Luisa Crawford
Jul 19, 2025 03:30
NVIDIA introduces advanced techniques for reducing latency in large language model inference, leveraging JAX and XLA for significant performance improvements in GPU-based workloads.
In the ongoing quest to optimize inference workloads, NVIDIA has unveiled a series of enhancements aimed at reducing latency when running large language models (LLMs) in production environments. These optimizations matter most during the LLM decode phase, where shortening time-to-next-token is critical, according to NVIDIA.
Addressing Latency Challenges
NVIDIA’s approach partitions inference across multiple GPUs using tensor parallelism, specifically targeting the multilayer perceptron (MLP) and projection GEMM layers within transformer blocks. Splitting these large matrix multiplications reduces per-GPU compute time, but each partitioned layer must then combine partial results across GPUs, which makes communication efficiency a common bottleneck.
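As a rough illustration of where that communication comes from (a sketch under assumed shapes and names, not NVIDIA's kernels), the snippet below shows a row-parallel projection: each GPU holds only a slice of the hidden dimension, so it can compute only a partial output that must later be summed across GPUs.

#include <cuda_runtime.h>

// Sketch: one GPU's share of a tensor-parallel (row-parallel) projection GEMM.
// The weight matrix is split along the reduction (hidden) dimension, so each
// GPU produces a *partial* output row that must be summed across GPUs with an
// all-reduce before the next layer. Naive matvec used purely for clarity.
__global__ void partial_projection(const float* __restrict__ x_shard,  // [k_per_gpu]
                                   const float* __restrict__ w_shard,  // [k_per_gpu x n], row-major
                                   float* __restrict__ y_partial,      // [n]
                                   int k_per_gpu, int n) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= n) return;
  float acc = 0.0f;
  for (int k = 0; k < k_per_gpu; ++k) {     // reduce only over this GPU's slice
    acc += x_shard[k] * w_shard[k * n + col];
  }
  y_partial[col] = acc;                     // partial sum: an all-reduce is still required
}

The all-reduce needed to combine those partial outputs is exactly the collective discussed in the next section.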
During the decode stage, the per-token compute is small, so static overheads such as kernel launch and communication setup can dominate, driving up latency. NVIDIA therefore developed techniques to minimize these overheads, which account for a significant share of overall decode latency.
Innovations in All-Reduce Algorithms
NVIDIA’s research revealed that the all-reduce collective in tensor-parallel layers was a significant bottleneck, consuming approximately 23% of end-to-end decode latency. All-reduce traditionally uses the ring algorithm, which is bandwidth-optimal for large messages but incurs high latency for the small message sizes typical of decode.
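The gap can be seen in a standard alpha-beta cost model (an illustrative approximation with assumed symbols, not figures from NVIDIA): with p GPUs, an n-byte message, per-step latency \alpha, and per-link bandwidth B,

T_{\text{ring}} \approx 2(p-1)\,\alpha + \frac{2(p-1)}{p}\cdot\frac{n}{B}, \qquad T_{\text{one-shot}} \approx \alpha + (p-1)\cdot\frac{n}{B}.

For the small n of the decode phase the \alpha terms dominate, so collapsing 2(p-1) latency-bound steps into a single step pays off even though more bytes move in total.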
To address this, NVIDIA implemented a custom single-shot all-reduce algorithm that aggregates data from all peers and performs the reduction in a single stage. Exchanging data simultaneously over NVLink cuts communication latency, even though the scheme moves more total data than the ring algorithm.
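A minimal sketch of the idea (with an assumed buffer layout and names, not NVIDIA's production kernel): every GPU exposes its input buffer to its peers, and one kernel reads each peer's copy directly over NVLink and accumulates it in a single pass.

// Pointers to every rank's input buffer, made directly addressable through
// peer access (see the setup snippet below). Passed by value to the kernel.
struct PeerPtrs { float* bufs[8]; };

// Single-shot all-reduce for the small tensors typical of decode: each thread
// sums element i across all peers' buffers in one stage -- no ring steps.
// A real kernel also needs a lightweight cross-GPU barrier (e.g. a flag
// exchange) so no rank reads a peer's buffer before it is fully written;
// that handshake is omitted here.
__global__ void one_shot_all_reduce(PeerPtrs peers, float* __restrict__ out,
                                    int num_ranks, int count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= count) return;
  float acc = 0.0f;
  for (int r = 0; r < num_ranks; ++r) {
    acc += peers.bufs[r][i];   // direct load from the peer GPU over NVLink
  }
  out[i] = acc;
}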
Furthermore, NVIDIA used cudaDeviceEnablePeerAccess to eliminate extra memory-copy overheads by letting kernels access buffers on peer GPUs directly. This approach is particularly effective in single-node, multi-GPU setups, where all GPUs are driven from a single process and peer mappings make remote buffers directly addressable from device code.
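In a single-process, single-node job, enabling that direct access is a one-time initialization step along these lines (a sketch with error handling omitted, not NVIDIA's code):

#include <cuda_runtime.h>

// Make every GPU's allocations directly addressable from every other GPU in
// this process, so kernels can read peer buffers over NVLink without
// intermediate cudaMemcpy staging.
void enable_full_peer_access(int num_gpus) {
  for (int dev = 0; dev < num_gpus; ++dev) {
    cudaSetDevice(dev);
    for (int peer = 0; peer < num_gpus; ++peer) {
      if (peer == dev) continue;
      int can_access = 0;
      cudaDeviceCanAccessPeer(&can_access, dev, peer);
      if (can_access) {
        cudaDeviceEnablePeerAccess(peer, 0);  // flags argument must be 0
      }
    }
  }
}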
Fusion and Performance Gains
The single-shot all-reduce kernel was further optimized by fusing it with the adjacent pointwise addition and layer-normalization operations into a single CUDA C++ kernel. This fusion minimizes kernel-launch overhead and data movement, delivering a ~3x speedup over the standalone all-reduce kernel and a ~27% improvement in decode-phase latency.
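A simplified version of such a fused kernel might look like the sketch below, reusing the PeerPtrs struct from the earlier snippet; the residual naming, epsilon, and the simple shared-memory layer norm are assumptions for illustration, not NVIDIA's implementation.

// Fused single-shot all-reduce + pointwise (residual) add + layer norm.
// One thread block handles one token's hidden vector; the reduced row stays
// in shared memory instead of round-tripping through HBM between the steps.
// Launch with grid = num_tokens and shared memory = hidden * sizeof(float).
__global__ void fused_allreduce_add_layernorm(
    PeerPtrs peers,                     // partial projection outputs per rank
    const float* __restrict__ residual, // [tokens x hidden]
    const float* __restrict__ gamma,    // [hidden] layer-norm scale
    const float* __restrict__ beta,     // [hidden] layer-norm shift
    float* __restrict__ out,            // [tokens x hidden]
    int num_ranks, int hidden, float eps) {
  extern __shared__ float row[];        // this token's reduced hidden vector
  __shared__ float s_sum, s_sq;
  const int base = blockIdx.x * hidden;

  if (threadIdx.x == 0) { s_sum = 0.0f; s_sq = 0.0f; }
  __syncthreads();

  // (1) One-shot all-reduce across ranks, fused with the pointwise add.
  float part_sum = 0.0f, part_sq = 0.0f;
  for (int h = threadIdx.x; h < hidden; h += blockDim.x) {
    float v = residual[base + h];
    for (int r = 0; r < num_ranks; ++r) v += peers.bufs[r][base + h];
    row[h] = v;
    part_sum += v;
    part_sq  += v * v;
  }
  atomicAdd(&s_sum, part_sum);
  atomicAdd(&s_sq,  part_sq);
  __syncthreads();

  // (2) Layer-norm statistics for this token, then normalize, scale, shift.
  float mean    = s_sum / hidden;
  float inv_std = rsqrtf(s_sq / hidden - mean * mean + eps);
  for (int h = threadIdx.x; h < hidden; h += blockDim.x) {
    out[base + h] = (row[h] - mean) * inv_std * gamma[h] + beta[h];
  }
}

Because the reduced activations never leave on-chip memory between the three steps, the fusion saves both the two extra kernel launches and the intermediate reads and writes to HBM.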
By grouping and launching these kernels as a single CUDA Graph, NVIDIA achieved an additional 5% reduction in decode latency. This comprehensive integration demonstrates the potential of custom kernels in enhancing inference efficiency.
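Stream capture is the usual way to package such a sequence of small decode kernels into one replayable launch; a minimal pattern (the enqueue_decode_step callback is a placeholder assumption, not NVIDIA's code) looks like this:

#include <cuda_runtime.h>

// Capture the decode-step kernel sequence once and replay it as one CUDA
// Graph launch per generated token, amortizing per-kernel launch overhead.
void run_decode_with_graph(cudaStream_t stream, int num_tokens,
                           void (*enqueue_decode_step)(cudaStream_t)) {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  // Record every kernel launched on `stream` between begin/end capture.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  enqueue_decode_step(stream);   // e.g. GEMMs plus the fused all-reduce/layer-norm kernel
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

  for (int step = 0; step < num_tokens; ++step) {
    cudaGraphLaunch(graph_exec, stream);   // one launch replays the whole step
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}

Replaying the captured graph submits the whole per-token kernel sequence with a single API call, which is what shaves the remaining launch overhead off each decode step.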
Future Optimizations and Developments
NVIDIA continues to explore further optimizations for low-latency inference, particularly for workloads with small message sizes. Upcoming features in NCCL 2.27 and future releases aim to reduce communication overheads, potentially delivering up to 4x faster communication for smaller payloads.
Additionally, NVIDIA is leveraging the GPU-initiated, device-side communication APIs in the NVIDIA OpenSHMEM Library (NVSHMEM) to interleave compute and communication blocks, hiding communication latency behind computation. Recent advances in the Mosaic GPU DSL make it easier to express such interleaved compute-communication fusion patterns, promising further improvements in distributed fusion kernels across different parallelism strategies.
For more detailed insights, the original article by NVIDIA can be accessed here.
Source: https://blockchain.news/news/enhancing-inference-efficiency-nvidias-innovations-with-jax-and-xla