init commit

熊兮
2025-05-27 18:55:46 +08:00
parent 6f52a67249
commit 25caa8a90a
65 changed files with 4893 additions and 1 deletion

14
easydistill/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

187
easydistill/cli.py Normal file
View File

@@ -0,0 +1,187 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import subprocess
import sys
from socket import socket
import argparse
import json
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
def run_cmd(cmd):
try:
p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout
shell=True,
universal_newlines=True # Ensure output is in text mode
)
error_detected = False
error_keywords = [
"ERROR",
"Error",
"error",
"Unrecognized model",
"failed",
"exception",
"Traceback"
]
# Read output in real-time and detect errors
while True:
line = p.stdout.readline()
if not line:
break
logging.info(line.rstrip()) # Log normally
# Check if any error keywords are present
if any(keyword.lower() in line.lower() for keyword in error_keywords):
error_detected = True
logging.error(f"Detected error in output: {line.strip()}")
# Wait for process to finish
returncode = p.wait()
# If errors were detected or return code is non-zero, return False
if error_detected or returncode != 0:
logging.error(f"Command failed (returncode={returncode}, errors detected)")
return False
return True # Return True indicates success
except Exception as e:
logging.error(f"Unexpected error running command: {e}")
return False
def process(job_type, config):
if not os.path.isabs(config):
config = os.path.join(script_dir, config)
# Knowledge Distillation tasks
if job_type in ['kd_black_box_train_only', 'kd_white_box_train_only']:
cmd_train = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, 'kd/train.py'),
'--config', config
]
cmd_train = ' '.join(cmd_train)
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
cmd_infer = [
'python', os.path.join(script_dir, 'kd/infer.py'),
'--config', config
]
cmd_infer = ' '.join(cmd_infer)
logging.info(f"Running command: {cmd_infer}")
infer_success = run_cmd(cmd_infer)
if infer_success:
cmd_train = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, 'kd/train.py'),
'--config', config
]
cmd_train = ' '.join(cmd_train)
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
else:
logging.error("Infer failed, skipping training")
# Reinforcement Learning tasks
elif job_type in ['rl_ppo', 'rl_grpo']:
cmd = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, f'rl/{job_type.split("_")[1]}.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
elif job_type in ['rl_reward_api', 'rl_reward_local']:
cmd = [
'python',
os.path.join(script_dir, 'rl/reward.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
# Instruction Processing tasks
elif job_type.startswith('instruction_'):
task_type = job_type.replace('instruction_', '')
cmd = [
'python',
os.path.join(script_dir, f'synthesis/synthesis_main.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
# Chain of Thought tasks
elif job_type.startswith('cot_'):
task_type = job_type.replace('cot_', '')
cmd = [
'python',
os.path.join(script_dir, f'synthesis/synthesis_main.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
# Ranking and DPO tasks
elif job_type.startswith('rank_'):
task_type = job_type.replace('rank_', '')
cmd = [
'python',
os.path.join(script_dir, f'rank/{task_type}.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
else:
logging.error(f"Unknown job type: {job_type}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config_path = args.config
config = json.load(open(config_path))
job_type = config["job_type"]
process(job_type, config_path)
if __name__ == '__main__':
main()
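
How the CLI is driven is easiest to see with a concrete config. The sketch below builds a minimal JSON config using the key names that cli.py and the kd/ scripts below actually read (job_type, models, dataset, inference, training); every concrete value, path and file name in it is a placeholder, not a shipped default.

# Hypothetical config for a kd_black_box_local run; all values are placeholders.
import json

config = {
    "job_type": "kd_black_box_local",              # dispatched by process() above
    "models": {
        "teacher": "/models/teacher",              # read by kd/infer.py
        "student": "/models/student"               # read by kd/train.py
    },
    "dataset": {
        "instruction_path": "instructions.json",   # prompts to distill
        "labeled_path": "labeled.json",            # teacher responses are written here
        "template": "chat_template.jinja2",        # Jinja2 chat template (see sketch later)
        "seed": 42
    },
    "inference": {
        "enable_chunked_prefill": True,
        "gpu_memory_utilization": 0.9,
        "trust_remote_code": True,
        "enforce_eager": False,
        "max_model_len": 4096,
        "max_new_tokens": 512,
        "temperature": 0.7,
        "seed": 42
    },
    "training": {"output_dir": "./result", "num_train_epochs": 3}   # forwarded to trl's SFTConfig
}
with open("kd_black_box_local.json", "w") as f:
    json.dump(config, f, indent=4)
# then: python easydistill/cli.py --config kd_black_box_local.json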

247
easydistill/kd/infer.py Normal file
View File

@@ -0,0 +1,247 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import argparse
import torch
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
import math
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='instruction'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
except Exception:
logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
def generate_teacher_response_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message = {"role": "user", "content": sample}
full_text = template.render(
message = message,
add_generation_prompt = True,
add_output = False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
responses = [output.outputs[0].text for output in outputs]
gen_data = [{'instruction': batch[i], 'output': responses[i]} for i in range(len(batch))]
outcomes = outcomes + gen_data
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def generate_teacher_logits_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch, # Pass the raw text directly
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=True,
max_tokens=config["inference"]["max_new_tokens"],
logprobs=config["inference"]["top_logits_num"],
)
)
# Extract the generated logits
responses = [output.outputs[0].text for output in outputs]
logits=[output.outputs[0].logprobs for output in outputs]
for logit in logits:
for pos in logit:
for k,v in pos.items():
pos[k]=math.exp(v.logprob)
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
for row in logits:
#for item in row:
writer.write(row)
def generate_teacher_response_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
system_prompt = config["inference"]["system_prompt"]
stream = config["inference"]["stream"]
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
if system_prompt == "":
message = [
{'role': 'user', 'content': sample}
]
else:
message = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content
else:
result = completion.choices[0].message.content
outcomes.append({'instruction': sample, 'output': result})
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "kd_black_box_api":
generate_teacher_response_api(data_list, config)
elif job_type == "kd_black_box_local":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_response_batch(tokenizer, llm, data_list, config)
elif job_type == "kd_white_box":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_logits_batch(tokenizer, llm, data_list, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()
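
All of the local inference and training paths render prompts through a user-supplied Jinja2 chat template taken from config["dataset"]["template"], called with the variables message, add_generation_prompt and add_output. Below is a toy stand-in covering the single-message form used in kd/infer.py and kd/train.py; the Qwen-style special tokens are an assumption, not the template shipped with the repository.

# Toy chat template; the real templates may use different markup.
from jinja2 import Environment, BaseLoader

template_str = (
    "<|im_start|>user\n{{ message.content }}<|im_end|>\n"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    "{% if add_output %}<|im_start|>assistant\n{{ message.output }}<|im_end|>\n{% endif %}"
)
template = Environment(loader=BaseLoader()).from_string(template_str)

# Inference-time rendering (as in generate_teacher_response_batch):
print(template.render(message={"role": "user", "content": "What is 2 + 2?"},
                      add_generation_prompt=True, add_output=False))
# Training-time rendering (as in kd/train.py's formatting_func):
print(template.render(message={"content": "What is 2 + 2?", "output": "4"},
                      add_generation_prompt=False, add_output=True))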

218
easydistill/kd/train.py Normal file
View File

@@ -0,0 +1,218 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset,Dataset
from typing import Optional, Dict, Union, List
from datasets import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase,AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer,SFTConfig
import torch
import jsonlines
import numpy as np
import torch.nn.functional as F
class DistillSFTTrainer(SFTTrainer):
def __init__(
self,
logits_dir: str = None,
teacher_vocab_size = None,
kd_ratio: float = 0.5,
max_seq_length : int = 1024,
distillation_type: str = "forward_kld",
**kwargs
):
super().__init__(**kwargs)
self.logits_dir = logits_dir
self.teacher_vocab_size = teacher_vocab_size
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
self.teacher_logits = []
with jsonlines.open(self.logits_dir) as reader:
for obj in reader:
self.teacher_logits.append(obj)
def _load_teacher_logits(self, batch_size: int, it: int, dp_rank: int, device: torch.device, no_model_batch: Dict):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = self.teacher_logits[start_idx:end_idx]
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
values = np.array(list(loaded_data[i][j].values()))
arr[i, j, keys] = values
logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
return self._shift_tensor_right(logits_tensor, no_model_batch['label'], pad_value=0)
def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: Optional[torch.Tensor]):
student_logits = student_logits[:, :self.max_seq_length, :]
teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[:, :, 0])
if self.distillation_type == "forward_kld":
# Forward KLD: KL(teacher || student) — the student spreads mass to cover the teacher (original implementation)
loss = F.kl_div(
F.log_softmax(student_logits, dim=-1),
teacher_probs,
reduction='none',
log_target=False
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
elif self.distillation_type == "reverse_kld":
# Reverse KLD: KL(student || teacher) — mode-seeking, the student concentrates on teacher modes
loss = F.kl_div(
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
F.softmax(student_logits, dim=-1),
reduction='none',
log_target=False
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
else:
raise ValueError(f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'")
return (loss * mask).sum() / mask.sum()
@staticmethod
def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
batch_size, seqlen, vocab_size = inputs.shape
device = inputs.device
labels_ne = labels != -100
shift_distances = torch.argmax(labels_ne.int(), dim=1)
idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
shifted_idx = idx - shift_distances.unsqueeze(1)
mask = shifted_idx >= 0
shifted_idx = shifted_idx.clamp(min=0)
inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
gathered = torch.gather(inputs_flat, 1, shifted_idx)
mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
def compute_loss(self, model: PreTrainedModel, inputs: Dict[str, torch.Tensor], return_outputs=False, num_items_in_batch=None):
outputs = model(**inputs)
lm_loss = outputs.loss
if self.logits_dir:
teacher_logits = self._load_teacher_logits(
batch_size=inputs['input_ids'].size(0),
it=self.state.global_step,
dp_rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
device=model.device,
no_model_batch={'label': inputs.get('labels', None)}
)
distil_loss = self._compute_white_box_distillation_loss(
student_logits=outputs.logits,
teacher_logits=teacher_logits,
labels=inputs.get('labels', None)
)
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
else:
total_loss = lm_loss
return (total_loss, outputs) if return_outputs else total_loss
def formatting_func(examples):
env = Environment(loader=BaseLoader())
try:
message = {"content": examples["instruction"],"output":examples["output"]}
full_text = template.render(
message=message,
add_generation_prompt=False,
add_output=True
)
return full_text
except Exception as e:
logging.warning(f"Error processing sample: {str(e)}")
return ""
def train(config):
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"])
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
global template
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
training_arguments = SFTConfig(**config["training"])
try:
job_type = config["job_type"]
if "kd_black_box" in job_type:
dataset = dataset.shuffle(seed=config["dataset"]["seed"])
trainer = SFTTrainer(
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset["train"],
formatting_func=formatting_func
)
elif "kd_white_box" in job_type:
teacher_vocab_size=json.load(open(os.path.join(config["models"]["teacher"], 'config.json')))['vocab_size']
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
teacher_vocab_size=teacher_vocab_size,
kd_ratio=config["distillation"]["kd_ratio"],
max_seq_length=config["distillation"]["max_seq_length"],
distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset["train"],
formatting_func=formatting_func
)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()
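
For the white-box path, generate_teacher_logits_batch (in kd/infer.py above) appends one JSON line per sample to config["dataset"]["logits_path"]: a list with one sparse distribution per generated token, mapping teacher-vocab token ids to probabilities (math.exp of the vLLM logprobs). DistillSFTTrainer._load_teacher_logits then scatters these into a dense (batch, max_seq_length, vocab_size) tensor. A made-up two-token record and its dense reconstruction, with invented token ids, probabilities and vocab size:

# Illustrative only: token ids, probabilities and the vocab size are invented.
import jsonlines
import numpy as np

row = [
    {"9707": 0.86, "13": 0.09, "198": 0.02},   # sparse teacher distribution, step 0
    {"15": 0.55, "30": 0.31, "11": 0.08},      # sparse teacher distribution, step 1
]
with jsonlines.open("teacher_logits.jsonl", mode="w") as writer:
    writer.write(row)

teacher_vocab_size = 151936                    # e.g. vocab_size from the teacher's config.json
dense = np.zeros((len(row), teacher_vocab_size))
for j, step in enumerate(row):
    dense[j, np.array(list(step.keys()), dtype=int)] = np.array(list(step.values()))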

262
easydistill/rank/infer.py Normal file
View File

@@ -0,0 +1,262 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from jinja2 import Environment, FileSystemLoader
from tqdm import tqdm
from openai import OpenAI
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='prompt'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None, is_teacher_model=True):
if is_teacher_model:
model_path = config["models"]["teacher"]
else:
model_path = config["models"]["student"]
logging.info(f"Loading ckpt and tokenizer: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
except Exception:
logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
def generate_teacher_student_response_api(data_list, config):
client = OpenAI(
api_key=config["inference"]["api_key"],
base_url=config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
system_prompt = config["inference"]["system_prompt"]
stream = config["inference"]["stream"]
# load student model
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
device_map="auto",
trust_remote_code=True
)
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
# for teacher model
if system_prompt == "":
message=[
{'role': 'user', 'content': sample}
]
else:
message=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': sample}
]
completion = client.chat.completions.create(
messages=message,
model=model,
max_completion_tokens=config["inference"]["max_new_tokens"],
stream=stream,
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content
else:
result = completion.choices[0].message.content
# for student model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": sample}
]
text = student_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
generated_ids = student_model.generate(
**model_inputs,
max_new_tokens=config["inference"]["max_new_tokens"]
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
gen_data = {'prompt': sample, 'chosen': result, 'rejected': rejected}
outcomes.append(gen_data)
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def generate_model_response_batch(tokenizer, llm, data_list, config, batch_size=32, is_teacher_model=True):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
model_outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"]
)
)
model_responses = [output.outputs[0].text for output in model_outputs]
if is_teacher_model:
gen_data = [{'prompt': batch[i], 'chosen': model_responses[i]} for i in range(len(batch))]
else:
gen_data = [{'prompt': batch[i], 'rejected': model_responses[i]} for i in range(len(batch))]
outcomes = outcomes + gen_data
return outcomes
def merge_outcomes(teacher_outcomes, student_outcomes, config):
try:
student_dict = {item['prompt']: item['rejected'] for item in student_outcomes}
merged_outcomes = []
for teacher_item in teacher_outcomes:
prompt = teacher_item['prompt']
if prompt in student_dict:
merged_outcome = {
'prompt': prompt,
'chosen': teacher_item['chosen'],
'rejected': student_dict[prompt]
}
merged_outcomes.append(merged_outcome)
with open(config["dataset"]["labeled_path"], 'w') as file:
json.dump(merged_outcomes, file, ensure_ascii=False, indent=4)
except Exception as e:
print(f"An error occurred: {e}")
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "rank_dpo_api":
generate_teacher_student_response_api(data_list, config)
elif job_type == "rank_dpo_local":
teacher_tokenizer, teacher_llm = load_tokenizer_and_vllm(config, is_teacher_model=True)
teacher_outcomes = generate_model_response_batch(teacher_tokenizer, teacher_llm, data_list, config, is_teacher_model=True)
del teacher_llm
student_tokenizer, student_llm = load_tokenizer_and_vllm(config, is_teacher_model=False)
student_outcomes = generate_model_response_batch(student_tokenizer, student_llm, data_list, config, is_teacher_model=False)
del student_llm
merge_outcomes(teacher_outcomes, student_outcomes, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()
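
rank/infer.py writes its preference data to config["dataset"]["labeled_path"] as a JSON list of prompt/chosen/rejected triples, with the teacher response as "chosen" and the student response as "rejected"; rank/train.py below renders each field through the chat template before DPO training. A single invented record in that format:

# One invented record in the labeled_path format produced above.
import json

dpo_pairs = [
    {
        "prompt": "Summarize knowledge distillation in one sentence.",
        "chosen": "A smaller student model is trained to reproduce the behaviour of a larger teacher model.",
        "rejected": "It is when models are trained."
    }
]
with open("dpo_labeled.json", "w") as f:
    json.dump(dpo_pairs, f, ensure_ascii=False, indent=4)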

105
easydistill/rank/train.py Normal file
View File

@@ -0,0 +1,105 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig
import copy
def process_dataset(dataset_path, dataset_seed, env, template):
examples = []
with open(dataset_path, 'r') as file:
examples = json.load(file)
output_text = {
"prompt": [],
"chosen": [],
"rejected": []
}
# use chat template
for i in range(len(examples)):
try:
prompt_message = {"content": examples[i]["prompt"]}
prompt = template.render(message=prompt_message, add_generation_prompt=False, add_output=False)
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
chosen = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
chosen = chosen[len(prompt):]
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
rejected = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
rejected = rejected[len(prompt):]
output_text["prompt"].append(prompt)
output_text["chosen"].append(chosen)
output_text["rejected"].append(rejected)
except Exception as e:
logging.warning(f"Error processing sample: {e}")
dataset = Dataset.from_dict(output_text)
dataset = dataset.shuffle(seed=dataset_seed)
return dataset
def train(config):
dataset_path = config["dataset"]["labeled_path"]
dataset_seed = config["dataset"]["seed"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
dataset = process_dataset(dataset_path, dataset_seed, env, template)
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
training_arguments = DPOConfig(**config["training"])
trainer = DPOTrainer(
student_model,
ref_model=copy.deepcopy(student_model),
args=training_arguments,
train_dataset=dataset,
processing_class=student_tokenizer
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,111 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import random
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
def process_dataset(dataset_path, dataset_seed, env, template, train_ratio):
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
except FileNotFoundError:
print(f"Error: The file '{dataset_path}' was not found.")
except json.JSONDecodeError:
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
except Exception as e:
print(f"An unexpected error occurred: {e}")
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
message = {"content": examples[i]["prompt"]}
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
sample = {"prompt": rendered}
output_dataset.append(sample)
except Exception as e:
logging.warning(f"Error processing sample: {e}")
random.seed(dataset_seed)
random.shuffle(output_dataset)
split_index = int(len(output_dataset) * train_ratio)
train_list = output_dataset[:split_index]
eval_list = output_dataset[split_index:]
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
def train(config):
dataset_path = config["dataset"]["instruction_path"]
dataset_seed = config["dataset"]["seed"]
train_ratio = config["dataset"]["train_ratio"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, train_ratio)
print(train_dataset)
print(eval_dataset)
reward_model_path = config["models"]["reward"]
sft_model_path = config["models"]["student"]
reward_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
sft_model = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
training_arguments = GRPOConfig(**config["training"])
trainer = GRPOTrainer(
args=training_arguments,
processing_class=tokenizer,
model=sft_model,
reward_funcs=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

122
easydistill/rl/ppo_train.py Normal file
View File

@@ -0,0 +1,122 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import random
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import PPOConfig, PPOTrainer
def process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio):
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
except FileNotFoundError:
print(f"Error: The file '{dataset_path}' was not found.")
except json.JSONDecodeError:
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
except Exception as e:
print(f"An unexpected error occurred: {e}")
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
message = {"content": examples[i]["instruction"]}
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
tokens = tokenizer.encode(rendered)
sample = {"input_ids": tokens}
output_dataset.append(sample)
except Exception as e:
logging.warning(f"Error processing sample: {e}")
random.seed(dataset_seed)
random.shuffle(output_dataset)
split_index = int(len(output_dataset) * train_ratio)
train_list = output_dataset[:split_index]
eval_list = output_dataset[split_index:]
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
def train(config):
dataset_path = config["dataset"]["instruction_path"]
dataset_seed = config["dataset"]["seed"]
train_ratio = config["dataset"]["train_ratio"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio)
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
print(train_dataset)
print(eval_dataset)
reward_model_path = config["models"]["reward"]
sft_model_path = config["models"]["student"]
value_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
policy = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
training_arguments = PPOConfig(**config["training"])
trainer = PPOTrainer(
config=training_arguments,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()
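
Both RL trainers read their prompts from config["dataset"]["instruction_path"]; note that ppo_train.py keys on item["instruction"] while the GRPO script above keys on item["prompt"], so the field name has to match the script being run. A small invented file that satisfies either reader:

# Invented prompt file; both field names are included so either RL script can read it.
import json

rl_prompts = [
    {"instruction": "Explain the Apache-2.0 patent grant briefly.",
     "prompt": "Explain the Apache-2.0 patent grant briefly."},
    {"instruction": "Write a haiku about distillation.",
     "prompt": "Write a haiku about distillation."},
]
with open("rl_instructions.json", "w") as f:
    json.dump(rl_prompts, f, ensure_ascii=False, indent=4)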

View File

@@ -0,0 +1,258 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import torch
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='prompt'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
except Exception:
logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
def generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
positive_system_prompt = config["inference"]["positive_system_prompt"]
negative_system_prompt = config["inference"]["negative_system_prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
positive_new_batch = []
negative_new_batch = []
for sample in batch:
positive_message = [
{'role': 'system', 'content': positive_system_prompt},
{'role': 'user', 'content': sample}
]
positive_full_text = template.render(
message = positive_message,
add_generation_prompt = True,
add_output = False
)
positive_new_batch.append(positive_full_text)
negative_message = [
{'role': 'system', 'content': negative_system_prompt},
{'role': 'user', 'content': sample}
]
negative_full_text = template.render(
message = negative_message,
add_generation_prompt = True,
add_output = False
)
negative_new_batch.append(negative_full_text)
positive_outputs = llm.generate(
positive_new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
positive_responses = [output.outputs[0].text for output in positive_outputs]
positive_gen_data = [{'prompt': batch[i], 'chosen': positive_responses[i]} for i in range(len(batch))]
negative_outputs = llm.generate(
negative_new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
negative_responses = [output.outputs[0].text for output in negative_outputs]
negative_gen_data = [{'prompt': batch[i], 'rejected': negative_responses[i]} for i in range(len(batch))]
merged_data = merge_outcomes(positive_gen_data, negative_gen_data)
outcomes = outcomes + merged_data
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def merge_outcomes(positive_gen_data, negative_gen_data):
negative_dict = {item['prompt']: item['rejected'] for item in negative_gen_data}
merged_outcomes = []
for positive_item in positive_gen_data:
prompt = positive_item['prompt']
if prompt in negative_dict:
merged_outcome = {
'prompt': prompt,
'chosen': positive_item['chosen'],
'rejected': negative_dict[prompt]
}
merged_outcomes.append(merged_outcome)
return merged_outcomes
def generate_teacher_response_for_reward_model_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
positive_system_prompt = config["inference"]["positive_system_prompt"]
negative_system_prompt = config["inference"]["negative_system_prompt"]
stream = config["inference"]["stream"]
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
positive_message = [
{'role': 'system', 'content': positive_system_prompt},
{'role': 'user', 'content': sample}
]
positive_completion = client.chat.completions.create(
messages = positive_message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
if stream:
positive_result = ""
for chunk in positive_completion:
positive_result += chunk.choices[0].delta.content
else:
positive_result = positive_completion.choices[0].message.content
negative_message = [
{'role': 'system', 'content': negative_system_prompt},
{'role': 'user', 'content': sample}
]
negative_completion = client.chat.completions.create(
messages = negative_message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
if stream:
negative_result = ""
for chunk in negative_completion:
negative_result += chunk.choices[0].delta.content
else:
negative_result = negative_completion.choices[0].message.content
outcomes.append({'prompt': sample, 'chosen': positive_result, 'rejected': negative_result})
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "rl_reward_api":
generate_teacher_response_for_reward_model_api(data_list, config)
elif job_type == "rl_reward_local":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,107 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
from datasets import Dataset
def process_dataset(dataset_path, tokenizer, config, template):
kwargs = {"padding": "max_length", "truncation": True, "max_length": config["training"]["max_length"], "return_tensors": "pt"}
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
except FileNotFoundError:
print(f"Error: The file '{dataset_path}' was not found.")
except json.JSONDecodeError:
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
except Exception as e:
print(f"An unexpected error occurred: {e}")
print(examples)
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
prompt_plus_chosen_response = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
prompt_plus_rejected_response = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
sample = {
"input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0]
}
output_dataset.append(sample)
except Exception as e:
logging.warning(f"Error processing sample: {e}")
dataset = Dataset.from_list(output_dataset)
return dataset
def train(config):
dataset_path = config["dataset"]["labeled_path"]
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
dataset = process_dataset(dataset_path, student_tokenizer, config, template)
student_model = AutoModelForSequenceClassification.from_pretrained(
config["models"]["student"],
num_labels=1,
trust_remote_code=True
)
student_model.config.pad_token_id = student_tokenizer.pad_token_id
training_arguments = RewardConfig(**config["training"])
trainer = RewardTrainer(
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,274 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import jsonlines
import logging
import os
from jinja2 import Environment, FileSystemLoader
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
from utils import write_data_to_json_file
# I have checked this function.
def cot_generate_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content
else:
result = completion.choices[0].message.content
if result is not None:
outcomes.append({"instruction": sample, "output": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def cot_generate_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
sample = prompt + "\n" + sample
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
for i in range(len(batch)):
if responses[i] is not None:
outcomes.append((prompt + "\n" + batch[i], responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_long2short_api(data_list_ins, data_list_out, config):
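# Compress existing reasoning chains: for each (instruction, output) pair, ask a remote OpenAI-compatible
# model for a simplified reasoning process and append the result as JSON lines.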
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content or ""
else:
result = completion.choices[0].message.content
if result is not None:
outcomes.append((sample,result))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
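# Batch variant of reasoning-chain compression using a local vLLM teacher and the configured chat template.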
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for ins,out in batch:
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
for i, (ins, out) in enumerate(batch):
if responses[i] is not None:
# Rebuild this item's prompt so each response is paired with its own instruction
prompted = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
outcomes.append((prompted, responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_short2long_api(data_list_ins, data_list_out, config):
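# Expand existing reasoning chains: for each (instruction, output) pair, ask a remote OpenAI-compatible
# model for an extended reasoning process and append the result as JSON lines.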
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content or ""
else:
result = completion.choices[0].message.content
if result is not None:
outcomes.append((sample,result))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
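# Batch variant of reasoning-chain expansion using a local vLLM teacher and the configured chat template.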
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for ins,out in batch:
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
for i, (ins, out) in enumerate(batch):
if responses[i] is not None:
# Rebuild this item's prompt so each response is paired with its own instruction
prompted = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
outcomes.append((prompted, responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)

View File

@@ -0,0 +1,293 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import os
from jinja2 import Environment, FileSystemLoader
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
import random
import re
from utils import read_json_field, write_data_to_json_file, load_tokenizer_and_vllm
def extract_answer(content):
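# Return the text inside <answer>...</answer> tags, or None if the tags are absent.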
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, content, re.DOTALL)
if match:
return match.group(1)
else:
return None
def extract_instruction_response(content):
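# Return the texts inside <instruction>...</instruction> and <response>...</response> tags,
# or (None, None) if either tag is missing.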
instruction_pattern = r'<instruction>(.*?)</instruction>'
instruction_match = re.search(instruction_pattern, content, re.DOTALL)
response_pattern = r'<response>(.*?)</response>'
response_match = re.search(response_pattern, content, re.DOTALL)
if instruction_match and response_match:
return instruction_match.group(1), response_match.group(1)
else:
return None, None
def generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples):
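# Build num_output_samples prompts, each combining the base prompt with num_in_context_samples
# examples drawn at random from data_list.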
if num_in_context_samples > len(data_list):
raise ValueError("num_in_context_samples cannot be larger than the length of data_list")
output_list = []
for _ in range(num_output_samples):
selected_samples = random.sample(data_list, num_in_context_samples)
combined_prompts = prompt + "\n" + "".join([sample + "\n" for sample in selected_samples])
output_list.append(combined_prompts)
return output_list
def expand_instruction_api(data_list, config):
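# Synthesize new instructions through a remote OpenAI-compatible API using in-context examples;
# only responses containing an <answer> block are kept.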
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
num_output_samples = config["dataset"]["num_output_samples"]
num_in_context_samples = config["dataset"]["num_in_context_samples"]
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
outcomes = []
for sample in tqdm(prompt_list, desc="Calling remote model and generating responses"):
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content or ""
else:
result = completion.choices[0].message.content
result = extract_answer(result)
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def expand_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
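# Local vLLM variant of instruction expansion; prompts are rendered through the configured chat template
# before batched generation.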
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
num_output_samples = config["dataset"]["num_output_samples"]
num_in_context_samples = config["dataset"]["num_in_context_samples"]
prompt = config["inference"]["prompt"]
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
outcomes = []
batches = [prompt_list[i:i + batch_size] for i in range(0, len(prompt_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"]
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
result = extract_answer(responses[i])
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def refine_instruction_api(data_list, config):
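# Refine each input instruction through a remote OpenAI-compatible API; only responses containing
# an <answer> block are kept.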
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content or ""
else:
result = completion.choices[0].message.content
result = extract_answer(result)
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def refine_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
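# Local vLLM variant of instruction refinement with batched generation.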
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
sample = prompt + "\n" + sample
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
result = extract_answer(responses[i])
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def instruction_response_extraction_api(data_list, config):
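# Extract (instruction, response) pairs from raw text through a remote OpenAI-compatible API;
# only outputs containing both <instruction> and <response> tags are kept.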
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream= stream,
)
if stream:
result = ""
for chunk in completion:
result += chunk.choices[0].delta.content or ""
else:
result = completion.choices[0].message.content
new_instruction, new_response = extract_instruction_response(result)
if new_instruction is not None and new_response is not None:
outcomes.append({"instruction": new_instruction, "output": new_response})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def instruction_response_extraction_batch(tokenizer, llm, data_list, config, batch_size=32):
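# Local vLLM variant of instruction-response extraction with batched generation.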
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
sample = prompt + "\n" + sample
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
new_instruction, new_response = extract_instruction_response(responses[i])
if new_instruction is not None and new_response is not None:
outcomes.append({"instruction": new_instruction, "output": new_response})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])

View File

@@ -0,0 +1,107 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import logging
import json
from instruct_synthesis import (
expand_instruction_api,
expand_instruction_batch,
refine_instruction_api,
refine_instruction_batch,
instruction_response_extraction_api,
instruction_response_extraction_batch
)
from cot_synthesis import (
cot_generate_api,
cot_generate_batch,
cot_long2short_api,
cot_long2short_batch,
cot_short2long_api,
cot_short2long_batch
)
from utils import read_json_field, load_tokenizer_and_vllm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def data_synthesis_with_teacher_model(config):
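# Dispatch the configured job_type to the matching API-based or local vLLM synthesis routine.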
logging.info('Generating distillation data from the teacher model!')
job_type = config["job_type"]
if job_type in ["instruction_response_extraction_api", "instruction_response_extraction_batch"]:
data_list = read_json_field(config["dataset"]["input_path"], field_name="data")
elif job_type in ["cot_long2short_api","cot_long2short_batch","cot_short2long_api","cot_short2long_batch"]:
data_list_ins = read_json_field(config["dataset"]["input_path"])
data_list_out = read_json_field(config["dataset"]["input_path"], field_name="output")
else:
data_list = read_json_field(config["dataset"]["input_path"])
try:
if job_type == "instruction_expansion_api":
expand_instruction_api(data_list, config)
elif job_type == "instruction_expansion_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
expand_instruction_batch(tokenizer, llm, data_list, config)
elif job_type == "instruction_refinement_api":
refine_instruction_api(data_list, config)
elif job_type == "instruction_refinement_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
refine_instruction_batch(tokenizer, llm, data_list, config)
elif job_type == "instruction_response_extraction_api":
instruction_response_extraction_api(data_list, config)
elif job_type == "instruction_response_extraction_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
instruction_response_extraction_batch(tokenizer, llm, data_list, config)
elif job_type == "cot_generation_api":
cot_generate_api(data_list, config)
elif job_type == "cot_generation_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_generate_batch(tokenizer, llm, data_list, config)
elif job_type == "cot_long2short_api":
cot_long2short_api(data_list_ins, data_list_out, config)
elif job_type == "cot_long2short_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config)
elif job_type == "cot_short2long_api":
cot_short2long_api(data_list_ins, data_list_out, config)
elif job_type == "cot_short2long_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
with open(args.config) as f:
config = json.load(f)
data_synthesis_with_teacher_model(config)
if __name__ == "__main__":
main()
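# A minimal config for this script might look like the following. The keys are inferred from the reads
# above; the field values (paths, model name, sampling settings) are illustrative, not shipped defaults:
# {
#     "job_type": "cot_generation_batch",
#     "dataset": {
#         "input_path": "instructions.json",
#         "output_path": "distill_data.json",
#         "template": "chat_template.jinja"
#     },
#     "inference": {
#         "prompt": "Please reason step by step.",
#         "max_new_tokens": 1024,
#         "temperature": 0.7,
#         "seed": 42,
#         "enable_chunked_prefill": true,
#         "gpu_memory_utilization": 0.9,
#         "trust_remote_code": true,
#         "enforce_eager": false,
#         "max_model_len": 4096
#     },
#     "models": {"teacher": "Qwen/Qwen2.5-7B-Instruct"}
# }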

View File

@@ -0,0 +1,85 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import torch
import logging
from vllm import LLM
from transformers import AutoTokenizer
def read_json_field(filename, field_name='instruction'):
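# Read one field (default: 'instruction') from every record of a JSON file and return the values as a list.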
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
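# Write a list of records to a JSON file, preserving non-ASCII characters.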
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
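# Load the teacher tokenizer and a vLLM engine spanning all visible GPUs, aligning eos/pad tokens beforehand.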
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
except Exception:
logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm