Add LoRA to training
@@ -30,10 +30,11 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     TrainingArguments,
+    AutoConfig
 )
 from qwen_vl_utils import process_vision_info
 from trl import SFTTrainer, SFTConfig
+from peft import LoraConfig
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -73,10 +74,6 @@ class DistillSFTTrainer(SFTTrainer):
         self.kd_ratio = kd_ratio
         self.max_seq_length = max_seq_length
         self.distillation_type = distillation_type
-        self.teacher_logits = []
-        with jsonlines.open(self.logits_dir) as reader:
-            for obj in reader:
-                self.teacher_logits.append(obj)
 
     def _load_teacher_logits(
         self,
@@ -88,7 +85,16 @@ class DistillSFTTrainer(SFTTrainer):
     ):
         start_idx = dp_rank * batch_size + batch_size * it
         end_idx = dp_rank * batch_size + batch_size * (it + 1)
-        loaded_data = self.teacher_logits[start_idx:end_idx]
+        loaded_data = []
+        # Open file and read only the specific lines needed for the current batch
+        with jsonlines.open(self.logits_dir) as reader:
+            for i, obj in enumerate(reader):
+                if i >= start_idx and i < end_idx:
+                    loaded_data.append(obj)
+                elif i >= end_idx:
+                    break
+
         arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
         for i in range(len(loaded_data)):
             for j in range(len(loaded_data[i])):
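
Note: the loop added above streams the teacher-logits .jsonl and stops once the batch window has been read, instead of holding every record in memory as the removed `self.teacher_logits` list did. An equivalent way to express the same slice, shown here only as a sketch (the helper name is illustrative) and assuming the file stores one record per line:

    import itertools
    import jsonlines

    def read_logit_rows(logits_path, start_idx, end_idx):
        # Return records [start_idx, end_idx) without loading the whole file.
        with jsonlines.open(logits_path) as reader:
            return list(itertools.islice(reader, start_idx, end_idx))
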
@@ -117,6 +123,8 @@ class DistillSFTTrainer(SFTTrainer):
             else torch.ones_like(student_logits[:, :, 0])
         )
 
+        mask = mask[:, : self.max_seq_length]
+
         if self.distillation_type == "forward_kld":
             # Forward KLD: student learns from teacher (original implementation)
             loss = F.kl_div(
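
Note: truncating the mask to `self.max_seq_length` keeps it aligned with the teacher-logits array built earlier before the distillation loss is computed. A minimal sketch of a masked forward KLD in the same spirit, assuming `student_logits` are raw logits, `teacher_probs` are already probabilities, and `mask` is 1 on valid tokens (names are illustrative, not taken from this file):

    import torch
    import torch.nn.functional as F

    def masked_forward_kld(student_logits, teacher_probs, mask):
        # Per-token KL(teacher || student), zeroed out on masked positions.
        log_p_student = F.log_softmax(student_logits, dim=-1)
        kld = F.kl_div(log_p_student, teacher_probs, reduction="none").sum(dim=-1)
        return (kld * mask).sum() / mask.sum().clamp(min=1)
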
@@ -197,9 +205,23 @@ def train(config):
         raw_data = json.load(f)
     dataset = MMDataset(raw_data)
     student_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-        config["models"]["student"], trust_remote_code=True
+        config["models"]["student"],
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+        device_map="auto",
     )
     processor = Qwen2_5_VLProcessor.from_pretrained(config["models"]["student"])
 
+    # Creating LoRA configuration
+    lora_config = LoraConfig(
+        r=16,  # Rank of the LoRA layers
+        lora_alpha=32,  # Scaling factor for the LoRA layers
+        lora_dropout=0.1,  # Dropout rate for the LoRA layers
+        bias="none",  # No bias in LoRA layers
+        task_type="CAUSAL_LM",  # Task type for the LoRA layers
+        target_modules=["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "o_proj"],  # Target modules for LoRA
+    )
+
     training_arguments = SFTConfig(**config["training"])
     training_arguments.gradient_checkpointing_kwargs = dict(use_reentrant=False)
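
Note: the `lora_config` defined above is handed to the trainers via `peft_config` in the hunks below, so TRL applies the adapter itself. To sanity-check what becomes trainable before training, a sketch along these lines (reusing `student_model` and `lora_config` from this file) works with `peft`:

    from peft import get_peft_model

    peft_model = get_peft_model(student_model, lora_config)
    # Prints trainable vs. total parameter counts; only the LoRA matrices on the
    # listed target_modules should be trainable.
    peft_model.print_trainable_parameters()
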
@@ -241,14 +263,18 @@ def train(config):
         trainer = SFTTrainer(
             model=student_model,
             data_collator=collate_fn,
-            processing_class=processor.tokenizer,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
+            peft_config=lora_config,
         )
     elif "mmkd_white_box" in job_type:
-        teacher_vocab_size = json.load(
-            open(os.path.join(config["models"]["teacher"], "config.json"))
-        )["vocab_size"]
+        teacher_config = AutoConfig.from_pretrained(
+            config["models"]["teacher"],
+            trust_remote_code=True
+        )
+        teacher_vocab_size = teacher_config.vocab_size
+
         trainer = DistillSFTTrainer(
             logits_dir=config["dataset"]["logits_path"],
             data_collator=collate_fn,
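
Note: `AutoConfig.from_pretrained` reads only the checkpoint's config.json, so the teacher's `vocab_size` is obtained without loading any weights; `trust_remote_code=True` also covers teachers whose config class ships inside the checkpoint.
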
@@ -259,7 +285,8 @@ def train(config):
                 "distillation_type", "forward_kld"
             ),
             model=student_model,
-            processing_class=processor.tokenizer,
+            peft_config=lora_config,
+            tokenizer=processor.tokenizer,
             args=training_arguments,
             train_dataset=dataset,
         )