distillation/easydistill/kd/multi_train.py

#!/usr/bin/env python
import json
import argparse
import os
import logging
from jinja2 import Environment, FileSystemLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import torch
import jsonlines
import numpy as np
import torch.nn.functional as F


def formatting_func(examples):
    """
    Formats a single example for student training using the loaded template.
    """
    try:
        message = {"content": examples["instruction"], "output": examples["output"]}
        full_text = template.render(
            message=message,
            add_generation_prompt=False,
            add_output=True
        )
        return full_text
    except Exception as e:
        logging.warning(f"Error processing sample: {str(e)}")
        return ""


class MultiDistillSFTTrainer(SFTTrainer):
    """
    Extension of SFTTrainer to support multiple teacher models in white-box distillation.
    """

    def __init__(
        self,
        logits_dirs: list,
        teacher_vocab_sizes: list,
        kd_ratio: float,
        max_seq_length: int,
        distillation_type: str = "forward_kld",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.logits_dirs = logits_dirs
        self.teacher_vocab_sizes = teacher_vocab_sizes
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type
        # Load and cache each teacher's logits
        self.teacher_logits_list = []
        for path in self.logits_dirs:
            entries = []
            with jsonlines.open(path) as reader:
                for item in reader:
                    entries.append(item)
            self.teacher_logits_list.append(entries)

    def _load_teacher_logits_for(self, t_idx: int, batch_size: int, step: int, rank: int, device: torch.device, labels: torch.Tensor):
        """
        Slice and shift the teacher logits for teacher index t_idx.
        """
        data = self.teacher_logits_list[t_idx]
        vocab_size = self.teacher_vocab_sizes[t_idx]
        start = rank * batch_size + batch_size * step
        end = start + batch_size
        slice_ = data[start:end]
        arr = np.zeros((batch_size, self.max_seq_length, vocab_size), dtype=np.float32)
        for i, sample in enumerate(slice_):
            for pos, dist in enumerate(sample):
                idxs = np.fromiter(dist.keys(), dtype=int)
                vals = np.fromiter(dist.values(), dtype=float)
                arr[i, pos, idxs] = vals
        tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(tensor, labels, pad_value=0.0)
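
    # Note on the logits files: each line of a teacher's JSONL file is read as one training
    # sample, i.e. a list of sparse per-position distributions mapping token ids to
    # probabilities. An illustrative record (values made up for the example) would be:
    #   [{"15": 0.91, "42": 0.06}, {"7": 0.88, "15": 0.05}, ...]
    # The records are assumed to be stored in the same order as the training dataset, since
    # they are indexed above by global step, batch size, and rank.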

    def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor):
        """
        Compute the token-level distillation loss between the student logits and one
        teacher's probability distributions, masking positions whose label is -100.
        """
        student_logits = student_logits[:, :self.max_seq_length, :]
        teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
        mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[..., 0])
        if self.distillation_type == "forward_kld":
            loss = F.kl_div(
                F.log_softmax(student_logits, dim=-1),
                teacher_probs,
                reduction='none',
                log_target=False
            ).sum(dim=-1) / mask.sum()
        elif self.distillation_type == "reverse_kld":
            loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),
                F.softmax(student_logits, dim=-1),
                reduction='none',
                log_target=False
            ).sum(dim=-1) / mask.sum()
        else:
            raise ValueError(f"Unsupported distillation type: {self.distillation_type}")
        return (loss * mask).sum() / mask.sum()
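
    # In the terms used above, "forward_kld" evaluates KL(teacher || student) via F.kl_div on
    # the student's log-probabilities, while "reverse_kld" swaps the roles of the two
    # distributions and evaluates KL(student || teacher).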

    @staticmethod
    def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
        """
        Shift each sequence's teacher distributions to the right so that position 0 lines up
        with the first label that is not -100; earlier positions are filled with pad_value.
        """
        batch, seqlen, vocab = inputs.shape
        device = inputs.device
        ne = labels != -100
        shift = torch.argmax(ne.int(), dim=1)
        idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch, seqlen)
        shifted = idx - shift.unsqueeze(1)
        mask = shifted >= 0
        shifted = shifted.clamp(min=0)
        flat = inputs.view(batch, seqlen, vocab)
        shifted = shifted.unsqueeze(2).expand(-1, -1, vocab)
        gathered = torch.gather(flat, 1, shifted)
        mask = mask.unsqueeze(2).expand_as(gathered)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
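
    # Worked example of the shift above (hypothetical values): with labels
    # [-100, -100, 5, 7], the first supervised position is index 2, so the teacher
    # distribution stored at position 0 is moved to position 2, position 1 to position 3,
    # and so on; positions 0-1 are filled with pad_value.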

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Combine the ordinary language-modeling loss with the distillation losses from all
        teachers, weighted by kd_ratio.
        """
        outputs = model(**inputs)
        lm = outputs.loss
        if not self.logits_dirs:
            return (lm, outputs) if return_outputs else lm
        batch = inputs['input_ids'].size(0)
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        step = self.state.global_step
        labels = inputs.get('labels', None)
        dist_losses = []
        for i in range(len(self.logits_dirs)):
            t_logits = self._load_teacher_logits_for(
                i, batch, step, rank, model.device, labels
            )
            dist_losses.append(
                self._compute_white_box_distillation_loss(
                    outputs.logits, t_logits, labels
                )
            )
        total_dist = sum(dist_losses)
        loss = (1 - self.kd_ratio) * lm + self.kd_ratio * total_dist
        return (loss, outputs) if return_outputs else loss


def train_multi(config):
    """
    Run multi-teacher white-box distillation as described by a JSON config.
    """
    # Load data
    ds = load_dataset("json", data_files=config["dataset"]["labeled_path"])["train"]
    # Student setup
    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], trust_remote_code=True)
    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], trust_remote_code=True)
    # Template
    global template
    tpl = config["dataset"]["template"]
    tpl_dir, tpl_file = os.path.split(tpl)
    env = Environment(loader=FileSystemLoader(tpl_dir))
    template = env.get_template(tpl_file)
    # Training args
    args = SFTConfig(**config["training"])
    # Multi-teacher config
    t_paths = config["models"]["teacher"]
    if not isinstance(t_paths, list):
        t_paths = [t_paths]
    l_paths = config["dataset"]["logits_path"]
    if not isinstance(l_paths, list):
        l_paths = [l_paths]
    assert len(t_paths) == len(l_paths), "Mismatch between number of teachers and logits paths"
    # Read each teacher's vocabulary size from its config.json
    sizes = []
    for tp in t_paths:
        with open(os.path.join(tp, 'config.json')) as f:
            sizes.append(json.load(f)['vocab_size'])
    # Trainer
    trainer = MultiDistillSFTTrainer(
        model=student_model,
        processing_class=student_tokenizer,
        args=args,
        train_dataset=ds,
        formatting_func=formatting_func,
        logits_dirs=l_paths,
        teacher_vocab_sizes=sizes,
        kd_ratio=config["distillation"]["kd_ratio"],
        max_seq_length=config["distillation"]["max_seq_length"],
        distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
    )
    # Train and save
    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    student_tokenizer.save_pretrained(config["training"]["output_dir"])
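
# For reference, a minimal config consumed by train_multi might look like the following
# (paths and hyperparameter values are placeholders, not part of this repository):
#
#   {
#     "models":       {"student": "path/to/student", "teacher": ["path/to/teacher_a", "path/to/teacher_b"]},
#     "dataset":      {"labeled_path": "data/train.jsonl",
#                      "logits_path": ["logits/teacher_a.jsonl", "logits/teacher_b.jsonl"],
#                      "template": "templates/instruction.jinja"},
#     "training":     {"output_dir": "output/", "num_train_epochs": 1, "per_device_train_batch_size": 1},
#     "distillation": {"kd_ratio": 0.5, "max_seq_length": 1024, "distillation_type": "forward_kld"}
#   }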


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    opt = p.parse_args()
    with open(opt.config, 'r') as f:
        conf = json.load(f)
    train_multi(conf)


if __name__ == "__main__":
    main()
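
# Typical invocation (the config filename is a placeholder):
#   python multi_train.py --config multi_teacher_config.json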