add lora to training

2025-08-20 10:13:19 +00:00
parent 228fa8c81b
commit 4110d9e12a


@@ -30,10 +30,11 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
+    AutoConfig
 )
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
+from peft import LoraConfig
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
         self.kd_ratio = kd_ratio
         self.max_seq_length = max_seq_length
         self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)
     def _load_teacher_logits(
         self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
     ):
         start_idx = dp_rank * batch_size + batch_size * it
         end_idx = dp_rank * batch_size + batch_size * (it + 1)
-        loaded_data = self.teacher_logits[start_idx:end_idx]
+        loaded_data = []
+        # Open file and read only the specific lines needed for the current batch
+        with jsonlines.open(self.logits_dir) as reader:
+            for i, obj in enumerate(reader):
+                if i >= start_idx and i < end_idx:
+                    loaded_data.append(obj)
+                elif i >= end_idx:
+                    break
         arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
         for i in range(len(loaded_data)):
             for j in range(len(loaded_data[i])):
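The change above drops the in-memory teacher_logits cache and instead streams only the records that belong to the current batch from the JSONL file. A minimal standalone sketch of that access pattern, assuming each JSONL line holds the teacher outputs for one sample (the file name here is a placeholder):

import jsonlines

def read_jsonl_slice(path, start_idx, end_idx):
    """Return records [start_idx, end_idx) without loading the whole file."""
    rows = []
    with jsonlines.open(path) as reader:
        for i, obj in enumerate(reader):
            if i < start_idx:
                continue  # skip records before the batch window
            if i >= end_idx:
                break     # stop once the window is exhausted
            rows.append(obj)
    return rows

# e.g. the rows for batch index 2 of size 4 on data-parallel rank 0
batch = read_jsonl_slice("teacher_logits.jsonl", start_idx=8, end_idx=12)

This keeps memory flat but re-scans the file head on every batch; a byte-offset index would be the usual way to make the lookup O(1) if that ever becomes a bottleneck.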
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
             else torch.ones_like(student_logits[:, :, 0])
         )
+        mask = mask[:, : self.max_seq_length]
         if self.distillation_type == "forward_kld":
             # Forward KLD: student learns from teacher (original implementation)
             loss = F.kl_div(
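The F.kl_div call is truncated by the hunk boundary, so the exact reduction used in the repository is not visible here. As a rough sketch of the technique this branch names, forward KLD feeds student log-probabilities and teacher probabilities to F.kl_div and averages over the unmasked positions; this assumes dense teacher probabilities of shape (batch, seq, vocab):

import torch
import torch.nn.functional as F

def forward_kld(student_logits, teacher_probs, mask):
    # KL(teacher || student): the student is pulled toward the teacher distribution
    log_p_student = F.log_softmax(student_logits, dim=-1)
    # reduction="none" keeps per-element terms so padding positions can be masked out
    kld = F.kl_div(log_p_student, teacher_probs, reduction="none").sum(dim=-1)
    mask = mask.float()
    return (kld * mask).sum() / mask.sum().clamp(min=1.0)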
@@ -197,9 +205,23 @@ def train(config):
         raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        config["models"]["student"], trust_remote_code=True
+        config["models"]["student"],
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+        device_map="auto",
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
+    # Creating LoRA configuration
+    lora_config = LoraConfig(
+        r=16,  # Rank of the LoRA layers
+        lora_alpha=32,  # Scaling factor for the LoRA layers
+        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        bias="none",  # No bias in LoRA layers
+        task_type="CAUSAL_LM",  # Task type for the LoRA layers
+        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
+    )
     training_arguments = SFTConfig(**config["training"])
     training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
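The LoraConfig above is passed to the trainers through peft_config in the hunks below, so TRL wraps the model itself. The same configuration can also be attached directly with peft's get_peft_model, which is a quick way to confirm how small the trainable-parameter count becomes; the checkpoint id below is a placeholder, not the one used by this script:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

lora_config = LoraConfig(
    r=16,            # rank of the low-rank update matrices
    lora_alpha=32,   # scaling factor (effective scale is alpha / r)
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # reports trainable vs. total parameter counts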
@@ -241,14 +263,18 @@ def train(config):
         trainer = SFTTrainer(
             model=student_model,
             data_collator=collate_fn,
-            processing_class=processor.tokenizer,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
+            peft_config=lora_config,
         )
     elif "mmkd_white_box" in job_type:
-        teacher_vocab_size = json.load(
-            open(os.path.join(config["models"]["teacher"], "config.json"))
-        )["vocab_size"]
+        teacher_config = AutoConfig.from_pretrained(
+            config["models"]["teacher"],
+            trust_remote_code=True
+        )
+        teacher_vocab_size = teacher_config.vocab_size
         trainer = DistillSFTTrainer(
             logits_dir=config["dataset"]["logits_path"],
             data_collator=collate_fn,
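Reading the vocabulary size through AutoConfig replaces the hand-rolled json.load of config.json: only the lightweight configuration is loaded (no weights), and trust_remote_code covers teachers with custom architectures. A small sketch with a placeholder path:

from transformers import AutoConfig

teacher_config = AutoConfig.from_pretrained(
    "path/to/teacher",       # placeholder; the script uses config["models"]["teacher"]
    trust_remote_code=True,
)
print(teacher_config.vocab_size)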
@@ -259,7 +285,8 @@ def train(config):
"distillation_type", "forward_kld"
),
model=student_model,
processing_class=processor.tokenizer,
peft_config=lora_config,
tokenizer=processor.tokenizer,
args=training_arguments,
train_dataset=dataset,
)
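Because both trainers now receive peft_config, a run typically checkpoints only the LoRA adapter weights rather than a full model. A common follow-up, sketched with placeholder paths and under the assumption that the run's output directory holds a PEFT adapter, is merging the adapter back into the bfloat16 base model for inference:

import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration

base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "path/to/student", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "path/to/output_dir")  # load the saved adapter
merged = model.merge_and_unload()                              # fold LoRA deltas into the base weights
merged.save_pretrained("path/to/merged_model")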