<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://www.spatters.ca/feed.xml" rel="self" type="application/atom+xml" /><link href="https://www.spatters.ca/" rel="alternate" type="text/html" /><updated>2025-09-03T00:08:05+00:00</updated><id>https://www.spatters.ca/feed.xml</id><title type="html">spatters.ca</title><subtitle>Sam&apos;s blog.
</subtitle><author><name>Sam Patterson</name></author><entry><title type="html">Improving FP16/16 matmul accuracy with two-stage accumulation</title><link href="https://www.spatters.ca/two-stage-fp16-mma" rel="alternate" type="text/html" title="Improving FP16/16 matmul accuracy with two-stage accumulation" /><published>2025-03-31T00:00:00+00:00</published><updated>2025-03-31T00:00:00+00:00</updated><id>https://www.spatters.ca/fp16-mma</id><content type="html" xml:base="https://www.spatters.ca/two-stage-fp16-mma"><![CDATA[<p>On Nvidia consumer GPUs such as the RTX 4090, FP16/32 matrix multiplication is limited to run at half the speed of FP16/16, meaning users need to choose between either using tensor core operations that accumulate in FP16 precision or only getting 50% of the GPU’s peak performance.</p>

<p>We can improve the accuracy of FP16/16 matrix multiplication with a two-stage accumulation strategy: use FP16/16 tensor core <code class="language-plaintext highlighter-rouge">mma</code> instructions but accumulate the results outside the <code class="language-plaintext highlighter-rouge">mma</code> in separate FP32 registers.</p>

<p>This is done by changing the main loop of the matmul kernel from</p>
<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">k</span> <span class="o">+=</span> <span class="n">K_BLOCK</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// load global-&gt;shared-&gt;reg etc.</span>
  <span class="c1">// ...</span>
  <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="n">bReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">);</span>
  <span class="n">__syncthreads</span><span class="p">();</span>
<span class="p">}</span>
</code></pre></div></div>
<p>to (simplified for clarity)</p>

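<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">unsigned</span> <span class="n">cReg</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span>
<span class="kt">float</span> <span class="n">dRegAcc</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">};</span> <span class="c1">// FP32 accumulators kept outside the mma</span>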
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">k</span> <span class="o">+=</span> <span class="n">K_BLOCK</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// load global-&gt;shared-&gt;reg etc.</span>
  <span class="c1">// ...</span>
  <span class="n">half</span> <span class="o">*</span><span class="n">dRegPtr</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">half</span> <span class="o">*&gt;</span><span class="p">(</span><span class="n">dReg</span><span class="p">);</span>
  <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="n">bReg</span><span class="p">,</span> <span class="n">cReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">);</span>
  <span class="n">dRegAcc</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+=</span> <span class="n">__half2float</span><span class="p">(</span><span class="n">dRegPtr</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
  <span class="n">dRegAcc</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="n">__half2float</span><span class="p">(</span><span class="n">dRegPtr</span><span class="p">[</span><span class="mi">1</span><span class="p">]);</span>
  <span class="n">dRegAcc</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">+=</span> <span class="n">__half2float</span><span class="p">(</span><span class="n">dRegPtr</span><span class="p">[</span><span class="mi">2</span><span class="p">]);</span>
  <span class="n">dRegAcc</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">+=</span> <span class="n">__half2float</span><span class="p">(</span><span class="n">dRegPtr</span><span class="p">[</span><span class="mi">3</span><span class="p">]);</span>
  <span class="n">__syncthreads</span><span class="p">();</span>
<span class="p">}</span>
</code></pre></div></div>
<p>The full code is available in Kernel 3.2 <a href="https://github.com/spatters/mma-matmul/blob/5e730a1f931b3caeca3164f3777f7c5593bd9577/kernel_3.cu#L274">here</a>.</p>
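
<p>To see why moving the running sum into FP32 helps, here is a minimal host-side sketch rather than the kernel itself. It assumes FP16 rounding can be approximated by keeping an 11-bit significand (ignoring FP16's exponent range), and the helper names <code class="language-plaintext highlighter-rouge">round_to_half_precision</code>, <code class="language-plaintext highlighter-rouge">accumulate_fp16</code> and <code class="language-plaintext highlighter-rouge">accumulate_two_stage</code> are hypothetical. How the Tensor Core rounds internally is not modelled.</p>

```cpp
#include <cassert>
#include <cmath>

// Round a float to an 11-bit significand, roughly what FP16 keeps.
// Simplification (assumption): FP16's limited exponent range is ignored,
// so there is no overflow or subnormal handling.
float round_to_half_precision(float x) {
  if (x == 0.0f) return 0.0f;
  int exponent = 0;
  float mantissa = std::frexp(x, &exponent);                // |mantissa| in [0.5, 1)
  float scaled = std::nearbyint(std::ldexp(mantissa, 11));  // keep 11 significant bits
  return std::ldexp(scaled, exponent - 11);
}

// Single-stage: one running sum, rounded to FP16 precision after every add.
float accumulate_fp16(float product, int k_total) {
  float acc = 0.0f;
  for (int k = 0; k < k_total; ++k) {
    acc = round_to_half_precision(acc + product);
  }
  return acc;
}

// Two-stage: sum each block of k_block products at FP16 precision starting
// from zero, then add the block's partial sum into an FP32 accumulator.
float accumulate_two_stage(float product, int k_total, int k_block) {
  assert(k_total % k_block == 0);
  float acc32 = 0.0f;
  for (int k = 0; k < k_total; k += k_block) {
    float acc16 = 0.0f;  // plays the role of the zeroed cReg
    for (int i = 0; i < k_block; ++i) {
      acc16 = round_to_half_precision(acc16 + product);
    }
    acc32 += acc16;      // plays the role of dRegAcc[i] += __half2float(...)
  }
  return acc32;
}
```

<p>With every product equal to 0.01 and 4096 terms, the single-stage sum in this model stops growing once 0.01 falls below half the spacing between representable values, while the two-stage sum stays close to the exact 40.96 because each FP16 partial sum remains small.</p>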

<p>An alternative approach to maintaining separate FP32 accumulator registers in the main loop would be to use Split/Stream-K and convert to FP32 when accumulating partial results.</p>

<h2 id="performance">Performance</h2>
<p>We look at the performance impact on one problem shape: M=N=K=4096, using normally distributed inputs<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>. The benchmarking setup is as described in my previous post on <a href="https://www.spatters.ca/mma-matmul#benchmarking-setup">Ada matmuls</a>.</p>

<p>On this problem shape, the two-stage accumulation kernel achieves 209.1 TFLOP/s, which is 79% of cuBLAS FP16/16 performance.</p>

<table>
  <thead>
    <tr>
      <th>Kernel</th>
      <th>Execution Time</th>
      <th>TFLOP/s    </th>
      <th>% 4090 peak FP16/16          </th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>cublasGemmEx FP16/32</td>
      <td>895 us</td>
      <td>153.6</td>
      <td>46.5%</td>
    </tr>
    <tr>
      <td>Two-stage accumulation</td>
      <td>657 us</td>
      <td>209.1</td>
      <td>63.3%</td>
    </tr>
    <tr>
      <td>cublasGemmEx FP16/16</td>
      <td>520 us</td>
      <td>264.2</td>
      <td>80.0%</td>
    </tr>
  </tbody>
</table>

<h2 id="accuracy">Accuracy</h2>
<p>We compare the results of each kernel to a reference kernel that computes the matmul using FP32 operations on CUDA cores. Percentiles of absolute error of each kernel compared to this reference are shown in the plot below.<br />
<img src="/assets/images/abs-error-percentiles.png" alt="abs-error-perc" /></p>

<p>Roughly speaking, the two-stage accumulation kernel has ~100x larger absolute error than cuBLAS FP16/32 and ~10x smaller absolute error than cuBLAS FP16/16.</p>

<p>So the two-stage kernel is 36% faster than cuBLAS FP16/32 with ~100x larger absolute error, whereas cuBLAS FP16/16 is 72% faster than cuBLAS FP16/32 with ~1000x the absolute error.</p>

<h3 id="references">References</h3>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>As discussed in great detail on <a href="https://www.thonking.ai/p/strangely-matrix-multiplications">Horace He’s blog</a>, the distribution of input data has a noticeable impact on performance. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Sam Patterson</name></author><summary type="html"><![CDATA[On Nvidia consumer GPUs such as the RTX 4090, FP16/32 matrix multiplication is limited to run at half the speed of FP16/16, meaning users need to choose between either using tensor core operations that accumulate in FP16 precision or only getting 50% of the GPU’s peak performance.]]></summary></entry><entry><title type="html">Implementing a fast Tensor Core matmul on the Ada Architecture</title><link href="https://www.spatters.ca/mma-matmul" rel="alternate" type="text/html" title="Implementing a fast Tensor Core matmul on the Ada Architecture" /><published>2024-11-15T00:00:00+00:00</published><updated>2024-11-15T00:00:00+00:00</updated><id>https://www.spatters.ca/mma-matmul</id><content type="html" xml:base="https://www.spatters.ca/mma-matmul"><![CDATA[<p>Using Tensor Cores is now a prerequisite to get anywhere near peak performance on NVIDIA GPUs. In this post we work through the process of developing an efficient Tensor Core matrix multiplication kernel targeting the Ada architecture.</p>

<p>We start with a naive implementation and by incorporating techniques used in CUTLASS<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">1</a></sup>, finish with a kernel that matches cuBLAS performance (on one particular problem specification):</p>

<table>
  <thead>
    <tr>
      <th>Kernel</th>
      <th>Execution Time</th>
      <th>TFLOP/s    </th>
      <th>% cuBLAS    </th>
      <th>% 4090 peak          </th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>cublasGemmEx</td>
      <td>895 us</td>
      <td>153.6</td>
      <td>100%</td>
      <td>93.0%</td>
    </tr>
    <tr>
      <td>Kernel 1.0: Naive mma</td>
      <td>4680 us</td>
      <td>29.4</td>
      <td>19.1%</td>
      <td>17.8%</td>
    </tr>
    <tr>
      <td>Kernel 1.1: Naive + 2x tiling</td>
      <td>2400 us</td>
      <td>57.3</td>
      <td>37.3%</td>
      <td>34.7%</td>
    </tr>
    <tr>
      <td>Kernel 2.0: Permuted shmem</td>
      <td>1080 us</td>
      <td>127.3</td>
      <td>82.9%</td>
      <td>77.0%</td>
    </tr>
    <tr>
      <td>Kernel 2.1: Permuted shmem + register tweak</td>
      <td>1030 us</td>
      <td>133.4</td>
      <td>86.9%</td>
      <td>80.8%</td>
    </tr>
    <tr>
      <td>Kernel 3.0: N-stage async pipeline</td>
      <td>1000 us</td>
      <td>137.4</td>
      <td>89.5%</td>
      <td>83.2%</td>
    </tr>
    <tr>
      <td>Kernel 3.1: N-stage + 4x tiling</td>
      <td>895 us</td>
      <td>153.6</td>
      <td>100%</td>
      <td>93.0%</td>
    </tr>
  </tbody>
</table>

<p>In the process we’ll learn about the <code class="language-plaintext highlighter-rouge">mma</code>, <code class="language-plaintext highlighter-rouge">ldmatrix</code> and <code class="language-plaintext highlighter-rouge">cp.async</code> PTX instructions, how CUTLASS’s permuted shared memory layout avoids bank conflicts and how to set up an n-stage global to shared memory pipeline. The code is written as simply as possible: the aim is ease of understanding rather than generality or robustness.</p>

<p>As may be clear already, this post was heavily inspired by Simon Boehm’s great worklog on optimizing a CUDA matmul kernel<sup id="fnref:8" role="doc-noteref"><a href="#fn:8" class="footnote" rel="footnote">2</a></sup>.</p>

<h2 id="problem-definition">Problem Definition</h2>
<p>We’ll focus on one particular problem shape: M=N=K=4096, for <code class="language-plaintext highlighter-rouge">fp16</code> A/B and <code class="language-plaintext highlighter-rouge">fp32</code> C/D. This operation is <code class="language-plaintext highlighter-rouge">2*4096^3 = 137.4 GFLOP</code><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">3</a></sup> (conventional to count one FMA as 2 FLOP) and the peak <code class="language-plaintext highlighter-rouge">fp16/32</code> throughput of the RTX 4090 is 165.2 TFLOP/s<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">4</a></sup>, so the lower bound on kernel execution time is ~830 us.</p>

<p>We can use the peak throughput number to deduce how many cycles one Tensor Core instruction takes to complete (latency). All our kernels will use the PTX <code class="language-plaintext highlighter-rouge">m16n8k16</code> <code class="language-plaintext highlighter-rouge">mma</code> instruction. This is the largest Tensor Core matmul supported on Ada, so it’s reasonable to assume the peak throughput is obtained using this instruction.</p>

<p>The m16n8k16 operation is <code class="language-plaintext highlighter-rouge">2*16*8*16=4096</code> FLOP, and there are 512 Tensor Cores on the RTX 4090, hence computing one mma on all Tensor Cores gives 2,097,152 FLOP. Given the peak throughput of 165.2 TFLOP/s at the boost clock of 2520 MHz, it must take 12.7 ns = 32 cycles for the <code class="language-plaintext highlighter-rouge">m16n8k16</code> <code class="language-plaintext highlighter-rouge">mma</code> operation to complete. This is roughly consistent with empirical benchmarks<sup id="fnref:7" role="doc-noteref"><a href="#fn:7" class="footnote" rel="footnote">5</a></sup>.</p>

<p>Our problem shape of M=N=K=4096 requires 256x512x256 = 33,554,432 individual m16n8k16 <code class="language-plaintext highlighter-rouge">mma</code> instructions, which is 65,536 card-wide waves of <code class="language-plaintext highlighter-rouge">mma</code>s. Hence in the best case, with no cycles stalled waiting for input, the minimum number of cycles this will take is 65,536 * 32 = 2,097,152, which is 832 us at the boost clock of 2520 MHz. Note that this agrees with the number computed from peak throughput by construction, as we derived the 32 cycle latency from that throughput.</p>
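
<p>The arithmetic in this section can be checked at compile time. The following is a small sketch using only the numbers quoted above; the constant and function names are hypothetical and chosen for illustration.</p>

```cpp
#include <cstdint>

// Problem shape and mma tile shape from the text.
constexpr std::int64_t M = 4096, N = 4096, K = 4096;
constexpr std::int64_t MMA_M = 16, MMA_N = 8, MMA_K = 16;
constexpr std::int64_t NUM_TENSOR_CORES = 512;  // RTX 4090, as stated above
constexpr std::int64_t CYCLES_PER_MMA = 32;     // derived above from peak throughput
constexpr double BOOST_CLOCK_HZ = 2520e6;

// One FMA counts as 2 FLOP.
constexpr std::int64_t total_flop() { return 2 * M * N * K; }
constexpr std::int64_t flop_per_mma() { return 2 * MMA_M * MMA_N * MMA_K; }

// Number of m16n8k16 instructions needed to cover the whole problem.
constexpr std::int64_t total_mma_count() {
  return (M / MMA_M) * (N / MMA_N) * (K / MMA_K);
}

// One "wave" issues one mma on every Tensor Core of the card.
constexpr std::int64_t mma_waves() { return total_mma_count() / NUM_TENSOR_CORES; }
constexpr std::int64_t min_cycles() { return mma_waves() * CYCLES_PER_MMA; }
constexpr double min_time_us() { return min_cycles() / BOOST_CLOCK_HZ * 1e6; }

static_assert(flop_per_mma() == 4096, "2*16*8*16");
static_assert(total_mma_count() == 33554432, "256*512*256");
static_assert(mma_waves() == 65536, "card-wide waves");
static_assert(min_cycles() == 2097152, "best-case cycle count");
```

<p>Evaluating <code class="language-plaintext highlighter-rouge">min_time_us()</code> gives roughly 832 us, the same lower bound as in the text.</p>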

<h2 id="benchmarking-setup">Benchmarking Setup</h2>
<p>As a baseline for performance we use the cuBLAS <code class="language-plaintext highlighter-rouge">cublasGemmEx</code> API with <code class="language-plaintext highlighter-rouge">fp16</code> inputs and <code class="language-plaintext highlighter-rouge">fp32</code> accumulation. This performs an <code class="language-plaintext highlighter-rouge">M=N=K=4096</code> matrix multiply in 895 us which is a throughput of 153.6 TFLOP/s, 93.0% of the RTX 4090’s peak.</p>

<p>How to accurately time CUDA kernel execution could fill an entire post, but in summary, either CUDA events or nsight-compute give broadly consistent results if you first lock the GPU and memory clocks. I used nsight-compute as it measures kernel execution more precisely than is possible using events<sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">6</a></sup>.</p>

<p>By default nsight-compute locks to the GPU’s base clock, but as I wanted to compare to the RTX 4090’s stated peak throughput I locked at the boost clock of 2520 MHz. Kernels were run 55 times, the first 5 runs discarded and average results on the remaining 50 reported.</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nb">sudo </span>nvidia-smi <span class="nt">-pm</span> ENABLED
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-gpu-clocks</span><span class="o">=</span>2520     <span class="c"># lock at boost clock</span>
<span class="nb">sudo </span>nvidia-smi <span class="nt">--lock-memory-clocks</span><span class="o">=</span>10501 <span class="c"># max for RTX 4090</span>
ncu <span class="nt">-s</span> 5 <span class="nt">-k</span> <span class="nv">$my_kernel_name</span> <span class="nt">--clock-control</span> none <span class="nt">--print-summary</span> per-gpu <span class="nv">$my_executable</span>
</code></pre></div></div>
<p>Benchmarks were run on Pop!_OS 22.04 LTS, CUDA Toolkit Version 12.4, CUDA Driver Version 550.67.</p>

<h3 id="aside-tensor-core-matrix-multiply-apis">Aside: Tensor Core Matrix Multiply APIs</h3>
<p>There are three separate Tensor Core matmul APIs in CUDA/PTX:</p>
<ul>
  <li>WMMA: High level API available in both <a href="https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-matrix-functions">CUDA</a> and <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-wmma-instructions">PTX</a></li>
  <li>MMA: Lower level API just available in <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction">PTX</a></li>
  <li>WGMMA: sm_90 only API that operates on warp-groups (consecutive groups of 4 warps). Just available in <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-multiply-accumulate-instructions">PTX</a></li>
</ul>

<p>All kernels in this post use the PTX mma API. wgmma is not an option as I am using an Ada architecture GPU. I chose mma over wmma as mma is a lower level API and my aim is to build an understanding of the underlying Tensor Core operations. Using mma also reportedly delivers higher performance than wmma though that comparison is old<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">7</a></sup>.</p>

<h2 id="kernel-10-naive-mma-kernel">Kernel 1.0: Naive mma kernel</h2>
<p>The first kernel is a naive implementation resulting from reading the <code class="language-plaintext highlighter-rouge">mma</code> instruction documentation and handling data movement from global memory to registers in the simplest way possible.</p>

<p>In the Ada architecture there are 4 warp schedulers per SM, each with their own Tensor Core. Hence we want at least 4 warps per thread block (not strictly required as multiple thread blocks can run concurrently on one SM). In this kernel we use a 16x16 thread block, containing 8 warps. Each warp computes one 16x8 output tile and we arrange the warps in a 2 row x 4 column grid, so that each thread block computes a 32x32 output tile.</p>

<div class="language-c highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// arrangement of warps in output tile</span>
<span class="c1">// (warp_0 | warp_1 | warp_2 | warp_3)</span>
<span class="c1">// (warp_4 | warp_5 | warp_6 | warp_7)</span>
</code></pre></div></div>
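
<p>As a rough sketch of this layout, the mapping from a thread's position in the 16x16 block to its warp, and from the warp to the top-left corner of its 16x8 output tile, could be written as follows. The function names are hypothetical, and it is an assumption that the <code class="language-plaintext highlighter-rouge">mWarp</code> and <code class="language-plaintext highlighter-rouge">nWarp</code> offsets used by the kernel correspond to the row and column computed here.</p>

```cpp
struct TileOffset {
  int row;
  int col;
};

constexpr int WARP_SIZE = 32;
constexpr int WARPS_PER_ROW = 4;  // warps are arranged in a 2 row x 4 column grid
constexpr int MMA_M = 16;         // rows of one warp's output tile
constexpr int MMA_N = 8;          // columns of one warp's output tile

// Linear thread index within the thread block -> warp index (0..7 for 16x16).
constexpr int warp_id(int tx, int ty, int block_dim_x) {
  return (ty * block_dim_x + tx) / WARP_SIZE;
}

// Top-left corner, within the block's 32x32 output tile, of the
// 16x8 tile computed by the given warp.
constexpr TileOffset warp_tile_offset(int warp) {
  return TileOffset{(warp / WARPS_PER_ROW) * MMA_M,
                    (warp % WARPS_PER_ROW) * MMA_N};
}
```

<p>For example, warp 4 starts the second row of tiles, so its tile begins at row 16, column 0.</p>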
<p>There are multiple <code class="language-plaintext highlighter-rouge">mma</code> instructions for different data types and matrix shapes. As mentioned previously, in this and all subsequent kernels we’ll use</p>
<ul>
  <li><code class="language-plaintext highlighter-rouge">mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32</code></li>
</ul>

<p>which performs (per warp) the matrix multiplication <code class="language-plaintext highlighter-rouge">D = A * B + C</code> where <code class="language-plaintext highlighter-rouge">A</code> is a <code class="language-plaintext highlighter-rouge">16x16</code> <code class="language-plaintext highlighter-rouge">fp16</code> matrix, <code class="language-plaintext highlighter-rouge">B</code> is a <code class="language-plaintext highlighter-rouge">16x8</code> <code class="language-plaintext highlighter-rouge">fp16</code> matrix and C/D are <code class="language-plaintext highlighter-rouge">16x8</code> <code class="language-plaintext highlighter-rouge">fp32</code> matrices.</p>

<p>As <code class="language-plaintext highlighter-rouge">mma</code> is a PTX instruction, calling it from CUDA code requires using <a href="https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html">inline PTX</a> which we wrap in a helper function:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__device__</span> <span class="kt">void</span> <span class="nf">mma_m16n8k16</span><span class="p">(</span><span class="k">const</span> <span class="kt">unsigned</span> <span class="o">*</span><span class="n">A</span><span class="p">,</span> <span class="k">const</span> <span class="kt">unsigned</span> <span class="o">*</span><span class="n">B</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">C</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">D</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">asm</span><span class="p">(</span>
      <span class="s">"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "</span>
      <span class="s">"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};</span><span class="se">\n</span><span class="s">"</span>
      <span class="o">:</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="s">"=f"</span><span class="p">(</span><span class="n">D</span><span class="p">[</span><span class="mi">3</span><span class="p">])</span>
      <span class="o">:</span>
      <span class="s">"r"</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">A</span><span class="p">[</span><span class="mi">3</span><span class="p">]),</span>
      <span class="s">"r"</span><span class="p">(</span><span class="n">B</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="s">"r"</span><span class="p">(</span><span class="n">B</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span>
      <span class="s">"f"</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="s">"f"</span><span class="p">(</span><span class="n">C</span><span class="p">[</span><span class="mi">3</span><span class="p">])</span>
      <span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">mma</code> instruction is warp-wide: each of the 32 threads provides 8 <code class="language-plaintext highlighter-rouge">fp16</code> elements from A, 4 <code class="language-plaintext highlighter-rouge">fp16</code> elements from B and 4 <code class="language-plaintext highlighter-rouge">fp32</code> elements from C, and receives 4 output <code class="language-plaintext highlighter-rouge">fp32</code> elements from D. The 8 <code class="language-plaintext highlighter-rouge">fp16</code> elements of A are packed into four 32-bit registers, and similarly the 4 elements of B into two 32-bit registers.</p>
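
<p>The per-thread counts and the packing can be sketched on the host as below. This works on raw 16-bit patterns rather than a real <code class="language-plaintext highlighter-rouge">half</code> type, and <code class="language-plaintext highlighter-rouge">pack_half2</code> is a hypothetical helper. It assumes the lower-indexed element sits in the low 16 bits, which is what reinterpreting a <code class="language-plaintext highlighter-rouge">half</code> array as <code class="language-plaintext highlighter-rouge">unsigned</code> gives on a little-endian target.</p>

```cpp
#include <cstdint>

constexpr int WARP_SIZE = 32;

// Elements each thread holds: the tile's element count divided over the warp.
constexpr int A_ELEMS_PER_THREAD = 16 * 16 / WARP_SIZE;  // fp16 elements of A
constexpr int B_ELEMS_PER_THREAD = 16 * 8 / WARP_SIZE;   // fp16 elements of B
constexpr int CD_ELEMS_PER_THREAD = 16 * 8 / WARP_SIZE;  // fp32 elements of C or D

// Two fp16 values share one 32-bit register.
constexpr int A_REGS_PER_THREAD = A_ELEMS_PER_THREAD / 2;
constexpr int B_REGS_PER_THREAD = B_ELEMS_PER_THREAD / 2;

// Pack two fp16 bit patterns into one 32-bit register value,
// lower-indexed element in the low 16 bits.
constexpr std::uint32_t pack_half2(std::uint16_t lo, std::uint16_t hi) {
  return static_cast<std::uint32_t>(lo) |
         (static_cast<std::uint32_t>(hi) << 16);
}

static_assert(A_REGS_PER_THREAD == 4, "matches the four \"r\" operands for A");
static_assert(B_REGS_PER_THREAD == 2, "matches the two \"r\" operands for B");
```

<p>For instance, packing the fp16 bit patterns of 1.0 (<code class="language-plaintext highlighter-rouge">0x3C00</code>) and 2.0 (<code class="language-plaintext highlighter-rouge">0x4000</code>) gives <code class="language-plaintext highlighter-rouge">0x40003C00</code>.</p>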

<p>The matrix elements held by each thread in its registers are called a matrix fragment, and the required mapping from thread ID to fragments for A is shown below:
<img src="/assets/images/a-fragment.png" alt="a-fragment" />
A is split into 4 <code class="language-plaintext highlighter-rouge">8x8</code> submatrices, and each submatrix is split across the warp in a row major fashion, with each thread holding two <code class="language-plaintext highlighter-rouge">fp16</code> values in one of its 32 bit registers. Mappings for <code class="language-plaintext highlighter-rouge">B, C &amp; D</code> are defined similarly and can be found in the <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-float">PTX docs</a>.</p>
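
<p>The A fragment mapping can be written as a small function from lane and element index to a coordinate in the 16x16 tile. This is a host-side sketch with hypothetical names; <code class="language-plaintext highlighter-rouge">group_id</code> and <code class="language-plaintext highlighter-rouge">group_lane_id</code> follow the PTX docs' definitions of <code class="language-plaintext highlighter-rouge">groupID</code> and <code class="language-plaintext highlighter-rouge">threadID_in_group</code>.</p>

```cpp
struct Coord {
  int row;
  int col;
};

// A warp's 32 lanes form 8 groups of 4 lanes each.
constexpr int group_id(int lane) { return lane >> 2; }      // 0..7, selects the row
constexpr int group_lane_id(int lane) { return lane % 4; }  // 0..3, selects a column pair

// Coordinate in the 16x16 A tile of the i-th fp16 element (0..7) held by a lane.
// Elements 2, 3, 6, 7 come from the lower two 8x8 submatrices (rows + 8),
// elements 4..7 come from the right two 8x8 submatrices (columns + 8).
constexpr Coord a_fragment_coord(int lane, int i) {
  return Coord{
      group_id(lane) + ((i == 2 || i == 3 || i == 6 || i == 7) ? 8 : 0),
      group_lane_id(lane) * 2 + (i & 1) + (i >= 4 ? 8 : 0)};
}
```

<p>Lane 0 therefore holds elements (0,0) and (0,1) in its first register, while lane 31's last element is the bottom-right corner (15,15).</p>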

<p>We will later use the <code class="language-plaintext highlighter-rouge">ldmatrix</code> instruction to load fragments to registers, but for now we’ll do this per thread to demonstrate the mapping. The main loop of Kernel 1.0 contains the code to load matrix fragments to registers and call the <code class="language-plaintext highlighter-rouge">mma</code> instruction.</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">kStart</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">kStart</span> <span class="o">&lt;</span> <span class="n">K</span><span class="p">;</span> <span class="n">kStart</span> <span class="o">+=</span> <span class="n">K_BLOCK</span><span class="p">)</span> <span class="p">{</span>
  <span class="c1">// load from global to shared memory</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">m</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">;</span> <span class="o">++</span><span class="n">m</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">As</span><span class="p">[</span><span class="n">m</span><span class="o">*</span><span class="n">K_BLOCK</span> <span class="o">+</span> <span class="n">ty</span><span class="p">][</span><span class="n">tx</span><span class="p">]</span> <span class="o">=</span> <span class="n">A</span><span class="p">[(</span><span class="n">mBlock</span> <span class="o">+</span> <span class="n">ty</span> <span class="o">+</span> <span class="n">m</span><span class="o">*</span><span class="n">K_BLOCK</span><span class="p">)</span><span class="o">*</span><span class="n">K</span> <span class="o">+</span> <span class="n">kStart</span> <span class="o">+</span> <span class="n">tx</span><span class="p">];</span>
  <span class="p">}</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">n</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">;</span> <span class="o">++</span><span class="n">n</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">Bs</span><span class="p">[</span><span class="n">ty</span><span class="p">][</span><span class="n">n</span><span class="o">*</span><span class="n">K_BLOCK</span> <span class="o">+</span> <span class="n">tx</span><span class="p">]</span> <span class="o">=</span> <span class="n">B</span><span class="p">[(</span><span class="n">kStart</span> <span class="o">+</span> <span class="n">ty</span><span class="p">)</span> <span class="o">*</span> <span class="n">K</span> <span class="o">+</span> <span class="n">nBlock</span> <span class="o">+</span> <span class="n">n</span><span class="o">*</span><span class="n">K_BLOCK</span> <span class="o">+</span> <span class="n">tx</span><span class="p">];</span>
  <span class="p">}</span>
  <span class="n">__syncthreads</span><span class="p">();</span>

  <span class="c1">// load from shmem to fp16 registers</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span>    <span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span>    <span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span>    <span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span> <span class="o">+</span> <span class="mi">8</span><span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span>    <span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span> <span class="o">+</span> <span class="mi">8</span><span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span>    <span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">8</span><span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">5</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span>    <span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">9</span><span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span> <span class="o">+</span> <span class="mi">8</span><span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">8</span><span class="p">];</span>
  <span class="n">aReg</span><span class="p">[</span><span class="mi">7</span><span class="p">]</span> <span class="o">=</span> <span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">groupID</span> <span class="o">+</span> <span class="mi">8</span><span class="p">][</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">9</span><span class="p">];</span>

  <span class="n">bReg</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">Bs</span><span class="p">[</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">0</span><span class="p">][</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">groupID</span><span class="p">];</span>
  <span class="n">bReg</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">Bs</span><span class="p">[</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">groupID</span><span class="p">];</span>
  <span class="n">bReg</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">Bs</span><span class="p">[</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">8</span><span class="p">][</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">groupID</span><span class="p">];</span>
  <span class="n">bReg</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">Bs</span><span class="p">[</span><span class="n">groupLaneID</span><span class="o">*</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">9</span><span class="p">][</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">groupID</span><span class="p">];</span>
  <span class="c1">// pack fp16 registers to u32 and call mma</span>
  <span class="kt">unsigned</span> <span class="k">const</span> <span class="o">*</span><span class="n">aPtr</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="kt">unsigned</span> <span class="k">const</span> <span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">aReg</span><span class="p">);</span>
  <span class="kt">unsigned</span> <span class="k">const</span> <span class="o">*</span><span class="n">bPtr</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="kt">unsigned</span> <span class="k">const</span> <span class="o">*&gt;</span><span class="p">(</span><span class="o">&amp;</span><span class="n">bReg</span><span class="p">);</span>
  <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aPtr</span><span class="p">,</span> <span class="n">bPtr</span><span class="p">,</span> <span class="n">dReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">);</span>
  <span class="n">__syncthreads</span><span class="p">();</span>
<span class="p">}</span>
</code></pre></div></div>
<h3 id="performance">Performance</h3>
<p>Kernel 1.0 has an execution time of 4.67 ms, giving a throughput of 29.4 TFLOP/s, which is 19.1% of cuBLAS and 17.8% of peak RTX 4090 performance. In fact, this is only 35.6% of the RTX 4090’s peak FP32 performance, so a reasonably optimized kernel that does not use Tensor Cores would be faster. The reasons for the poor performance are:</p>
<ol>
  <li>Each thread loads individual 16b values in an uncoalesced load pattern</li>
  <li>The loads from shared memory to registers have multiple bank conflicts</li>
  <li>Each element loaded is only used in the input to one <code class="language-plaintext highlighter-rouge">mma</code> instruction, so the ratio of memory access to computation is low</li>
</ol>

<p>The Warp State Statistics chart in Nsight Compute shows the impact of these problems: on average, per instruction executed, a warp spends 31 cycles stalled on shared memory throttles (MIO), 15 cycles stalled on barrier waits and 11 cycles stalled on long scoreboard (global load) dependencies.</p>

<p><img src="/assets/images/kernel-1-warp-stats-1.png" alt="kernel-1-warp-stats" /></p>

<p>We can also use the profiler to query the count of <code class="language-plaintext highlighter-rouge">mma</code> instructions executed and elapsed cycles:</p>
<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nt">-------------------------------------------</span> <span class="nt">-----------</span> <span class="nt">-------------</span>
Metric Name                                 Metric Unit  Metric Value
<span class="nt">-------------------------------------------</span> <span class="nt">-----------</span> <span class="nt">-------------</span>
sm__cycles_elapsed.avg                            cycle 11,787,459.33
sm__cycles_elapsed.max                            cycle    11,822,127
sm__cycles_elapsed.min                            cycle    11,739,016
sm__cycles_elapsed.sum                            cycle 1,508,794,794
smsp__inst_executed_pipe_tensor_op_hmma.avg        inst        65,536
smsp__inst_executed_pipe_tensor_op_hmma.max        inst        66,048
smsp__inst_executed_pipe_tensor_op_hmma.min        inst        65,024
smsp__inst_executed_pipe_tensor_op_hmma.sum        inst    33,554,432
<span class="nt">-------------------------------------------</span> <span class="nt">-----------</span> <span class="nt">-------------</span>
</code></pre></div></div>
<p>The total number of <code class="language-plaintext highlighter-rouge">mma</code> instructions is 33,554,432 as calculated earlier, with 65,536 being computed on each Tensor Core. The number of cycles elapsed per <code class="language-plaintext highlighter-rouge">mma</code> was 11,787,459 / 65,536 = 179.9, so we are far from the 32-cycle best case.</p>
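<p>As a sanity check, the arithmetic above can be reproduced on the host. This is a minimal sketch rather than part of the kernel: it assumes the M = N = K = 4096 problem size implied by the instruction count and the 512 SM sub-partitions implied by dividing the <code class="language-plaintext highlighter-rouge">.sum</code> metric by the <code class="language-plaintext highlighter-rouge">.avg</code> metric, and the helper names <code class="language-plaintext highlighter-rouge">mmaCount</code> and <code class="language-plaintext highlighter-rouge">cyclesPerMma</code> are hypothetical.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Number of m16n8k16 mma instructions needed to cover an M x N x K matmul.
// Assumes M and K are multiples of 16 and N is a multiple of 8.
constexpr std::int64_t mmaCount(std::int64_t M, std::int64_t N, std::int64_t K) {
  return (M / 16) * (N / 8) * (K / 16);
}

// Average elapsed SM cycles per mma issued on a single Tensor Core.
inline double cyclesPerMma(double elapsedCycles, double mmaPerTensorCore) {
  return elapsedCycles / mmaPerTensorCore;
}
```

<p>With the profiler values from the table, <code class="language-plaintext highlighter-rouge">cyclesPerMma(11787459.33, 65536.0)</code> gives roughly 179.9, matching the figure quoted above.</p>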

<p>The three problems described above will be addressed in Kernel 2: Point 1 by using vectorized and coalesced loads, Point 2 by using a permuted shared memory layout and Point 3 as each warp will compute multiple output tiles.</p>

<p>To isolate the impact of tiling from the other changes, we add 2x tiling in the M and N dimensions in Kernel 1.1. In this kernel each warp executes 4 <code class="language-plaintext highlighter-rouge">mma</code> instructions, meaning each thread block computes a 64x64 output tile. This reduces execution time to 2.40 ms, increasing throughput to 57.3 TFLOP/s, which is 37.3% of cuBLAS and 34.7% of peak performance.</p>

<h2 id="kernel-20-vectorized-loads--permuted-shared-memory-layout">Kernel 2.0: Vectorized Loads &amp; Permuted Shared Memory Layout</h2>
<p>In this kernel we use some of the techniques (vectorized loads and permuted shared memory layout) discussed in the GTC 2020 CUTLASS presentation<sup id="fnref:3:1" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">1</a></sup> to resolve the performance issues of Kernel 1. The memory layout diagrams in this section are taken from that presentation. The majority of the performance of the final kernel comes from the permuted shared memory layout introduced in this section.</p>

<p>Throughout this kernel we operate on <code class="language-plaintext highlighter-rouge">uint4</code> 128b vectors containing 8 consecutive <code class="language-plaintext highlighter-rouge">fp16</code> elements in the K dimension of A and B. Working with 128b vectors is natural when using Tensor Cores as the fundamental Tensor Core operation is an 8 by 8 by 128b matrix multiply, i.e. each 128b vector forms one row of the input matrices. Using 128b vectors also means we can vectorize memory operations.</p>

<p>We keep the 16x16 thread block dimensions from Kernel 1. The main loop of the kernel is shown below:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// row / column indices when storing to shared memory</span>
<span class="kt">int</span> <span class="n">storeRow</span> <span class="o">=</span> <span class="n">warpID</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">+</span> <span class="n">laneID</span> <span class="o">/</span> <span class="mi">8</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">storeCol</span> <span class="o">=</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">%</span> <span class="mi">8</span><span class="p">)</span> <span class="o">^</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">/</span> <span class="mi">8</span><span class="p">);</span>

<span class="c1">// row/column indices when loading from permuted shmem layout to registers</span>
<span class="kt">int</span> <span class="n">loadRowA</span> <span class="o">=</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">%</span> <span class="mi">16</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">loadColA</span> <span class="o">=</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">/</span> <span class="mi">16</span> <span class="o">+</span> <span class="mi">4</span> <span class="o">*</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">%</span> <span class="mi">2</span><span class="p">))</span> <span class="o">^</span> <span class="p">(</span><span class="n">loadRowA</span> <span class="o">%</span> <span class="mi">4</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">loadRowB</span> <span class="o">=</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">%</span> <span class="mi">8</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">;</span>
<span class="kt">int</span> <span class="n">loadColB</span> <span class="o">=</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">/</span> <span class="mi">8</span> <span class="o">+</span> <span class="mi">4</span> <span class="o">*</span> <span class="p">(</span><span class="n">laneID</span> <span class="o">%</span> <span class="mi">2</span><span class="p">))</span> <span class="o">^</span> <span class="p">(</span><span class="n">loadRowB</span> <span class="o">%</span> <span class="mi">4</span><span class="p">);</span>

<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">K</span><span class="o">/</span><span class="mi">8</span><span class="p">;</span> <span class="n">k</span> <span class="o">+=</span> <span class="mi">4</span><span class="p">)</span> <span class="p">{</span>
  <span class="n">As</span><span class="p">[</span><span class="n">storeRow</span><span class="p">][</span><span class="n">storeCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">globalTileA</span><span class="p">[(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">];</span>
  <span class="n">Bs</span><span class="p">[</span><span class="n">storeRow</span><span class="p">][</span><span class="n">storeCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">globalTileB</span><span class="p">[(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">];</span>
  <span class="n">__syncthreads</span><span class="p">();</span>

  <span class="c1">// loop over the two (M=16, K=4) tiles of A and the two (N=8, K=4) tiles of B</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">m</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">;</span> <span class="n">m</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">mTile</span> <span class="o">=</span> <span class="n">m</span> <span class="o">*</span> <span class="mi">8</span><span class="p">;</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">n</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">;</span> <span class="n">n</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="kt">int</span> <span class="n">nTile</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="mi">4</span><span class="p">;</span>
      <span class="n">load_matrix_x4</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="p">(</span><span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">mTile</span> <span class="o">+</span> <span class="n">loadRowA</span><span class="p">]</span> <span class="o">+</span> <span class="n">loadColA</span><span class="p">));</span>
      <span class="n">load_matrix_x2</span><span class="p">(</span><span class="n">bReg</span><span class="p">,</span> <span class="p">(</span><span class="n">Bs</span><span class="p">[</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">nTile</span> <span class="o">+</span> <span class="n">loadRowB</span><span class="p">]</span> <span class="o">+</span> <span class="n">loadColB</span><span class="p">));</span>
      <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="n">bReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">]);</span>
      <span class="n">load_matrix_x4</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="p">(</span><span class="n">As</span><span class="p">[</span><span class="n">mWarp</span> <span class="o">+</span> <span class="n">mTile</span> <span class="o">+</span> <span class="n">loadRowA</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="n">loadColA</span><span class="o">^</span><span class="mi">2</span><span class="p">)));</span>
      <span class="n">load_matrix_x2</span><span class="p">(</span><span class="n">bReg</span><span class="p">,</span> <span class="p">(</span><span class="n">Bs</span><span class="p">[</span><span class="n">nWarp</span> <span class="o">+</span> <span class="n">nTile</span> <span class="o">+</span> <span class="n">loadRowB</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="n">loadColB</span><span class="o">^</span><span class="mi">2</span><span class="p">)));</span>
      <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">,</span> <span class="n">bReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">],</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">]);</span>
    <span class="p">}</span>
  <span class="p">}</span>
  <span class="n">__syncthreads</span><span class="p">();</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Looking first at the load from global to shared, each thread block loads A/B tiles of shape <code class="language-plaintext highlighter-rouge">(M/N=64, K=4)</code> <code class="language-plaintext highlighter-rouge">uint4</code> values from global memory to shared memory in a K-major fashion (i.e. row-major for A and column-major for B), with consecutive threads loading consecutive <code class="language-plaintext highlighter-rouge">uint4</code> values in the K-dimension using vectorized 128b loads. To coalesce these loads, the kernel requires A to be stored row-major in global memory and B to be stored column-major.</p>

<p>At the warp level, we load <code class="language-plaintext highlighter-rouge">uint4</code> tiles of shape <code class="language-plaintext highlighter-rouge">(M/N=8, K=4)</code> containing 8 rows/columns of A/B each containing 4 <code class="language-plaintext highlighter-rouge">uint4</code> values. This results in eight 64B memory transactions, each transaction reading two 32B sectors out of a 128B cache line containing four sectors in total.</p>

<p>This tile is stored in a <code class="language-plaintext highlighter-rouge">uint4</code> shared memory array of shape <code class="language-plaintext highlighter-rouge">(4, 8)</code> with two K=4 row/column slices stored per shared memory row. This shared memory shape is used as shared memory has 32 banks which are each 4 bytes wide, hence a row of 8 <code class="language-plaintext highlighter-rouge">uint4</code> values spans the 32 shared memory banks.</p>

<p>To avoid bank conflicts, threads which are part of the same memory request must not access addresses which map to the same bank. When each thread requests a 16B (128b) value, the warp-level 512B request is split into 4 phases, each consisting of 8 consecutive threads, as the maximum shared memory bandwidth is 32 banks * 4B = 128B per cycle. This means that it is sufficient to avoid bank conflicts within the 8 threads in each phase, rather than within the full warp of 32 threads.</p>

<p>When storing to shared memory, the column indices for each row are permuted by XORing them with the row index: <code class="language-plaintext highlighter-rouge">storeCol = (laneID % 8) ^ (laneID / 8)</code>. The store from global to shared would be bank conflict free without this permutation, but it is required to avoid bank conflicts when loading data to registers from shared memory.</p>

<p>This diagram from <sup id="fnref:3:2" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">1</a></sup> illustrates how one warp loads from global to shared using the permuted layout:</p>

<p><img src="/assets/images/load-global-store-shared.png" alt="load-global-store-shared" /></p>
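<p>The store-side claim can be checked with a small host-side model of the index arithmetic. The sketch below is illustrative only: <code class="language-plaintext highlighter-rouge">Slot</code>, <code class="language-plaintext highlighter-rouge">storeSlot</code> and <code class="language-plaintext highlighter-rouge">storeIsConflictFree</code> are hypothetical names, the <code class="language-plaintext highlighter-rouge">warpID * 4</code> row offset is dropped, and it assumes the 4 phases of 8 consecutive lanes described above. Since a shared memory row of 8 <code class="language-plaintext highlighter-rouge">uint4</code> values spans all 32 banks, two lanes hit the same banks exactly when they use the same column.</p>

```cpp
#include <array>
#include <cassert>

// Position of one uint4 within a warp's (4, 8) shared memory sub-array.
struct Slot {
  int row;
  int col;
};

// Where lane `laneID` stores its uint4, relative to the warp's first row.
// Mirrors storeRow / storeCol from the kernel without the warpID * 4 offset.
inline Slot storeSlot(int laneID) {
  return {laneID / 8, (laneID % 8) ^ (laneID / 8)};
}

// Returns true if, within every phase of 8 consecutive lanes, all 8 columns
// are distinct. Each uint4 column covers 4 banks, so distinct columns mean
// distinct banks.
inline bool storeIsConflictFree() {
  for (int phase = 0; phase < 4; ++phase) {
    std::array<bool, 8> used{};
    for (int lane = phase * 8; lane < phase * 8 + 8; ++lane) {
      int col = storeSlot(lane).col;
      if (used[col]) return false;
      used[col] = true;
    }
  }
  return true;
}
```

<p>Within a phase the XOR only reorders the 8 columns of a single row, which is why the store stays conflict free with or without the permutation.</p>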

<p>Once data is loaded to shared memory, each warp computes a matmul on a <code class="language-plaintext highlighter-rouge">(M=32, K=4)</code> tile of A and a <code class="language-plaintext highlighter-rouge">(N=16, K=4)</code> tile of B. As the <code class="language-plaintext highlighter-rouge">mma</code> instruction computes a M=16, N=8, K=16 matmul we split these tiles into two <code class="language-plaintext highlighter-rouge">(M=16, K=4)</code> tiles of A / <code class="language-plaintext highlighter-rouge">(N=8, K=4)</code> tiles of B and compute their products in a nested loop. At the innermost level of this loop, we first load the <code class="language-plaintext highlighter-rouge">k=0..1</code> subtiles of the current A and B tiles into registers and compute their product using the <code class="language-plaintext highlighter-rouge">mma</code> instruction. We then load the <code class="language-plaintext highlighter-rouge">k=2..3</code> subtiles and perform a second <code class="language-plaintext highlighter-rouge">mma</code>.</p>

<p>We use the <code class="language-plaintext highlighter-rouge">ldmatrix</code> PTX instruction to load these tiles from shared memory to registers. This warp-wide instruction loads 1, 2 or 4 <code class="language-plaintext highlighter-rouge">8x128b</code> matrices and stores each matrix in one 32b register per thread in the fragment layout discussed previously. Each 128b row of these matrices is stored in one <code class="language-plaintext highlighter-rouge">uint4</code> vector in shared memory and each thread in the warp provides the address of one of these rows as described in the docs:</p>

<p><img src="/assets/images/ldmatrix-ptx-docs.png" alt="ldmatrix-docs" /></p>

<p>This means that to load a <code class="language-plaintext highlighter-rouge">(M=16, k=0..1)</code> subtile of <code class="language-plaintext highlighter-rouge">A</code>, we use the <code class="language-plaintext highlighter-rouge">.x4</code> variant of <code class="language-plaintext highlighter-rouge">ldmatrix</code>, with threads <code class="language-plaintext highlighter-rouge">0..15</code> providing the addresses of the elements with indices <code class="language-plaintext highlighter-rouge">m=0..15, k=0</code> and threads <code class="language-plaintext highlighter-rouge">16..31</code> providing the addresses of elements with indices <code class="language-plaintext highlighter-rouge">m=0..15, k=1</code>. Crucially, we permuted the layout of the tiles of A when storing to shared memory, and hence each thread needs to compute the address of its required element in the permuted layout.</p>

<p>Each <code class="language-plaintext highlighter-rouge">(M=16, K=4)</code> tile of A is stored in a subarray of 8 consecutive rows of the <code class="language-plaintext highlighter-rouge">As</code> shared memory array and each <code class="language-plaintext highlighter-rouge">(N=8, K=4)</code> tile of B is stored in a subarray of 4 consecutive rows of <code class="language-plaintext highlighter-rouge">Bs</code>. The <code class="language-plaintext highlighter-rouge">mWarp, nWarp</code> and <code class="language-plaintext highlighter-rouge">mTile, nTile</code> variables specify the start row of the subarrays of <code class="language-plaintext highlighter-rouge">As</code>, <code class="language-plaintext highlighter-rouge">Bs</code> for each warp / each iteration of the tile loop. Within each subarray the <code class="language-plaintext highlighter-rouge">loadRowA/B, loadColA/B</code> variables specify the location of the required element in the permuted layout.</p>

<p>The following diagram from <sup id="fnref:3:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">1</a></sup> illustrates the locations in shared memory provided by each thread to <code class="language-plaintext highlighter-rouge">ldmatrix</code> when loading a <code class="language-plaintext highlighter-rouge">(M=16, K=4)</code> tile of A:</p>

<p><img src="/assets/images/shared-register.png" alt="shared-register" /></p>

<p>The elements of the <code class="language-plaintext highlighter-rouge">k=0</code> slice of the subtile, loaded by threads <code class="language-plaintext highlighter-rouge">0..15</code> are shaded in blue. The elements loaded by threads <code class="language-plaintext highlighter-rouge">0..7</code> are all in distinct shared memory banks due to the permuted layout, as are those loaded by threads <code class="language-plaintext highlighter-rouge">8..15</code> and hence there are no bank conflicts.</p>

<p>This is also true for the <code class="language-plaintext highlighter-rouge">k=1</code> slice which is shaded in green. If the permutation had not been applied, threads <code class="language-plaintext highlighter-rouge">0,2,4,6</code> would all access banks <code class="language-plaintext highlighter-rouge">0..3</code> and threads <code class="language-plaintext highlighter-rouge">1,3,5,7</code> would all access banks <code class="language-plaintext highlighter-rouge">16..19</code>, causing multiple bank conflicts.</p>
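<p>The same kind of host-side model confirms the load-side behaviour. This is a sketch under the assumption that <code class="language-plaintext highlighter-rouge">ldmatrix</code> services the addresses in groups of 8 consecutive lanes (one 8x128b matrix per group); <code class="language-plaintext highlighter-rouge">loadColA</code> here is a hypothetical function wrapping the kernel's <code class="language-plaintext highlighter-rouge">loadColA</code> expression, with a flag to switch the XOR off.</p>

```cpp
#include <algorithm>
#include <array>
#include <cassert>

// Column (0..7) of the uint4 that lane `laneID` addresses when loading the
// k=0..1 subtile of A. With permuted == false the XOR is skipped, which
// models a plain (unpermuted) shared memory layout.
inline int loadColA(int laneID, bool permuted) {
  int loadRowA = (laneID % 16) / 2;
  int col = laneID / 16 + 4 * (laneID % 2);
  return permuted ? (col ^ (loadRowA % 4)) : col;
}

// Largest number of lanes, within any group of 8 consecutive lanes, that
// address the same column and therefore the same 4 banks.
// A result of 1 means the load is bank-conflict free.
inline int worstCaseLanesPerColumn(bool permuted) {
  int worst = 0;
  for (int group = 0; group < 4; ++group) {
    std::array<int, 8> hits{};
    for (int lane = group * 8; lane < group * 8 + 8; ++lane) {
      ++hits[loadColA(lane, permuted)];
    }
    worst = std::max(worst, *std::max_element(hits.begin(), hits.end()));
  }
  return worst;
}
```

<p>In this model <code class="language-plaintext highlighter-rouge">worstCaseLanesPerColumn(true)</code> is 1, while <code class="language-plaintext highlighter-rouge">worstCaseLanesPerColumn(false)</code> is 4, corresponding to the four threads per bank group described above.</p>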

<p>The elements shaded in yellow/gray belong to the <code class="language-plaintext highlighter-rouge">k=2..3</code> slices, which are inputs to the second <code class="language-plaintext highlighter-rouge">mma</code>. The column indices for these slices can be computed efficiently from the column indices of the <code class="language-plaintext highlighter-rouge">k=0..1</code> slices by applying <code class="language-plaintext highlighter-rouge">xor 2</code> to those indices.</p>
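<p>The <code class="language-plaintext highlighter-rouge">xor 2</code> shortcut works because the K-index occupies the low two bits of the unpermuted column and bit 1 is clear for <code class="language-plaintext highlighter-rouge">k=0..1</code>, so adding 2 is the same as XORing with 2, and XOR commutes with the row permutation. A brute-force check of this, using the hypothetical helpers <code class="language-plaintext highlighter-rouge">permutedCol</code> and <code class="language-plaintext highlighter-rouge">xor2MatchesRecompute</code>, might look like:</p>

```cpp
#include <cassert>

// Permuted column of the uint4 at K-index kk (0..3), for a matrix row with
// parity mParity (m % 2) that lives in shared memory row sharedRow.
inline int permutedCol(int kk, int mParity, int sharedRow) {
  return (kk + 4 * mParity) ^ (sharedRow % 4);
}

// Checks, for every k=0..1 case, that recomputing the column for kk + 2
// gives the same result as XORing the k=0..1 column with 2.
inline bool xor2MatchesRecompute() {
  for (int kk = 0; kk < 2; ++kk) {
    for (int mParity = 0; mParity < 2; ++mParity) {
      for (int sharedRow = 0; sharedRow < 8; ++sharedRow) {
        int recomputed = permutedCol(kk + 2, mParity, sharedRow);
        int shortcut = permutedCol(kk, mParity, sharedRow) ^ 2;
        if (recomputed != shortcut) return false;
      }
    }
  }
  return true;
}
```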

<p>Loading the <code class="language-plaintext highlighter-rouge">k=0..1</code> and <code class="language-plaintext highlighter-rouge">k=2..3</code> subtiles of B is similar, except that as the subtile dimension is <code class="language-plaintext highlighter-rouge">(N=8, k=0..1)</code> there are only sixteen 128b matrix rows to load. Hence we use <code class="language-plaintext highlighter-rouge">ldmatrix.x2</code>, which loads two 8x128b matrices using only the addresses in threads <code class="language-plaintext highlighter-rouge">0..15</code>.</p>

<p>As with the <code class="language-plaintext highlighter-rouge">mma</code> instruction, we define helper functions <code class="language-plaintext highlighter-rouge">load_matrix_x4</code>, <code class="language-plaintext highlighter-rouge">load_matrix_x2</code> to wrap the inline PTX. Looking at <code class="language-plaintext highlighter-rouge">load_matrix_x4</code> as an example:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__device__</span> <span class="kt">void</span> <span class="nf">load_matrix_x4</span><span class="p">(</span><span class="kt">unsigned</span> <span class="o">*</span><span class="n">destReg</span><span class="p">,</span> <span class="n">uint4</span> <span class="o">*</span><span class="n">srcAddr</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">unsigned</span> <span class="n">ptxSrcAddr</span> <span class="o">=</span> <span class="n">__cvta_generic_to_shared</span><span class="p">(</span><span class="n">srcAddr</span><span class="p">);</span>
  <span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span>
      <span class="s">"ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];</span><span class="se">\n</span><span class="s">"</span>
      <span class="o">:</span> <span class="s">"=r"</span><span class="p">(</span><span class="n">destReg</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="s">"=r"</span><span class="p">(</span><span class="n">destReg</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="s">"=r"</span><span class="p">(</span><span class="n">destReg</span><span class="p">[</span><span class="mi">2</span><span class="p">]),</span> <span class="s">"=r"</span><span class="p">(</span><span class="n">destReg</span><span class="p">[</span><span class="mi">3</span><span class="p">])</span>
      <span class="o">:</span>  <span class="s">"r"</span><span class="p">(</span><span class="n">ptxSrcAddr</span><span class="p">)</span>
      <span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>

<p>Two things to note:</p>
<ol>
  <li><code class="language-plaintext highlighter-rouge">__cvta_generic_to_shared</code> is a CUDA function that takes a standard C/C++ pointer, which is 64b, and converts it to a 32b pointer, as shared memory is a 32b address space.</li>
  <li>The <code class="language-plaintext highlighter-rouge">volatile</code> qualifier is needed for this instruction: without it the loads do not get synchronized properly and threads end up with incorrect data, as I discovered after much painful debugging.</li>
</ol>

<p>Once the main loop has finished, the output tile for each warp is contained in <code class="language-plaintext highlighter-rouge">dReg</code>, the output registers of the <code class="language-plaintext highlighter-rouge">mma</code> instructions. There is a <code class="language-plaintext highlighter-rouge">stmatrix</code> instruction to copy back from registers to shared memory, but this requires <code class="language-plaintext highlighter-rouge">sm_90</code>, so we need to handle this ourselves. We write directly from registers to global memory. It may be possible to optimize this by writing first to shared memory and then writing to global memory in a coalesced pattern, but that requires more shared memory and could reduce occupancy. I experimented with this but did not see a performance improvement.</p>

<h3 id="performance-1">Performance</h3>
<p>Kernel 2.0 has greatly increased performance. Execution time is 1080 us, a throughput of 127.3 TFLOP/s, which is 82.9% of cuBLAS and 77.0% of RTX 4090 peak performance. We can make one minor tweak to the kernel to improve performance further. Currently we reload each tile of A for each tile of B. This reduces register usage but introduces redundant loads from shared memory to registers.</p>

<p>In Kernel 2.1 we only load each tile of A once, which improves performance to 1030 us, 133.4 TFLOP/s, 86.9% of cuBLAS and 80.8% of peak. The number of elapsed cycles per mma for Kernel 2.1 is 38, much closer to the minimum of 32.</p>

<p>The permuted shared memory layout should make these kernels bank-conflict free and we verify this for Kernel 2.1:</p>

<p><img src="/assets/images/kernel-2b-conflict.png" alt="kernel-2b-conflict" /></p>

<p>Looking at the warp stats shows that the most frequent cause of stalls is now waiting for the Tensor Cores to be free - this is good!</p>

<p><img src="/assets/images/kernel-2b-warp-stats-1.png" alt="2b-warp-stats" /></p>

<p>There are still a considerable number of barrier and long scoreboard stalls, which we’ll address in Kernel 3.0 by introducing an n-stage pipeline from global to shared memory using the <code class="language-plaintext highlighter-rouge">cp.async</code> instruction.</p>

<h2 id="kernel-30-n-stage-global-to-shared-pipeline">Kernel 3.0: N-stage global to shared pipeline</h2>
<p>There are asynchronous copy APIs both in CUDA (<code class="language-plaintext highlighter-rouge">cuda::memcpy_async</code>) and PTX (<code class="language-plaintext highlighter-rouge">cp.async</code>). The <code class="language-plaintext highlighter-rouge">cuda::memcpy_async</code> API does not support copying with a permuted layout and hence we use the PTX <code class="language-plaintext highlighter-rouge">cp.async</code> API. As before we define a wrapper function for the inline PTX call:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__device__</span> <span class="kt">void</span> <span class="nf">cp_async</span><span class="p">(</span><span class="n">uint4</span> <span class="o">*</span><span class="n">dstAddr</span><span class="p">,</span> <span class="k">const</span> <span class="n">uint4</span> <span class="o">*</span><span class="n">srcAddr</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">unsigned</span> <span class="n">ptxDstAddr</span> <span class="o">=</span> <span class="n">__cvta_generic_to_shared</span><span class="p">(</span><span class="n">dstAddr</span><span class="p">);</span>
  <span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;</span><span class="se">\n</span><span class="s">"</span>
      <span class="o">::</span> <span class="s">"r"</span><span class="p">(</span><span class="n">ptxDstAddr</span><span class="p">),</span>
      <span class="s">"l"</span><span class="p">(</span><span class="n">srcAddr</span><span class="p">),</span>
      <span class="s">"n"</span><span class="p">(</span><span class="mi">16</span><span class="p">));</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The final <code class="language-plaintext highlighter-rouge">"n"(16)</code> input is the number of bytes to copy, and needs to be a compile time constant.</p>

<p>We can then use this function to replace the global to shared load:</p>
<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Replace this</span>
<span class="n">As</span><span class="p">[</span><span class="n">storeRow</span><span class="p">][</span><span class="n">storeCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">globalTileA</span><span class="p">[(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">];</span>
<span class="n">Bs</span><span class="p">[</span><span class="n">storeRow</span><span class="p">][</span><span class="n">storeCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">globalTileB</span><span class="p">[(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">];</span>
<span class="c1">// With</span>
<span class="n">cp_async</span><span class="p">(</span><span class="n">As</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">globalTileA</span> <span class="o">+</span> <span class="p">(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">);</span>
<span class="n">cp_async</span><span class="p">(</span><span class="n">Bs</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">globalTileB</span> <span class="o">+</span> <span class="p">(</span><span class="n">warpID</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span><span class="o">*</span><span class="n">K</span><span class="o">/</span><span class="mi">8</span> <span class="o">+</span> <span class="n">k</span> <span class="o">+</span> <span class="n">laneID</span><span class="o">%</span><span class="mi">4</span><span class="p">);</span>
<span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cp.async.commit_group;</span><span class="se">\n</span><span class="s">"</span> <span class="o">::</span><span class="p">);</span>
</code></pre></div></div>
<p>The <code class="language-plaintext highlighter-rouge">cp.async.commit_group</code> instruction groups these copies together in a <code class="language-plaintext highlighter-rouge">cp.async-group</code> which can later be waited on using <code class="language-plaintext highlighter-rouge">cp.async.wait_group</code>.</p>

<p>We now use <code class="language-plaintext highlighter-rouge">cp.async</code> to set up an n-stage pipeline from global to shared memory. We create circular buffers of size <code class="language-plaintext highlighter-rouge">N_STAGES</code> for A and B in shared memory. Before the main loop of the kernel we preload the first <code class="language-plaintext highlighter-rouge">N_STAGES - 1</code> stages into these shared memory buffers using <code class="language-plaintext highlighter-rouge">cp.async</code>:</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__shared__</span> <span class="n">uint4</span> <span class="n">As</span><span class="p">[</span><span class="n">N_STAGES</span><span class="o">*</span><span class="mi">32</span><span class="p">][</span><span class="mi">8</span><span class="p">];</span>
<span class="n">__shared__</span> <span class="n">uint4</span> <span class="n">Bs</span><span class="p">[</span><span class="n">N_STAGES</span><span class="o">*</span><span class="mi">32</span><span class="p">][</span><span class="mi">8</span><span class="p">];</span>
<span class="c1">// PRELUDE: load first (N_STAGES - 1) into shared memory</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">nStage</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">nStage</span> <span class="o">&lt;</span> <span class="n">N_STAGES</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span> <span class="n">nStage</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">int</span> <span class="n">kStart</span> <span class="o">=</span> <span class="n">nStage</span> <span class="o">*</span> <span class="mi">4</span><span class="p">;</span>
  <span class="n">aStorePtr</span> <span class="o">=</span> <span class="n">As</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="n">nStage</span><span class="p">;</span>
  <span class="n">bStorePtr</span> <span class="o">=</span> <span class="n">Bs</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="n">nStage</span><span class="p">;</span>
  <span class="n">cp_async</span><span class="p">(</span><span class="n">aStorePtr</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">aGlobalAddress</span> <span class="o">+</span> <span class="n">kStart</span><span class="p">);</span>
  <span class="n">cp_async</span><span class="p">(</span><span class="n">bStorePtr</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">bGlobalAddress</span> <span class="o">+</span> <span class="n">kStart</span><span class="p">);</span>
  <span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cp.async.commit_group;</span><span class="se">\n</span><span class="s">"</span> <span class="o">::</span><span class="p">);</span>
<span class="p">}</span>
</code></pre></div></div>
<p>At the start of the main loop there are at most <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> <code class="language-plaintext highlighter-rouge">cp.async</code> groups pending; this is an invariant that is maintained at each loop iteration. We initialize the shared memory load and store pointers to stages <code class="language-plaintext highlighter-rouge">0</code> and <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> respectively and then wait for the oldest copy to complete, i.e. until there are at most <code class="language-plaintext highlighter-rouge">N_STAGES-2</code> <code class="language-plaintext highlighter-rouge">cp.async</code> groups pending. Note that a <code class="language-plaintext highlighter-rouge">__syncthreads</code> is required after <code class="language-plaintext highlighter-rouge">wait_group</code>, as <code class="language-plaintext highlighter-rouge">wait_group</code> only waits for the copies issued by the executing thread; it does not synchronize across threads.</p>

<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">//  MAIN LOOP OVER K BLOCKS</span>
<span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">nStage</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">nStage</span> <span class="o">&lt;</span> <span class="n">K</span><span class="o">/</span><span class="mi">32</span><span class="p">;</span> <span class="n">nStage</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
  <span class="kt">int</span> <span class="n">kStart</span> <span class="o">=</span> <span class="p">(</span><span class="n">N_STAGES</span><span class="o">-</span><span class="mi">1</span><span class="o">+</span><span class="n">nStage</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span><span class="p">;</span>
  <span class="n">aStorePtr</span> <span class="o">=</span> <span class="n">As</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="p">((</span><span class="n">nStage</span> <span class="o">+</span> <span class="n">N_STAGES</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">N_STAGES</span><span class="p">);</span>
  <span class="n">bStorePtr</span> <span class="o">=</span> <span class="n">Bs</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="p">((</span><span class="n">nStage</span> <span class="o">+</span> <span class="n">N_STAGES</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">N_STAGES</span><span class="p">);</span>
  <span class="n">aLoadPtr</span> <span class="o">=</span> <span class="n">As</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="p">(</span><span class="n">nStage</span> <span class="o">%</span> <span class="n">N_STAGES</span><span class="p">);</span>
  <span class="n">bLoadPtr</span> <span class="o">=</span> <span class="n">Bs</span> <span class="o">+</span> <span class="mi">32</span> <span class="o">*</span> <span class="p">(</span><span class="n">nStage</span> <span class="o">%</span> <span class="n">N_STAGES</span><span class="p">);</span>
  
  <span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cp.async.wait_group %0;</span><span class="se">\n</span><span class="s">"</span> <span class="o">::</span> <span class="s">"n"</span><span class="p">(</span><span class="n">N_STAGES</span><span class="o">-</span><span class="mi">2</span><span class="p">));</span>
  <span class="n">__syncthreads</span><span class="p">();</span>

  <span class="c1">// Preload the fragments for k=0..1, k=2..3 for both A/B tiles </span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">m</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">;</span> <span class="n">m</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">load_matrix_x4</span><span class="p">(</span><span class="n">aReg</span><span class="p">[</span><span class="n">m</span><span class="p">]</span>    <span class="p">,</span> <span class="n">aLoadPtr</span><span class="p">[</span><span class="n">m</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">warpOffsetA</span> <span class="o">+</span> <span class="n">loadRowA</span><span class="p">]</span> <span class="o">+</span> <span class="n">loadColA</span><span class="p">);</span>
    <span class="n">load_matrix_x4</span><span class="p">(</span><span class="n">aReg</span><span class="p">[</span><span class="n">m</span><span class="p">]</span> <span class="o">+</span> <span class="mi">4</span><span class="p">,</span> <span class="n">aLoadPtr</span><span class="p">[</span><span class="n">m</span><span class="o">*</span><span class="mi">8</span> <span class="o">+</span> <span class="n">warpOffsetA</span> <span class="o">+</span> <span class="n">loadRowA</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="n">loadColA</span><span class="o">^</span><span class="mi">2</span><span class="p">));</span>
  <span class="p">}</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">n</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">;</span> <span class="n">n</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">load_matrix_x2</span><span class="p">(</span><span class="n">bReg</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>   <span class="p">,</span> <span class="n">bLoadPtr</span><span class="p">[</span><span class="n">n</span><span class="o">*</span><span class="mi">4</span> <span class="o">+</span> <span class="n">warpOffsetB</span> <span class="o">+</span> <span class="n">loadRowB</span><span class="p">]</span> <span class="o">+</span> <span class="n">loadColB</span><span class="p">);</span>
    <span class="n">load_matrix_x2</span><span class="p">(</span><span class="n">bReg</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="n">bLoadPtr</span><span class="p">[</span><span class="n">n</span><span class="o">*</span><span class="mi">4</span> <span class="o">+</span> <span class="n">warpOffsetB</span> <span class="o">+</span> <span class="n">loadRowB</span><span class="p">]</span> <span class="o">+</span> <span class="p">(</span><span class="n">loadColB</span><span class="o">^</span><span class="mi">2</span><span class="p">));</span>
  <span class="p">}</span>

  <span class="c1">// Start next cp.async: on last N_STAGES-1 iterations the results of </span>
  <span class="c1">// these copies are not used. The copies are done solely to allow</span>
  <span class="c1">// us to keep the argument to `wait_group` fixed at N_STAGES-2</span>
  <span class="n">kStart</span> <span class="o">=</span> <span class="p">(</span><span class="n">kStart</span> <span class="o">&gt;</span> <span class="mi">512</span><span class="o">-</span><span class="mi">4</span><span class="p">)</span> <span class="o">?</span> <span class="mi">512</span><span class="o">-</span><span class="mi">4</span> <span class="o">:</span> <span class="n">kStart</span><span class="p">;</span>
  <span class="n">cp_async</span><span class="p">(</span><span class="n">aStorePtr</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">aGlobalAddress</span> <span class="o">+</span> <span class="n">kStart</span><span class="p">);</span>
  <span class="n">cp_async</span><span class="p">(</span><span class="n">bStorePtr</span><span class="p">[</span><span class="n">storeRow</span><span class="p">]</span> <span class="o">+</span> <span class="n">storeCol</span><span class="p">,</span> <span class="n">bGlobalAddress</span> <span class="o">+</span> <span class="n">kStart</span><span class="p">);</span>
  <span class="k">asm</span> <span class="k">volatile</span><span class="p">(</span><span class="s">"cp.async.commit_group;</span><span class="se">\n</span><span class="s">"</span> <span class="o">::</span><span class="p">);</span>

  <span class="c1">// Compute the mmas</span>
  <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">m</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">m</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">;</span> <span class="n">m</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">n</span><span class="o">=</span><span class="mi">0</span><span class="p">;</span> <span class="n">n</span><span class="o">&lt;</span><span class="mi">2</span><span class="p">;</span> <span class="n">n</span><span class="o">++</span><span class="p">)</span> <span class="p">{</span>
      <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">[</span><span class="n">m</span><span class="p">]</span>    <span class="p">,</span> <span class="n">bReg</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>    <span class="p">,</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">]);</span>
      <span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aReg</span><span class="p">[</span><span class="n">m</span><span class="p">]</span> <span class="o">+</span> <span class="mi">4</span><span class="p">,</span> <span class="n">bReg</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="n">dReg</span><span class="p">[</span><span class="n">m</span><span class="p">][</span><span class="n">n</span><span class="p">]);</span>
    <span class="p">}</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>
<p>Next we load the current shared memory stage to registers. In this kernel we preload the entire <code class="language-plaintext highlighter-rouge">(M/N=64, K=4)</code> tile into registers, requiring 16 registers for <code class="language-plaintext highlighter-rouge">A</code> and 8 for <code class="language-plaintext highlighter-rouge">B</code>. The extra shared memory required by the <code class="language-plaintext highlighter-rouge">N_STAGES</code> shared memory buffers is the occupancy bottleneck, so it makes sense to spend these extra registers on issuing the loads as early and as much in parallel as possible. After starting the loads to registers, we submit the next <code class="language-plaintext highlighter-rouge">cp.async</code> copies, and finally we perform the <code class="language-plaintext highlighter-rouge">mma</code> instructions. The load and store pointers advance modulo <code class="language-plaintext highlighter-rouge">N_STAGES</code> on each iteration.</p>

<p>As <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> K blocks were loaded before the main loop, on the last <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> iterations through the main loop we don’t need to load any more data from global memory. However, the argument to <code class="language-plaintext highlighter-rouge">cp.async.wait_group</code> needs to be a compile-time constant, and submitting superfluous copies is a hacky way to keep the argument to <code class="language-plaintext highlighter-rouge">wait_group</code> fixed at <code class="language-plaintext highlighter-rouge">N_STAGES-2</code>. Without these copies the kernel would be incorrect unless we decreased this argument on each of the last <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> iterations.</p>
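<p>The schedule described above can be checked on the host with a small model. This is only a sketch and runs no GPU code: <code class="language-plaintext highlighter-rouge">simulatePipeline</code>, <code class="language-plaintext highlighter-rouge">PipelineTrace</code> and the <code class="language-plaintext highlighter-rouge">issueTailCopies</code> flag are hypothetical names introduced here, and the model makes the conservative assumption that a copy is only known to have landed once <code class="language-plaintext highlighter-rouge">wait_group</code> has retired its group.</p>

```cpp
#include <algorithm>
#include <cassert>
#include <deque>
#include <utility>
#include <vector>

// Host-side model of the cp.async schedule; no GPU code runs here.
// A "group" is the pair (shared memory slot, K block being copied into it).
struct PipelineTrace {
  int maxPendingAfterWait = 0;  // most groups still in flight right after a wait
  int superfluousCopies = 0;    // tail copies whose data is never read
  bool everyLoadReady = true;   // each iteration found its K block in its slot
};

// ASSUMPTION: a copy counts as complete only once wait_group has retired
// its group, and groups retire oldest-first.
PipelineTrace simulatePipeline(int nStages, int numKBlocks, bool issueTailCopies) {
  assert(nStages >= 2 && numKBlocks >= nStages - 1);
  std::vector<int> resident(nStages, -1);   // K block known to be in each slot
  std::deque<std::pair<int, int>> pending;  // committed but not yet retired
  PipelineTrace trace;

  // Prelude: commit the first nStages-1 K blocks.
  for (int s = 0; s < nStages - 1; ++s) pending.push_back({s, s});

  for (int i = 0; i < numKBlocks; ++i) {
    // cp.async.wait_group (nStages-2): wait until at most nStages-2 groups remain.
    while (static_cast<int>(pending.size()) > nStages - 2) {
      resident[pending.front().first] = pending.front().second;
      pending.pop_front();
    }
    trace.maxPendingAfterWait =
        std::max(trace.maxPendingAfterWait, static_cast<int>(pending.size()));

    // The register loads read slot i % nStages and expect K block i there.
    if (resident[i % nStages] != i) trace.everyLoadReady = false;

    // The next copy targets the slot that was read on the previous iteration.
    int storeSlot = (i + nStages - 1) % nStages;
    int nextBlock = i + nStages - 1;
    if (nextBlock >= numKBlocks) {
      if (!issueTailCopies) continue;  // nothing committed: the pending count drains
      nextBlock = numKBlocks - 1;      // clamp, as the kernel does with kStart
      ++trace.superfluousCopies;
    }
    pending.push_back({storeSlot, nextBlock});
  }
  return trace;
}
```

<p>With <code class="language-plaintext highlighter-rouge">simulatePipeline(3, 128, true)</code> (128 K blocks corresponds to <code class="language-plaintext highlighter-rouge">K/32</code> for K=4096) every iteration finds its data ready and <code class="language-plaintext highlighter-rouge">N_STAGES-1</code> superfluous copies are issued. Passing <code class="language-plaintext highlighter-rouge">false</code> for the last argument shows the failure mode: with a fixed wait argument, the final iteration no longer waits for its own K block.</p>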

<h3 id="performance-2">Performance</h3>
<p>Sadly, after all that effort Kernel 3.0 is only a very minor improvement over Kernel 2.1. For <code class="language-plaintext highlighter-rouge">N_STAGES=3</code>, the execution time is 1000 us, giving 137.4 TFLOP/s, 89.5% of cuBLAS and 83.2% of RTX 4090 peak performance. Setting <code class="language-plaintext highlighter-rouge">N_STAGES=4</code> gives similar performance, and larger values reduce it. The warp state stats show that overall stalls are lower than in Kernel 2.1:</p>

<p><img src="/assets/images/kernel-3-warp-stats.png" alt="3-warp-stats" /></p>

<p>This is partially due to reduced occupancy: Kernel 2.1 has 32 warps per SM while Kernel 3.0 has 24 due to the extra shared memory requirements.</p>
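<p>For a rough sense of the shared memory involved, the size of the two circular buffers can be computed directly from their declarations (<code class="language-plaintext highlighter-rouge">uint4 As[N_STAGES*32][8]</code> and the same for <code class="language-plaintext highlighter-rouge">Bs</code>). The sketch below is host-side arithmetic only; <code class="language-plaintext highlighter-rouge">Uint4</code> and <code class="language-plaintext highlighter-rouge">pipelineSharedBytes</code> are hypothetical names, and how these bytes translate into resident warps depends on per-SM limits that are not modelled here.</p>

```cpp
#include <cassert>
#include <cstddef>

// Host-side stand-in for CUDA's uint4 (four 32-bit words).
struct Uint4 { unsigned x, y, z, w; };
static_assert(sizeof(Uint4) == 16, "uint4 is expected to occupy 16 bytes");

// Bytes of shared memory taken by the As and Bs circular buffers,
// each declared as uint4[N_STAGES*32][8].
std::size_t pipelineSharedBytes(int nStages) {
  assert(nStages >= 1);
  const std::size_t rowsPerStage = 32;  // rows of each stage
  const std::size_t colsPerRow = 8;     // uint4 elements per row
  const std::size_t buffers = 2;        // As and Bs
  return buffers * static_cast<std::size_t>(nStages) * rowsPerStage * colsPerRow *
         sizeof(Uint4);
}
```

<p>Each stage therefore costs 8 KiB per thread block across both buffers, so <code class="language-plaintext highlighter-rouge">N_STAGES=3</code> needs 24 KiB and <code class="language-plaintext highlighter-rouge">N_STAGES=4</code> needs 32 KiB.</p>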

<p>As stalls due to barrier synchronization are still high, a reasonable optimization is to try increasing the work each warp does within a main loop iteration. We do this in Kernel 3.1 by increasing the tiling in the M/N dimensions from 2 to 4. This doubles the thread block tile size to <code class="language-plaintext highlighter-rouge">(M/N=128, K=4)</code> meaning that each warp performs 4x4x2=32 <code class="language-plaintext highlighter-rouge">mma</code> instructions per main loop iteration.</p>
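<p>The effect of the larger tiling on the work done between barriers can be written out as simple arithmetic. In this sketch <code class="language-plaintext highlighter-rouge">warpWork</code> and <code class="language-plaintext highlighter-rouge">WarpIterationWork</code> are hypothetical names; the register counts for a tiling of 4 assume Kernel 3.1 keeps the same preload scheme as Kernel 3.0 (four registers per <code class="language-plaintext highlighter-rouge">A</code> fragment, two per <code class="language-plaintext highlighter-rouge">B</code> fragment), which is an extrapolation rather than something stated above.</p>

```cpp
#include <cassert>

// Per-warp work in one main loop iteration for a square warp tiling.
struct WarpIterationWork {
  int mmaInstructions;  // m16n8k16 mma instructions issued
  int aRegisters;       // 32-bit registers holding A fragments
  int bRegisters;       // 32-bit registers holding B fragments
};

// Two k16 steps cover one K block of 4 uint4 (32 half-precision values).
constexpr int kStepsPerIteration = 2;

WarpIterationWork warpWork(int tilesPerDim) {
  assert(tilesPerDim > 0);
  WarpIterationWork work;
  work.mmaInstructions = tilesPerDim * tilesPerDim * kStepsPerIteration;
  work.aRegisters = tilesPerDim * kStepsPerIteration * 4;  // one x4 load per tile and k step
  work.bRegisters = tilesPerDim * kStepsPerIteration * 2;  // one x2 load per tile and k step
  return work;
}
```

<p>Here <code class="language-plaintext highlighter-rouge">warpWork(2)</code> gives 8 <code class="language-plaintext highlighter-rouge">mma</code> instructions with 16 and 8 fragment registers, matching Kernel 3.0, while <code class="language-plaintext highlighter-rouge">warpWork(4)</code> gives 32 <code class="language-plaintext highlighter-rouge">mma</code> instructions. Each barrier is thus amortized over four times as many <code class="language-plaintext highlighter-rouge">mma</code> instructions for only twice the fragment registers.</p>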

<p>Kernel 3.1 has an execution time of 895 us, giving a throughput of 153.6 TFLOP/s, 100% of cuBLAS and 93.0% of RTX 4090 peak performance. The warp state stats show that the vast majority of stalls are now due to waiting for Tensor Cores. In fact, each warp now waits on average 36 cycles for a Tensor Core to be available:</p>

<p><img src="/assets/images/kernel-3b-warp-stats.png" alt="3b-warp-stats" /></p>

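<p>The ratio of elapsed cycles to mma instructions for Kernel 3.1 is 34.2. Taking 32 cycles per <code class="language-plaintext highlighter-rouge">m16n8k16</code> <code class="language-plaintext highlighter-rouge">mma</code> as the ideal, this corresponds to 32/34.2 = 93.6% of peak, consistent with the 93.0% figure above.</p>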
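<p>The throughput figures quoted in this section follow from the execution times by simple arithmetic. The sketch below assumes a problem size of M=N=K=4096 (K=4096 is implied by the <code class="language-plaintext highlighter-rouge">512-4</code> clamp in the main loop; M and N are assumptions) and a peak of 165.2 TFLOP/s, the value implied by the percentages quoted here. The function names are hypothetical.</p>

```cpp
#include <cassert>

// FLOP count of an M x N x K matmul: one multiply and one add per inner-product term.
double matmulFlop(double m, double n, double k) { return 2.0 * m * n * k; }

// Throughput in TFLOP/s for a kernel that takes `microseconds` to run.
double tflopPerSecond(double m, double n, double k, double microseconds) {
  assert(microseconds > 0.0);
  return matmulFlop(m, n, k) / (microseconds * 1e-6) / 1e12;
}

// `value` expressed as a percentage of `reference`.
double percentOf(double value, double reference) {
  assert(reference > 0.0);
  return 100.0 * value / reference;
}
```

<p>For example, <code class="language-plaintext highlighter-rouge">tflopPerSecond(4096, 4096, 4096, 1000)</code> gives about 137.4 and <code class="language-plaintext highlighter-rouge">tflopPerSecond(4096, 4096, 4096, 895)</code> about 153.6. The same helper gives <code class="language-plaintext highlighter-rouge">percentOf(32, 34.2)</code> of about 93.6 for the cycle ratio.</p>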

<p>Surprisingly, nsight-compute shows the Tensor Core utilization as only 47.3%, so what is going on?</p>

<p><img src="/assets/images/tensor-core-util.png" alt="tc-util" /></p>

<p>It seems that nsight uses a fixed latency of 16 cycles when computing <code class="language-plaintext highlighter-rouge">smsp__pipe_tensor_op_hmma_cycles_active</code>, as the metric value is consistently 16 times the value of <code class="language-plaintext highlighter-rouge">smsp__inst_executed_pipe_tensor_op_hmma</code>. This seems to be an error, as the latency for the <code class="language-plaintext highlighter-rouge">m16n8k16</code> <code class="language-plaintext highlighter-rouge">mma</code> instruction should be 32 cycles, so the utilization should be 94.6%.</p>

<p>One final thing I noticed is that both Kernels 3.0 &amp; 3.1 have bank conflicts. For 3.1 nsight shows:</p>

<p><img src="/assets/images/kernel-3b-conflict.png" alt="kernel-3b-conflict" /></p>

<p>Confusingly in this view (Memory Tables) the conflicts appear only in the shared loads, whereas in the source metrics they appear both when copying from global to shared and when loading from shared to registers. The shared loads in particular use the same <code class="language-plaintext highlighter-rouge">ldmatrix</code> instruction as in Kernel 2, so I’m not sure how moving to <code class="language-plaintext highlighter-rouge">cp.async</code> introduces a conflict there.</p>

<p>It’s possible these conflicts are not real: nsight-compute reports erroneous conflicts in some cases, as described <a href="https://forums.developer.nvidia.com/t/shared-memory-bank-conflicts-and-nsight-metric/115731/12">here</a>. I need to look into this further and will update the post if/when I find out what’s going on.</p>

<h2 id="conclusion">Conclusion</h2>
<p>We’ve gone from a naive implementation with correspondingly poor performance, to a kernel that is on par with cuBLAS, at least for this extremely specific problem formulation. In the process we’ve developed an understanding of <code class="language-plaintext highlighter-rouge">mma</code> and related PTX instructions, along with the techniques needed to feed data to Tensor Cores efficiently.</p>

<h2 id="code">Code</h2>
<p>The code for all kernels is available here: <a href="https://github.com/spatters/mma-matmul">https://github.com/spatters/mma-matmul</a>.</p>

<h3 id="appendix-floating-point-accuracy">Appendix: Floating Point Accuracy</h3>
<p>NVIDIA does not fully document the exact numerical behavior of the Tensor Core <code class="language-plaintext highlighter-rouge">mma</code> instruction. The PTX ISA states: 
<img src="/assets/images/mma-numeric.png" alt="mma-numeric" />
Getting into these details is not the focus of this post, but one example of rounding error is worth noting. Kernel 1.0 accumulates the results of the main loop over K directly in <code class="language-plaintext highlighter-rouge">dReg</code> meaning at each iteration the accumulation <code class="language-plaintext highlighter-rouge">dReg = dReg + aReg * bReg</code> happens within the <code class="language-plaintext highlighter-rouge">mma</code> operation, which can cause loss of precision if <code class="language-plaintext highlighter-rouge">dReg</code> is large compared to <code class="language-plaintext highlighter-rouge">aReg * bReg</code>.</p>

<p>When testing correctness of the implementation I initialize inputs with <code class="language-plaintext highlighter-rouge">U[0,1)</code> values. This means <code class="language-plaintext highlighter-rouge">dReg</code> grows monotonically as we loop over K, and performing the accumulation directly in the <code class="language-plaintext highlighter-rouge">mma</code> operation causes round off such that the result using <code class="language-plaintext highlighter-rouge">mma</code> is consistently lower than a reference implementation using <code class="language-plaintext highlighter-rouge">fp16/fp32</code> operations on CUDA cores (relative difference around 1e-5). This issue can be avoided by instead performing an mma without accumulation, and accumulating the results outside, i.e.</p>
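<div class="language-c++ highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">float4</span> <span class="n">dRegAcc</span> <span class="o">=</span> <span class="p">{</span><span class="mf">0.f</span><span class="p">,</span> <span class="mf">0.f</span><span class="p">,</span> <span class="mf">0.f</span><span class="p">,</span> <span class="mf">0.f</span><span class="p">};</span>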
<span class="kt">float</span> <span class="n">cReg</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="mf">0.</span><span class="p">};</span>
<span class="n">mma_m16n8k16</span><span class="p">(</span><span class="n">aPtr</span><span class="p">,</span> <span class="n">bPtr</span><span class="p">,</span> <span class="n">cReg</span><span class="p">,</span> <span class="n">dReg</span><span class="p">);</span>
<span class="n">float4</span> <span class="o">*</span><span class="n">dRegPtr</span> <span class="o">=</span> <span class="k">reinterpret_cast</span><span class="o">&lt;</span><span class="n">float4</span> <span class="o">*&gt;</span><span class="p">(</span><span class="n">dReg</span><span class="p">);</span>
<span class="n">dRegAcc</span><span class="p">.</span><span class="n">x</span> <span class="o">+=</span> <span class="n">dRegPtr</span><span class="o">-&gt;</span><span class="n">x</span><span class="p">;</span>
<span class="n">dRegAcc</span><span class="p">.</span><span class="n">y</span> <span class="o">+=</span> <span class="n">dRegPtr</span><span class="o">-&gt;</span><span class="n">y</span><span class="p">;</span>
<span class="n">dRegAcc</span><span class="p">.</span><span class="n">z</span> <span class="o">+=</span> <span class="n">dRegPtr</span><span class="o">-&gt;</span><span class="n">z</span><span class="p">;</span>
<span class="n">dRegAcc</span><span class="p">.</span><span class="n">w</span> <span class="o">+=</span> <span class="n">dRegPtr</span><span class="o">-&gt;</span><span class="n">w</span><span class="p">;</span>
</code></pre></div></div>
<p>Applied to Kernel 3.1, this incurs a performance penalty of around 10 us, reduces the difference to the reference kernel by two orders of magnitude and centers it around zero. A detailed investigation into the numerical behavior of Tensor Cores in general can be found in <sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">8</a></sup>.</p>
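<p>The direction of this effect can be reproduced on the host with a toy model. This is a sketch under a loud assumption: the rounding inside the <code class="language-plaintext highlighter-rouge">mma</code> is modelled as truncation toward zero applied to every add, which is a stand-in chosen for illustration and not documented hardware behavior, so the magnitudes are not meant to match the GPU. The names <code class="language-plaintext highlighter-rouge">addTruncate</code>, <code class="language-plaintext highlighter-rouge">makeInputs</code> and <code class="language-plaintext highlighter-rouge">compareAccumulation</code> are hypothetical, and the inputs stand in for the products being summed.</p>

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Adds two non-negative floats and rounds the exact sum toward zero.
// ASSUMPTION: truncation stands in for the undocumented rounding inside the mma.
float addTruncate(float a, float b) {
  assert(a >= 0.0f && b >= 0.0f);
  double exact = static_cast<double>(a) + static_cast<double>(b);  // exact for these magnitudes
  float rounded = static_cast<float>(exact);                       // round to nearest
  if (static_cast<double>(rounded) > exact)                        // rounded up: step back down
    rounded = std::nextafter(rounded, 0.0f);
  return rounded;
}

// Deterministic pseudo-random values in [0, 1) with 24 significant bits.
std::vector<float> makeInputs(int count) {
  std::vector<float> values(count);
  std::uint32_t state = 12345u;
  for (float &v : values) {
    state = state * 1664525u + 1013904223u;  // simple linear congruential generator
    v = static_cast<float>(state >> 8) * (1.0f / 16777216.0f);
  }
  return values;
}

struct AccumulationErrors {
  double direct;    // error when one accumulator is carried through every truncating add
  double twoStage;  // error when short truncated partial sums are added in plain FP32
};

AccumulationErrors compareAccumulation(const std::vector<float> &terms, int chunk) {
  assert(chunk > 0);
  const std::size_t chunkSize = static_cast<std::size_t>(chunk);
  double reference = 0.0;  // double-precision reference sum
  float direct = 0.0f;     // accumulates "inside the mma"
  float partial = 0.0f;    // per-chunk sum that starts from zero
  float outer = 0.0f;      // separate FP32 accumulator outside the "mma"
  for (std::size_t i = 0; i < terms.size(); ++i) {
    reference += terms[i];
    direct = addTruncate(direct, terms[i]);
    partial = addTruncate(partial, terms[i]);
    if ((i + 1) % chunkSize == 0 || i + 1 == terms.size()) {
      outer += partial;  // ordinary round-to-nearest FP32 add
      partial = 0.0f;
    }
  }
  return {direct - reference, outer - reference};
}
```

<p>Calling <code class="language-plaintext highlighter-rouge">compareAccumulation(makeInputs(4096), 16)</code> gives a negative <code class="language-plaintext highlighter-rouge">direct</code> error, since every truncation discards a little of a growing sum, and a <code class="language-plaintext highlighter-rouge">twoStage</code> error of much smaller magnitude, since the truncating adds only ever see small partial sums.</p>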

<h3 id="references">References</h3>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:3" role="doc-endnote">
      <p><a href="https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745">GTC 2020 Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100</a> <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:3:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a> <a href="#fnref:3:2" class="reversefootnote" role="doc-backlink">&#8617;<sup>3</sup></a> <a href="#fnref:3:3" class="reversefootnote" role="doc-backlink">&#8617;<sup>4</sup></a></p>
    </li>
    <li id="fn:8" role="doc-endnote">
      <p><a href="https://siboehm.com/articles/22/CUDA-MMM">How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog</a> <a href="#fnref:8" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Originally FLOPS stood for Floating point Operations Per Second. However, in deep learning it is also used as a measure of quantity, i.e. to mean Floating point Operations. To prevent confusion I am using FLOP/s for rates and FLOP for quantities, as suggested <a href="https://blog.heim.xyz/flop-for-quantity-flop-s-for-performance">here</a>. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:6" role="doc-endnote">
      <p><a href="https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf">Ada Architecture White Paper</a> <a href="#fnref:6" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:7" role="doc-endnote">
      <p><a href="https://arxiv.org/pdf/2402.13499v1">Benchmarking and Dissecting the Nvidia Hopper GPU Architecture</a> <a href="#fnref:7" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:5" role="doc-endnote">
      <p><a href="https://forums.developer.nvidia.com/t/why-would-code-run-1-7x-faster-when-run-with-nvprof-than-without/56406/7">Why would code run 1.7x faster when run with nvprof than without</a> <a href="#fnref:5" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:1" role="doc-endnote">
      <p><a href="https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9593-cutensor-high-performance-tensor-operations-in-cuda-v2.pdf">GTC 2019 Programming Tensor Cores: Native Volta Tensor Cores With CUTLASS</a> <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:4" role="doc-endnote">
      <p><a href="https://eprints.maths.manchester.ac.uk/2774/1/fhmp20.pdf">Numerical Behavior of NVIDIA Tensor Cores</a> <a href="#fnref:4" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Sam Patterson</name></author><summary type="html"><![CDATA[Using Tensor Cores is now a prerequisite to get anywhere near peak performance on NVIDIA GPUs. In this post we work through the process of developing an efficient Tensor Core matrix multiplication kernel targeting the Ada architecture.]]></summary></entry></feed>