diff --git a/configs/kd_white_box_train_only_multi.json b/configs/kd_white_box_train_only_multi.json
new file mode 100644
index 0000000..572e18b
--- /dev/null
+++ b/configs/kd_white_box_train_only_multi.json
@@ -0,0 +1,43 @@
+{
+    "job_type": "kd_white_box_train_only_multi",
+    "dataset": {
+        "instruction_path": "./data/datasets/train_labeled_debug.json",
+        "labeled_path": "./data/datasets/train_labeled_debug.json",
+        "logits_path": ["./data/logits/qwen_logits.jsonl", "./data/logits/qwen2.5-14B_logits.jsonl"],
+        "template": "./chat_template/chat_template_qwen.jinja",
+        "seed": 42
+    },
+    "inference": {
+        "enable_chunked_prefill": true,
+        "seed": 777,
+        "gpu_memory_utilization": 0.9,
+        "temperature": 0.8,
+        "trust_remote_code": true,
+        "enforce_eager": false,
+        "max_model_len": 4096,
+        "max_new_tokens": 512,
+        "top_logits_num": 10
+    },
+    "distillation": {
+        "kd_ratio": 0.1,
+        "max_seq_length": 512,
+        "distillation_type": "forward_kld"
+    },
+    "models": {
+        "teacher": ["./model_hub/qwen2.5-7B/", "./model_hub/qwen2.5-14B/"],
+        "student": "./model_hub/qwen2.5-0.5B/"
+    },
+    "training": {
+        "output_dir": "./result/",
+        "num_train_epochs": 5,
+        "per_device_train_batch_size": 1,
+        "gradient_accumulation_steps": 8,
+        "max_length": 512,
+        "save_steps": 1000,
+        "logging_steps": 1,
+        "learning_rate": 2e-5,
+        "weight_decay": 0.05,
+        "warmup_ratio": 0.1,
+        "lr_scheduler_type": "cosine"
+    }
+}
\ No newline at end of file
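
Note on the logits files: each entry of "logits_path" pairs positionally with the corresponding entry of "models.teacher", and each file is JSONL with one line per training sample. The trainer added in kd/multi_train.py below reads every line as a list holding, per token position, a {token_id: value} mapping with at most "top_logits_num" entries, and densifies it into a full-vocabulary array. A minimal sketch of that layout and of the densification step; the token ids and values here are made up, the real files come from the white-box inference stage, and the vocabulary size is read from each teacher's config.json in practice:

import numpy as np

# One JSONL line = one sample: a {token_id: value} mapping per token position,
# truncated to the top_logits_num highest-scoring tokens (values are illustrative).
record = [
    {"151644": 0.62, "872": 0.21, "198": 0.05},   # position 0
    {"108386": 0.88, "3837": 0.07},               # position 1
]

max_seq_length, vocab_size = 512, 152064          # placeholders; train_multi reads vocab_size from config.json
dense = np.zeros((max_seq_length, vocab_size), dtype=np.float32)
for pos, dist in enumerate(record):
    idxs = np.fromiter(dist.keys(), dtype=int)    # JSON keys arrive as strings
    vals = np.fromiter(dist.values(), dtype=float)
    dense[pos, idxs] = vals                       # same densification as _load_teacher_logits_for
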
diff --git a/easydistill/cli.py b/easydistill/cli.py
index 474788e..96dca0e 100644
--- a/easydistill/cli.py
+++ b/easydistill/cli.py
@@ -91,6 +91,17 @@ def process(job_type, config):
         logging.info(f"Running command: {cmd_train}")
         run_cmd(cmd_train)
 
+    elif job_type in ['kd_black_box_train_only_multi', 'kd_white_box_train_only_multi']:
+        cmd_train = [
+            'accelerate', 'launch',
+            '--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
+            os.path.join(script_dir, 'kd/multi_train.py'),
+            '--config', config
+        ]
+        cmd_train = ' '.join(cmd_train)
+        logging.info(f"Running command: {cmd_train}")
+        run_cmd(cmd_train)
+
     elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
         cmd_infer = [
             'python', os.path.join(script_dir, 'kd/infer.py'),
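
The new branch simply shells out, via accelerate launch, to the kd/multi_train.py entry point added below. For quick single-process experiments the trainer can also be driven directly; this minimal sketch assumes the repository root is on PYTHONPATH so that easydistill.kd.multi_train is importable, and skips the multi-GPU accelerate setup the CLI uses:

import json
from easydistill.kd.multi_train import train_multi  # import path is an assumption

# Reuse the config added above; train_multi expects the already-parsed dict.
with open("configs/kd_white_box_train_only_multi.json") as f:
    config = json.load(f)

train_multi(config)
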
diff --git a/easydistill/kd/multi_train.py b/easydistill/kd/multi_train.py
new file mode 100644
index 0000000..c6489f4
--- /dev/null
+++ b/easydistill/kd/multi_train.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+import json
+import argparse
+import os
+import logging
+from jinja2 import Environment, FileSystemLoader
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import SFTTrainer, SFTConfig
+import torch
+import jsonlines
+import numpy as np
+import torch.nn.functional as F
+
+
+def formatting_func(examples):
+    """
+    Formats a single example for student training using the loaded template.
+    """
+    try:
+        message = {"content": examples["instruction"], "output": examples["output"]}
+        full_text = template.render(
+            message=message,
+            add_generation_prompt=False,
+            add_output=True
+        )
+        return full_text
+    except Exception as e:
+        logging.warning(f"Error processing sample: {str(e)}")
+        return ""
+
+
+class MultiDistillSFTTrainer(SFTTrainer):
+    """
+    Extension of SFTTrainer to support multiple teacher models in white-box distillation.
+    """
+    def __init__(
+        self,
+        logits_dirs: list,
+        teacher_vocab_sizes: list,
+        kd_ratio: float,
+        max_seq_length: int,
+        distillation_type: str = "forward_kld",
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.logits_dirs = logits_dirs
+        self.teacher_vocab_sizes = teacher_vocab_sizes
+        self.kd_ratio = kd_ratio
+        self.max_seq_length = max_seq_length
+        self.distillation_type = distillation_type
+        # Load and cache each teacher's logits
+        self.teacher_logits_list = []
+        for path in self.logits_dirs:
+            entries = []
+            with jsonlines.open(path) as reader:
+                for item in reader:
+                    entries.append(item)
+            self.teacher_logits_list.append(entries)
+
+    def _load_teacher_logits_for(self, t_idx: int, batch_size: int, step: int, rank: int, device: torch.device, labels: torch.Tensor):
+        """
+        Slice and shift the teacher logits for teacher index t_idx.
+        """
+        data = self.teacher_logits_list[t_idx]
+        vocab_size = self.teacher_vocab_sizes[t_idx]
+        start = rank * batch_size + batch_size * step
+        end = start + batch_size
+        slice_ = data[start:end]
+        arr = np.zeros((batch_size, self.max_seq_length, vocab_size), dtype=np.float32)
+        for i, sample in enumerate(slice_):
+            for pos, dist in enumerate(sample):
+                idxs = np.fromiter(dist.keys(), dtype=int)
+                vals = np.fromiter(dist.values(), dtype=float)
+                arr[i, pos, idxs] = vals
+        tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
+        return self._shift_tensor_right(tensor, labels, pad_value=0.0)
+
+    def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor):
+        student_logits = student_logits[:, :self.max_seq_length, :]
+        teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
+        mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[..., 0])
+        if self.distillation_type == "forward_kld":
+            loss = F.kl_div(
+                F.log_softmax(student_logits, dim=-1),
+                teacher_probs,
+                reduction='none',
+                log_target=False
+            ).sum(dim=-1) / mask.sum()
+        elif self.distillation_type == "reverse_kld":
+            loss = F.kl_div(
+                torch.log(teacher_probs.clamp(min=1e-10)),
+                F.softmax(student_logits, dim=-1),
+                reduction='none',
+                log_target=False
+            ).sum(dim=-1) / mask.sum()
+        else:
+            raise ValueError(f"Unsupported distillation type: {self.distillation_type}")
+        return (loss * mask).sum() / mask.sum()
+
+    @staticmethod
+    def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
+        batch, seqlen, vocab = inputs.shape
+        device = inputs.device
+        ne = labels != -100
+        shift = torch.argmax(ne.int(), dim=1)
+        idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch, seqlen)
+        shifted = idx - shift.unsqueeze(1)
+        mask = shifted >= 0
+        shifted = shifted.clamp(min=0)
+        flat = inputs.view(batch, seqlen, vocab)
+        shifted = shifted.unsqueeze(2).expand(-1, -1, vocab)
+        gathered = torch.gather(flat, 1, shifted)
+        mask = mask.unsqueeze(2).expand_as(gathered)
+        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
+
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        outputs = model(**inputs)
+        lm = outputs.loss
+        if not self.logits_dirs:
+            return (lm, outputs) if return_outputs else lm
+        batch = inputs['input_ids'].size(0)
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        step = self.state.global_step
+        labels = inputs.get('labels', None)
+        dist_losses = []
+        for i in range(len(self.logits_dirs)):
+            t_logits = self._load_teacher_logits_for(
+                i,
+                batch, step, rank, model.device, labels
+            )
+            dist_losses.append(
+                self._compute_white_box_distillation_loss(
+                    outputs.logits, t_logits, labels
+                )
+            )
+        total_dist = sum(dist_losses)
+        loss = (1 - self.kd_ratio) * lm + self.kd_ratio * total_dist
+        return (loss, outputs) if return_outputs else loss
+
+
+def train_multi(config):
+    # Load data
+    ds = load_dataset("json", data_files=config["dataset"]["labeled_path"])["train"]
+    # Student setup
+    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], trust_remote_code=True)
+    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], trust_remote_code=True)
+    # Template
+    global template
+    tpl = config["dataset"]["template"]
+    tpl_dir, tpl_file = os.path.split(tpl)
+    env = Environment(loader=FileSystemLoader(tpl_dir))
+    template = env.get_template(tpl_file)
+    # Training args
+    args = SFTConfig(**config["training"])
+    # Multi-teacher config
+    t_paths = config["models"]["teacher"]
+    if not isinstance(t_paths, list): t_paths = [t_paths]
+    l_paths = config["dataset"]["logits_path"]
+    if not isinstance(l_paths, list): l_paths = [l_paths]
+    assert len(t_paths) == len(l_paths), "Mismatch teachers vs logits paths"
+    sizes = []
+    for tp in t_paths:
+        c = json.load(open(os.path.join(tp, 'config.json')))
+        sizes.append(c['vocab_size'])
+    # Trainer
+    trainer = MultiDistillSFTTrainer(
+        model=student_model,
+        processing_class=student_tokenizer,
+        args=args,
+        train_dataset=ds,
+        formatting_func=formatting_func,
+        logits_dirs=l_paths,
+        teacher_vocab_sizes=sizes,
+        kd_ratio=config["distillation"]["kd_ratio"],
+        max_seq_length=config["distillation"]["max_seq_length"],
+        distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
+    )
+    # Train and save
+    trainer.train()
+    trainer.save_model(config["training"]["output_dir"])
+    student_tokenizer.save_pretrained(config["training"]["output_dir"])
+
+
+def main():
+    p = argparse.ArgumentParser()
+    p.add_argument("--config", required=True)
+    opt = p.parse_args()
+    conf = json.load(open(opt.config, 'r'))
+    train_multi(conf)
+
+
+if __name__ == "__main__":
+    main()
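
A small sanity check of the alignment performed by _shift_tensor_right: teacher distributions are stored starting at position 0 of each logits record, while the student's supervised labels only begin after the prompt, so the teacher tensor is shifted right until its first row lines up with the first non -100 label. This sketch assumes the repository root is on PYTHONPATH so the module above is importable, and uses toy float32 tensors rather than the bfloat16 used during training:

import torch
from easydistill.kd.multi_train import MultiDistillSFTTrainer  # import path is an assumption

vocab, seqlen = 8, 6
teacher = torch.arange(seqlen * vocab, dtype=torch.float32).reshape(1, seqlen, vocab)
labels = torch.tensor([[-100, -100, 11, 12, 13, -100]])  # first two positions are masked prompt tokens

shifted = MultiDistillSFTTrainer._shift_tensor_right(teacher, labels, pad_value=0.0)

assert torch.equal(shifted[0, 2], teacher[0, 0])  # teacher row 0 now sits at the first label position
assert shifted[0, :2].abs().sum() == 0            # positions before the first label are padded with 0.0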