init commit

This commit is contained in:
熊兮
2025-05-27 18:55:46 +08:00
parent 6f52a67249
commit 25caa8a90a
65 changed files with 4893 additions and 1 deletions

View File

@@ -0,0 +1,76 @@
# DistilQwen2.5-0324: training fast-thinking models
## Brief Introduction
As large language models advance rapidly, balancing the trade-off between efficient inference and deep thinking capabilities has been a key focus in both academia and industry. DeepSeek-V3-0324 does not employ a deep thinking mode by default, which accelerates inference while still handling complex tasks well. The DistilQwen2.5-0324 series not only inherits the essence of the original models' chain-of-thought distillation but also introduces fast-thinking strategies that significantly boost inference speed, enabling these models to execute complex tasks efficiently on resource-constrained devices and in edge computing scenarios.
## Detailed Steps
### Processing of Instructional Dataset
DistilQwen2.5-0324 was trained on data distilled from DeepSeek-V3-0324, together with data distilled from DeepSeek-R1 and then rewritten with a long2short strategy. For DeepSeek-V3-0324, the official recommendation is not to use a system prompt; for the long2short scenario, the following prompt was used. You can employ this method to shorten DeepSeek-R1's outputs and distill your own model.
```json
{
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
}
```
Run the long2short rewriting job through `easydistill/kd/infer.py` with the stage-1 config:
```bash
python easydistill/kd/infer.py --config=distilqwen2.5-0324_stage1.json
```
The training dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "Step 1: Determine the total number of incisors in the upper jaw...The final answer is: \\boxed{8}"
}
]
```
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we can run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-0324_stage2.json
```
Please refer to the config file `distilqwen2.5-0324_stage2.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
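For intuition, black-box KD at this stage is essentially supervised fine-tuning on the teacher's responses. The snippet below is a minimal illustration (not the EasyDistill implementation; the student model name follows the stage-2 config, the data file name follows `labeled_path`, and the maximum length is an assumption) of how each instruction-response pair can be rendered with the student's chat template into a causal-LM training example:
```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

def build_example(sample, max_length=2048):
    # Render the pair with the student's chat template, exactly as at inference time
    messages = [
        {"role": "user", "content": sample["instruction"]},
        {"role": "assistant", "content": sample["output"]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    ids = tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
    # For plain SFT, the labels are the input ids themselves (next-token prediction)
    return {"input_ids": ids, "labels": ids}

with open("distil_qwen_0324.json", "r", encoding="utf-8") as f:
    train_examples = [build_example(s) for s in json.load(f)]
```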
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-DS3-0324-7B`, `alibaba-pai/DistilQwen2.5-DS3-0324-14B`, and `alibaba-pai/DistilQwen2.5-DS3-0324-32B`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-7B/")
# Download the 14B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-14B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-14B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-32B/")
```
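After downloading, the models can be used like any other Qwen2.5-style chat model. Below is a minimal usage sketch with `transformers`; the prompt and generation settings are illustrative:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "Solve: what is 17 * 24?"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Fast-thinking models answer directly, so a moderate token budget is usually enough
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```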
## Performance
- **32B Model** approaches the performance of closed-source models with 10x the parameters on the GPQA Diamond benchmark
- **Significant Improvement in Reasoning Efficiency** (see comparison table below)
| Model | MMLU_PRO Tokens | AIME2024 Tokens | Speed Gain |
|--------------------------------|-----------------|-----------------|------------|
| DistilQwen2.5-R1-32B (Slow-Thinking) | 4198 | 12178 | 1x |
| DistilQwen2.5-DS3-0324-32B | 690 | 4177 | 5-8x |

View File

@@ -0,0 +1,14 @@
{
"job_type": "cot_long2short_api",
"dataset": {
"input_path": "./raw.json",
"output_path": "./raw_simplified.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
"max_new_tokens": 1024
}
}

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_0324.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,142 @@
# DistilQwen2.5-R1: training distilled reasoning models based on CoTs generated by DeepSeek-R1
## Brief Introduction
As large language models (LLMs) evolve toward deep reasoning capabilities, deploying them in resource-constrained environments (e.g., mobile devices, edge computing) remains challenging. The DistilQwen2.5-R1 series addresses this by transferring reasoning capabilities from ultra-large models (e.g., DeepSeek-R1) to compact models through innovative distillation techniques, achieving high performance while reducing computational costs.
## Data Generation Detailed Steps
### 1. Generate Thinking Dataset
DistilQwen2.5-R1 is trained on chain-of-thought data distilled from DeepSeek-R1. We provide the system prompts used for distilling the R1 data and for training Qwen2.5. You can use these system prompts to call DeepSeek-R1 to generate your own data and train the model.
```json
{
"system": "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:"
}
```
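For illustration, the sketch below shows one way to collect CoT data with this system prompt through an OpenAI-compatible endpoint; the endpoint, model name, and file names are placeholders rather than part of EasyDistill:
```python
import json
from openai import OpenAI

SYSTEM_PROMPT = "..."  # the system prompt shown above

client = OpenAI(api_key="TOKEN", base_url="ENDPOINT")  # placeholders

with open("instructions.json", "r", encoding="utf-8") as f:
    instructions = json.load(f)

distilled = []
for item in instructions:
    completion = client.chat.completions.create(
        model="deepseek-reasoner",  # assumption: an R1-style endpoint
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": item["instruction"]},
        ],
    )
    distilled.append({"instruction": item["instruction"],
                      "output": completion.choices[0].message.content})

with open("raw.json", "w", encoding="utf-8") as f:
    json.dump(distilled, f, ensure_ascii=False, indent=2)
```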
### 2. Determine the Difficulty Level
We critique the CoT quality according to the cognitive capabilities of smaller models. You can use the following system prompt with QwQ-32B to determine the difficulty level of each CoT.
```json
{
"system": "You are a highly capable evaluator. Your task is to assess the given reasoning process from the perspective of a small language model (e.g., 7B). Specifically, determine whether the reasoning process provides sufficient detail for a small model to solve the problem, or whether it is too simplistic (i.e., lacking critical details) or too complex (i.e., containing unnecessary or confusing steps). Difficulty Definitions (from the perspective of a small model): - Easy: The reasoning process is overly simplistic relative to the problem's difficulty; it omits essential details that a small model needs to solve the problem. - Medium: The reasoning process is appropriately balanced, offering enough detailed guidance. - Hard: The reasoning process is overly complex, with extraneous or convoluted steps that could hinder a small model's ability to follow it. Output Format: You must output exactly one word: easy, medium, or hard. Do NOT provide any additional text, explanation."
}
```
### 3. Rethink and Refine the CoTs
The CoTs are then rethought and refined based on the critiques, using the following prompts:
#### easy
```json
{
"system": "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem, its answer, and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary."
}
```
#### hard
```json
{
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer, and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
}
```
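Putting steps 2 and 3 together, a simple driver can first grade each CoT and then route it to the extension or simplification prompt. The sketch below assumes an OpenAI-compatible client serving the evaluator model (e.g., QwQ-32B); the user-message formatting is an assumption:
```python
from openai import OpenAI

client = OpenAI(api_key="TOKEN", base_url="ENDPOINT")  # placeholder endpoint
EVAL_MODEL = "qwq-32b"                                 # assumption

DIFFICULTY_PROMPT = "..."  # the evaluator system prompt from step 2
EXTEND_PROMPT = "..."      # the "easy" prompt from step 3
SIMPLIFY_PROMPT = "..."    # the "hard" prompt from step 3

def refine_cot(problem, thought, solution):
    # Step 2: grade the CoT from the perspective of a small model
    user_msg = f"Problem: {problem}\nReasoning: {thought}\nAnswer: {solution}"
    grade = client.chat.completions.create(
        model=EVAL_MODEL,
        messages=[{"role": "system", "content": DIFFICULTY_PROMPT},
                  {"role": "user", "content": user_msg}],
    ).choices[0].message.content.strip().lower()

    if grade == "medium":
        return thought  # appropriately balanced, keep as-is

    # Step 3: extend overly simple CoTs, simplify overly complex ones
    system = EXTEND_PROMPT if grade == "easy" else SIMPLIFY_PROMPT
    return client.chat.completions.create(
        model=EVAL_MODEL,
        messages=[{"role": "system", "content": system},
                  {"role": "user", "content": user_msg}],
    ).choices[0].message.content
```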
The training dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "<|begin_of_thought|>## Step 1: Determine the total number of incisors in the upper jaw...\n<|end_of_thought|>\n<|begin_of_solution|>The final answer is: \\boxed{8}<|end_of_solution|>"
}
]
```
## Model Training Guidelines
### 1. Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-r1_stage1.json
```
Please refer to the config file `distilqwen2.5-r1_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### 2. CogPO
CogPO (Cognitive Preference Optimization) is a novel algorithm designed to enhance the reasoning abilities of small language models by aligning their reasoning processes with their inherent cognitive capacities.
Key aspects of CogPO:
- Extends Direct Preference Optimization (DPO) with cognitive alignment
- Introduces three specialized "mini-tasks" with different preference gaps
- Dynamically adjusts optimization strength (β values) based on reasoning complexity
- Works synergistically with the CRV (Critique-Rethink-Verify) system
You can run CogPO with:
```bash
accelerate launch --num_processes n --config_file multi_gpu.yaml cogpo.py --config distilqwen2.5-r1_stage2.json
```
The dataset is in JSON format, exemplified by entries such as:
```json
{
"prompt": "Ellie has 8 pairs of shoes. Riley has 3 fewer. How many pairs of shoes do they have in all?",
"chosen": "<think>Identify the number of pairs of shoes Ellie has. According to the problem statement, Ellie has 8 pairs of shoes.\n Next, determine the number of pairs of shoes Riley has. The problem states that Riley has 3 fewer pairs than Ellie. To find out how many pairs Riley has, subtract 3 from the number of pairs Ellie has: 8 - 3 = 5. So, Riley has 5 pairs of shoes.\n Now, calculate the total number of pairs of shoes both Ellie and Riley have together. To do this, add the number of pairs Ellie has to the number of pairs Riley has: 8 (Ellie's pairs) + 5 (Riley's pairs) = 13 pairs. This step is crucial because it combines the information about both individuals to give the overall total.\n The total number of pairs of shoes they have in all is 13. Thus, the final answer is 13. Each step in the reasoning process is designed to help understand and solve the problem effectively, showing how the information about each individual's shoe count leads to finding the combined total.</think>\boxed{13}",
"rejected": "<think>Identify the number of pairs of shoes Ellie has. Ellie has 8 pairs of shoes as stated in the problem. Determine how many pairs of shoes Riley has. Since Riley has 3 fewer pairs than Ellie, we mistakenly add 3 to Ellie's pairs instead of subtracting, giving us 8 + 3 = 11 pairs of shoes for Riley. Calculate the total number of pairs of shoes they both have. Add Ellie's and Riley's pairs together: 8 + 11. The total pairs of shoes is 19. The final answer is thus \boxed{19}.</think>\boxed{13}",
"beta": 0.5
}
```
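Note that each record carries its own `beta`, so the optimization strength can vary per sample. The sketch below illustrates preparing such a file; the mapping from difficulty mini-task to β is purely illustrative and not the values used in the actual recipe:
```python
import json

# Illustrative only: a per-mini-task beta assignment (values are hypothetical)
beta_by_task = {"easy": 0.9, "medium": 0.5, "hard": 0.3}

pairs = [
    {"prompt": "Ellie has 8 pairs of shoes. Riley has 3 fewer. How many pairs of shoes do they have in all?",
     "chosen": "<think>...</think>\\boxed{13}",
     "rejected": "<think>...</think>\\boxed{19}",
     "task": "medium"},
]

records = [{"prompt": p["prompt"], "chosen": p["chosen"], "rejected": p["rejected"],
            "beta": beta_by_task[p["task"]]} for p in pairs]

with open("cogpo_train.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```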
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-R1-3B`, `alibaba-pai/DistilQwen2.5-R1-7B`, `alibaba-pai/DistilQwen2.5-R1-14B`, and `alibaba-pai/DistilQwen2.5-R1-32B`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 3B model
model_name = "alibaba-pai/DistilQwen2.5-R1-3B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-3B/")
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-R1-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-7B/")
# Download the 14B model
model_name = "alibaba-pai/DistilQwen2.5-R1-14B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-14B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen2.5-R1-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-32B/")
```
## Performance
We compared the DistilQwen2.5-R1 series with leading reasoning models across four benchmarks:
### 7B Model Comparison
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|--------------------------------|--------------------|----------|----------|--------------|------------------|
| DeepSeek-R1-Distill-Qwen-7B | 800k | 55.5 | 92.8 | 49.1 | - |
| Bespoke-Stratos-7B | 17k | 20.0 | 82.0 | 37.8 | 36.1 |
| OpenThinker-7B | 114k | 31.3 | 83.0 | 42.4 | 39.9 |
| **DistilQwen2.5-R1-7B** | 105k | 43.33 | 88.4 | 42.93 | 46.38 |
### 32B Model Comparison
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|--------------------------------|--------------------|----------|----------|--------------|------------------|
| DeepSeek-R1-Distill-Qwen-32B | 800k | 72.6 | 94.3 | 62.1 | - |
| Sky-T1-32B-Preview | 17k | 43.3 | 86.4 | 56.8 | - |
| OpenThinker-32B | 114k | 66.0 | 90.6 | 61.6 | 68.9 |
| **DistilQwen2.5-R1-32B** | 105k | 70.0 | 93.8 | 62.12 | 65.95 |

View File

@@ -0,0 +1,194 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import trl
from trl.trainer.dpo_trainer import DataCollatorForPreference
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union
import torch, torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType
from trl.trainer.utils import cap_exp
import json
import argparse
@dataclass
class DataCollatorForPreferenceWithBeta(DataCollatorForPreference):
    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        betas = torch.tensor([float(ex["beta"]) for ex in examples], dtype=torch.float32)
        for ex in examples:
            ex.pop("beta")
        batch = super().torch_call(examples)
        batch["betas"] = betas
        return batch


class CogPOTrainer(DPOTrainer):
    def get_batch_loss_metrics(
        self,
        model,
        batch,
        train_eval: str = "train",
    ):
        metrics = {}
        betas = batch.pop("betas").to(self.accelerator.device)
        model_output = self.concatenated_forward(model, batch)
        if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
            ref_chosen_logps = batch["ref_chosen_logps"]
            ref_rejected_logps = batch["ref_rejected_logps"]
        else:
            ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
        losses, chosen_rewards, rejected_rewards = self._dpo_sigmoid_loss(
            model_output["chosen_logps"],
            model_output["rejected_logps"],
            ref_chosen_logps,
            ref_rejected_logps,
            betas,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        if self.args.rpo_alpha is not None:
            losses = losses + self.args.rpo_alpha * model_output["nll_loss"]
        if self.use_weighting:
            losses = losses * model_output["policy_weights"]
        if self.aux_loss_enabled:
            losses = losses + self.aux_loss_coef * model_output["aux_loss"]
        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
        metrics[f"{prefix}rewards/margins"] = (
            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
        )
        metrics[f"{prefix}logps/chosen"] = (
            self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}logps/rejected"] = (
            self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
        )
        metrics[f"{prefix}logits/chosen"] = (
            self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
        )
        metrics[f"{prefix}logits/rejected"] = (
            self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
        )
        if self.args.rpo_alpha is not None:
            metrics[f"{prefix}nll_loss"] = (
                self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
            )
        if self.aux_loss_enabled:
            metrics[f"{prefix}aux_loss"] = (
                self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
            )
        return losses.mean(), metrics

    def _dpo_sigmoid_loss(
        self,
        chosen_logps: torch.FloatTensor,
        rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        betas: torch.FloatTensor,
    ):
        device = self.accelerator.device
        chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
        rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
        # Δ = (log p_c - log p_r) - (log p̂_c - log p̂_r)
        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
            if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
                alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
            logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
        else:
            logratios = chosen_logps - rejected_logps
            if self.reference_free:
                ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
            else:
                ref_logratios = ref_chosen_logps - ref_rejected_logps
            logratios = logratios.to(self.accelerator.device)
            ref_logratios = ref_logratios.to(self.accelerator.device)
            logits = logratios - ref_logratios
            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
        losses = (
            -F.logsigmoid(betas * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-betas * logits) * self.label_smoothing
        )
        chosen_rewards = betas * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
        rejected_rewards = betas * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
        return losses, chosen_rewards, rejected_rewards


def train(config):
    model_name = config["models"]["student"]
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"], split='train')
    dpo_args = DPOConfig(
        output_dir=config["training"]["output_dir"],
        num_train_epochs=config["training"]["num_train_epochs"],
        loss_type=config["training"]["loss_type"],
        beta=config["training"]["beta"],
        per_device_train_batch_size=config["training"]["per_device_train_batch_size"],
        remove_unused_columns=False,
    )
    collator = DataCollatorForPreferenceWithBeta(
        pad_token_id=tokenizer.pad_token_id
    )
    trainer = CogPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
    )
    trainer.train()
    trainer.save_model(config["training"]["output_dir"])
    tokenizer.save_pretrained(config["training"]["output_dir"])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    config = json.load(open(args.config))
    train(config)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_r1.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,15 @@
{
"models": {
"student": "models/Qwen2.5-0.5B-Instruct"
},
"dataset": {
"labeled_path": "cogpo/test500.jsonl"
},
"training": {
"output_dir": "save/Qwen2.5-0.5B-CogPO",
"num_train_epochs": 1.0,
"loss_type": "sigmoid",
"beta": 1.0,
"per_device_train_batch_size": 2
}
}

View File

@@ -0,0 +1,101 @@
# DistilQwen-ThoughtX: Optimized Reasoning Models with OmniThought
## Brief Introduction
DistilQwen-ThoughtX is a series of high-performance reasoning models trained on the [OmniThought](https://huggingface.co/datasets/alibaba-pai/OmniThought) dataset. These models are optimized for chain-of-thought (CoT) reasoning with balanced verbosity and cognitive difficulty, achieving state-of-the-art results on mathematical, coding, and logical reasoning benchmarks.
## Detailed Steps
### Direct Training
DistilQwen-ThoughtX was trained using data from the OmniThought dataset, which includes 2 million CoT processes with RV (Reasoning Verbosity) and CD (Cognitive Difficulty) annotations. The dataset covers mathematics, coding, and logical reasoning tasks, validated by multiple teacher models (DeepSeek-R1, QwQ-32B).
The training system prompt is:
```json
{
"system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
}
```
Using the OmniThought dataset, we can run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-thoughtx-train.json
```
When training your own model, remember to filter the samples so that their RV and CD annotations fall within the desired range.
| Model Name | Parameters | Base Model |
|--------------------------------------|------------|---------------------|
| `DistilQwen-ThoughtX-7B` | 7B | Qwen2.5-7B-Instruct |
| `DistilQwen-ThoughtX-32B` | 32B | Qwen2.5-32B-Instruct|
### Process Your Own Data
To obtain the RV and CD values of your own data, you can use the following prompts to call QwQ-32B or DeepSeek-R1, score your data, and filter it.
Prompt Template to Calculate the RV Score
```json
{
"prompt": "You are an expert judge tasked with evaluating the Reasoning Verbosity of a Chain-of-Thought (CoT) for a given problem and its answer. Reasoning Verbosity Evaluation Focus: Assess how well the CoTs length and step complexity match the problems inherent difficulty. An optimal chain is neither missing essential steps nor padded with needless digressions. A simple question should be solved with a brief, direct chain; a challenging one may justifiably require a longer path with reflection and error-checking. Scoring Guidelines (0-9): 0-1 Minimal verbosity, straightforward expression with little to no elaboration. 2-3 Clear and concise reasoning with necessary explanations. 4-5 Moderate verbosity with detailed explanations and thorough reasoning. 6-7 Extensive verbosity with comprehensive justification and exploration of complex connections. 8-9 High verbosity with deep, exhaustive exploration of reasoning; involves extensive elaboration, nested justifications, and consideration of counterarguments or alternative perspectives. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Reasoning Verbosity 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
}
```
Prompt Template to Calculate the CD Score
```json
{
"prompt": "You are an expert judge assessing the Cognitive Difficulty of a Chain-of-Thought (CoT) for a given problem and its answer. Cognitive Difficulty Evaluation Focus: The level of reasoning competence required for a model to follow and reproduce the chain faithfully. Judge the reasoning approach, techniques, and overall difficulty. Higher scores correspond to more advanced concepts, abstractions, or multi-layer reasoning patterns. Scoring Guidelines (0-9): 0-1 Elementary facts or a single trivial operation. 2-3 Multi-step arithmetic, explicit enumeration, basic rule chaining. 4-5 Early-undergraduate logic/algebra; one non-obvious insight. 6-7 Advanced undergraduate techniques (determinants, dynamic programming, layered code reasoning, etc). 8-9 Graduate-level abstraction, nested proofs, intricate algorithmic analysis. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Cognitive Difficulty 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
}
```
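As a sketch of scoring and filtering your own CoT data with these templates (the client, model name, field names, and score thresholds below are assumptions; the templates are filled by simple string substitution):
```python
import json
from openai import OpenAI

RV_PROMPT = "..."  # the RV template above, with {problem}/{thought}/{solution} placeholders
CD_PROMPT = "..."  # the CD template above

client = OpenAI(api_key="TOKEN", base_url="ENDPOINT")  # e.g., a QwQ-32B endpoint

def score(template, model, sample):
    prompt = template.format(problem=sample["instruction"],
                             thought=sample["thought"],
                             solution=sample["output"])
    reply = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    ).choices[0].message.content.strip()
    return int(reply)  # the templates ask for a single integer from 0 to 9

with open("my_cot_data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

kept = []
for sample in data:
    rv = score(RV_PROMPT, "qwq-32b", sample)
    cd = score(CD_PROMPT, "qwq-32b", sample)
    # Keep CoTs whose verbosity and difficulty sit in a moderate band (thresholds are illustrative)
    if 2 <= rv <= 7 and 2 <= cd <= 7:
        kept.append(sample)

with open("filtered_cot_data.json", "w", encoding="utf-8") as f:
    json.dump(kept, f, ensure_ascii=False, indent=2)
```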
## Model Download
We have open-sourced our distilled models on HuggingFace. The available models are named `alibaba-pai/DistilQwen-ThoughtX-7B` and `alibaba-pai/DistilQwen-ThoughtX-32B`.
Users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 7B model
model_name = "alibaba-pai/DistilQwen-ThoughtX-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-7B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen-ThoughtX-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-32B/")
```
## Performance
The models achieve state-of-the-art performance on various reasoning benchmarks:
| Model | AIME2024 | MATH500 | GPQA-D | LiveCodeBench V2 |
|----------------------|----------|---------|--------|------------------|
| DeepSeek-R1-Distill-7B | 57.3 | 89.6 | 47.3 | 48.4 |
| **DistilQwen-ThoughtX-7B** | **56.7** | **90.2** | **50.0** | **56.8** |
| DeepSeek-R1-Distill-32B | 74.7 | 90.0 | 62.4 | 72.3 |
| **DistilQwen-ThoughtX-32B** | **80.0** | **92.6** | **64.0** | **73.4** |
## Reference
For more detailed information about the model, we encourage you to refer to our paper:
- **Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations**
Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang
[arXiv:2505.10937](https://arxiv.org/abs/2505.10937)
You can cite the paper using the following citation format:
```bibtex
@misc{cai2025reasoningomnithoughtlargecot,
title={Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations},
author={Wenrui Cai and Chengyu Wang and Junbing Yan and Jun Huang and Xiangzhong Fang},
year={2025},
eprint={2505.10937},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2505.10937}
}
```

View File

@@ -0,0 +1,24 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_thoughtX.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"max_length":4096,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,135 @@
# DistilQwen2.5: Combining Black-Box and White-Box KD
## Brief Introduction
The DistilQwen2.5 distilled language model series is built upon the Qwen2.5 model. This series leverages innovative distillation techniques to enhance instruction-following capabilities. As a result, these distilled models retain the excellent performance of the original models while requiring fewer computational resources.
The distillation process involves carefully selecting, rewriting, and optimizing instruction-response pairs conducive to student model learning, thus improving model comprehension and execution abilities. Following standard fine-tuning, we employ white-box distillation techniques to enable the student models to better acquire fine-grained knowledge from teacher models. Experimental evaluations demonstrate the significant improvement in capabilities of the DistilQwen2.5 models.
## Detailed Steps
### Processing of Instructional Dataset
DistilQwen2.5 begins by collecting diverse, high-quality instructional data from sources like Magpie, Openhermes, and Mammoth 2, along with proprietary datasets. The data includes Chinese and English instructions, which are scored for difficulty and task relevance. This process is very similar to the recipe of DistilQwen2.
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
"output": "I'd be happy to help you review your lecture text..."
}
]
```
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
```python
# Validate SDK token
from modelscope.hub.api import HubApi
api = HubApi()
api.login('your_token_id')
# Dataset download
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('PAI/DistilQwen_100k')
```
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5_stage1.json
```
Please refer to the config file `distilqwen2.5_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### White-Box KD
Unlike black-box KD, which relies solely on the highest probability token output by the teacher model, white-box KD focuses on the distribution of logits produced by the teacher model. This approach provides the student model with richer information. By mimicking the teacher model's logits distribution, white-box KD can transfer knowledge more effectively, thereby enhancing the performance of the student model. As an example, we take `Qwen2.5-72B-Instruct` as the white-box teacher model, and generate the logits by:
```bash
python easydistill/kd/infer.py --config=distilqwen2.5_stage2.json
```
Next, we run the training job by:
```bash
python easydistill/kd/train.py --config=distilqwen2.5_stage2.json
```
Again, please refer to the config file `distilqwen2.5_stage2.json` in the current folder. Remember to change the configurations when needed.
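For intuition, the core of the forward-KLD objective used in white-box KD can be written in a few lines of PyTorch. This is a simplified sketch that ignores padding masks and sequence truncation; how the distillation term is mixed with the SFT loss via `kd_ratio` is only indicated in the comment and is an assumption about the overall objective:
```python
import torch
import torch.nn.functional as F

def forward_kld(student_logits, teacher_logits, temperature=1.0):
    # KL(teacher || student), averaged over all token positions in the batch
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)

# A common way to combine the terms, with kd_ratio as in distilqwen2.5_stage2.json:
# loss = (1 - kd_ratio) * sft_loss + kd_ratio * forward_kld(student_logits, teacher_logits)
```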
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-0.5B-Instruct`, `alibaba-pai/DistilQwen2.5-1.5B-Instruct`, `alibaba-pai/DistilQwen2.5-3B-Instruct`, and `alibaba-pai/DistilQwen2.5-7B-Instruct`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 0.5B model
model_name = "alibaba-pai/DistilQwen2.5-0.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-0.5B/")
# Download the 1.5B model
model_name = "alibaba-pai/DistilQwen2.5-1.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-1.5B/")
# Download the 3B model
model_name = "alibaba-pai/DistilQwen2.5-3B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-3B/")
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-7B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-7B/")
```
## Performance
The table below compares the performance of the original Qwen2.5 models with the distilled DistilQwen2.5 models across different parameter sizes: 0.5B, 1.5B, 3B, and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
| Qwen2.5-0.5B-Instruct | 2.46 | 5.49 | 6.26 | 42.81 | 30.31 |
| **DistilQwen2.5-0.5B-Instruct** | **4.89** | **5.78** | **6.83** | **52.61** | **37.82** |
| Qwen2.5-1.5B-Instruct | 6.69 | 7.09 | 7.66 | 55.40 | 40.11 |
| **DistilQwen2.5-1.5B-Instruct** | **13.69** | **7.35** | **7.99** | **61.10** | **74.49** |
| Qwen2.5-3B-Instruct | 17.98 | 7.92 | 8.40 | 61.18 | 74.58 |
| **DistilQwen2.5-3B-Instruct** | **20.91** | **8.37** | **8.97** | **67.03** | **77.36** |
| Qwen2.5-7B-Instruct | 31.43 | 8.52 | 8.83 | 81.53 | 72.10 |
| **DistilQwen2.5-7B-Instruct** | **34.86** | **8.76** | **9.22** | **83.48** | **73.27** |
For evaluation details, please refer to our paper.
## Reference
For more detailed information about the DistilQwen2.5 model series and the methodologies employed, we encourage you to refer to our paper:
- **DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models**
Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang
[arXiv:2504.15027](https://arxiv.org/abs/2504.15027)
You can cite the paper using the following citation format:
```bibtex
@misc{wang2025distilqwen25industrialpracticestraining,
title={DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models},
author={Chengyu Wang and Junbing Yan and Yuanhao Yue and Jun Huang},
year={2025},
eprint={2504.15027},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2504.15027},
}
```

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,40 @@
{
"job_type": "kd_white_box",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"logits_path": "logits.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"distillation": {
"kd_ratio": 0.5,
"max_seq_length": 512,
"distillation_type": "forward_kld"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-72B-Instruct/",
"student": "result_stage1/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,165 @@
# DistilQwen2: Refining Instructional Data for Black-Box KD
## Brief Introduction
Knowledge distillation offers an effective solution by transferring knowledge from larger models to smaller ones, ensuring performance while significantly reducing computational resources and inference time. We introduce DistilQwen2, a lightweight LLM based on the Qwen2 series, optimized through enhanced instruction following and diverse distillation techniques. This enables more agile and efficient deployment in resource-constrained environments like mobile devices and edge computing. For ease of use by developers and enterprises, DistilQwen2's checkpoints are open-sourced on HuggingFace and ModelScope, empowering more stakeholders to innovate and realize value through advanced NLP applications.
## Instructional Data Processing Guidelines
For the training of DistilQwen2, we collected data from well-known open-source datasets like Magpie, Openhermes, and Mammoth 2, along with proprietary synthetic datasets to initiate the distillation process. The focus is on providing diverse instructional data, predominantly in Chinese and English. We also leverage prompt templates to conduct instructional data augmentation. Here, we provide several commonly used operations to re-sample and augment the dataset.
### Instruction Set Expansion
The instruction expansion operator is employed to generate a diverse set of instruction variations, ensuring that student models are exposed to a comprehensive range of instructions. After instruction expansion, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_expansion_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_expansion_batch.json`.
### Instruction Refinement
The instruction refinement operator further enhances the quality and diversity of the training data while preserving the semantic integrity of the tasks expressed in the instructions, ensuring that the rewritten content remains faithful to the original intent and task category. After instruction refinement, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_refinement_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_refinement_batch.json`.
### Instruction Resampling
We also consider task balance when selecting useful instructional data pairs. The task distributions are defined based on our paper listed in the reference section. You can run the job by:
```bash
python task_resampling.py --input-file input.json --output-file output.json --api-key <your_api_key> --base-url <base_url>
```
The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth..."
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?"
}
]
```
After processing the instructions, you can generate the teacher model's responses.
### Open-Source Dataset
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
"output": "I'd be happy to help you review your lecture text..."
}
]
```
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
```python
# Validate SDK token
from modelscope.hub.api import HubApi
api = HubApi()
api.login('your_token_id')
# Dataset download
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('PAI/DistilQwen_100k')
```
## Model Training Guidelines
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. For simplicity, we use the `DistilQwen_100k` dataset in this tutorial; we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage1.json
```
Please refer to the config file `distilqwen2_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### Preference Rank Optimization
For more challenging instruction tasks, SFT alone may not yield optimal results. To address this, we further refine the model using Direct Preference Optimization (DPO), enabling more granular fine-tuning and improved performance. First, we generate the student's outputs as rejected responses; the contents of the SFT dataset are regarded as the prompts and chosen responses. Please refer to the following script:
```bash
python dpo_student_infer_only.py --config=distilqwen2_stage2.json
```
Next, we run the training job by:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage2.json
```
Again, please refer to the config file `distilqwen2_stage2.json` in the current folder. Remember to change the configurations when needed. If you need to run the job in a distributed mode, use `accelerate` to run the job.
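The resulting preference file (`distil_qwen_100k_dpo.json` in the stage-2 config) contains prompt/chosen/rejected triples, for example (the record below is shortened and the rejected response is invented for illustration):
```json
[
    {
        "prompt": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
        "chosen": "I'd be happy to help you review your lecture text...",
        "rejected": "Sure, lectures can be improved in many ways..."
    }
]
```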
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2-1.5B-Instruct` and `alibaba-pai/DistilQwen2-7B-Instruct`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
model_name = "alibaba-pai/DistilQwen2-1.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-1.5B/")
model_name = "alibaba-pai/DistilQwen2-7B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-7B/")
```
## Performance
The table below compares the performance of the original Qwen2 models with the distilled DistilQwen2 models across different parameter sizes: 1.5B and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
## Reference
For more detailed information about the DistilQwen2 model series and the methodologies employed, we encourage you to refer to our paper:
- **Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning**
Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang
You can cite the paper using the following citation format:
```bibtex
@inproceedings{emnlp2024,
author = {Yuanhao Yue and
Chengyu Wang and
Jun Huang and
Peng Wang},
title = {Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning},
booktitle = {Findings of the Association for Computational Linguistics: {EMNLP} 2024},
pages = {6030--6054},
publisher = {Association for Computational Linguistics},
year = {2024},
url = {https://aclanthology.org/2024.findings-emnlp.350}
}
```

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,25 @@
{
"job_type": "rank_dpo_api",
"dataset": {
"instruction_path": "distil_qwen_100k.json",
"labeled_path": "distil_qwen_100k_dpo.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "result_stage1/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"beta": 0.1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,105 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename):
    try:
        with open(filename, 'r') as file:
            data = json.load(file)
        samples = []
        for item in data:
            instruction = item["instruction"]
            output = item["output"]
            samples.append({"prompt": instruction, "chosen": output})
        return samples
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def write_data_to_json_file(data, file_path):
    try:
        with open(file_path, 'w') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def generate_student_response(data_list, config):
    # load student model
    student_tokenizer = AutoTokenizer.from_pretrained(
        config["models"]["student"],
        trust_remote_code=True
    )
    student_model = AutoModelForCausalLM.from_pretrained(
        config["models"]["student"],
        device_map="auto",
        trust_remote_code=True
    )
    outcomes = []
    for sample in tqdm(data_list, desc="Generating student responses"):
        prompt = sample["prompt"]
        chosen = sample["chosen"]
        # generate the student's response, which serves as the rejected sample
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = student_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
        generated_ids = student_model.generate(
            **model_inputs,
            max_new_tokens=config["inference"]["max_new_tokens"]
        )
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        gen_data = {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
        outcomes.append(gen_data)
    write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='path to the json config file')
    args = parser.parse_args()
    config = json.load(open(args.config))
    data_list = read_json_field(config["dataset"]["instruction_path"])
    generate_student_response(data_list, config)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,156 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import re
import logging
from openai import OpenAI
from collections import Counter
import random
import argparse
predefined_distribution = {
    'Math': 0.167,
    'Code Generation': 0.083,
    'Writing': 0.017,
    'Computer Science': 0.017,
    'Reasoning': 0.167,
    'Complex Format': 0.017,
    'Code Debug': 0.083,
    'Common-Sense': 0.017,
    'Counterfactual': 0.017,
    'Multilingual': 0.017,
    'Roleplay': 0.017,
    'Biology': 0.017,
    'Technology': 0.017,
    'Ethics': 0.017,
    'Sport': 0.017,
    'Law': 0.017,
    'Medicine': 0.017,
    'Literature': 0.017,
    'Entertainment': 0.017,
    'Art': 0.017,
    'Music': 0.017,
    'Toxicity': 0.017,
    'Economy': 0.017,
    'Physics': 0.017,
    'History': 0.017,
    'Chemistry': 0.017,
    'Philosophy': 0.017,
    'Health': 0.017,
    'Ecology': 0.017,
    'Grammar': 0.017,
    'Paraphrase': 0.017,
    'Others': 0.041
}

predefined_prompt = """
You are a data annotation expert. Please classify the task type or domain of #Given Instruction.
The task type or domain should be in the list: [Math, Code Generation, Writing, Computer Science, Reasoning, Complex Format, Code Debug, Common-Sense, Counterfactual, Multilingual, Roleplay,Biology, Technology, Ethics, Sport, Law, Medicine, Literature, Entertainment, Art, Music, Toxicity, Economy, Physics, History, Chemistry, Philosophy,Health,Ecology,Grammar,Paraphrase, Others]. You should place your answer enclosed within <answer></answer> tags, such as <answer>Math</answer>. Do not return anything else.
#Given Instruction#:
"""


def extract_answer(content):
    pattern = r'<answer>(.*?)</answer>'
    match = re.search(pattern, content, re.DOTALL)
    if match:
        return match.group(1)
    else:
        return None


def classify_instruction(instruction, client, model):
    message = [
        {"role": "user", "content": predefined_prompt + "\n" + instruction}
    ]
    completion = client.chat.completions.create(
        messages=message,
        model=model,
        max_completion_tokens=1024
    )
    result = completion.choices[0].message.content.strip()
    print(result)
    result = extract_answer(result)
    if result is None or result not in predefined_distribution.keys():
        result = 'Others'
    print(result)
    return result


def main(args):
    # Load dataset
    with open(args.input_file, 'r') as file:
        data = json.load(file)
    # Initialize OpenAI client
    client = OpenAI(
        api_key=args.api_key,
        base_url=args.base_url
    )
    models = client.models.list()
    model = models.data[0].id
    logging.info(model)
    # Classify each instruction
    classified_data = []
    count = 0
    for item in data:
        category = classify_instruction(item['instruction'], client, model)
        classified_data.append({'instruction': item['instruction'], 'category': category})
        count += 1
        print(count)
    # Count occurrences per category
    category_counts = Counter(item['category'] for item in classified_data)
    total_samples = len(classified_data)
    # Resample according to predefined distribution
    resampled_data = []
    for category, target_ratio in predefined_distribution.items():
        target_count = int(total_samples * target_ratio)
        category_samples = [item for item in classified_data if item['category'] == category]
        if len(category_samples) == 0:
            logging.warning("No instructions are provided for the category: " + category)
            continue
        if len(category_samples) > target_count:
            print(category)
            print(len(category_samples))
            print(target_count)
            # Randomly sample the required number of instructions
            resampled_category_samples = random.sample(category_samples, target_count)
        else:
            # If not enough samples, repeat the existing ones
            resampled_category_samples = category_samples * (target_count // len(category_samples)) + random.sample(category_samples, target_count % len(category_samples))
        resampled_data.extend(resampled_category_samples)
    # Save final dataset
    with open(args.output_file, 'w') as file:
        json.dump(resampled_data, file, indent=4)
    print("Resampling complete. Final output saved to '{}'.".format(args.output_file))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Task and Domain Classification')
    parser.add_argument('--input-file', type=str, required=True, help='Input JSON file containing instructions.')
    parser.add_argument('--output-file', type=str, required=True, help='Output JSON file to store resampled instructions.')
    parser.add_argument('--api-key', type=str, required=True, help='API key.')
    parser.add_argument('--base-url', type=str, required=True, help='Base URL.')
    args = parser.parse_args()
    main(args)