add

models/LLaVA/build/lib/llava/train/llava_trainer.py (new file, 49 lines added)
@@ -0,0 +1,49 @@
import os
import torch
import torch.nn as nn

from transformers import Trainer
from typing import Dict, Optional, Sequence


def unwrap_model(model: nn.Module) -> nn.Module:
    """
    Recursively unwraps a model from potential containers (as used in distributed training).

    Args:
        model (`torch.nn.Module`): The model to unwrap.
    """
    # since there could be multiple levels of wrapping, unwrap recursively
    if hasattr(model, "module"):
        return unwrap_model(model.module)
    else:
        return model


class LLaVATrainer(Trainer):

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            # Save only the multimodal adapter weights (projector and token embeddings)
            _state_dict = state_dict
            if _state_dict is None:
                # Only save the model itself if we are using distributed training
                model_to_save = unwrap_model(self.model)
                _state_dict = model_to_save.state_dict()

            weight_to_save = {}
            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
            for k, v in _state_dict.items():
                if any(key_match in k for key_match in keys_to_match):
                    weight_to_save[k] = v

            current_folder = output_dir.split('/')[-1]
            parent_folder = os.path.dirname(output_dir)
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))

        super(LLaVATrainer, self)._save(output_dir, state_dict)
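For reference, a minimal sketch of how an adapter file written by this `_save` override might be loaded back. The helper `load_mm_adapter` and the example path are hypothetical, and the target model is assumed to use the same `mm_projector` / `embed_tokens` / `embed_in` parameter names matched above:

import torch
import torch.nn as nn


def load_mm_adapter(model: nn.Module, adapter_path: str) -> None:
    # adapter_path points at a file produced by LLaVATrainer._save, e.g. a
    # hypothetical 'output/mm_projector/checkpoint-1000.bin'.
    weights = torch.load(adapter_path, map_location='cpu')
    # The file contains only the projector / embedding entries, so load
    # non-strictly and leave every other parameter untouched.
    model.load_state_dict(weights, strict=False)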