add lora to training
@@ -30,10 +30,11 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
+    AutoConfig
 )
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
 
+from peft import LoraConfig
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
         self.kd_ratio = kd_ratio
         self.max_seq_length = max_seq_length
         self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)
 
     def _load_teacher_logits(
         self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
     ):
         start_idx = dp_rank * batch_size + batch_size * it
         end_idx = dp_rank * batch_size + batch_size * (it + 1)
-        loaded_data = self.teacher_logits[start_idx:end_idx]
+
+        loaded_data = []
+        # Open file and read only the specific lines needed for the current batch
+        with jsonlines.open(self.logits_dir) as reader:
+            for i, obj in enumerate(reader):
+                if i >= start_idx and i < end_idx:
+                    loaded_data.append(obj)
+                elif i >= end_idx:
+                    break
+
         arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
         for i in range(len(loaded_data)):
             for j in range(len(loaded_data[i])):
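
The hunk above swaps the eager, whole-file load of teacher logits for a streamed read that keeps only the current batch slice in memory, which matters when the logits file is large. Below is a minimal, self-contained sketch of the same slice-while-streaming pattern; the file name and the usage values are assumptions for illustration, not part of the commit.

import jsonlines

def read_batch_slice(path: str, start_idx: int, end_idx: int) -> list:
    """Stream a JSONL file and keep only the records in [start_idx, end_idx)."""
    batch = []
    with jsonlines.open(path) as reader:  # yields one parsed JSON object per line
        for i, obj in enumerate(reader):
            if start_idx <= i < end_idx:
                batch.append(obj)
            elif i >= end_idx:
                break  # nothing past the slice is needed, so stop reading early
    return batch

# Hypothetical usage: dp_rank=0, batch_size=4, iteration it=2 -> records 8..11
# records = read_batch_slice("teacher_logits.jsonl", 8, 12)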
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
             else torch.ones_like(student_logits[:, :, 0])
         )
+
+        mask = mask[:, : self.max_seq_length]
 
         if self.distillation_type == "forward_kld":
             # Forward KLD: student learns from teacher (original implementation)
             loss = F.kl_div(
@@ -197,9 +205,23 @@ def train(config):
         raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        config["models"]["student"], trust_remote_code=True
+        config["models"]["student"],
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+        device_map="auto",
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
 
+    # Creating LoRA configuration
+    lora_config = LoraConfig(
+        r=16,  # Rank of the LoRA layers
+        lora_alpha=32,  # Scaling factor for the LoRA layers
+        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        bias="none",  # No bias in LoRA layers
+        task_type="CAUSAL_LM",  # Task type for the LoRA layers
+        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
+    )
+
     training_arguments = SFTConfig(**config["training"])
     training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
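
The lora_config created above is handed to the trainers in the hunks below via peft_config=lora_config, so TRL wraps the student in a PEFT adapter and only the low-rank matrices on the listed projection modules are trained. As a rough sanity check of what that implies, here is a minimal sketch of the equivalent manual wrapping with peft; the small text-only model id is an assumption used purely for illustration, not the student loaded by the script.

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Assumed base model for illustration; the script loads the student from
# config["models"]["student"] instead.
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],
)

peft_model = get_peft_model(base, lora_config)
peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts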
@@ -241,14 +263,18 @@ def train(config):
         trainer = SFTTrainer(
             model=student_model,
             data_collator=collate_fn,
-            processing_class=processor.tokenizer,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
+            peft_config=lora_config,
         )
     elif "mmkd_white_box" in job_type:
-        teacher_vocab_size = json.load(
-            open(os.path.join(config["models"]["teacher"], "config.json"))
-        )["vocab_size"]
+        teacher_config = AutoConfig.from_pretrained(
+            config["models"]["teacher"],
+            trust_remote_code=True
+        )
+        teacher_vocab_size = teacher_config.vocab_size
+
         trainer = DistillSFTTrainer(
             logits_dir=config["dataset"]["logits_path"],
             data_collator=collate_fn,
@@ -259,7 +285,8 @@ def train(config):
                 "distillation_type", "forward_kld"
             ),
             model=student_model,
-            processing_class=processor.tokenizer,
+            peft_config=lora_config,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
         )
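
Because both trainers now receive peft_config, the artifact written to the output directory is typically a LoRA adapter rather than a full set of fine-tuned weights. A minimal sketch of loading such an adapter back onto the base student for inference follows; the model id and adapter path are assumptions, in the script they come from the training config.

import torch
from peft import PeftModel
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

base_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # assumed base student checkpoint
adapter_dir = "./output"                 # assumed trainer output_dir holding the saved adapter

base = Qwen2_5_VLForConditionalGeneration.from_pretrained(base_id, torch_dtype=torch.bfloat16)
processor = Qwen2_5_VLProcessor.from_pretrained(base_id)

model = PeftModel.from_pretrained(base, adapter_dir)
model = model.merge_and_unload()  # optional: fold the LoRA deltas into the base weights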