init commit
This commit is contained in:
274
easydistill/synthesis/cot_synthesis.py
Normal file
274
easydistill/synthesis/cot_synthesis.py
Normal file
@@ -0,0 +1,274 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import jsonlines
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
|
||||
from utils import write_data_to_json_file
|
||||
|
||||
|
||||
# I have checked this function.
|
||||
def cot_generate_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"]
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": sample, "output": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def cot_generate_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_long2short_api(data_list_ins, data_list_out, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
if result is not None:
|
||||
outcomes.append((sample,result))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for ins,out in batch:
|
||||
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_short2long_api(data_list_ins, data_list_out, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
if result is not None:
|
||||
outcomes.append((sample,result))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for ins,out in batch:
|
||||
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
293
easydistill/synthesis/instruct_synthesis.py
Normal file
293
easydistill/synthesis/instruct_synthesis.py
Normal file
@@ -0,0 +1,293 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
import random
|
||||
import re
|
||||
|
||||
from utils import read_json_field, write_data_to_json_file, load_tokenizer_and_vllm
|
||||
|
||||
|
||||
def extract_answer(content):
|
||||
pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def extract_instruction_response(content):
|
||||
instruction_pattern = r'<instruction>(.*?)</instruction>'
|
||||
instruction_match = re.search(instruction_pattern, content, re.DOTALL)
|
||||
response_pattern = r'<response>(.*?)</response>'
|
||||
response_match = re.search(response_pattern, content, re.DOTALL)
|
||||
if instruction_match and response_match:
|
||||
return instruction_match.group(1), response_match.group(1)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples):
|
||||
if num_in_context_samples > len(data_list):
|
||||
raise ValueError("num_in_context_samples cannot be larger than the length of data_list")
|
||||
output_list = []
|
||||
for _ in range(num_output_samples):
|
||||
selected_samples = random.sample(data_list, num_in_context_samples)
|
||||
combined_prompts = prompt + "\n" + "".join([sample + "\n" for sample in selected_samples])
|
||||
output_list.append(combined_prompts)
|
||||
return output_list
|
||||
|
||||
|
||||
def expand_instruction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
num_output_samples = config["dataset"]["num_output_samples"]
|
||||
num_in_context_samples = config["dataset"]["num_in_context_samples"]
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
|
||||
outcomes = []
|
||||
for sample in tqdm(prompt_list, desc="Calling remote model and generating responses"):
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
result = extract_answer(result)
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def expand_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
num_output_samples = config["dataset"]["num_output_samples"]
|
||||
num_in_context_samples = config["dataset"]["num_in_context_samples"]
|
||||
prompt = config["inference"]["prompt"]
|
||||
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
|
||||
|
||||
outcomes = []
|
||||
batches = [prompt_list[i:i + batch_size] for i in range(0, len(prompt_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
result = extract_answer(responses[i])
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def refine_instruction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
result = extract_answer(result)
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def refine_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
result = extract_answer(responses[i])
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def instruction_response_extraction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream= stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
new_instruction, new_response = extract_instruction_response(result)
|
||||
if new_instruction is not None and new_response is not None:
|
||||
outcomes.append({"instruction": new_instruction, "output": new_response})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def instruction_response_extraction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
logging.info(sample)
|
||||
sample = prompt + "\n" + sample
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
new_instruction, new_response = extract_instruction_response(responses[i])
|
||||
if new_instruction is not None and new_response is not None:
|
||||
outcomes.append({"instruction": new_instruction, "output": new_response})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
107
easydistill/synthesis/synthesis_main.py
Normal file
107
easydistill/synthesis/synthesis_main.py
Normal file
@@ -0,0 +1,107 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
|
||||
from instruct_synthesis import (
|
||||
expand_instruction_api,
|
||||
expand_instruction_batch,
|
||||
refine_instruction_api,
|
||||
refine_instruction_batch,
|
||||
instruction_response_extraction_api,
|
||||
instruction_response_extraction_batch
|
||||
)
|
||||
from cot_synthesis import (
|
||||
cot_generate_api,
|
||||
cot_generate_batch,
|
||||
cot_long2short_api,
|
||||
cot_long2short_batch,
|
||||
cot_short2long_api,
|
||||
cot_short2long_batch
|
||||
)
|
||||
from utils import read_json_field, load_tokenizer_and_vllm
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def data_synthesis_with_teacher_model(config):
|
||||
logging.info('Generating distillation data from the teacher model!')
|
||||
job_type = config["job_type"]
|
||||
if job_type == "instruction_response_extraction_api":
|
||||
data_list = read_json_field(config["dataset"]["input_path"], field_name="data")
|
||||
elif job_type in ["cot_long2short_api","cot_long2short_batch","cot_short2long_api","cot_short2long_batch"]:
|
||||
data_list_ins = read_json_field(config["dataset"]["input_path"])
|
||||
data_list_out = read_json_field(config["dataset"]["input_path"], field_name="output")
|
||||
else:
|
||||
data_list = read_json_field(config["dataset"]["input_path"])
|
||||
|
||||
try:
|
||||
if job_type == "instruction_expansion_api":
|
||||
expand_instruction_api(data_list, config)
|
||||
elif job_type == "instruction_expansion_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
expand_instruction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "instruction_refinement_api":
|
||||
refine_instruction_api(data_list, config)
|
||||
elif job_type == "instruction_refinement_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
refine_instruction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "instruction_response_extraction_api":
|
||||
instruction_response_extraction_api(data_list, config)
|
||||
elif job_type == "instruction_response_extraction_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
instruction_response_extraction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "cot_generation_api":
|
||||
cot_generate_api(data_list, config)
|
||||
elif job_type == "cot_generation_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_generate_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "cot_long2short_api":
|
||||
cot_long2short_api(data_list_ins, data_list_out, config)
|
||||
elif job_type == "cot_long2short_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config)
|
||||
|
||||
elif job_type == "cot_short2long_api":
|
||||
cot_short2long_api(data_list_ins, data_list_out, config)
|
||||
elif job_type == "cot_short2long_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
data_synthesis_with_teacher_model(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
85
easydistill/synthesis/utils.py
Normal file
85
easydistill/synthesis/utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
from vllm import LLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def read_json_field(filename, field_name='instruction'):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output_fields = []
|
||||
for item in data:
|
||||
if field_name in item:
|
||||
output_fields.append(item[field_name])
|
||||
return output_fields
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def load_tokenizer_and_vllm(config, eos_token=None):
|
||||
teacher_model_path = config["models"]["teacher"]
|
||||
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
|
||||
tokenizer.padding_side = "left"
|
||||
if eos_token:
|
||||
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
||||
logging.info(f"eos_token {eos_token} from user input")
|
||||
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
|
||||
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
|
||||
else:
|
||||
raise ValueError("No available eos_token or eos_token_id.")
|
||||
try:
|
||||
tokenizer.eos_token = eos_token
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
tokenizer.pad_token = eos_token
|
||||
tokenizer.pad_token_id = eos_token_id
|
||||
except:
|
||||
logging.info(f"[WARNING] Cannot set tokenizer.eos_token")
|
||||
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
|
||||
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=teacher_model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
|
||||
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
|
||||
trust_remote_code=config["inference"]["trust_remote_code"],
|
||||
dtype=torch.bfloat16,
|
||||
enforce_eager=config["inference"]["enforce_eager"],
|
||||
max_model_len=config["inference"]["max_model_len"],
|
||||
)
|
||||
logging.info("vLLM model loaded successfully")
|
||||
return tokenizer, llm
|
Reference in New Issue
Block a user