Merge pull request #14 from wyy-code/main
add multi-teacher white-box distillation
configs/kd_white_box_train_only_multi.json (Normal file, 43 lines)
@@ -0,0 +1,43 @@
{
    "job_type": "kd_white_box_train_only_multi",
    "dataset": {
        "instruction_path": "./data/datasets/train_labeled_debug.json",
        "labeled_path": "./data/datasets/train_labeled_debug.json",
        "logits_path": ["./data/logits/qwen_logits.jsonl", "./data/logits/qwen2.5-14B_logits.jsonl"],
        "template": "./chat_template/chat_template_qwen.jinja",
        "seed": 42
    },
    "inference": {
        "enable_chunked_prefill": true,
        "seed": 777,
        "gpu_memory_utilization": 0.9,
        "temperature": 0.8,
        "trust_remote_code": true,
        "enforce_eager": false,
        "max_model_len": 4096,
        "max_new_tokens": 512,
        "top_logits_num": 10
    },
    "distillation": {
        "kd_ratio": 0.1,
        "max_seq_length": 512,
        "distillation_type": "forward_kld"
    },
    "models": {
        "teacher": ["./model_hub/qwen2.5-7B/", "./model_hub/qwen2.5-14B/"],
        "student": "./model_hub/qwen2.5-0.5B/"
    },
    "training": {
        "output_dir": "./result/",
        "num_train_epochs": 5,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 8,
        "max_length": 512,
        "save_steps": 1000,
        "logging_steps": 1,
        "learning_rate": 2e-5,
        "weight_decay": 0.05,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "cosine"
    }
}
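The list-valued models.teacher and dataset.logits_path entries are what make this a multi-teacher job: they are paired positionally, and multi_train.py below asserts the two lists have the same length. A minimal sanity check of that pairing, reusing the config path added in this PR (only a sketch, not part of the pipeline):

# Pair each teacher with its pre-computed logits file before launching.
import json

with open("configs/kd_white_box_train_only_multi.json") as f:
    cfg = json.load(f)

teachers = cfg["models"]["teacher"]
logits = cfg["dataset"]["logits_path"]
assert len(teachers) == len(logits), "each teacher needs its own logits file"
for teacher_path, logits_path in zip(teachers, logits):
    print(f"{teacher_path} -> {logits_path}")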
@@ -91,6 +91,17 @@ def process(job_type, config):
        logging.info(f"Running command: {cmd_train}")
        run_cmd(cmd_train)

    elif job_type in ['kd_black_box_train_only_multi', 'kd_white_box_train_only_multi']:
        cmd_train = [
            'accelerate', 'launch',
            '--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
            os.path.join(script_dir, 'kd/multi_train.py'),
            '--config', config
        ]
        cmd_train = ' '.join(cmd_train)
        logging.info(f"Running command: {cmd_train}")
        run_cmd(cmd_train)

    elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
        cmd_infer = [
            'python', os.path.join(script_dir, 'kd/infer.py'),
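With the example config above, the command this new branch joins, logs, and runs comes out roughly as follows (parent_dir and script_dir are variables resolved by the pipeline at runtime, and the muti_gpu.yaml spelling follows the accelerate config file referenced here):

accelerate launch --config_file <parent_dir>/configs/accelerate_config/muti_gpu.yaml <script_dir>/kd/multi_train.py --config configs/kd_white_box_train_only_multi.json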
easydistill/kd/multi_train.py (Normal file, 192 lines)
@@ -0,0 +1,192 @@
#!/usr/bin/env python
import json
import argparse
import os
import logging
from jinja2 import Environment, FileSystemLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import torch
import jsonlines
import numpy as np
import torch.nn.functional as F


def formatting_func(examples):
    """
    Formats a single example for student training using the loaded template.
    """
    try:
        message = {"content": examples["instruction"], "output": examples["output"]}
        full_text = template.render(
            message=message,
            add_generation_prompt=False,
            add_output=True
        )
        return full_text
    except Exception as e:
        logging.warning(f"Error processing sample: {str(e)}")
        return ""

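formatting_func relies on the chat template shipped at ./chat_template/chat_template_qwen.jinja. The template below is a made-up stand-in, shown only to illustrate the variables the function passes in (message, add_generation_prompt, add_output); it is not the real template's content.

# Hypothetical stand-in for chat_template_qwen.jinja, for illustration only.
from jinja2 import Environment

toy_template = Environment().from_string(
    "<|im_start|>user\n{{ message.content }}<|im_end|>\n"
    "{% if add_output %}<|im_start|>assistant\n{{ message.output }}<|im_end|>\n{% endif %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

print(toy_template.render(
    message={"content": "Explain knowledge distillation.", "output": "It trains a small model to match a larger one."},
    add_generation_prompt=False,
    add_output=True,
))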

class MultiDistillSFTTrainer(SFTTrainer):
    """
    Extension of SFTTrainer to support multiple teacher models in white-box distillation.
    """
    def __init__(
        self,
        logits_dirs: list,
        teacher_vocab_sizes: list,
        kd_ratio: float,
        max_seq_length: int,
        distillation_type: str = "forward_kld",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.logits_dirs = logits_dirs
        self.teacher_vocab_sizes = teacher_vocab_sizes
        self.kd_ratio = kd_ratio
        self.max_seq_length = max_seq_length
        self.distillation_type = distillation_type
        # Load and cache each teacher's logits
        self.teacher_logits_list = []
        for path in self.logits_dirs:
            entries = []
            with jsonlines.open(path) as reader:
                for item in reader:
                    entries.append(item)
            self.teacher_logits_list.append(entries)
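Each cached entry is one JSON line per training sample, holding, for every output position, a mapping from token id to probability with top_logits_num entries (10 in the config above); that layout is implied by how _load_teacher_logits_for scatters the values below. A small sketch of writing such a file, with invented token ids and probabilities:

# Sketch of the expected per-teacher logits JSONL layout (values invented).
import jsonlines

toy_sample = [
    {"151643": 0.62, "108386": 0.21, "99489": 0.05},   # distribution at output position 0
    {"104169": 0.48, "100007": 0.30, "99245": 0.11},   # distribution at output position 1
]

with jsonlines.open("toy_teacher_logits.jsonl", mode="w") as writer:
    writer.write(toy_sample)   # one such list per training sample, in dataset order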
    def _load_teacher_logits_for(self, t_idx: int, batch_size: int, step: int, rank: int, device: torch.device, labels: torch.Tensor):
        """
        Slice and shift the teacher logits for teacher index t_idx.
        """
        data = self.teacher_logits_list[t_idx]
        vocab_size = self.teacher_vocab_sizes[t_idx]
        start = rank * batch_size + batch_size * step
        end = start + batch_size
        slice_ = data[start:end]
        arr = np.zeros((batch_size, self.max_seq_length, vocab_size), dtype=np.float32)
        for i, sample in enumerate(slice_):
            for pos, dist in enumerate(sample):
                idxs = np.fromiter(dist.keys(), dtype=int)
                vals = np.fromiter(dist.values(), dtype=float)
                arr[i, pos, idxs] = vals
        tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
        return self._shift_tensor_right(tensor, labels, pad_value=0.0)
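The slice taken from the cached records assumes the JSONL rows follow dataset order; the offset combines the process rank with the trainer's global step:

# Index arithmetic used above, with illustrative numbers.
batch_size, rank, step = 1, 1, 3
start = rank * batch_size + batch_size * step   # = 4
end = start + batch_size                        # = 5, i.e. records data[4:5]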
    def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor):
        student_logits = student_logits[:, :self.max_seq_length, :]
        teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
        mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[..., 0])
        if self.distillation_type == "forward_kld":
            loss = F.kl_div(
                F.log_softmax(student_logits, dim=-1),
                teacher_probs,
                reduction='none',
                log_target=False
            ).sum(dim=-1) / mask.sum()
        elif self.distillation_type == "reverse_kld":
            loss = F.kl_div(
                torch.log(teacher_probs.clamp(min=1e-10)),
                F.softmax(student_logits, dim=-1),
                reduction='none',
                log_target=False
            ).sum(dim=-1) / mask.sum()
        else:
            raise ValueError(f"Unsupported distillation type: {self.distillation_type}")
        return (loss * mask).sum() / mask.sum()
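For reference, with log_target=False PyTorch's F.kl_div(input, target) computes target * (log target - input) elementwise, so the forward_kld branch sums to KL(teacher || student) per position and the reverse_kld branch to KL(student || teacher). A standalone check of the forward case:

# Sanity check of the forward-KLD branch against the definition of KL divergence.
import torch
import torch.nn.functional as F

student_logits = torch.tensor([[2.0, 0.5, -1.0]])
teacher_probs = torch.tensor([[0.7, 0.2, 0.1]])

kld = F.kl_div(
    F.log_softmax(student_logits, dim=-1),
    teacher_probs,
    reduction="none",
    log_target=False,
).sum(dim=-1)

manual = (teacher_probs * (teacher_probs.log() - F.log_softmax(student_logits, dim=-1))).sum(dim=-1)
assert torch.allclose(kld, manual)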
    @staticmethod
    def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
        batch, seqlen, vocab = inputs.shape
        device = inputs.device
        ne = labels != -100
        shift = torch.argmax(ne.int(), dim=1)
        idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch, seqlen)
        shifted = idx - shift.unsqueeze(1)
        mask = shifted >= 0
        shifted = shifted.clamp(min=0)
        flat = inputs.view(batch, seqlen, vocab)
        shifted = shifted.unsqueeze(2).expand(-1, -1, vocab)
        gathered = torch.gather(flat, 1, shifted)
        mask = mask.unsqueeze(2).expand_as(gathered)
        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
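This helper pushes the teacher distributions to the right by the number of leading -100 (prompt) labels, so the distribution for the first response token lines up with the first supervised position. A toy run, assuming the new file is importable as easydistill.kd.multi_train:

# Toy illustration of _shift_tensor_right (arbitrary values).
import torch
from easydistill.kd.multi_train import MultiDistillSFTTrainer   # import path assumed

labels = torch.tensor([[-100, -100, 5, 9]])                       # two masked prompt tokens
teacher = torch.arange(4 * 3, dtype=torch.float32).view(1, 4, 3)  # (batch=1, seqlen=4, vocab=3)

shifted = MultiDistillSFTTrainer._shift_tensor_right(teacher, labels)
# shifted[0, 0] and shifted[0, 1] are zero padding,
# shifted[0, 2] equals teacher[0, 0], shifted[0, 3] equals teacher[0, 1]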
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        lm = outputs.loss
        if not self.logits_dirs:
            return (lm, outputs) if return_outputs else lm
        batch = inputs['input_ids'].size(0)
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        step = self.state.global_step
        labels = inputs.get('labels', None)
        dist_losses = []
        for i in range(len(self.logits_dirs)):
            t_logits = self._load_teacher_logits_for(
                i, batch, step, rank, model.device, labels
            )
            dist_losses.append(
                self._compute_white_box_distillation_loss(
                    outputs.logits, t_logits, labels
                )
            )
        total_dist = sum(dist_losses)
        loss = (1 - self.kd_ratio) * lm + self.kd_ratio * total_dist
        return (loss, outputs) if return_outputs else loss
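Note that the per-teacher losses are summed rather than averaged, so with kd_ratio fixed the distillation term grows with the number of teachers. With the shipped kd_ratio of 0.1 and two teachers (loss values made up for illustration):

# Made-up numbers, only to show how the terms combine in compute_loss.
lm_loss = 2.0
teacher_kd = [1.5, 1.8]
kd_ratio = 0.1
loss = (1 - kd_ratio) * lm_loss + kd_ratio * sum(teacher_kd)   # 0.9 * 2.0 + 0.1 * 3.3 = 2.13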

def train_multi(config):
    # Load data
    ds = load_dataset("json", data_files=config["dataset"]["labeled_path"])["train"]
    # Student setup
    student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"], trust_remote_code=True)
    student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], trust_remote_code=True)
    # Template
    global template
    tpl = config["dataset"]["template"]
    tpl_dir, tpl_file = os.path.split(tpl)
    env = Environment(loader=FileSystemLoader(tpl_dir))
    template = env.get_template(tpl_file)
    # Training args
    args = SFTConfig(**config["training"])
    # Multi-teacher config
    t_paths = config["models"]["teacher"]
    if not isinstance(t_paths, list):
        t_paths = [t_paths]
    l_paths = config["dataset"]["logits_path"]
    if not isinstance(l_paths, list):
        l_paths = [l_paths]
    assert len(t_paths) == len(l_paths), "Mismatch teachers vs logits paths"
    sizes = []
    for tp in t_paths:
        c = json.load(open(os.path.join(tp, 'config.json')))
        sizes.append(c['vocab_size'])
    # Trainer
    trainer = MultiDistillSFTTrainer(
        model=student_model,
        processing_class=student_tokenizer,
        args=args,
        train_dataset=ds,
        formatting_func=formatting_func,
        logits_dirs=l_paths,
        teacher_vocab_sizes=sizes,
        kd_ratio=config["distillation"]["kd_ratio"],
        max_seq_length=config["distillation"]["max_seq_length"],
        distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
    )
    # Train and save
    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    student_tokenizer.save_pretrained(config["training"]["output_dir"])
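train_multi reads each teacher's vocab_size straight from its config.json. An equivalent lookup (not what the function above does) is transformers' AutoConfig, sketched here with the teacher paths from the shipped config:

# Alternative vocab-size lookup via AutoConfig; a sketch, not part of this PR.
from transformers import AutoConfig

sizes = [
    AutoConfig.from_pretrained(path, trust_remote_code=True).vocab_size
    for path in ["./model_hub/qwen2.5-7B/", "./model_hub/qwen2.5-14B/"]
]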

def main():
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    opt = p.parse_args()
    conf = json.load(open(opt.config, 'r'))
    train_multi(conf)


if __name__ == "__main__":
    main()