init commit

This commit is contained in:
熊兮
2025-05-27 18:55:46 +08:00
parent 6f52a67249
commit 25caa8a90a
65 changed files with 4893 additions and 1 deletions

View File

@@ -0,0 +1,165 @@
# DistilQwen2: Refining Instructional Data for Black-Box KD
## Brief Introduction
Knowledge distillation offers an effective solution by transferring knowledge from larger models to smaller ones, ensuring performance while significantly reducing computational resources and inference time. We introduce DistilQwen2, a lightweight LLM based on the Qwen2 series, optimized through enhanced instruction following and diverse distillation techniques. This enables more agile and efficient deployment in resource-constrained environments like mobile devices and edge computing. For ease of use by developers and enterprises, DistilQwen2's checkpoints are open-sourced on HuggingFace and ModelScope, empowering more stakeholders to innovate and realize value through advanced NLP applications.
## Instructional Data Processing Guidelines
For the training of DistilQwen2, we collected data from well-known open-source datasets like Magpie, Openhermes, and Mammoth 2, along with proprietary synthetic datasets, to initiate the distillation process. The focus is on providing diverse instructional data, predominantly in Chinese and English. We also leverage prompt templates to conduct instructional data augmentation. Here, we provide several commonly used operations to re-sample and augment the dataset.
### Instruction Set Expansion
The instruction expansion operator is employed to generate a diverse set of instruction variations, ensuring that student models are exposed to a comprehensive range of instructions. After instruction expansion, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_expansion_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_expansion_batch.json`.
### Instruction Refinement
The instruction refinement operator further enhances the quality and diversity of the training data while preserving the semantic integrity of the tasks expressed in the instructions, ensuring that the rewritten content remains faithful to the original intent and task category. After instruction refinement, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_refinement_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_refinement_batch.json`.
### Instruction Resampling
We also consider task balance when selecting useful instructional data pairs. The task distributions are defined based on our paper listed in the reference section. You can run the job by:
```bash
python task_resampling.py --input-file input.json --output-file output.json --api-key <your_api_key> --base-url <base_url>
```
The input dataset is in JSON format, with instruction-only entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth..."
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?"
}
]
```
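After classification and resampling, each instruction is written out together with its assigned category. Based on the `task_resampling.py` script included in this commit, the output entries look roughly like the following (the category label is illustrative):
```json
[
    {
        "instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth...",
        "category": "Math"
    }
]
```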
After the instructions have been processed, you can call the teacher model to generate the corresponding responses.
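For reference, the following is a minimal sketch of this step using an OpenAI-compatible endpoint; the file names, endpoint, and model name are placeholders, and in practice the EasyDistill synthesis scripts drive this step through their JSON config files.
```python
import json
from openai import OpenAI

# Placeholder endpoint, key, and model name; replace with your teacher model service.
client = OpenAI(api_key="<your_api_key>", base_url="<base_url>")

with open("output.json") as f:  # resampled instructions
    items = json.load(f)

labeled = []
for item in items:
    completion = client.chat.completions.create(
        model="<teacher_model_name>",
        messages=[{"role": "user", "content": item["instruction"]}],
    )
    labeled.append({
        "instruction": item["instruction"],
        "output": completion.choices[0].message.content,
    })

with open("labeled.json", "w") as f:
    json.dump(labeled, f, ensure_ascii=False, indent=4)
```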
### Open-Source Dataset
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
"output": "I'd be happy to help you review your lecture text..."
}
]
```
The dataset is available on ModelScope and Hugging Face. Users can download it with the ModelScope SDK as follows:
```python
# Validate SDK token
from modelscope.hub.api import HubApi
api = HubApi()
api.login('your_token_id')
# Dataset download
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('PAI/DistilQwen_100k')
```
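Alternatively, if you prefer pulling the data from Hugging Face, the standard `datasets` loader can be used; the repository id below assumes the same naming as on ModelScope and should be adjusted if it differs.
```python
from datasets import load_dataset

# Assumed repository id on Hugging Face; adjust if necessary.
ds = load_dataset("alibaba-pai/DistilQwen_100k", split="train")
print(ds[0]["instruction"])
```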
## Model Training Guidelines
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. For simplicity, this tutorial uses the `DistilQwen_100k` dataset; we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage1.json
```
Please refer to the config file `distilqwen2_stage1.json` in the current folder. If you need to run the job in distributed mode, use `accelerate` to launch it.
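A typical multi-GPU invocation might look like the following; the `accelerate` config file name is a placeholder for your own environment:
```bash
accelerate launch --config_file accelerate_config.yaml \
    easydistill/kd/train.py --config=distilqwen2_stage1.json
```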
### Preference Rank Optimization
For more challenging instruction tasks, SFT alone may not yield optimal results. To address this, we further refine the model using Direct Preference Optimization (DPO), enabling more granular fine-tuning and improved performance. First, we generate the student model's outputs and use them as the rejected responses, while the instructions and outputs in the SFT dataset are regarded as the prompts and chosen responses. Please refer to the following script:
```bash
python dpo_student_infer_only.py --config=distilqwen2_stage2.json
```
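Based on the `dpo_student_infer_only.py` script in this commit, the resulting `distil_qwen_100k_dpo.json` file stores preference triples in the following form; the rejected response shown here is purely illustrative:
```json
[
    {
        "prompt": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
        "chosen": "I'd be happy to help you review your lecture text...",
        "rejected": "Here are some generic tips for giving lectures..."
    }
]
```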
Next, we run the training job by:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage2.json
```
Again, please refer to the config file `distilqwen2_stage2.json` in the current folder and adjust the configurations as needed. If you need to run the job in distributed mode, use `accelerate` to launch it, as shown above.
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2-1.5B-Instruct` and `alibaba-pai/DistilQwen2-7B-Instruct`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
model_name = "alibaba-pai/DistilQwen2-1.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-1.5B/")
model_name = "alibaba-pai/DistilQwen2-7B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-7B/")
```
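After downloading, the checkpoints can be loaded with `transformers` in the usual way; the following is a minimal inference sketch:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load from the hub id; pass the local snapshot path instead if you downloaded it above.
model_name = "alibaba-pai/DistilQwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

messages = [{"role": "user", "content": "Give me a short introduction to knowledge distillation."}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```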
## Performance
The table below compares the performance of the original Qwen2 models with the distilled DistilQwen2 models across different parameter sizes: 1.5B and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
## Reference
For more detailed information about the DistilQwen2 model series and the methodologies employed, we encourage you to refer to our paper:
- **Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning**
Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang
You can cite the paper using the following citation format:
```bibtex
@inproceedings{emnlp2024,
author = {Yuanhao Yue and
Chengyu Wang and
Jun Huang and
Peng Wang},
title = {Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning},
booktitle = {Findings of the Association for Computational Linguistics: {EMNLP} 2024},
pages = {6030--6054},
publisher = {Association for Computational Linguistics},
year = {2024},
url = {https://aclanthology.org/2024.findings-emnlp.350}
}
```

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,25 @@
{
"job_type": "rank_dpo_api",
"dataset": {
"instruction_path": "distil_qwen_100k.json",
"labeled_path": "distil_qwen_100k_dpo.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "result_stage1/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"beta": 0.1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,105 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
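# Read the SFT dataset and convert each (instruction, output) pair into a prompt/chosen record.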
def read_json_field(filename):
try:
with open(filename, 'r') as file:
data = json.load(file)
        outcomes = []
        for item in data:
            instruction = item["instruction"]
            answer = item["output"]
            # Keep the dataset's original output as the "chosen" response.
            outcomes.append({"prompt": instruction, "chosen": answer})
        return outcomes
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def generate_student_response(data_list, config):
# load student model
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
device_map="auto",
trust_remote_code=True
)
outcomes = []
    for sample in tqdm(data_list, desc="Generating student responses"):
prompt = sample["prompt"]
chosen = sample["chosen"]
# for student model
messages = [
{"role": "user", "content": prompt}
]
text = student_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
generated_ids = student_model.generate(
**model_inputs,
max_new_tokens=config["inference"]["max_new_tokens"]
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
gen_data = {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
outcomes.append(gen_data)
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
data_list = read_json_field(config["dataset"]["instruction_path"])
generate_student_response(data_list, config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,156 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import re
import logging
from openai import OpenAI
from collections import Counter
import random
import argparse
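# Target task distribution used for resampling; the ratios follow the DistilQwen2 reference paper.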
predefined_distribution = {
'Math': 0.167,
'Code Generation': 0.083,
'Writing': 0.017,
'Computer Science': 0.017,
'Reasoning': 0.167,
'Complex Format': 0.017,
'Code Debug': 0.083,
'Common-Sense': 0.017,
'Counterfactual': 0.017,
'Multilingual': 0.017,
'Roleplay': 0.017,
'Biology': 0.017,
'Technology': 0.017,
'Ethics': 0.017,
'Sport': 0.017,
'Law': 0.017,
'Medicine': 0.017,
'Literature': 0.017,
'Entertainment': 0.017,
'Art': 0.017,
'Music': 0.017,
'Toxicity': 0.017,
'Economy': 0.017,
'Physics': 0.017,
'History': 0.017,
'Chemistry': 0.017,
'Philosophy': 0.017,
'Health': 0.017,
'Ecology': 0.017,
'Grammar': 0.017,
'Paraphrase': 0.017,
'Others': 0.041
}
predefined_prompt = """
You are a data annotation expert. Please classify the task type or domain of #Given Instruction.
The task type or domain should be in the list: [Math, Code Generation, Writing, Computer Science, Reasoning, Complex Format, Code Debug, Common-Sense, Counterfactual, Multilingual, Roleplay, Biology, Technology, Ethics, Sport, Law, Medicine, Literature, Entertainment, Art, Music, Toxicity, Economy, Physics, History, Chemistry, Philosophy, Health, Ecology, Grammar, Paraphrase, Others]. You should place your answer enclosed within <answer></answer> tags, such as <answer>Math</answer>. Do not return anything else.
#Given Instruction#:
"""
def extract_answer(content):
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, content, re.DOTALL)
if match:
return match.group(1)
else:
return None
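# Prompt the annotation model to classify one instruction into a predefined task category.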
def classify_instruction(instruction, client, model):
message = [
{"role": "user", "content": predefined_prompt + "\n" + instruction}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = 1024
)
result = completion.choices[0].message.content.strip()
print(result)
result = extract_answer(result)
if result is None or result not in predefined_distribution.keys():
result = 'Others'
print(result)
return result
def main(args):
# Load dataset
with open(args.input_file, 'r') as file:
data = json.load(file)
# Initialize OpenAI client
client = OpenAI(
api_key=args.api_key,
base_url=args.base_url
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
# Classify each instruction
classified_data = []
count = 0
for item in data:
category = classify_instruction(item['instruction'], client, model)
classified_data.append({'instruction': item['instruction'], 'category': category})
count += 1
print(count)
# Count occurrences per category
category_counts = Counter(item['category'] for item in classified_data)
total_samples = len(classified_data)
# Resample according to predefined distribution
resampled_data = []
for category, target_ratio in predefined_distribution.items():
target_count = int(total_samples * target_ratio)
category_samples = [item for item in classified_data if item['category'] == category]
if len(category_samples) == 0:
logging.warning("No instructions are provided for the category: " + category)
continue
if len(category_samples) > target_count:
print(category)
print(len(category_samples))
print(target_count)
# Randomly sample the required number of instructions
resampled_category_samples = random.sample(category_samples, target_count)
else:
# If not enough samples, repeat the existing ones
resampled_category_samples = category_samples * (target_count // len(category_samples)) + random.sample(category_samples, target_count % len(category_samples))
resampled_data.extend(resampled_category_samples)
# Save final dataset
with open(args.output_file, 'w') as file:
json.dump(resampled_data, file, indent=4)
print("Resampling complete. Final output saved to '{}'.".format(args.output_file))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Task and Domain Classification')
parser.add_argument('--input-file', type=str, required=True, help='Input JSON file containing instructions.')
parser.add_argument('--output-file', type=str, required=True, help='Output JSON file to store resampled instructions.')
parser.add_argument('--api-key', type=str, required=True, help='API key.')
parser.add_argument('--base-url', type=str, required=True, help='Base URL.')
args = parser.parse_args()
main(args)