Data-Efficient Knowledge Distillation for Supervised Fine-Tuning with NeMo-Aligner

Knowledge Distillation in NeMo-Aligner

Knowledge distillation is an approach for transferring the knowledge of a much larger teacher model to a smaller student model, ideally yielding a compact, easily deployable student with comparable accuracy to the teacher. Knowledge distillation has gained popularity in pretraining settings, but there are fewer resources available for performing knowledge distillation during supervised fine-tuning (SFT). 

NeMo-Aligner now provides an open-source implementation of knowledge distillation during SFT that is more data-efficient than standard SFT and yields higher accuracy (Table 1).

Table 1. Evaluation results for the SFT baseline and the KD-finetuned models

| Training objective | Train steps | MMLU (5-shot) | MMLU (0-shot) | HumanEval (0-shot) | MBPP (0-shot) | GSM8K (0-shot) | MATH (0-shot) |
|---|---|---|---|---|---|---|---|
| SFT loss | 600,000 | 65.3 | 56.9 | 64.6 | 71.7 | 84.2 | 30.12 |
| KD + SFT loss | 420,000 | 65.3 | 57.3 | 70.1 | 73.3 | 85.2 | 35.84 |
| KD + SFT loss | 600,000 | 65.3 | 57.6 | 72.0 | 73.8 | 84.8 | 36.6 |

There are a number of approaches to transfer knowledge from a large model during SFT. The most common approach involves using the teacher model for synthetic data generation, which we refer to as KD-SDG. The synthetically generated data is then used to fine-tune the student model.

A second, seminal approach trains the student to match the teacher's output logits. It was introduced in Distilling the Knowledge in a Neural Network, and we refer to it as KD-logit.

This method provides a more informative gradient signal by exploiting the teacher's knowledge of the similarities and dissimilarities across classes, often termed dark knowledge. For more information, see Dark Knowledge in Neural Networks.
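The KD-logit objective can be written as a temperature-scaled KL divergence between the teacher's and student's output distributions, scaled by T² as in Distilling the Knowledge in a Neural Network. The sketch below is a toy, pure-Python illustration with a three-class vocabulary; a real implementation operates on full vocabulary tensors:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_logit_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions.

    Scaled by T^2 so gradient magnitudes stay comparable as T varies,
    following the convention from the original distillation paper.
    """
    p = softmax(teacher_logits, temperature)  # soft targets
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return temperature ** 2 * kl

# Toy example: the teacher's soft targets carry "dark knowledge" about
# which wrong classes are more plausible than others.
teacher = [8.0, 2.0, -1.0]
student = [5.0, 1.0, 0.0]
loss = kd_logit_loss(teacher, student)
```

The loss is zero exactly when the two softened distributions coincide, and raising the temperature spreads more probability mass onto the non-argmax classes, which is what exposes the cross-class similarity structure to the student.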

In this post and in NeMo-Aligner, we focus on applying KD-logit during SFT.

NeMo-Aligner’s offline KD-logit pipeline consists of these key steps:

  1. A preprocessing step in which the teacher model makes predictions on the training data. The logits from the teacher model are added to the training data.
  2. A training step in which the student is trained to match its logits with the teacher’s logits.
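The two steps above can be sketched as follows. This is an illustrative mock-up, not NeMo-Aligner's actual API: the record layout (`teacher_topk_indices`, `teacher_topk_logits`) and function names are hypothetical. The key idea is that the teacher's logits are cached offline, restricted to the top-K vocabulary entries to keep storage manageable, and the student is matched against that cached distribution:

```python
import math

def top_k(logits, k):
    """Return (indices, values) of the k largest logits."""
    order = sorted(range(len(logits)), key=lambda i: -logits[i])[:k]
    return order, [logits[i] for i in order]

# Step 1: preprocessing -- run the teacher once over the training data and
# cache its top-K logits alongside each example (hypothetical layout).
def preprocess(example, teacher_logits, k=100):
    idx, vals = top_k(teacher_logits, k)
    example["teacher_topk_indices"] = idx
    example["teacher_topk_logits"] = vals
    return example

def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

# Step 2: training -- the student matches the teacher's distribution
# restricted (and renormalized) over the cached top-K vocabulary entries.
def kd_step_loss(example, student_logits):
    t = example["teacher_topk_logits"]
    s = [student_logits[i] for i in example["teacher_topk_indices"]]
    log_p, log_q = log_softmax(t), log_softmax(s)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# Toy usage with a 5-token vocabulary and k=3.
ex = preprocess({}, teacher_logits=[3.0, 1.0, 0.5, -2.0, -5.0], k=3)
zero_loss = kd_step_loss(ex, [3.0, 1.0, 0.5, -2.0, -5.0])
some_loss = kd_step_loss(ex, [0.0, 2.0, 1.0, 0.0, 0.0])
```

Running the teacher only once in a preprocessing pass is what makes the pipeline "offline": training then proceeds at normal SFT speed, without keeping the 340B teacher in memory.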

Results

Table 1 shows that fine-tuning a model using the knowledge distillation objective yields higher accuracy and requires fewer training tokens than vanilla SFT. We conducted experiments using a base Nemotron-4 15B student model and a fine-tuned Nemotron-4 340B teacher model.

The dataset used for SFT is a combination generated using the techniques described in the following papers:

Both the math and code portions of the dataset were produced with synthetic data generation. These experiments set K = 100 and λ = 0.1.
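With these hyperparameters, the "KD + SFT loss" rows in Table 1 correspond to a weighted blend of the two objectives. The convention below (λ on the hard-label SFT term, 1 − λ on the KD term) is one plausible reading of λ = 0.1, not necessarily NeMo-Aligner's exact formulation:

```python
def combined_loss(kd_loss, sft_loss, lam=0.1):
    """Blend the soft-label KD term with the hard-label SFT term.

    Assumed convention: lam scales the SFT cross-entropy and (1 - lam)
    scales the KD term; the text states lambda = 0.1 without pinning
    down which term it weights.
    """
    return lam * sft_loss + (1.0 - lam) * kd_loss

# Toy usage: with lam = 0.1 the KD term dominates the objective.
total = combined_loss(kd_loss=2.0, sft_loss=3.0)
```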

With the same number of training steps, the model fine-tuned with the joint knowledge distillation and SFT objective outperforms the SFT baseline on five of the six evaluation metrics in Table 1 and matches it on the sixth. In particular, we saw significant improvements on the HumanEval, MBPP, and MATH benchmarks, which measure coding and mathematical reasoning skills. On MMLU, which evaluates a diverse range of language-understanding tasks, the KD-finetuned model outperforms the baseline in the zero-shot setting and matches it in the 5-shot setting.

With only 70% of the training tokens, the KD-finetuned Nemotron-4 still outperforms the vanilla SFT model on the same five evaluation metrics.

Conclusion

These results have two important implications. First, we’ve shown that knowledge distillation can be used to improve the accuracy of fine-tuned models. This is especially useful in settings where data is scarce, as fewer training tokens are needed to achieve good accuracy.

Second, we’ve demonstrated that KD-logit can be used in conjunction with synthetically generated (SDG) data to achieve compounding benefits.

Frequently Asked Questions

Q: What is knowledge distillation? A: Knowledge distillation is an approach for transferring the knowledge of a much larger teacher model to a smaller student model, ideally yielding a compact, easily deployable student with comparable accuracy to the teacher.

Q: What is NeMo-Aligner? A: NeMo-Aligner is NVIDIA's open-source toolkit for aligning large language models. It includes an implementation of knowledge distillation during SFT that is more data-efficient and yields higher accuracy than standard SFT.

Q: How does NeMo-Aligner’s offline KD-logit pipeline work? A: NeMo-Aligner’s offline KD-logit pipeline consists of two key steps: preprocessing and training. During preprocessing, the teacher model makes predictions on the training data, and the logits from the teacher model are added to the training data. During training, the student is trained to match its logits with the teacher’s logits.

Q: What are the benefits of using knowledge distillation during SFT? A: Using knowledge distillation during SFT can improve the accuracy of fine-tuned models and reduce the number of training tokens needed to achieve good accuracy.
