So far, we have analyzed our implementations mainly by a combination of theoretical analysis and wallclock time measurements. In this section, we will run the kernels using the Nsight Compute profiler. This executes the kernel multiple times in a controlled setup (e.g., locked GPU frequency), gathering different statistics from the GPU's performance counters during each run.
This will give us detailed information on how much we utilize different parts of the memory subsystem, and what causes warps to stall (i.e., spend time not issuing an instruction).
In the previous sections, we have used the K2200 as a reference GPU. While the basic principles of GPU programming have not changed since its release date in 2014, it is missing several important architectural improvements that are present in more recent GPUs. For this reason, from now on, we will instead consider the Quadro RTX 4000, which, among other things, has a Unified Data Cache that can be partitioned into L1 cache and shared memory, and independent execution units for integer and floating-point operations. More information on the Turing architecture can be found in the tuning guide and whitepaper.
Until now, we have not considered caches on the GPU at all. This is, in part, because older GPUs like the K2200 can only make use of the L1 cache in a very limited capacity: only data that is guaranteed to remain unchanged for the duration of the kernel call can be cached. In some cases, the compiler might be able to automatically prove that this constraint is fulfilled; otherwise, caching only happens for data accessed through the `__ldg` intrinsic.
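As an illustration, here is a minimal sketch (a hypothetical kernel, not one of our V0–V3 versions) of routing a read-only load through the cache via `__ldg`; the `const __restrict__` qualifiers are also the kind of hint that can let the compiler prove read-only access on its own:

```cuda
// Minimal sketch: explicitly caching read-only data on older GPUs (hypothetical kernel).
__global__ void scale(const float* __restrict__ in, float* __restrict__ out, int n, float s)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = s * __ldg(&in[i]);   // load through the read-only data cache
}
```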
In contrast, on the RTX 4000 (and any more recent GPU), the L1 cache can handle reads of non-constant data. In addition to the L1 cache local to each streaming multiprocessor, there is also a single L2 cache that is shared among all SMs.
Do these improved caching capabilities relieve us of the need for coalesced accesses, register reuse, and shared memory? Let's find out by profiling our kernels!
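As a quick reminder of what coalescing means, here is a simplified sketch of the two access patterns (a hypothetical copy kernel, not the actual V0/V1 code): in the non-coalesced version, neighboring threads of a warp read addresses that are a full row apart, so each thread triggers its own memory transaction; in the coalesced version, the 32 threads of a warp read 32 consecutive elements, which the hardware can serve with a few wide transactions.

```cuda
// Simplified sketch of non-coalesced vs. coalesced reads (hypothetical copy kernel).
__global__ void copy_noncoalesced(const float* in, float* out, int rows, int cols)
{
    int r = blockIdx.x * blockDim.x + threadIdx.x;   // one thread per row
    if (r < rows)
        for (int c = 0; c < cols; ++c)
            out[r * cols + c] = in[r * cols + c];    // neighboring threads access
                                                     // addresses `cols` floats apart
}

__global__ void copy_coalesced(const float* in, float* out, int rows, int cols)
{
    int c = blockIdx.x * blockDim.x + threadIdx.x;   // one thread per column
    if (c < cols)
        for (int r = 0; r < rows; ++r)
            out[r * cols + c] = in[r * cols + c];    // neighboring threads access
                                                     // consecutive addresses
}
```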
The first question we want to ask is: does switching to coalesced reads reduce stress on the memory system? Let's check how much data we are reading from DRAM. For V0, this is a huge amount: we read a total of 60 GB. Surprisingly, though, switching to V1 does not actually help; instead, we end up reading a total of 127 GB, more than twice as much as before. Our new access pattern has reduced the hit rate in the L1 cache from 96% to 56%.
Why is it faster, then? First, it is important to realize just how much bandwidth there is between a GPU's DRAM and its cores. In fact, V0 uses only 4% of the available bandwidth, and V1 increases that number to 24%; neither kernel is limited by DRAM bandwidth, but rather by latency.
Let's take a closer look at what the warps are spending their time on. First, how much time does each warp spend just waiting for something? In V0, we measure that 103 cycles are spent, on average, between two consecutive instructions issued by a single warp. For V1, this number is only 38 cycles. Unlike on a CPU, we do not actually need each warp to issue an instruction every cycle, because each scheduler inside the SM is responsible for multiple warps. As long as at least one of them is ready to go, we can keep the cores busy.
This is not something we achieve, though. For V0, the scheduler has an eligible warp to advance in only 8% of all cycles. This improves substantially with V1, where we get 21%: still not great, but much better.
What is it that prevents the warps from executing? To answer that question, we can look at how many cycles each warp spends in different Warp Stall states, explained below.
In V0, by far the most time is spent in the LG Throttle state. This means that the warp wanted to issue a memory read request, but there were already too many pending memory requests in flight, so the warp could not append to the queue. Because V0's reads are not coalesced, each thread generates an independent memory request, and each individual entry in the queue only fetches a small amount of memory. Thus, the queue of outstanding loads fills up quickly, and threads have to wait before they can submit a new memory request, resulting in low memory bandwidth utilization.
In V1, the amount of LG Throttle stalls is reduced drastically, from 96 cycles to 28 cycles. This reduction comes at a cost, though: compared to V0, the improved memory accesses in V1 increase Long Scoreboard stalls from 1.8 cycles to 4 cycles. A Long Scoreboard stall indicates that the warp cannot progress because it has to wait for data to be retrieved from memory. Thus, LG Throttle indicates warps halting at an instruction that fetches data from memory, whereas Long Scoreboard indicates warps waiting at an instruction that depends on the fetched data.
The increase in Long Scoreboard stalls demonstrates a general phenomenon when optimizing GPU kernels: fixing one bottleneck (here, the queue of outstanding loads running out of capacity) speeds up the overall execution only until a new bottleneck becomes limiting, which in this case turns out to be the latency of memory accesses. Only because the reduction in LG Throttle far exceeds the increase in Long Scoreboard do we get significant wallclock speed-ups.
The remaining stall reasons are Selected, Not Selected, and Wait. Selected means that the warp was chosen by the scheduler in the given cycle, and is thus not an actual stall; by definition, this reason accounts for exactly one cycle per issued instruction. Not Selected, on the other hand, means that the warp was ready to issue its next instruction, but the scheduler picked a different warp; this is an indicator of having sufficient parallel work available. Finally, a warp ends up in the Wait state if its next instruction depends on the result of a previous instruction with fixed latency.
To further reduce stress on the memory system, in V2 we started reusing data in registers, avoiding the need to go to the memory system for a large fraction of the operations. We can see the success of this strategy in the following table:
|  | V0 | V1 | V2 |
| --- | --- | --- | --- |
| DRAM reads | 60 GB | 127 GB | 17 GB |
| DRAM bandwidth utilization | 4% | 24% | 21% |
| L2 reads | 155 GB | 455 GB | 27 GB |
| L1 hit rate | 96% | 56% | 54% |
| L1 read requests | 16 billion | 16 billion | 2.0 billion |
By making good use of the registers, we managed to reduce the amount of data transferred through the memory hierarchy by an order of magnitude across all levels. Consequently, we eliminated LG Throttle entirely as a reason for warps to stall. Now we are down to only 8 cycles between issued instructions.
With up to 12 warps per scheduler, that should be enough to keep the cores busy, yet we still see significant amounts of stalling in the diagram above. In fact, the schedulers still only issue an instruction in 61% of all cycles. What went wrong?
Registers are a scarce resource on the GPU, with all threads of all blocks on one SM competing for them. If a block requires more registers than the SM has available (65k), then the kernel will not launch at all. In more benign cases, the number of blocks that can execute on an SM concurrently simply decreases until the register demand fits. This means that some warp slots in the schedulers remain unused: the kernel has low occupancy.
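If you want to check this without the profiler, the CUDA runtime can report both a kernel's register usage and how many blocks fit on an SM. A small sketch (`my_kernel` is a hypothetical placeholder):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void my_kernel(float* data) { /* hypothetical placeholder */ }

void report_occupancy(int blockSize)
{
    // Registers (and other resources) the compiler assigned to the kernel.
    cudaFuncAttributes attr;
    cudaFuncGetAttributes(&attr, my_kernel);
    std::printf("registers per thread: %d\n", attr.numRegs);

    // How many blocks of this size fit on one SM given those resources.
    int blocksPerSM = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSM, my_kernel, blockSize, 0);
    std::printf("active warps per SM: %d\n", blocksPerSM * blockSize / 32);
}
```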
Going from V1 to V2, the number of registers per thread increased from 50 to 96. As a result, we went from 32 active warps per SM down to 20, out of a maximum of 32 warps per multiprocessor. This decrease is also reflected in a decline of Not Selected warp stalls, from 2 cycles to only 0.8 cycles.
Consequently, it becomes much more important that each warp issues instructions in rapid succession, because there are fewer warps available to hide the latencies. There is still a substantial amount of Long Scoreboard stalls, with warps waiting on average 3 cycles between instructions for data to arrive from memory.
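To make the register trade-off concrete, here is a deliberately simplified sketch of the register-reuse idea (a hypothetical kernel, not the actual V2): each value loaded from global memory is reused for several operations, which cuts down memory traffic, but the intermediate results now occupy extra registers in every thread.

```cuda
// Simplified register-reuse sketch (hypothetical kernel, not the actual V2).
// `b` and `out` are assumed to hold TILE rows of length n each.
constexpr int TILE = 4;

__global__ void reuse_sketch(const float* a, const float* b, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    float av = a[i];                  // loaded from memory once, kept in a register
    float acc[TILE];                  // intermediate results also live in registers
    for (int t = 0; t < TILE; ++t)
        acc[t] = av + b[t * n + i];   // `av` is reused TILE times per load
    for (int t = 0; t < TILE; ++t)
        out[t * n + i] = acc[t];
}
```

Increasing TILE reduces memory traffic further, but it also increases the register footprint per thread, which is exactly the effect that lowered the occupancy of V2.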
Let's peek into the detailed statistics of V3. First, we generate an extended version of the memory table that also covers shared memory:
|  | V0 | V1 | V2 | V3 |
| --- | --- | --- | --- | --- |
| DRAM reads | 60 GB | 127 GB | 17 GB | 16 GB |
| DRAM bandwidth utilization | 4% | 24% | 21% | 25% |
| L2 reads | 155 GB | 455 GB | 27 GB | 28 GB |
| L1 hit rate | 96% | 56% | 54% | 7% |
| L1 read requests | 15638 million | 15638 million | 1976 million | 247 million |
| L1 write requests | 1.2 million | 1.2 million | 1.2 million | |
| Shared memory reads | – | – | – | 1976 million |
| Shared memory writes | – | – | – | 247 million |
We can see a perfect correspondence between the number of L1 loads performed by V2 and the number of shared-memory loads performed by V3, as well as between V3's remaining L1 loads and its shared-memory stores. This reflects exactly what we do in the code: everything we load from global memory gets stored in shared memory, and we then access it in the same way as V2.
We can also see that, now that the main calculation uses shared memory, there is not much left to gain from the L1 cache. That is no real loss, though: in V2, only 54% of our 1976 million requests actually hit the cache, whereas now all 1976 million requests are served from shared memory.
The warp stall statistics show a successful reduction of Long Scoreboard stalls, at the cost of introducing entirely new stall reasons to this kernel. These are caused by the `__syncthreads()` synchronization barriers: if some threads reach a barrier faster than others, they cannot do any more useful work, and instead sit in a Barrier stall.
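The code structure that introduces these barriers looks roughly as follows (a generic shared-memory staging sketch, not the actual V3 kernel): the block cooperatively copies a tile from global memory into shared memory, synchronizes, lets every thread read the tile many times, and synchronizes again before the tile is overwritten.

```cuda
// Generic shared-memory staging pattern (sketch, not the actual V3 kernel).
// Assumes blockDim.x == TILE and n is a multiple of TILE.
constexpr int TILE = 128;

__global__ void staged_sum(const float* in, float* out, int n)
{
    __shared__ float tile[TILE];

    float acc = 0.0f;
    for (int base = 0; base < n; base += TILE) {
        // One global load and one shared store per thread.
        tile[threadIdx.x] = in[base + threadIdx.x];
        __syncthreads();                 // threads arriving early sit in a Barrier stall

        // Every thread then reads the whole tile from shared memory.
        for (int k = 0; k < TILE; ++k)
            acc += tile[k];
        __syncthreads();                 // wait again before the tile is overwritten
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}
```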
Finally, let us take a look at how much we use the different units in the GPU now:
We can see that both the arithmetic-logic unit (ALU, used for the minimum operations) and the fused multiply-add unit (FMA, used for the additions) pipelines are running at about 70% capacity, about 15% higher than without shared memory. However, the most utilized pipe turns out to be the load/store unit (LSU), at over 80%.
The combination of warp stalls due to MIO Throttle and the LSU being the most-used pipeline suggests that we could still see improvements if we reduced the number of instructions needed for shared-memory accesses. The obvious way would be to increase register reuse, but that would require too many registers to hold intermediate results. Thus, a preferable alternative is to load more data with a single instruction. This can be achieved using vectorized loads, which allow up to 16 bytes to be loaded with just one instruction.
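A minimal sketch of what such a vectorized access looks like (a hypothetical copy kernel; assumes the pointers are 16-byte aligned and the element count is a multiple of four): reinterpreting the data as `float4` lets a single instruction move four floats at once, and the same trick applies to shared-memory accesses.

```cuda
// Sketch of 16-byte vectorized accesses (hypothetical kernel).
// Assumes `in` and `out` are 16-byte aligned and n is a multiple of 4.
__global__ void copy_float4(const float* in, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    const float4* in4  = reinterpret_cast<const float4*>(in);
    float4*       out4 = reinterpret_cast<float4*>(out);
    if (i < n / 4)
        out4[i] = in4[i];   // one 128-bit load and one 128-bit store per thread
}
```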
Finally, let us go back to the original question: how important are the different optimizations on a more modern GPU with a general-purpose L1 cache? To that end, the figure below shows the fraction of peak performance achieved on the K2200 and the RTX 4000, respectively.
We can see that the differences between the GPUs are not that big. While the non-coalesced V0 code sees strong benefits from the improved cache hierarchy, with about twice the relative performance, it still achieves far less than 10% of what the GPU is capable of.