init commit
This commit is contained in:
194
recipes/distilqwen_series/distillqwen2.5-r1/cogpo.py
Normal file
194
recipes/distilqwen_series/distillqwen2.5-r1/cogpo.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import trl
|
||||
from trl.trainer.dpo_trainer import DataCollatorForPreference
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
import torch, torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType
|
||||
from trl.trainer.utils import cap_exp
|
||||
import json
|
||||
import argparse
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPreferenceWithBeta(DataCollatorForPreference):
|
||||
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
|
||||
betas = torch.tensor([float(ex["beta"]) for ex in examples], dtype=torch.float32)
|
||||
for ex in examples:
|
||||
ex.pop("beta")
|
||||
batch = super().torch_call(examples)
|
||||
batch["betas"] = betas
|
||||
return batch
|
||||
|
||||
|
||||
class CogPOTrainer(DPOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch,
|
||||
train_eval: str = "train",
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
betas = batch.pop("betas").to(self.accelerator.device)
|
||||
model_output = self.concatenated_forward(model, batch)
|
||||
|
||||
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
|
||||
ref_chosen_logps = batch["ref_chosen_logps"]
|
||||
ref_rejected_logps = batch["ref_rejected_logps"]
|
||||
else:
|
||||
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self._dpo_sigmoid_loss(
|
||||
model_output["chosen_logps"],
|
||||
model_output["rejected_logps"],
|
||||
ref_chosen_logps,
|
||||
ref_rejected_logps,
|
||||
betas,
|
||||
)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
losses = losses + self.args.rpo_alpha * model_output["nll_loss"]
|
||||
|
||||
if self.use_weighting:
|
||||
losses = losses * model_output["policy_weights"]
|
||||
|
||||
if self.aux_loss_enabled:
|
||||
losses = losses + self.aux_loss_coef * model_output["aux_loss"]
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||
metrics[f"{prefix}rewards/margins"] = (
|
||||
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logps/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logps/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logits/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logits/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
|
||||
)
|
||||
if self.args.rpo_alpha is not None:
|
||||
metrics[f"{prefix}nll_loss"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
|
||||
)
|
||||
if self.aux_loss_enabled:
|
||||
metrics[f"{prefix}aux_loss"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
|
||||
)
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
||||
def _dpo_sigmoid_loss(
|
||||
self,
|
||||
chosen_logps: torch.FloatTensor,
|
||||
rejected_logps: torch.FloatTensor,
|
||||
ref_chosen_logps: torch.FloatTensor,
|
||||
ref_rejected_logps: torch.FloatTensor,
|
||||
betas: torch.FloatTensor,
|
||||
):
|
||||
|
||||
device = self.accelerator.device
|
||||
chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
|
||||
rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
|
||||
|
||||
# 2) Δ = (log p_c - log p_r) - (log p̂_c - log p̂_r)
|
||||
if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
|
||||
alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
|
||||
if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
|
||||
alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
|
||||
logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
|
||||
else:
|
||||
logratios = chosen_logps - rejected_logps
|
||||
if self.reference_free:
|
||||
ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
|
||||
else:
|
||||
ref_logratios = ref_chosen_logps - ref_rejected_logps
|
||||
logratios = logratios.to(self.accelerator.device)
|
||||
ref_logratios = ref_logratios.to(self.accelerator.device)
|
||||
logits = logratios - ref_logratios
|
||||
|
||||
if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
|
||||
logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
|
||||
|
||||
|
||||
losses = (
|
||||
-F.logsigmoid(betas * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-betas * logits) * self.label_smoothing
|
||||
)
|
||||
|
||||
chosen_rewards = betas * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
|
||||
rejected_rewards = betas * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
|
||||
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
|
||||
def train(config):
|
||||
model_name = config["models"]["student"]
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"], split='train')
|
||||
|
||||
dpo_args = DPOConfig(
|
||||
output_dir=config["training"]["output_dir"],
|
||||
num_train_epochs=config["training"]["num_train_epochs"],
|
||||
loss_type=config["training"]["loss_type"],
|
||||
beta=config["training"]["beta"],
|
||||
per_device_train_batch_size=config["training"]["per_device_train_batch_size"],
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
|
||||
collator = DataCollatorForPreferenceWithBeta(
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
trainer = CogPOTrainer(
|
||||
model=model,
|
||||
args=dpo_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user