add mmkd, white mmkd

This commit is contained in:
yyh
2025-07-24 11:27:11 +08:00
parent e5ff55e4e2
commit 98398d4e73
17 changed files with 655 additions and 97 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

BIN
configs/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -1,30 +0,0 @@
{
    "job_type": "mmkd_black_box_api",
    "dataset": {
        "instruction_path": "train.json",
        "labeled_path": "train_labeled.json",
        "seed": 42
    },
    "inference": {
        "base_url": "ENDPOINT",
        "api_key": "TOKEN",
        "system_prompt": "You are a helpful assistant.",
        "max_new_tokens": 512
    },
    "models": {
        "student": "student/Qwen/Qwen2.5-VL-3B-Instruct/"
    },
    "training": {
        "output_dir": "./result/",
        "num_train_epochs": 3,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 8,
        "max_length": 512,
        "save_steps": 1000,
        "logging_steps": 1,
        "learning_rate": 2e-5,
        "weight_decay": 0.05,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "cosine"
    }
}
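
The config above feeds the mmkd_black_box_api path, which (per the mmkd/infer.py diff below) builds an OpenAI-compatible client from base_url and api_key. A minimal sketch of the kind of request that path issues; the model name is hypothetical, and ENDPOINT/TOKEN are the same placeholders as in the config:

from openai import OpenAI

# Sketch only: mirrors how the api job reads base_url / api_key / system_prompt.
client = OpenAI(base_url="ENDPOINT", api_key="TOKEN")
completion = client.chat.completions.create(
    model="teacher-model",  # hypothetical name; the real job takes this from its config
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who are they?"},
    ],
    max_tokens=512,  # matches max_new_tokens above
)
print(completion.choices[0].message.content)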

View File

@@ -0,0 +1,35 @@
{
    "job_type": "mmkd_black_box_local",
    "dataset": {
        "instruction_path": "/mnt/workspace/yyh/easydistill/test_data/mllm_demo_reformat.json",
        "labeled_path": "/mnt/data/yyh/easydistill/test_data/mllm_demo_distill.json",
        "seed": 42
    },
    "inference": {
        "enable_chunked_prefill": true,
        "seed": 777,
        "gpu_memory_utilization": 0.9,
        "temperature": 0.8,
        "trust_remote_code": true,
        "enforce_eager": false,
        "max_model_len": 4096,
        "max_new_tokens": 512
    },
    "models": {
        "teacher": "/mnt/data/yyh/models/Qwen2.5-VL-3B-Instruct",
        "student": "/mnt/data/yyh/models/Qwen2.5-VL-3B-Instruct"
    },
    "training": {
        "output_dir": "./result/",
        "num_train_epochs": 3,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 8,
        "max_length": 512,
        "save_steps": 1000,
        "logging_steps": 1,
        "learning_rate": 2e-5,
        "weight_decay": 0.05,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "cosine"
    }
}
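
As a quick orientation, a sketch of how a job config like this is consumed: the dispatcher reads job_type to pick the mmkd branch, the inference block is handed to vLLM, and the training block to SFTConfig (the file name below is illustrative):

import json

# Load the job config and inspect the fields the mmkd pipeline dispatches on.
with open("configs/mmkd_black_box_local.json") as f:  # hypothetical path
    config = json.load(f)

print(config["job_type"])                     # "mmkd_black_box_local"
print(config["models"]["teacher"])            # teacher checkpoint served by vLLM
print(config["inference"]["max_new_tokens"])  # generation budget per sample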

BIN
data/.DS_Store vendored Normal file

Binary file not shown.

data/alpaca_en_demo.json Normal file

File diff suppressed because one or more lines are too long

data/mllm_demo.json Normal file
View File

@@ -0,0 +1,226 @@
[
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/1.jpg"
                },
                {
                    "type": "text",
                    "text": "Who are they?"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "They're Kane and Gretzka from Bayern Munich."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/1.jpg"
                },
                {
                    "type": "text",
                    "text": "What are they doing?"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "They are celebrating on the soccer field."
        }
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/2.jpg"
                },
                {
                    "type": "text",
                    "text": "Who is he?"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "He's Thomas Muller from Bayern Munich."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Why is he on the ground?"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "Because he's sliding on his knees to celebrate."
        }
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/3.jpg"
                },
                {
                    "type": "text",
                    "text": "Please describe this image"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "Chinese astronaut Gui Haichao is giving a speech."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What has he accomplished?"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads."
        }
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/1.jpg"
                },
                {
                    "type": "text",
                    "text": "\u4ed6\u4eec\u662f\u8c01\uff1f"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u4ed6\u4eec\u662f\u62dc\u4ec1\u6155\u5c3c\u9ed1\u7684\u51ef\u6069\u548c\u683c\u96f7\u8328\u5361\u3002"
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/1.jpg"
                },
                {
                    "type": "text",
                    "text": "\u4ed6\u4eec\u5728\u505a\u4ec0\u4e48\uff1f"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u4ed6\u4eec\u5728\u8db3\u7403\u573a\u4e0a\u5e86\u795d\u3002"
        }
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/2.jpg"
                },
                {
                    "type": "text",
                    "text": "\u4ed6\u662f\u8c01\uff1f"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u4ed6\u662f\u6765\u81ea\u62dc\u4ec1\u6155\u5c3c\u9ed1\u7684\u6258\u9a6c\u65af\u00b7\u7a46\u52d2\u3002"
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "\u4ed6\u4e3a\u4ec0\u4e48\u5728\u5730\u4e0a\uff1f"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u56e0\u4e3a\u4ed6\u6b63\u5728\u53cc\u819d\u8dea\u5730\u6ed1\u884c\u5e86\u795d\u3002"
        }
    ],
    [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "/mnt/data/yyh/easydistill/test_data/mllm_demo_data/3.jpg"
                },
                {
                    "type": "text",
                    "text": "\u8bf7\u63cf\u8ff0\u8fd9\u5f20\u56fe\u7247"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u4e2d\u56fd\u5b87\u822a\u5458\u6842\u6d77\u6f6e\u6b63\u5728\u8bb2\u8bdd\u3002"
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "\u4ed6\u53d6\u5f97\u8fc7\u54ea\u4e9b\u6210\u5c31\uff1f"
                }
            ]
        },
        {
            "role": "assistant",
            "content": "\u4ed6\u4e8e2022\u5e746\u6708\u88ab\u4efb\u547d\u4e3a\u795e\u821f\u5341\u516d\u53f7\u4efb\u52a1\u7684\u6709\u6548\u8f7d\u8377\u4e13\u5bb6\uff0c\u4ece\u800c\u6210\u4e3a2023\u5e745\u670830\u65e5\u8fdb\u5165\u592a\u7a7a\u7684\u9996\u4f4d\u5e73\u6c11\u5b87\u822a\u5458\u3002\u4ed6\u8d1f\u8d23\u5728\u8f68\u64cd\u4f5c\u7a7a\u95f4\u79d1\u5b66\u5b9e\u9a8c\u6709\u6548\u8f7d\u8377\u3002"
        }
    ]
]
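
Each entry in mllm_demo.json is a complete chat: a system message followed by alternating user/assistant turns, where user content is a list of typed parts ("image" and "text") and assistant content is a plain string. A small sketch that checks this shape before distillation:

import json

# Sanity-check the conversation schema used by data/mllm_demo.json.
with open("data/mllm_demo.json") as f:
    conversations = json.load(f)

for chat in conversations:
    assert chat[0]["role"] == "system"
    for message in chat[1:]:
        assert message["role"] in ("user", "assistant")
        if message["role"] == "user":
            assert all(part["type"] in ("image", "text") for part in message["content"])
print(len(conversations), "conversations OK")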

BIN
data/mllm_demo_data/1.jpg Normal file

Binary file not shown.

Size: 12 KiB

BIN
data/mllm_demo_data/2.jpg Normal file

Binary file not shown.

Size: 22 KiB

BIN
data/mllm_demo_data/3.jpg Normal file

Binary file not shown.

Size: 16 KiB

BIN
easydistill/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -91,17 +91,6 @@ def process(job_type, config):
         logging.info(f"Running command: {cmd_train}")
         run_cmd(cmd_train)
-    elif job_type in ['kd_black_box_train_only_multi', 'kd_white_box_train_only_multi']:
-        cmd_train = [
-            'accelerate', 'launch',
-            '--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
-            os.path.join(script_dir, 'kd/multi_train.py'),
-            '--config', config
-        ]
-        cmd_train = ' '.join(cmd_train)
-        logging.info(f"Running command: {cmd_train}")
-        run_cmd(cmd_train)
     elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
         cmd_infer = [
             'python', os.path.join(script_dir, 'kd/infer.py'),
@@ -110,6 +99,10 @@ def process(job_type, config):
         cmd_infer = ' '.join(cmd_infer)
         logging.info(f"Running command: {cmd_infer}")
         infer_success = run_cmd(cmd_infer)
+        ###############################
+        infer_success = True
+        ###############################
         if infer_success:
             cmd_train = [
                 'accelerate', 'launch',
@@ -123,6 +116,32 @@ def process(job_type, config):
         else:
             logging.error("Infer failed, skipping training")
+    elif job_type in ['mmkd_black_box_api', 'mmkd_black_box_local', 'mmkd_white_box']:
+        cmd_infer = [
+            'python', os.path.join(script_dir, 'mmkd/infer.py'),
+            '--config', config
+        ]
+        cmd_infer = ' '.join(cmd_infer)
+        logging.info(f"Running command: {cmd_infer}")
+        infer_success = run_cmd(cmd_infer)
+        ###############################
+        infer_success = True
+        ###############################
+        if infer_success:
+            cmd_train = [
+                'accelerate', 'launch',
+                '--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
+                os.path.join(script_dir, 'mmkd/train.py'),
+                '--config', config
+            ]
+            cmd_train = ' '.join(cmd_train)
+            logging.info(f"Running command: {cmd_train}")
+            run_cmd(cmd_train)
+        else:
+            logging.error("Infer failed, skipping training")
     # Reinforcement Learning tasks
     elif job_type in ['rl_ppo', 'rl_grpo']:
         cmd = [
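
For reference, the joined command strings that the new mmkd branch logs and runs look like the following; script_dir, parent_dir and the config path depend on the local checkout, so placeholders are used here rather than guessing concrete paths:

python <script_dir>/mmkd/infer.py --config <config>
accelerate launch --config_file <parent_dir>/configs/accelerate_config/muti_gpu.yaml <script_dir>/mmkd/train.py --config <config>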

View File

@@ -14,11 +14,16 @@
 # limitations under the License.
 # ==============================================================================
-import json
+import json, jsonlines
+import math
 import argparse
 import logging
 from tqdm import tqdm
 from openai import OpenAI
+import torch
+from transformers import AutoProcessor, AutoTokenizer
+from vllm import LLM, SamplingParams
+from qwen_vl_utils import process_vision_info
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -28,12 +33,7 @@ def read_json_field(filename):
     try:
         with open(filename, 'r') as file:
             data = json.load(file)
-            outputs = []
-            for item in data:
-                text = item["instruction"]
-                image = item["image"]
-                outputs.append((text, image))
-            return outputs
+            return data
     except FileNotFoundError:
         logging.error("The file was not found.")
     except json.JSONDecodeError:
@@ -50,6 +50,170 @@ def write_data_to_json_file(data, file_path):
     except Exception as e:
         logging.error(f"An error occurred: {e}")
 
+def load_tokenizer_and_vllm(config, eos_token=None):
+    model_path = config["models"]["teacher"]
+    logging.info(f"Loading processor & vLLM model from {model_path}")
+    # 1. Use AutoProcessor throughout (it already bundles the tokenizer, image processor and video processor)
+    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+    # 2. eos / pad token handling (kept consistent with the official example; pad_token is no longer set explicitly)
+    if eos_token:
+        eos_token_id = processor.tokenizer.convert_tokens_to_ids(eos_token)
+        logging.info(f"eos_token {eos_token} from user input")
+    elif hasattr(processor.tokenizer, "eos_token_id") and processor.tokenizer.eos_token_id is not None:
+        eos_token_id = processor.tokenizer.eos_token_id
+        eos_token = processor.tokenizer.convert_ids_to_tokens(eos_token_id)
+        logging.info(f"Initial eos_token_id {eos_token_id} from tokenizer")
+    else:
+        raise ValueError("No available eos_token or eos_token_id.")
+    # 3. Set the tokenizer's eos-related fields (pad_token stays None and is handled by vLLM automatically)
+    try:
+        processor.tokenizer.eos_token = eos_token
+        processor.tokenizer.eos_token_id = eos_token_id
+    except Exception as e:
+        logging.warning(f"[WARNING] Cannot set eos_token: {e}")
+    logging.info(
+        f"processor.tokenizer eos_token: {processor.tokenizer.eos_token}, "
+        f"eos_token_id: {processor.tokenizer.eos_token_id}"
+    )
+    num_gpus = torch.cuda.device_count()
+    llm = LLM(
+        model=model_path,
+        tensor_parallel_size=num_gpus,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"image": 10, "video": 10},  # adjust as needed
+        # the remaining hyperparameters are taken from the original config
+        gpu_memory_utilization=config["inference"].get("gpu_memory_utilization", 0.9),
+        max_model_len=config["inference"].get("max_model_len", 4096),
+        enforce_eager=config["inference"].get("enforce_eager", False),
+    )
+    logging.info("Qwen2.5-VL vLLM model loaded successfully")
+    return processor, llm
+
+
+def generate_teacher_response_batch(processor, llm, data_list, config, batch_size=32):
+    outcomes = []
+    sampling_params = SamplingParams(
+        n=1,
+        top_k=1,
+        temperature=config["inference"]["temperature"],
+        seed=config["inference"]["seed"],
+        max_tokens=config["inference"]["max_new_tokens"],
+    )
+    batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
+    for batch in tqdm(batches, desc="Generating responses"):
+        new_batch = []
+        batch_outcomes = []
+        for sample in batch:
+            batch_outcomes.append(sample)
+            prompt = processor.apply_chat_template(
+                sample,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            image_inputs, video_inputs = process_vision_info(sample)
+            mm_data = {}
+            if image_inputs is not None:
+                mm_data["image"] = image_inputs
+            sample_inputs = {
+                "prompt": prompt,
+                "multi_modal_data": mm_data,
+            }
+            new_batch.append(sample_inputs)
+        outputs = llm.generate(new_batch, sampling_params=sampling_params)
+        for b in range(len(batch_outcomes)):
+            generated_text = outputs[b].outputs[0].text
+            out = {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": generated_text,
+                    }
+                ],
+            }
+            batch_outcomes[b].append(out)
+        outcomes.extend(batch_outcomes)
+    write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
+
+
+def generate_teacher_logits_batch(processor, llm, data_list, config, batch_size=32):
+    outcomes = []
+    sampling_params = SamplingParams(
+        n=1,
+        top_k=1,
+        temperature=config["inference"]["temperature"],
+        seed=config["inference"]["seed"],
+        skip_special_tokens=False,
+        max_tokens=config["inference"]["max_new_tokens"],
+        logprobs=config["inference"]["top_logits_num"],
+    )
+    batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
+    logits = []
+    for batch in tqdm(batches, desc="Generating responses"):
+        new_batch = []
+        batch_outcomes = []
+        for sample in batch:
+            batch_outcomes.append(sample)
+            prompt = processor.apply_chat_template(
+                sample,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            image_inputs, video_inputs = process_vision_info(sample)
+            mm_data = {}
+            if image_inputs is not None:
+                mm_data["image"] = image_inputs
+            sample_inputs = {
+                "prompt": prompt,
+                "multi_modal_data": mm_data,
+            }
+            new_batch.append(sample_inputs)
+        outputs = llm.generate(new_batch, sampling_params=sampling_params)
+        logits += [output.outputs[0].logprobs for output in outputs]
+        for b in range(len(batch_outcomes)):
+            generated_text = outputs[b].outputs[0].text
+            out = {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": generated_text,
+                    }
+                ],
+            }
+            batch_outcomes[b].append(out)
+        outcomes.extend(batch_outcomes)
+    # convert vLLM Logprob objects to plain probabilities before serialization
+    for logit in logits:
+        for pos in logit:
+            for k, v in pos.items():
+                pos[k] = math.exp(v.logprob)
+    with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
+        for row in logits:
+            writer.write(row)
+    write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
+
 def generate_teacher_response_api(data_list, config):
     client = OpenAI(
@@ -98,10 +262,18 @@ def generate_teacher_response_api(data_list, config):
 def infer_with_teacher_model(config):
     logging.info('Generating distillation data from the teacher model!')
     data_list = read_json_field(config["dataset"]["instruction_path"])
     try:
         job_type = config["job_type"]
         if job_type == "mmkd_black_box_api":
             generate_teacher_response_api(data_list, config)
+        elif job_type == "mmkd_black_box_local":
+            processor, llm = load_tokenizer_and_vllm(config)
+            generate_teacher_response_batch(processor, llm, data_list, config)
+        elif job_type == "mmkd_white_box":
+            processor, llm = load_tokenizer_and_vllm(config)
+            generate_teacher_logits_batch(processor, llm, data_list, config)
         else:
             logging.error(f"Invalid job type: {job_type}")
             raise ValueError(f"Invalid job type: {job_type}")
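
The white-box path writes one jsonlines row per sample, each row being a list of per-position dictionaries mapping token ids to probabilities over the top top_logits_num candidates (the loop above exponentiates vLLM's logprobs first). A sketch of reading that file back; the file name is whatever dataset.logits_path points to, and JSON round-tripping turns the token-id keys into strings:

import jsonlines

# Inspect the teacher distributions written by generate_teacher_logits_batch.
with jsonlines.open("logits.jsonl") as reader:  # hypothetical file name
    row = next(iter(reader))                    # one generated sample
    print(len(row), "decoded positions")
    first = row[0]                               # {token_id_as_str: probability}
    print(sorted(first.items(), key=lambda kv: -kv[1])[:3])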

View File

@@ -15,10 +15,17 @@
 # ==============================================================================
 import json
+import torch
+import numpy as np
+import jsonlines
+import torch.nn.functional as F
+import os
 import argparse
 import logging
 from datasets import load_dataset, Dataset
+from typing import Optional, Dict, Union, List
 from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, AutoModelForCausalLM, AutoTokenizer, TrainingArguments
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
@@ -26,39 +33,147 @@ from trl import SFTTrainer, SFTConfig
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
+from torch.utils.data import Dataset
+from PIL import Image
+import os
+
+
+class MMDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[int(idx)]
+
+
+class DistillSFTTrainer(SFTTrainer):
+
+    def __init__(
+        self,
+        logits_dir: str = None,
+        teacher_vocab_size=None,
+        kd_ratio: float = 0.5,
+        max_seq_length: int = 1024,
+        distillation_type: str = "forward_kld",
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.logits_dir = logits_dir
+        self.teacher_vocab_size = teacher_vocab_size
+        self.kd_ratio = kd_ratio
+        self.max_seq_length = max_seq_length
+        self.distillation_type = distillation_type
+        self.teacher_logits = []
+        with jsonlines.open(self.logits_dir) as reader:
+            for obj in reader:
+                self.teacher_logits.append(obj)
+
+    def _load_teacher_logits(self, batch_size: int, it: int, dp_rank: int, device: torch.device, no_model_batch: Dict):
+        start_idx = dp_rank * batch_size + batch_size * it
+        end_idx = dp_rank * batch_size + batch_size * (it + 1)
+        loaded_data = self.teacher_logits[start_idx:end_idx]
+        arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
+        for i in range(len(loaded_data)):
+            for j in range(len(loaded_data[i])):
+                keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
+                values = np.array(list(loaded_data[i][j].values()))
+                arr[i, j, keys] = values
+        logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
+        return self._shift_tensor_right(logits_tensor, no_model_batch['label'], pad_value=0)
+
+    def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: Optional[torch.Tensor]):
+        student_logits = student_logits[:, :self.max_seq_length, :]
+        teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
+        mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[:, :, 0])
+        if self.distillation_type == "forward_kld":
+            # Forward KLD: student learns from teacher (original implementation)
+            loss = F.kl_div(
+                F.log_softmax(student_logits, dim=-1),
+                teacher_probs,
+                reduction='none',
+                log_target=False
+            ).sum(dim=-1) / torch.sum(mask.view(-1), dim=0)
+        elif self.distillation_type == "reverse_kld":
+            # Reverse KLD: teacher provides certainty to student
+            loss = F.kl_div(
+                torch.log(teacher_probs.clamp(min=1e-10)),  # avoid log(0)
+                F.softmax(student_logits, dim=-1),
+                reduction='none',
+                log_target=False
+            ).sum(dim=-1) / torch.sum(mask.view(-1), dim=0)
+        else:
+            raise ValueError(f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'")
+        return (loss * mask).sum() / mask.sum()
+
+    @staticmethod
+    def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
+        batch_size, seqlen, vocab_size = inputs.shape
+        device = inputs.device
+        labels_ne = labels != -100
+        shift_distances = torch.argmax(labels_ne.int(), dim=1)
+        idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
+        shifted_idx = idx - shift_distances.unsqueeze(1)
+        mask = shifted_idx >= 0
+        shifted_idx = shifted_idx.clamp(min=0)
+        inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
+        shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
+        gathered = torch.gather(inputs_flat, 1, shifted_idx)
+        mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
+        return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
+
+    def compute_loss(self, model: PreTrainedModel, inputs: Dict[str, torch.Tensor], return_outputs=False, num_items_in_batch=None):
+        outputs = model(**inputs)
+        lm_loss = outputs.loss
+        if self.logits_dir:
+            teacher_logits = self._load_teacher_logits(
+                batch_size=inputs['input_ids'].size(0),
+                it=self.state.global_step,
+                dp_rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
+                device=model.device,
+                no_model_batch={'label': inputs.get('labels', None)}
+            )
+            distil_loss = self._compute_white_box_distillation_loss(
+                student_logits=outputs.logits,
+                teacher_logits=teacher_logits,
+                labels=inputs.get('labels', None)
+            )
+            total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
+        else:
+            total_loss = lm_loss
+        return (total_loss, outputs) if return_outputs else total_loss
+
+
 def train(config):
-    dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"])
-    dataset = dataset.shuffle(seed=config["dataset"]["seed"])["train"]
+    with open(config["dataset"]["labeled_path"], "r") as f:
+        raw_data = json.load(f)
+    dataset = MMDataset(raw_data)
 
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         config["models"]["student"],
         trust_remote_code=True
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
+    training_arguments = SFTConfig(**config["training"])
+    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_arguments.remove_unused_columns = False
+    training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}
 
     def collate_fn(examples):
         texts = []
         images = []
         for example in examples:
-            chat = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image", "image": example["image"]
-                        },
-                        {
-                            "type": "text", "text": example["instruction"]
-                        }
-                    ]
-                },
-                {
-                    "role": "assistant",
-                    "content": example["output"]
-                }
-            ]
+            chat = example
             text = processor.apply_chat_template(chat, tokenize=False)
             texts.append(text)
-            image, _ = process_vision_info(chat)
+            image, _ = process_vision_info(example)
             images.append(image)
 
         batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -75,10 +190,10 @@ def train(config):
         batch["labels"] = labels
         return batch
 
-    training_arguments = SFTConfig(**config["training"])
-    training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_arguments.remove_unused_columns = False
-    training_arguments.dataset_kwargs = {"skip_prepare_dataset": True}
+    try:
+        job_type = config["job_type"]
+        if "mmkd_black_box" in job_type:
 
     trainer = SFTTrainer(
         model=student_model,
@@ -87,6 +202,26 @@ def train(config):
         args=training_arguments,
         train_dataset=dataset
     )
+        elif "mmkd_white_box" in job_type:
+            teacher_vocab_size = json.load(open(os.path.join(config["models"]["teacher"], 'config.json')))['vocab_size']
+            trainer = DistillSFTTrainer(
+                logits_dir=config["dataset"]["logits_path"],
+                data_collator=collate_fn,
+                teacher_vocab_size=teacher_vocab_size,
+                kd_ratio=config["distillation"]["kd_ratio"],
+                max_seq_length=config["distillation"]["max_seq_length"],
+                distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
+                model=student_model,
+                processing_class=processor.tokenizer,
+                args=training_arguments,
+                train_dataset=dataset,
+            )
+        else:
+            logging.error(f"Invalid job type: {job_type}")
+            raise ValueError(f"Invalid job type: {job_type}")
+    except ValueError as e:
+        logging.error(f"Training job terminated: {e}")
+        return
 
     trainer.train()
     trainer.save_model(config["training"]["output_dir"])
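
The objective combined in compute_loss is total_loss = (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss, where the forward-KLD term compares student log-probabilities against the densified teacher distributions. A toy sketch of that mixing on random tensors (shapes and values are made up; only the formula mirrors the trainer):

import torch
import torch.nn.functional as F

kd_ratio = 0.5
student_logits = torch.randn(2, 8, 100)                    # (batch, seq, vocab)
teacher_probs = F.softmax(torch.randn(2, 8, 100), dim=-1)  # stands in for the top-k teacher probs
mask = torch.ones(2, 8)                                    # stands in for labels != -100

kld = F.kl_div(F.log_softmax(student_logits, dim=-1), teacher_probs,
               reduction='none', log_target=False).sum(dim=-1)
distil_loss = (kld * mask).sum() / mask.sum()
lm_loss = torch.tensor(2.3)                                # stands in for outputs.loss
total_loss = (1 - kd_ratio) * lm_loss + kd_ratio * distil_loss
print(total_loss)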

BIN
recipes/.DS_Store vendored Normal file

Binary file not shown.

BIN
recipes/distilqwen_series/.DS_Store vendored Normal file

Binary file not shown.

BIN
recipes/domain_specific/.DS_Store vendored Normal file

Binary file not shown.