#!/usr/bin/env python

import argparse
import json
import logging
import os

import jsonlines
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer


def formatting_func(examples):
    """
    Format a single example for student training using the globally loaded
    Jinja template (assigned in train_multi before the trainer is created).
    """
    try:
        message = {"content": examples["instruction"], "output": examples["output"]}
        full_text = template.render(
            message=message,
            add_generation_prompt=False,
            add_output=True
        )
        return full_text
    except Exception as e:
        logging.warning(f"Error processing sample: {str(e)}")
        return ""
class MultiDistillSFTTrainer(SFTTrainer):
    """
    Extension of SFTTrainer to support multiple teacher models in white-box distillation.
    """

    def __init__(
        self,
        logits_dirs: list,
        teacher_vocab_sizes: list,
        kd_ratio: float,
        max_seq_length: int,
        distillation_type: str = "forward_kld",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.logits_dirs = logits_dirs
        self.teacher_vocab_sizes = teacher_vocab_sizes
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type
        # Load and cache each teacher's logits
        self.teacher_logits_list = []
        for path in self.logits_dirs:
            with jsonlines.open(path) as reader:
                entries = list(reader)
            self.teacher_logits_list.append(entries)

    def _load_teacher_logits_for(self, t_idx: int, batch_size: int, step: int, rank: int, device: torch.device, labels: torch.Tensor):
        """
        Slice and shift the teacher logits for teacher index t_idx.
        """
        data = self.teacher_logits_list[t_idx]
        vocab_size = self.teacher_vocab_sizes[t_idx]
        # Records are assumed to be stored in dataset order, sharded by rank and step.
        start = rank * batch_size + batch_size * step
        end = start + batch_size
        slice_ = data[start:end]
        # Densify the sparse per-position {token_id: probability} records.
        arr = np.zeros((batch_size, self.max_seq_length, vocab_size), dtype=np.float32)
        for i, sample in enumerate(slice_):
            for pos, dist in enumerate(sample):
                if pos >= self.max_seq_length:
                    break  # drop positions beyond max_seq_length to avoid an index error
                idxs = np.fromiter(dist.keys(), dtype=int)
                vals = np.fromiter(dist.values(), dtype=float)
                arr[i, pos, idxs] = vals
        tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(tensor, labels, pad_value=0.0)
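
    # Expected layout of each teacher-logits JSONL file, inferred from the parsing
    # above (an assumption, not a format specification from this script): one JSON
    # line per training sample, each line holding a list of sparse per-token
    # distributions mapping token id (string key) to probability, e.g.
    #
    #   [{"15": 0.82, "42": 0.11, "7": 0.07}, {"3": 0.95, "99": 0.05}]
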
    def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor):
        # Align student and teacher tensors on sequence length and vocabulary size.
        student_logits = student_logits[:, :self.max_seq_length, :]
        teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
        # Only supervised positions (labels != -100) contribute to the loss.
        if labels is not None:
            mask = (labels[:, :student_logits.size(1)] != -100).float()
        else:
            mask = torch.ones_like(student_logits[..., 0])
        if self.distillation_type == "forward_kld":
            # Forward KLD: KL(teacher || student), with the student in log space.
            token_loss = F.kl_div(
                F.log_softmax(student_logits, dim=-1),
                teacher_probs,
                reduction='none',
                log_target=False
            ).sum(dim=-1)
        elif self.distillation_type == "reverse_kld":
            # Reverse KLD: KL(student || teacher), teacher probabilities clamped before the log.
            token_loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),
                F.softmax(student_logits, dim=-1),
                reduction='none',
                log_target=False
            ).sum(dim=-1)
        else:
            raise ValueError(f"Unsupported distillation type: {self.distillation_type}")
        # Masked mean over supervised tokens.
        return (token_loss * mask).sum() / mask.sum()
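
    # For reference: forward KLD treats the teacher as the target distribution,
    #   KL(p_teacher || q_student) = sum_v p_teacher(v) * log(p_teacher(v) / q_student(v)),
    # while reverse KLD swaps the roles,
    #   KL(q_student || p_teacher) = sum_v q_student(v) * log(q_student(v) / p_teacher(v)),
    # which matches how F.kl_div(input=log-probs, target=probs) is called above.
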
    @staticmethod
    def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
        """
        Shift each teacher distribution to the right so that its first entry lines up
        with the first supervised position (the first label that is not -100).
        """
        batch, seqlen, vocab = inputs.shape
        device = inputs.device
        # Per-sample offset: index of the first non-ignored label.
        ne = labels != -100
        shift = torch.argmax(ne.int(), dim=1)
        # gathered[i, pos] = inputs[i, pos - shift[i]]; positions before the offset are padded.
        idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch, seqlen)
        shifted = idx - shift.unsqueeze(1)
        mask = shifted >= 0
        shifted = shifted.clamp(min=0).unsqueeze(2).expand(-1, -1, vocab)
        gathered = torch.gather(inputs, 1, shifted)
        mask = mask.unsqueeze(2).expand_as(gathered)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
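
    # Illustration (hypothetical values): for labels = [-100, -100, 7, 9] the offset is 2,
    # so teacher rows [d0, d1, d2, d3] become [pad, pad, d0, d1], and d0 is compared with
    # the student's logits at the first supervised position.
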
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        lm_loss = outputs.loss
        if not self.logits_dirs:
            # No teacher logits configured: fall back to the plain language-modeling loss.
            return (lm_loss, outputs) if return_outputs else lm_loss
        batch_size = inputs['input_ids'].size(0)
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        step = self.state.global_step
        labels = inputs.get('labels', None)
        # One distillation loss per teacher; the terms are summed before mixing.
        dist_losses = []
        for i in range(len(self.logits_dirs)):
            t_logits = self._load_teacher_logits_for(
                i, batch_size, step, rank, model.device, labels
            )
            dist_losses.append(
                self._compute_white_box_distillation_loss(
                    outputs.logits, t_logits, labels
                )
            )
        total_dist = sum(dist_losses)
        # loss = (1 - kd_ratio) * CE + kd_ratio * sum_t KD_t
        loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * total_dist
        return (loss, outputs) if return_outputs else loss


def train_multi(config):
    # Load the labeled training data.
    ds = load_dataset("json", data_files=config["dataset"]["labeled_path"])["train"]
    # Student setup.
    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], trust_remote_code=True)
    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], trust_remote_code=True)
    # Prompt template, shared with formatting_func via a module-level global.
    global template
    tpl = config["dataset"]["template"]
    tpl_dir, tpl_file = os.path.split(tpl)
    env = Environment(loader=FileSystemLoader(tpl_dir))
    template = env.get_template(tpl_file)
    # Training arguments.
    args = SFTConfig(**config["training"])
    # Multi-teacher configuration: accept either a single path or a list of paths.
    t_paths = config["models"]["teacher"]
    if not isinstance(t_paths, list):
        t_paths = [t_paths]
    l_paths = config["dataset"]["logits_path"]
    if not isinstance(l_paths, list):
        l_paths = [l_paths]
    assert len(t_paths) == len(l_paths), "Mismatch between number of teachers and logits paths"
    # Read each teacher's vocabulary size from its Hugging Face config.json.
    sizes = []
    for tp in t_paths:
        with open(os.path.join(tp, 'config.json')) as f:
            c = json.load(f)
        sizes.append(c['vocab_size'])
    # Trainer.
    trainer = MultiDistillSFTTrainer(
        model=student_model,
        processing_class=student_tokenizer,
        args=args,
        train_dataset=ds,
        formatting_func=formatting_func,
        logits_dirs=l_paths,
        teacher_vocab_sizes=sizes,
        kd_ratio=config["distillation"]["kd_ratio"],
        max_seq_length=config["distillation"]["max_seq_length"],
        distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
    )
    # Train and save.
    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    student_tokenizer.save_pretrained(config["training"]["output_dir"])
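

# A sketch of the expected JSON config, inferred from the keys read above (field
# names are the ones this script accesses, the values are illustrative placeholders):
#
#   {
#     "models": {"student": "path/to/student", "teacher": ["path/to/teacher_a", "path/to/teacher_b"]},
#     "dataset": {
#       "labeled_path": "data/train.json",
#       "logits_path": ["logits/teacher_a.jsonl", "logits/teacher_b.jsonl"],
#       "template": "templates/chat_template.jinja"
#     },
#     "training": {"output_dir": "output/student-kd", "per_device_train_batch_size": 1, "num_train_epochs": 3},
#     "distillation": {"kd_ratio": 0.5, "max_seq_length": 1024, "distillation_type": "forward_kld"}
#   }
#
# The script would then be run as, e.g.:
#   python <this_script>.py --config kd_config.json

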
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    opt = p.parse_args()
    with open(opt.config, 'r') as f:
        conf = json.load(f)
    train_multi(conf)


if __name__ == "__main__":
    main()