diff --git a/easydistill/mmkd/train_lora.py b/easydistill/mmkd/train_lora.py
index 95609b0..b29eb22 100644
--- a/easydistill/mmkd/train_lora.py
+++ b/easydistill/mmkd/train_lora.py
@@ -30,10 +30,11 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
+    AutoConfig
 )
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
-
+from peft import LoraConfig
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
         self.kd_ratio = kd_ratio
         self.max_seq_length = max_seq_length
         self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)
 
     def _load_teacher_logits(
         self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
     ):
         start_idx = dp_rank * batch_size + batch_size * it
         end_idx = dp_rank * batch_size + batch_size * (it + 1)
-        loaded_data = self.teacher_logits[start_idx:end_idx]
+
+        loaded_data = []
+        # Open file and read only the specific lines needed for the current batch
+        with jsonlines.open(self.logits_dir) as reader:
+            for i, obj in enumerate(reader):
+                if i >= start_idx and i < end_idx:
+                    loaded_data.append(obj)
+                elif i >= end_idx:
+                    break
+
         arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
         for i in range(len(loaded_data)):
             for j in range(len(loaded_data[i])):
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
             else torch.ones_like(student_logits[:, :, 0])
         )
 
+        mask = mask[:, : self.max_seq_length]
+
         if self.distillation_type == "forward_kld":
             # Forward KLD: student learns from teacher (original implementation)
             loss = F.kl_div(
@@ -197,9 +205,23 @@ def train(config):
         raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        config["models"]["student"], trust_remote_code=True
+        config["models"]["student"],
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+        device_map="auto",
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
+
+    # Creating LoRA configuration
+    lora_config = LoraConfig(
+        r=16,  # Rank of the LoRA layers
+        lora_alpha=32,  # Scaling factor for the LoRA layers
+        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        bias="none",  # No bias in LoRA layers
+        task_type="CAUSAL_LM",  # Task type for the LoRA layers
+        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
+    )
 
     training_arguments = SFTConfig(**config["training"])
     training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
@@ -241,14 +263,18 @@ def train(config):
         trainer = SFTTrainer(
             model=student_model,
             data_collator=collate_fn,
-            processing_class=processor.tokenizer,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
+            peft_config=lora_config,
         )
     elif "mmkd_white_box" in job_type:
-        teacher_vocab_size = json.load(
-            open(os.path.join(config["models"]["teacher"], "config.json"))
-        )["vocab_size"]
+        teacher_config = AutoConfig.from_pretrained(
+            config["models"]["teacher"],
+            trust_remote_code=True
+        )
+        teacher_vocab_size = teacher_config.vocab_size
+
         trainer = DistillSFTTrainer(
             logits_dir=config["dataset"]["logits_path"],
             data_collator=collate_fn,
@@ -259,7 +285,8 @@ def train(config):
                 "distillation_type", "forward_kld"
             ),
             model=student_model,
-            processing_class=processor.tokenizer,
+            peft_config=lora_config,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
         )