6. Inference

Introduction

This section discusses the optimization techniques used to make inference with the LLaMA 3 405B model more efficient. With 405 billion parameters, the model needs substantial optimization to handle the computational load of inference. Two main techniques are highlighted: pipeline parallelism and FP8 quantization.

Pipeline Parallelism

Pipeline Parallelism is critical for managing the memory and computational requirements of large models like LLaMA 3. When using BF16 (16-bit floating-point) precision for the model parameters, the LLaMA 3 model (specifically the 405B variant) cannot fit into the GPU memory of a single machine equipped with 8 Nvidia H100 GPUs.
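
As a rough check on why one machine is not enough, the arithmetic can be sketched in a few lines of Python. The 80 GB per-GPU figure is an assumption (the 80 GB H100 variant), and the estimate counts weights only, ignoring the key-value cache and activations, so the real requirement is higher than shown here.

```python
import math

PARAMS = 405e9             # parameter count of the 405B model
BYTES_PER_PARAM_BF16 = 2   # BF16 stores each parameter in 16 bits
GPU_MEMORY_GB = 80         # assumption: 80 GB of memory per H100
GPUS_PER_MACHINE = 8


def weight_footprint_gb(params, bytes_per_param):
    """Memory needed for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return params * bytes_per_param / 1e9


def machines_needed(footprint_gb, gpu_memory_gb, gpus_per_machine):
    """Smallest number of machines whose combined GPU memory holds the weights."""
    machine_capacity_gb = gpu_memory_gb * gpus_per_machine
    return math.ceil(footprint_gb / machine_capacity_gb)


bf16_gb = weight_footprint_gb(PARAMS, BYTES_PER_PARAM_BF16)
capacity_gb = GPU_MEMORY_GB * GPUS_PER_MACHINE
print(f"BF16 weights: {bf16_gb:.0f} GB, one machine: {capacity_gb} GB")
print("machines needed:", machines_needed(bf16_gb, GPU_MEMORY_GB, GPUS_PER_MACHINE))
```

Under these assumptions the BF16 weights alone exceed the combined memory of eight GPUs, which is consistent with the two-machine, 16-GPU setup described in this section.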

  • Challenges and Solutions:

    • Memory Constraints: A single machine cannot hold the model's BF16 parameters, so the model must be spread across multiple GPUs and machines. To address this, BF16 inference is parallelized across 16 GPUs on two machines.

    • Within-Machine Optimization (Tensor Parallelism): High NVLink bandwidth within each machine allows the use of tensor parallelism, meaning the model parameters are distributed across multiple GPUs in the same machine.

    • Cross-Machine Optimization (Pipeline Parallelism): Between machines, where connectivity has lower bandwidth and higher latency, pipeline parallelism is employed. This approach splits the model's layers into sequential stages so that different stages can work on different inputs at the same time, increasing efficiency.

  • Bubbles Issue:

    • During training, pipeline bubbles (idle periods in which a stage waits for data) are a major efficiency concern. In inference they are not an issue, because there is no backward pass that would require a pipeline flush.
  • Micro-Batching:

    • Micro-batching improves throughput during inference by splitting the workload into smaller batches that can be processed in parallel.
    • Evaluations were conducted on workloads of 4,096 input tokens and 256 output tokens, covering both stages of inference:
      • Pre-fill Stage: Where the key-value cache is initialized.
      • Decoding Stage: During token generation.
    • Micro-batching enables concurrent execution of micro-batches in both stages, improving throughput at the same local batch size. The additional synchronization points also increase latency, but overall micro-batching still yields a better throughput-latency trade-off.
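
The effect of micro-batching on a pipeline can be illustrated with a toy timing model. This is a sketch under idealized assumptions, not the paper's evaluation setup: the function name, the assumption that a micro-batch costs a proportional share of the full batch's stage time, and the `sync_overhead` knob are all hypothetical.

```python
def pipeline_makespan(num_stages, num_microbatches, batch_stage_time, sync_overhead=0.0):
    """Time for one batch to clear all pipeline stages when split into micro-batches.

    Idealized assumptions (not taken from the paper):
      - every stage needs `batch_stage_time` for the full batch,
      - a micro-batch costs an equal share of that time,
      - each micro-batch pays a fixed `sync_overhead` at every stage.
    """
    step = batch_stage_time / num_microbatches + sync_overhead
    stage_free_at = [0.0] * num_stages  # when each stage finishes its current work
    for _ in range(num_microbatches):
        ready = 0.0  # when this micro-batch's input becomes available
        for stage in range(num_stages):
            # A stage starts once its input has arrived and it is idle.
            start = max(ready, stage_free_at[stage])
            stage_free_at[stage] = start + step
            ready = stage_free_at[stage]
    return stage_free_at[-1]


# Two stages (one per machine), full batch takes 1.0 time unit per stage.
no_split = pipeline_makespan(num_stages=2, num_microbatches=1, batch_stage_time=1.0)
two_way = pipeline_makespan(num_stages=2, num_microbatches=2, batch_stage_time=1.0)
with_sync = pipeline_makespan(2, 2, 1.0, sync_overhead=0.1)
print(no_split, two_way, with_sync)
```

In this model the second stage starts as soon as the first micro-batch is ready instead of waiting for the whole batch, so both stages overlap. The `sync_overhead` term stands in for the extra synchronization points that micro-batching introduces; real latency behavior depends on factors this sketch does not capture.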

FP8 Quantization

Quantization is another critical strategy used to make LLaMA 3 more efficient during inference. FP8 (8-bit floating-point) quantization leverages the native FP8 support on Nvidia H100 GPUs, enabling low-precision inference. This significantly reduces the computational and memory requirements.

  • Quantization Strategy:

    • Matrix Multiplications:
      • FP8 quantization is applied to most matrix multiplications in the model, especially in the feedforward network layers. These layers account for approximately 50% of the inference compute time.
      • The parameters in the self-attention layers are not quantized, as these layers are more sensitive to precision loss.
    • Dynamic Scaling Factors:

      • The quantization process uses dynamic scaling factors, computed at inference time from the data being processed rather than fixed in advance, for better accuracy. Custom CUDA kernels were developed to reduce the overhead of calculating these scaling factors.
    • Avoiding Quantization in Specific Layers:
      • The first and last Transformer layers are excluded from quantization, similar to the approach by Zhang et al. (2021). This is done because these layers are more prone to errors when quantized.
  • Handling High-Perplexity Tokens:

    • High-perplexity tokens, such as dates, can produce large activation values. These in turn lead to very high dynamic scaling factors in FP8 and a non-negligible number of underflows, which cause decoding errors. To manage this, the dynamic scaling factors are upper-bounded at 1200.
  • Row-wise Quantization:

    • Instead of using a single scaling factor for an entire tensor, the model uses row-wise quantization, computing scaling factors across rows of the parameter and activation matrices. This gives finer-grained control than tensor-wise quantization and helps maintain accuracy during inference. The approach is illustrated in Figure 25 in the paper.
  • Impact on Model Output Quality:

    • Evaluations show that FP8 quantization has a negligible impact on the model’s responses. A comparison of reward model scores for 100,000 responses generated with BF16 and with FP8 inference (Figure 26) shows almost identical distributions, indicating minimal degradation in model quality.
  • Efficiency Gains:

    • Experiments show that FP8 inference improves throughput by up to 50% during the pre-fill stage and gives a substantially better throughput-latency trade-off during decoding (as shown in Figure 27). This makes FP8 quantization a highly effective technique for optimizing inference with the 405B variant of LLaMA 3.
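
The quantization ideas above (dynamic scales, row-wise scaling, and the upper bound) can be sketched in plain Python with a simulated FP8 grid. Several things here are assumptions rather than details from the paper: the E4M3 variant of FP8, all function names, and the choice to apply the bound of 1200 to the per-row absolute maximum from which the scale is derived. The summary does not spell out the exact convention, so treat this as an illustration only.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value (assumes the E4M3 format)


def round_to_e4m3(x):
    """Round x to the nearest value on a simulated E4M3 grid, saturating at the max."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_E4M3_MAX)
    _, exp = math.frexp(a)           # a = m * 2**exp with 0.5 <= m < 1
    exponent = max(exp - 1, -6)      # -6 is the smallest normal exponent
    step = 2.0 ** (exponent - 3)     # 3 mantissa bits: 8 steps per power of two
    return sign * min(round(a / step) * step, FP8_E4M3_MAX)


def quantize_row(row, amax=None, amax_upper_bound=None):
    """Quantize one row with a dynamic scale; returns (quantized values, scale).

    Passing `amax` lets the caller share one maximum across rows (tensor-wise).
    `amax_upper_bound` is a hypothetical parameter modelling the bound of 1200.
    """
    if amax is None:
        amax = max(abs(v) for v in row)       # dynamic: computed from the data
    if amax_upper_bound is not None:
        amax = min(amax, amax_upper_bound)    # larger values will saturate
    amax = max(amax, 1e-12)                   # guard against all-zero rows
    scale = FP8_E4M3_MAX / amax
    return [round_to_e4m3(v * scale) for v in row], scale


def row_error(row, quantized, scale):
    """Largest absolute error after dequantizing a row."""
    return max(abs(v - q / scale) for v, q in zip(row, quantized))


# Row-wise vs tensor-wise: a small-magnitude row next to a large-magnitude row.
weights = [[0.01, -0.02, 0.03], [1000.0, -2500.0, 4000.0]]
tensor_amax = max(abs(v) for row in weights for v in row)
rowwise_err = row_error(weights[0], *quantize_row(weights[0]))
tensorwise_err = row_error(weights[0], *quantize_row(weights[0], amax=tensor_amax))
print("small-row error, row-wise:", rowwise_err, "tensor-wise:", tensorwise_err)

# Upper bound: one very large activation pushes a small value below the FP8 grid.
activations = [0.01, 0.5, 5000.0]
q_unbounded, _ = quantize_row(activations)
q_bounded, _ = quantize_row(activations, amax_upper_bound=1200.0)
print("unbounded:", q_unbounded, "bounded:", q_bounded)
```

In this sketch the shared tensor-wise scale is dominated by the large row, so the small row loses most of its precision, while a per-row scale keeps it accurate. With the bound applied, the outlier saturates at the largest representable value, but the small activation no longer underflows to zero.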

Conclusion

Section 6 of the paper presents two key optimization techniques for efficient inference in LLaMA 3: pipeline parallelism and FP8 quantization. These strategies address the memory and computational challenges posed by large models, allowing LLaMA 3 to operate effectively on GPUs like the Nvidia H100. By using pipeline parallelism and micro-batching, the model can scale across multiple GPUs and machines while maintaining high throughput. FP8 quantization further reduces the resource requirements with negligible impact on model quality, making LLaMA 3 both powerful and efficient for real-world applications.