init commit

熊兮
2025-05-27 18:55:46 +08:00
parent 6f52a67249
commit 25caa8a90a
65 changed files with 4893 additions and 1 deletion

LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

NOTICE Normal file

@@ -0,0 +1,13 @@
=============================================================
EasyDistill is an open-source tool developed by the Alibaba PAI Team
Licensed under the Apache License, Version 2.0
=============================================================
This toolkit implements some modules with reference to other repositories under
the same or different open-source licenses.
-----------------------------
Apache License, Version 2.0
The HuggingFace Inc. team
The OpenAI Team
The vLLM Team

README.md

@@ -1 +1,187 @@
# easydistill
# EasyDistill: Easy Knowledge Distillation for Large Language Models
Introducing **EasyDistill**, a pioneering toolkit for knowledge distillation (KD) of large language models (LLMs). With the growing complexity and size of LLMs, **EasyDistill** offers a versatile and user-friendly platform to streamline the KD process, supporting both black-box and white-box methodologies. It facilitates efficient model training, enabling smaller models to emulate the performance of larger ones without compromising accuracy. **EasyDistill** boasts an extensive range of features, including data synthesis, supervised fine-tuning, ranking optimization, and reinforcement learning, all tailored for various KD scenarios. Designed to accommodate both System 1 (fast, intuitive) and System 2 (slow, analytical) cognitive models, the toolkit is modular and easy to use, with a simple command-line interface guiding users. Beyond academic exploration, **EasyDistill** anchors practical industrial solutions, offering robust distilled models and open-source datasets, while also showcasing seamless integration with Alibaba Cloud's AI platform, PAI. Committed to bridging theoretical advancements with practical needs, **EasyDistill** empowers the NLP community, making state-of-the-art KD strategies accessible to researchers and industry practitioners alike.
## Technical Articles
We have a series of technical articles on the functionalities of EasyDistill.
- [AI Platform PAI: DistilQwen2.5-DS3-0324 Released, Knowledge Distillation + Fast Thinking to Solve Reasoning Problems More Efficiently](https://developer.aliyun.com/article/1661734)
- [DistilQwen2.5-R1 Released: Knowledge Distillation Powers Deep Thinking in Small Models](https://developer.aliyun.com/article/1659288)
- [DistilQwen2.5 Released: Distilled Small Models of Tongyi Qianwen Upgraded Again](https://developer.aliyun.com/article/1653842)
- [DistilQwen2: Knowledge Distillation Practice on Tongyi Qianwen Large Models](https://developer.aliyun.com/article/1633882)
- [TAPIR: A Large Language Model Distillation Algorithm Based on Multi-round Curriculum Learning](https://developer.aliyun.com/article/1635146)
## Overview
![EasyDistill Framework](resources/framework.png)
- **Toolkit Features**: EasyDistill provides versatile functionalities, including data synthesis, supervised fine-tuning, logits distillation, ranking optimization, and reinforcement learning techniques tailored for KD scenarios.
- **Compatibility**: It supports both System 1 (fast, intuitive) and System 2 (slow, analytical) models.
- **User-Friendly**: With its modular design and simple command-line interface, EasyDistill makes experimentation and implementation of KD strategies straightforward.
- **Industrial Integration**: Incorporates KD-based solutions and supports integration with platforms such as Alibaba Cloud's Platform for AI (PAI).
## Getting Started
1. Clone the repository:
```bash
git clone <repository-url>
cd EasyDistill
```
2. Install the required dependencies:
```bash
python setup.py install
```
3. Explore the usage of EasyDistill through the command-line interface:
```bash
easydistill --config <config-file-path>
```
The config file specifies the detailed settings of any knowledge distillation job that **EasyDistill** supports. A sample black-box distillation config is shown below:
```json
{
"job_type": "kd_black_box_local",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}
```
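For reference, the file at `instruction_path` is a plain JSON array in which each record carries an `instruction` field, and the teacher's answers are written to `labeled_path` as `instruction`/`output` pairs (see `read_json_field` and `generate_teacher_response_batch` in `easydistill/kd/infer.py`). A minimal sketch of `train.json`, with illustrative instructions:
```json
[
    {"instruction": "Explain the difference between black-box and white-box knowledge distillation."},
    {"instruction": "Write a Python function that reverses a string."}
]
```
After the inference stage completes, `train_labeled.json` contains the same instructions with an added `output` field holding each teacher response.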
## DistilQWen Series
The **DistilQwen** models represent a robust suite of distilled language models derived from the **EasyDistill** toolkit. Designed to capitalize on the principles of knowledge distillation, DistilQwen models offer a significant reduction in model size while maintaining high performance, making them ideal for resource-constrained environments. Whether you're aiming for efficient deployment in industrial scenarios or seeking to explore advanced KD methodologies, **DistilQwen** models are poised to meet diverse application needs with agility and precision.
### What's New: Adaptive Thinking Models
The most recent **DistilQwen** series is **DistilQwen-ThoughtX**, which exhibits improved reasoning abilities and generates CoTs of better-calibrated length than its predecessors. This model series is developed from the innovative **OmniThought** dataset by utilizing the novel Reasoning Verbosity (RV) and Cognitive Difficulty (CD) scores, which ensure that models receive rich, high-quality training data reflecting optimal CoT output length and difficulty. **DistilQwen-ThoughtX** outperforms other KD models in the open-source community. The performance of **DistilQwen-ThoughtX** is shown below.
| **Model** | **AIME2024** | **MATH500** | **GPQA-D** | **LCB V2** | **Avg.** |
|-----------------------------------------------|--------------|-------------|------------|------------|-----------|
| OpenThinker-7B | 31.3 | 83.0 | 42.4 | 39.9 | 49.1 |
| DeepSeek-R1-Distill-Qwen-7B | **57.3** | _89.6_ | 47.3 | 48.4 | 60.6 |
| OpenThinker2-7B | 50.0 | 88.4 | _49.3_ | _55.6_ | _60.8_ |
| **DistilQwen-ThoughtX-7B** | _56.7_ | **90.2** | **50.0** | **56.8** | **63.4** |
| LIMO-32B | 56.7 | 86.6 | 58.1 | 60.0 | 65.3 |
| OpenThinker-32B | 66.0 | 90.6 | 61.6 | 68.9 | 71.7 |
| DeepSeek-R1-Distill-Qwen-32B | 74.7 | 90.0 | 62.4 | 72.3 | 74.8 |
| OpenThinker2-32B | _76.7_ | _90.8_ | **64.1** | _72.5_ | _76.0_ |
| Light-R1-32B | 74.7 | 90.4 | 62.0 | 56.0 | 70.7 |
| s1.1-32B | 59.3 | 87.4 | 62.0 | 58.7 | 66.8 |
| **DistilQwen-ThoughtX-32B** | **80.0** | **92.6** | _64.0_ | **73.4** | **77.5** |
The **OmniThought** dataset is also publicly available; refer to the Datasets section below.
### System 1 Models
**DistilQwen2** is an enhanced version of the Qwen2 models, equipped with improved instruction-following capabilities for various NLP tasks. We employ GPT-4 and Qwen-max as teacher models to generate high-quality responses, balancing the task distribution of input instructions. Following SFT, a ranking optimization process is performed using the DPO algorithm to enhance alignment between the student models and the teacher models. **DistilQwen2.5** models are trained using a combination of black-box and white-box KD algorithms. We adhere to the same instruction data processing and black-box SFT procedure as employed in the production of **DistilQwen2**. Subsequently, white-box training is applied to refine the students' acquisition of intricate knowledge from the teacher models, specifically utilizing Qwen2.5-72B-Instruct as the open-source teacher model. The performance of **DistilQwen2** and **DistilQwen2.5** is shown below.
| **Model** | **AlpacaEval 2.0 (length control)** | **MT-Bench** | **MT-Bench (single)** | **IFEval (instruct-loose)** | **IFEval (strict-prompt)** |
|------------------------------------|-------------------------------------|--------------|-----------------------|-----------------------------|----------------------------|
| Qwen2.5-0.5B-Instruct | 2.46 | 5.49 | 6.26 | 42.81 | 30.31 |
| **DistilQwen2.5-0.5B-Instruct** | **4.89** | **5.78** | **6.83** | **52.61** | **37.82** |
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
| Qwen2.5-1.5B-Instruct | 6.69 | 7.09 | 7.66 | 55.40 | 40.11 |
| **DistilQwen2.5-1.5B-Instruct** | **13.69** | **7.35** | **7.99** | **61.10** | **74.49** |
| Qwen2.5-3B-Instruct | 17.98 | 7.92 | 8.40 | 61.18 | 74.58 |
| **DistilQwen2.5-3B-Instruct** | **20.91** | **8.37** | **8.97** | **67.03** | **77.36** |
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
| Qwen2.5-7B-Instruct | 31.43 | 8.52 | 8.83 | 81.53 | 72.10 |
| **DistilQwen2.5-7B-Instruct** | **34.86** | **8.76** | **9.22** | **83.48** | **73.27** |
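To make the white-box stage above concrete, the sketch below shows the kind of loss used in white-box KD: a standard cross-entropy term mixed with a forward KL divergence between the student's predictions and the teacher's saved top-k probabilities, weighted by the `kd_ratio` from `configs/kd_white_box.json`. This is an illustrative sketch under assumed tensor shapes (`[batch, seq, vocab]` logits, `[batch, seq, k]` teacher arrays), not the exact implementation in `easydistill/kd/train.py`.

```python
import torch
import torch.nn.functional as F

def white_box_kd_loss(student_logits, labels, teacher_topk_probs, teacher_topk_ids, kd_ratio=0.5):
    # Standard language-modeling cross-entropy on the ground-truth tokens.
    ce_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    # Student log-probabilities gathered at the teacher's top-k token ids.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    student_topk_log_probs = student_log_probs.gather(-1, teacher_topk_ids)
    # Forward KLD over the teacher's top-k support, averaged over positions.
    kld = (
        teacher_topk_probs
        * (teacher_topk_probs.clamp_min(1e-8).log() - student_topk_log_probs)
    ).sum(-1).mean()
    # Blend the two objectives with kd_ratio, as in configs/kd_white_box.json.
    return (1 - kd_ratio) * ce_loss + kd_ratio * kld
```

In practice, the teacher's top-k token ids and probabilities come from the `logits_path` file produced during the white-box inference stage.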
We have released two instruction-following datasets to the public. Refer to the Datasets section.
### System 2 Models
The **DistilQwen2.5-R1** model series utilizes DeepSeek-R1 as the teacher model. To align the reasoning abilities of smaller distilled models with their intrinsic cognitive capacities, the models are further refined using our CogPO algorithm, which outperforms other training methods. Additionally, we transfer the fast-thinking reasoning capabilities from DeepSeek-V3-0324 to the **DistilQwen2.5-DS3-0324** models. To shorten the reasoning process, the CoT simplification operator is employed to reduce the number of tokens in the training data for **DistilQwen2.5-R1**. Combined with a rewritten dataset comprising DeepSeek-V3-0324's CoT distillation data, we develop the **DistilQwen2.5-DS3-0324** models. The performance of **DistilQwen2.5-R1** and **DistilQwen2.5-DS3-0324** is shown below.
| **Model** | **AIME2024** | **MATH-500** | **GPQA Diamond** | **LiveCodeBench V2** |
|---------------------------------------|--------------|--------------|------------------|----------------------|
| Qwen2.5-3B-Instruct | 6.67 | 62.6 | 32.83 | 11.35 |
| **DistilQwen2.5-DS3-0324-3B** | **16.67** | **70.0** | **34.34** | **18.00** |
| Qwen2.5-7B-Instruct | 10.0 | 73.6 | 33.30 | 30.72 |
| **DistilQwen2.5-7B-R1** | **23.33** | **77.8** | **37.88** | **36.40** |
| **DistilQwen2.5-DS3-0324-7B** | **43.33** | **88.4** | **42.93** | **46.38** |
| Qwen2.5-14B-Instruct | 16.7 | 78.2 | 43.43 | 37.38 |
| **DistilQwen2.5-14B-R1** | **26.67** | **82.6** | **45.45** | **41.49** |
| **DistilQwen2.5-DS3-0324-14B** | **46.67** | **90.8** | **51.52** | **54.40** |
| Qwen2.5-32B-Instruct | 16.67 | 81.4 | 45.50 | 47.36 |
| **DistilQwen2.5-32B-R1** | **46.67** | **87.0** | **48.99** | **55.97** |
| **DistilQwen2.5-DS3-0324-32B** | **70.00** | **93.8** | **62.12** | **65.95** |
All the **DistilQwen** models are publicly available on HuggingFace and ModelScope.
## Released Datasets
We have also released several datasets based on the **EasyDistill** framework.
### Instruction Following Datasets
To assist community developers in avoiding catastrophic forgetting when fine-tuning the **DistilQwen** models, we have open-sourced two datasets: **DistilQwen_100K** and **DistilQwen_1M**. These datasets are intended to provide a solid foundation for model fine-tuning, enhancing adaptability to new tasks while retaining performance on previous tasks. They can also be utilized to improve instruction-following capabilities when fine-tuning other similar large language models. The datasets cover a range of contents, including mathematics, code, knowledge-based Q&A, instruction following, and creative generation, with sizes of 100K and 1M entries, respectively. Users can integrate **DistilQwen_100K** and **DistilQwen_1M**, or their subsets, with their own data during model fine-tuning to ensure excellent downstream task performance while maintaining the model's general capabilities, thus preserving its ability to generalize.
### Chain-of-Thought Reasoning Datasets
**OmniThought** is a large-scale dataset featuring **2 million** Chain-of-Thought (CoT) processes generated and validated by DeepSeek-R1 and QwQ-32B. Each CoT process in **OmniThought** is annotated with novel Reasoning Verbosity (RV) and Cognitive Difficulty (CD) scores, which describe the appropriateness of CoT verbosity and cognitive difficulty level for models to comprehend these reasoning processes. Based on our **OmniThought** dataset, we further train and release a series of high-performing models (**DistilQwen-ThoughtX-7B** and **DistilQwen-ThoughtX-32B**), specifically equipped with stronger reasoning abilities and optimal CoT output length and difficulty level. Refer to `recipes/open_datasets` for details.
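As an illustration of how the RV and CD annotations can be used, the sketch below selects CoT samples whose verbosity and difficulty suit a smaller student model; the field names `rv` and `cd` and the threshold values are hypothetical, for illustration only, not the dataset's documented schema.

```python
import json

def select_cots(path, rv_range=(2, 5), cd_max=6):
    # Keep CoT samples whose (hypothetical) Reasoning Verbosity field "rv" falls
    # within rv_range and whose Cognitive Difficulty field "cd" stays below cd_max.
    with open(path) as f:
        data = json.load(f)
    return [
        sample for sample in data
        if rv_range[0] <= sample["rv"] <= rv_range[1] and sample["cd"] <= cd_max
    ]
```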
All the datasets are publicly available on HuggingFace and ModelScope.
## Reference
We have [an arXiv paper](TBD) for you to cite for the **EasyDistill** library. Below are other papers related to our project.
- Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang. Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations. arXiv preprint
- Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang. Training Small Reasoning LLMs with Cognitive Preference Alignment. arXiv preprint
- Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang. DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models. **ACL 2025**
- Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang. Building a Family of Data Augmentation Models for Low-cost LLM Fine-tuning on the Cloud. **COLING 2025**
- Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang. Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning. **EMNLP 2024**
## License
This project is licensed under the [Apache License (Version 2.0)](LICENSE). This toolkit also contains some code modified from other repos under other open-source licenses. See the [NOTICE](NOTICE) file for more information.
## Join in the Discussion
We welcome community partners to collaborate and contribute to the development. You are welcome to join the DingTalk group 117440002081 to participate in the discussion.

configs/accelerate_config/muti_gpu.yaml Normal file

@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_clipping: 1.0
offload_optimizer_device: cpu
offload_param_device: cpu
zero_stage: 2
distributed_type: DEEPSPEED
gpu_ids: all
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

chat_template/chat_template_kd.jinja Normal file

@@ -0,0 +1,8 @@
{{'<|im_start|>system\nYou are a helpful assistant.<|im_end|>'}}
{{'<|im_start|>user\n' + message['content'] + '<|im_end|>'-}}
{% if add_generation_prompt %}
{{'<|im_start|>assistant'-}}
{% endif %}
{% if add_output %}
{{'<|im_start|>assistant\n' + message['output'] + '<|im_end|>'-}}
{% endif %}

configs/cot_generation_api.json Normal file

@@ -0,0 +1,14 @@
{
"job_type": "cot_generation_api",
"dataset": {
"input_path": "./cot_question.json",
"output_path": "./cot_question_with_answer.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:",
"max_new_tokens": 1024
}
}

configs/cot_generation_batch.json Normal file

@@ -0,0 +1,22 @@
{
"job_type": "cot_generation_batch",
"dataset": {
"input_path": "./cot_question.json",
"output_path": "./cot_question_with_answer.json",
"template" : "./chat_template/chat_template_kd.jinja"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference":{
"prompt" : "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/cot_long2short_api.json Normal file

@@ -0,0 +1,14 @@
{
"job_type": "cot_long2short_api",
"dataset": {
"input_path": "./raw.json",
"output_path": "./raw_simplified.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
"max_new_tokens": 1024
}
}

configs/cot_long2short_batch.json Normal file

@@ -0,0 +1,22 @@
{
"job_type": "cot_long2short_batch",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_simplified.json",
"template" : "./chat_template/chat_template_kd.jinja"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference":{
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/cot_short2long_api.json Normal file

@@ -0,0 +1,14 @@
{
"job_type": "cot_short2long_api",
"dataset": {
"input_path": "./raw.json",
"output_path": "./raw_extended.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem ,its answer and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps, so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\\n\\n), your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary.",
"max_new_tokens": 1024
}
}

configs/cot_short2long_batch.json Normal file

@@ -0,0 +1,22 @@
{
"job_type": "cot_short2long_batch",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_extended.json",
"template" : "./chat_template/chat_template_kd.jinja"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference":{
"prompt" : "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem ,its answer and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps, so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\\n\\n), your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary.",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/instruction_expansion_api.json Normal file

@@ -0,0 +1,16 @@
{
"job_type": "instruction_expansion_api",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_extended.json",
"num_in_context_samples": 3,
"num_output_samples": 10
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "Assume you are a data synthesis expert. Given a few instructions as in-context examples, you should generate a new instruction similar to the examples to support the training of large language models. You should place your answer enclosed within <answer></answer> tags. The examples are as follows:",
"max_new_tokens": 512
}
}

configs/instruction_expansion_batch.json Normal file

@@ -0,0 +1,24 @@
{
"job_type": "instruction_expansion_batch",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_extended.json",
"template" : "./chat_template/chat_template_kd.jinja",
"num_in_context_samples": 3,
"num_output_samples": 10
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference":{
"prompt" : "Assume you are a data synthesis expert. Given a few instructions as in-context examples, you should generate a new instruction similar to the examples to support the training of large language models. You should place your answer enclosed within <answer></answer> tags. The examples are as follows:",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/instruction_refinement_api.json Normal file

@@ -0,0 +1,14 @@
{
"job_type": "instruction_refinement_api",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_refined.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "Assume you are a prompt re-writing expert. Given an instruction as input, you should generate a new instruction semantically similar to the input to support the training of large language models. Transform the input raw prompt into a detailed prompt that comprehensively captures the users request. Make sure to maintain the original intent while significantly enhancing clarity and depth. You should place your answer enclosed within <answer></answer> tags. The input prompt is as follows:",
"max_new_tokens": 512
}
}

configs/instruction_refinement_batch.json Normal file

@@ -0,0 +1,22 @@
{
"job_type": "instruction_refinement_batch",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_refined.json",
"template" : "./chat_template/chat_template_kd.jinja"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference": {
"prompt" : "Assume you are a prompt re-writing expert. Given an instruction as input, you should generate a new instruction semantically similar to the input to support the training of large language models. Transform the input raw prompt into a detailed prompt that comprehensively captures the users request. Make sure to maintain the original intent while significantly enhancing clarity and depth. You should place your answer enclosed within <answer></answer> tags. The input prompt is as follows:",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/instruction_response_extraction_api.json Normal file

@@ -0,0 +1,14 @@
{
"job_type": "instruction_response_extraction_api",
"dataset": {
"input_path": "./raw.json",
"output_path": "./raw_extracted.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "Assume you are a data synthesis expert. Given plain text as input, you should generate an instruction-response pair where the instruction and the response are derived from the knowledge of the plain text to support the training of large language models. The response should properly answer the instruction. You should place your instruction enclosed within <instruction></instruction> tags, and place your response enclosed within <response></response> tags. The input plain text is as follows:",
"max_new_tokens": 1024
}
}

configs/instruction_response_extraction_batch.json Normal file

@@ -0,0 +1,22 @@
{
"job_type": "instruction_response_extraction_batch",
"dataset": {
"input_path": "./train.json",
"output_path": "./train_extended.json",
"template" : "./chat_template/chat_template_kd.jinja"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
},
"inference":{
"prompt" : "Assume you are a data synthesis expert. Given plain text as input, you should generate an instruction-response pair where the instruction and the response are derived from the knowledge of the plain text to support the training of large language models. The response should properly answer the instruction. You should place your instruction enclosed within <instruction></instruction> tags, and place your response enclosed within <response></response> tags. The input plain text is as follows:",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
}
}

configs/kd_black_box_api.json Normal file

@@ -0,0 +1,32 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "./chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"system_prompt" : "You are a helpful assistant.",
"max_new_tokens": 512
},
"models": {
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"max_length":512,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/kd_black_box_local.json Normal file

@@ -0,0 +1,36 @@
{
"job_type": "kd_black_box_local",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "./chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"max_length":512,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/kd_white_box.json Normal file

@@ -0,0 +1,42 @@
{
"job_type": "kd_white_box",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"logits_path": "./logits.json",
"template" : "./chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512,
"top_logits_num": 10
},
"distillation": {
"kd_ratio": 0.5,
"max_seq_length": 512,
"distillation_type": "forward_kld"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/rank_dpo_api.json Normal file

@@ -0,0 +1,32 @@
{
"job_type": "rank_dpo_api",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"system_prompt" : "You are a helpful assistant.",
"max_new_tokens": 512
},
"models": {
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"beta": 0.1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/rank_dpo_local.json Normal file

@@ -0,0 +1,37 @@
{
"job_type": "rank_dpo_api",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "chat_template/chat_template_kd.jinja",
"seed": 42
},
"inference":{
"system_prompt" : "You are a helpful assistant.",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"beta": 0.1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/rl_grpo.json Normal file

@@ -0,0 +1,25 @@
{
"job_type": "rl_grpo",
"dataset": {
"instruction_path": "sample.json",
"template" : "chat_template_kd.jinja",
"train_ratio": 0.7,
"seed": 42
},
"models": {
"reward": "reward/",
"student": "Qwen/Qwen2.5-0.5B-Instruct"
},
"training": {
"output_dir": "./result/",
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"num_train_epochs": 3,
"save_steps": 100,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/rl_ppo.json Normal file

@@ -0,0 +1,28 @@
{
"job_type": "rl_ppo",
"dataset": {
"instruction_path": "sample.json",
"template" : "chat_template_kd.jinja",
"train_ratio": 0.7,
"seed": 42
},
"models": {
"reward": "reward/",
"student": "Qwen/Qwen2.5-0.5B-Instruct"
},
"training": {
"output_dir": "./result/",
"total_episodes": 1000,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 100,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine",
"missing_eos_penalty": 1.0,
"stop_token": "eos",
"response_length": 512
}
}

configs/rl_reward_api.json Normal file

@@ -0,0 +1,32 @@
{
"job_type": "rl_reward_api",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "chat_template_kd.jinja"
},
"inference":{
"base_url": "http://1157703270994901.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/quickstart_deploy_20250427_6wt1/v1/",
"api_key": "NjQ3OGE2ZGNiOWM4YjZkZTY5NDM4YWEyZjUyNGI3ZjRjNTAyMjM0Mw==",
"stream": true,
"positive_system_prompt" : "You are a helpful assistant to generate high-quality responses.",
"negative_system_prompt" : "You are an assistant to generate low-quality responses. This is for the training of my reward model. Plese remember to generate low-quality responses.",
"max_new_tokens": 512
},
"models": {
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"max_length": 1024,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

configs/rl_reward_local.json Normal file

@@ -0,0 +1,37 @@
{
"job_type": "rl_reward_local",
"dataset": {
"instruction_path": "train.json",
"labeled_path": "train_labeled.json",
"template" : "chat_template_kd.jinja"
},
"inference":{
"positive_system_prompt" : "You are a helpful assistant to generate high-quality responses.",
"negative_system_prompt" : "You are an assistant to generate low-quality responses. This is for the training of my reward model. Plese remember to generate low-quality responses.",
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"models": {
"teacher": "model/Qwen/Qwen2.5-3B-Instruct/",
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "./result/",
"max_length": 1024,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

easydistill/__init__.py Normal file

@@ -0,0 +1,14 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

easydistill/cli.py Normal file

@@ -0,0 +1,187 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import subprocess
import sys
import argparse
import json
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
def run_cmd(cmd):
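# Launch a shell command, stream its merged stdout/stderr into the log as it runs,
# and report failure when the return code is non-zero or an error keyword appears.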
try:
p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout
shell=True,
universal_newlines=True # Ensure output is in text mode
)
error_detected = False
error_keywords = [
"ERROR",
"Error",
"error"
"Unrecognized model",
"failed",
"exception",
"Traceback"
]
# Read output in real-time and detect errors
while True:
line = p.stdout.readline()
if not line:
break
logging.info(line.rstrip()) # Log normally
# Check if any error keywords are present
if any(keyword.lower() in line.lower() for keyword in error_keywords):
error_detected = True
logging.error(f"Detected error in output: {line.strip()}")
# Wait for process to finish
returncode = p.wait()
# If errors were detected or return code is non-zero, return False
if error_detected or returncode != 0:
logging.error(f"Command failed (returncode={returncode}, errors detected)")
return False
return True # Return True indicates success
except Exception as e:
logging.error(f"Unexpected error running command: {e}")
return False
def process(job_type, config):
if not os.path.isabs(config):
config = os.path.join(script_dir, config)
# Knowledge Distillation tasks
if job_type in ['kd_black_box_train_only', 'kd_white_box_train_only']:
cmd_train = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, 'kd/train.py'),
'--config', config
]
cmd_train = ' '.join(cmd_train)
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
cmd_infer = [
'python', os.path.join(script_dir, 'kd/infer.py'),
'--config', config
]
cmd_infer = ' '.join(cmd_infer)
logging.info(f"Running command: {cmd_infer}")
infer_success = run_cmd(cmd_infer)
if infer_success:
cmd_train = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, 'kd/train.py'),
'--config', config
]
cmd_train = ' '.join(cmd_train)
logging.info(f"Running command: {cmd_train}")
run_cmd(cmd_train)
else:
logging.error("Infer failed, skipping training")
# Reinforcement Learning tasks
elif job_type in ['rl_ppo', 'rl_grpo']:
cmd = [
'accelerate', 'launch',
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
os.path.join(script_dir, f'rl/{job_type.split("_")[1]}.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
elif job_type in ['rl_reward_api', 'rl_reward_local']:
cmd = [
'python',
os.path.join(script_dir, 'rl/reward.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
# Instruction Processing and Chain-of-Thought tasks share the same synthesis entry point
elif job_type.startswith('instruction_') or job_type.startswith('cot_'):
cmd = [
'python',
os.path.join(script_dir, 'synthesis/synthesis_main.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
# Ranking and DPO tasks
elif job_type.startswith('rank_'):
task_type = job_type.replace('rank_', '')
cmd = [
'python',
os.path.join(script_dir, f'rank/{task_type}.py'),
'--config', config
]
cmd = ' '.join(cmd)
logging.info(f"Running command: {cmd}")
run_cmd(cmd)
else:
logging.error(f"Unknown job type: {job_type}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config_path = args.config
with open(config_path) as f:
config = json.load(f)
job_type = config["job_type"]
process(job_type, config_path)
if __name__ == '__main__':
main()

easydistill/kd/infer.py Normal file

@@ -0,0 +1,247 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json, jsonlines
import argparse
import torch
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
import math
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='instruction'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
except Exception:
logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
def generate_teacher_response_batch(tokenizer, llm, data_list, config, batch_size=32):
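# Black-box KD: render each instruction with the chat template, batch-generate
# teacher responses (top_k=1, i.e. greedy), and save instruction/output pairs.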
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message = {"role": "user", "content": sample}
full_text = template.render(
message = message,
add_generation_prompt = True,
add_output = False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
responses = [output.outputs[0].text for output in outputs]
gen_data = [{'instruction': batch[i], 'output': responses[i]} for i in range(len(batch))]
outcomes = outcomes + gen_data
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def generate_teacher_logits_batch(tokenizer, llm, data_list, config, batch_size=32):
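# White-box KD: prompt the teacher as above, but also record its top-k
# log-probabilities for every generated token so the student can later
# be trained against the teacher's output distribution.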
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch, # Pass the raw text directly
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=True,
max_tokens=config["inference"]["max_new_tokens"],
logprobs=config["inference"]["top_logits_num"],
)
)
# Extract the per-token top-k logprobs for each generated sequence
logits = [output.outputs[0].logprobs for output in outputs]
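# vLLM returns log-probabilities; convert each retained top-k entry
# to a probability before writing it to disk.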
for logit in logits:
for pos in logit:
for k,v in pos.items():
pos[k]=math.exp(v.logprob)
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
for row in logits:
writer.write(row)
def generate_teacher_response_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
system_prompt = config["inference"]["system_prompt"]
stream = config["inference"]["stream"]
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
if system_prompt == "":
message = [
{'role': 'user', 'content': sample}
]
else:
message = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
outcomes.append({'instruction': sample, 'output': result})
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "kd_black_box_api":
generate_teacher_response_api(data_list, config)
elif job_type == "kd_black_box_local":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_response_batch(tokenizer, llm, data_list, config)
elif job_type == "kd_white_box":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_logits_batch(tokenizer, llm, data_list, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
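# A minimal config sketch for this script (keys inferred from the lookups above;
# the values shown are illustrative placeholders, not project defaults):
# {
#   "job_type": "kd_black_box_local",
#   "dataset": {"instruction_path": "...", "labeled_path": "...", "logits_path": "...", "template": "..."},
#   "models": {"teacher": "..."},
#   "inference": {"temperature": 0.0, "seed": 42, "max_new_tokens": 512,
#                 "enable_chunked_prefill": true, "gpu_memory_utilization": 0.9,
#                 "trust_remote_code": true, "enforce_eager": false, "max_model_len": 4096}
# }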
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()

218
easydistill/kd/train.py Normal file

@@ -0,0 +1,218 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from typing import Optional, Dict, Union, List
from transformers import PreTrainedModel, PreTrainedTokenizerBase, AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer, SFTConfig
import torch
import jsonlines
import numpy as np
import torch.nn.functional as F
class DistillSFTTrainer(SFTTrainer):
def __init__(
self,
logits_dir: str = None,
teacher_vocab_size = None,
kd_ratio: float = 0.5,
        max_seq_length: int = 1024,
distillation_type: str = "forward_kld",
**kwargs
):
super().__init__(**kwargs)
self.logits_dir = logits_dir
self.teacher_vocab_size = teacher_vocab_size
self.kd_ratio = kd_ratio
self.max_seq_length = max_seq_length
self.distillation_type = distillation_type
        self.teacher_logits = []
        if self.logits_dir:  # guard so the trainer can be constructed without a logits file
            with jsonlines.open(self.logits_dir) as reader:
                for obj in reader:
                    self.teacher_logits.append(obj)
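    # NOTE: assumes the teacher-logits JSONL was written in the same order the
    # trainer consumes batches, and that each data-parallel rank reads a
    # contiguous, batch-aligned slice of it.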
def _load_teacher_logits(self, batch_size: int, it: int, dp_rank: int, device: torch.device, no_model_batch: Dict):
start_idx = dp_rank * batch_size + batch_size * it
end_idx = dp_rank * batch_size + batch_size * (it + 1)
loaded_data = self.teacher_logits[start_idx:end_idx]
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
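        # Scatter the sparse top-K probabilities into a dense
        # [batch, max_seq_length, teacher_vocab] tensor; absent tokens stay 0.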
for i in range(len(loaded_data)):
for j in range(len(loaded_data[i])):
keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
values = np.array(list(loaded_data[i][j].values()))
arr[i, j, keys] = values
logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
return self._shift_tensor_right(logits_tensor, no_model_batch['label'], pad_value=0)
def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: Optional[torch.Tensor]):
student_logits = student_logits[:, :self.max_seq_length, :]
teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[:, :, 0])
if self.distillation_type == "forward_kld":
            # Forward KLD: KL(teacher || student), mass-covering (original implementation)
loss = F.kl_div(
F.log_softmax(student_logits, dim=-1),
teacher_probs,
reduction='none',
log_target=False
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
elif self.distillation_type == "reverse_kld":
            # Reverse KLD: KL(student || teacher), mode-seeking; penalizes student
            # mass on tokens the teacher considers unlikely
loss = F.kl_div(
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
F.softmax(student_logits, dim=-1),
reduction='none',
log_target=False
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
else:
raise ValueError(f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'")
return (loss * mask).sum() / mask.sum()
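    # Right-shift each row of teacher logits so that teacher position 0 lands on
    # the first non-ignored label (prompt tokens carry label -100), padding the
    # gap with pad_value.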
@staticmethod
def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
batch_size, seqlen, vocab_size = inputs.shape
device = inputs.device
labels_ne = labels != -100
shift_distances = torch.argmax(labels_ne.int(), dim=1)
idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
shifted_idx = idx - shift_distances.unsqueeze(1)
mask = shifted_idx >= 0
shifted_idx = shifted_idx.clamp(min=0)
inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
gathered = torch.gather(inputs_flat, 1, shifted_idx)
mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
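    # Mixed objective: (1 - kd_ratio) * standard LM loss + kd_ratio * KD loss;
    # falls back to plain SFT when no teacher-logits file is configured.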
def compute_loss(self, model: PreTrainedModel, inputs: Dict[str, torch.Tensor], return_outputs=False, num_items_in_batch=None):
outputs = model(**inputs)
lm_loss = outputs.loss
if self.logits_dir:
teacher_logits = self._load_teacher_logits(
batch_size=inputs['input_ids'].size(0),
it=self.state.global_step,
dp_rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
device=model.device,
no_model_batch={'label': inputs.get('labels', None)}
)
distil_loss = self._compute_white_box_distillation_loss(
student_logits=outputs.logits,
teacher_logits=teacher_logits,
labels=inputs.get('labels', None)
)
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
else:
total_loss = lm_loss
return (total_loss, outputs) if return_outputs else total_loss
def formatting_func(examples):
    try:
        message = {"content": examples["instruction"], "output": examples["output"]}
full_text = template.render(
message=message,
add_generation_prompt=False,
add_output=True
)
return full_text
except Exception as e:
logging.warning(f"Error processing sample: {str(e)}")
return ""
def train(config):
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"])
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
global template
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
training_arguments = SFTConfig(**config["training"])
try:
job_type = config["job_type"]
if "kd_black_box" in job_type:
dataset = dataset.shuffle(seed=config["dataset"]["seed"])
trainer = SFTTrainer(
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset["train"],
formatting_func=formatting_func
)
elif "kd_white_box" in job_type:
teacher_vocab_size=json.load(open(os.path.join(config["models"]["teacher"], 'config.json')))['vocab_size']
trainer = DistillSFTTrainer(
logits_dir=config["dataset"]["logits_path"],
teacher_vocab_size=teacher_vocab_size,
kd_ratio=config["distillation"]["kd_ratio"],
max_seq_length=config["distillation"]["max_seq_length"],
distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset["train"],
formatting_func=formatting_func
)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

262
easydistill/rank/infer.py Normal file

@@ -0,0 +1,262 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from jinja2 import Environment, FileSystemLoader
from tqdm import tqdm
from openai import OpenAI
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='prompt'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None, is_teacher_model=True):
if is_teacher_model:
model_path = config["models"]["teacher"]
else:
model_path = config["models"]["student"]
logging.info(f"Loading ckpt and tokenizer: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
    except Exception:
        logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
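# DPO pair synthesis: the remote teacher's answer is recorded as 'chosen' and
# the local student's own answer as 'rejected' for the same prompt.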
def generate_teacher_student_response_api(data_list, config):
client = OpenAI(
api_key=config["inference"]["api_key"],
base_url=config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
system_prompt = config["inference"]["system_prompt"]
stream = config["inference"]["stream"]
# load student model
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
device_map="auto",
trust_remote_code=True
)
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
# for teacher model
if system_prompt == "":
message=[
{'role': 'user', 'content': sample}
]
else:
message=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': sample}
]
completion = client.chat.completions.create(
messages=message,
model=model,
max_completion_tokens=config["inference"]["max_new_tokens"],
stream=stream,
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
        # for student model (mirror the teacher path: skip the system turn when the prompt is empty)
        if system_prompt == "":
            messages = [
                {"role": "user", "content": sample}
            ]
        else:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": sample}
            ]
text = student_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
generated_ids = student_model.generate(
**model_inputs,
max_new_tokens=config["inference"]["max_new_tokens"]
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
gen_data = {'prompt': sample, 'chosen': result, 'rejected': rejected}
outcomes.append(gen_data)
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def generate_model_response_batch(tokenizer, llm, data_list, config, batch_size=32, is_teacher_model=True):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
model_outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"]
)
)
model_responses = [output.outputs[0].text for output in model_outputs]
if is_teacher_model:
gen_data = [{'prompt': batch[i], 'chosen': model_responses[i]} for i in range(len(batch))]
else:
gen_data = [{'prompt': batch[i], 'rejected': model_responses[i]} for i in range(len(batch))]
outcomes = outcomes + gen_data
return outcomes
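# Join teacher ('chosen') and student ('rejected') generations on their shared
# prompt to produce the final DPO preference records.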
def merge_outcomes(teacher_outcomes, student_outcomes, config):
try:
student_dict = {item['prompt']: item['rejected'] for item in student_outcomes}
merged_outcomes = []
for teacher_item in teacher_outcomes:
prompt = teacher_item['prompt']
if prompt in student_dict:
merged_outcome = {
'prompt': prompt,
'chosen': teacher_item['chosen'],
'rejected': student_dict[prompt]
}
merged_outcomes.append(merged_outcome)
with open(config["dataset"]["labeled_path"], 'w') as file:
json.dump(merged_outcomes, file, ensure_ascii=False, indent=4)
except Exception as e:
print(f"An error occurred: {e}")
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "rank_dpo_api":
generate_teacher_student_response_api(data_list, config)
elif job_type == "rank_dpo_local":
teacher_tokenizer, teacher_llm = load_tokenizer_and_vllm(config, is_teacher_model=True)
teacher_outcomes = generate_model_response_batch(teacher_tokenizer, teacher_llm, data_list, config, is_teacher_model=True)
del teacher_llm
student_tokenizer, student_llm = load_tokenizer_and_vllm(config, is_teacher_model=False)
student_outcomes = generate_model_response_batch(student_tokenizer, student_llm, data_list, config, is_teacher_model=False)
del student_llm
merge_outcomes(teacher_outcomes, student_outcomes, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()

105
easydistill/rank/train.py Normal file

@@ -0,0 +1,105 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig
import copy
def process_dataset(dataset_path, dataset_seed, env, template):
examples = []
with open(dataset_path, 'r') as file:
examples = json.load(file)
output_text = {
"prompt": [],
"chosen": [],
"rejected": []
}
# use chat template
for i in range(len(examples)):
try:
prompt_message = {"content": examples[i]["prompt"]}
prompt = template.render(message=prompt_message, add_generation_prompt=False, add_output=False)
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
chosen = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
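            # Keep only the completion part: DPOTrainer expects the prompt and
            # the completions ('chosen'/'rejected') as separate fields.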
chosen = chosen[len(prompt):]
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
rejected = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
rejected = rejected[len(prompt):]
output_text["prompt"].append(prompt)
output_text["chosen"].append(chosen)
output_text["rejected"].append(rejected)
        except Exception as e:
            logging.warning(f"Error processing sample: {e}")
dataset = Dataset.from_dict(output_text)
dataset = dataset.shuffle(seed=dataset_seed)
return dataset
def train(config):
dataset_path = config["dataset"]["labeled_path"]
dataset_seed = config["dataset"]["seed"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
dataset = process_dataset(dataset_path, dataset_seed, env, template)
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
training_arguments = DPOConfig(**config["training"])
trainer = DPOTrainer(
student_model,
ref_model=copy.deepcopy(student_model),
args=training_arguments,
train_dataset=dataset,
processing_class=student_tokenizer
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()


@@ -0,0 +1,111 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import random
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
def process_dataset(dataset_path, dataset_seed, env, template, train_ratio):
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
    except FileNotFoundError:
        logging.error(f"Error: The file '{dataset_path}' was not found.")
    except json.JSONDecodeError:
        logging.error(f"Error: The file '{dataset_path}' is not a valid JSON file.")
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
message = {"content": examples[i]["prompt"]}
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
sample = {"prompt": rendered}
output_dataset.append(sample)
        except Exception as e:
            logging.warning(f"Error processing sample: {e}")
    random.seed(dataset_seed)
    random.shuffle(output_dataset)  # seed before shuffling so the train/eval split is reproducible
split_index = int(len(output_dataset) * train_ratio)
train_list = output_dataset[:split_index]
eval_list = output_dataset[split_index:]
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
def train(config):
dataset_path = config["dataset"]["instruction_path"]
dataset_seed = config["dataset"]["seed"]
train_ratio = config["dataset"]["train_ratio"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, train_ratio)
    logging.info(train_dataset)
    logging.info(eval_dataset)
reward_model_path = config["models"]["reward"]
sft_model_path = config["models"]["student"]
reward_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
sft_model = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
training_arguments = GRPOConfig(**config["training"])
trainer = GRPOTrainer(
args=training_arguments,
processing_class=tokenizer,
model=sft_model,
reward_funcs=reward_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

122
easydistill/rl/ppo_train.py Normal file

@@ -0,0 +1,122 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import random
from jinja2 import Environment, BaseLoader, FileSystemLoader
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import PPOConfig, PPOTrainer
def process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio):
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
    except FileNotFoundError:
        logging.error(f"Error: The file '{dataset_path}' was not found.")
    except json.JSONDecodeError:
        logging.error(f"Error: The file '{dataset_path}' is not a valid JSON file.")
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
message = {"content": examples[i]["instruction"]}
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
tokens = tokenizer.encode(rendered)
sample = {"input_ids": tokens}
output_dataset.append(sample)
        except Exception as e:
            logging.warning(f"Error processing sample: {e}")
    random.seed(dataset_seed)
    random.shuffle(output_dataset)  # seed before shuffling so the train/eval split is reproducible
split_index = int(len(output_dataset) * train_ratio)
train_list = output_dataset[:split_index]
eval_list = output_dataset[split_index:]
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
def train(config):
dataset_path = config["dataset"]["instruction_path"]
dataset_seed = config["dataset"]["seed"]
train_ratio = config["dataset"]["train_ratio"]
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio)
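    # PPOTrainer generates continuations after each query, so a query that
    # already ends in EOS would stop generation immediately (hence the check).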
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
    logging.info(train_dataset)
    logging.info(eval_dataset)
reward_model_path = config["models"]["reward"]
sft_model_path = config["models"]["student"]
value_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
reward_model_path, trust_remote_code=True, num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
policy = AutoModelForCausalLM.from_pretrained(
sft_model_path, trust_remote_code=True
)
training_arguments = PPOConfig(**config["training"])
trainer = PPOTrainer(
config=training_arguments,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()


@@ -0,0 +1,258 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import torch
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename, field_name='prompt'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
    except Exception:
        logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm
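# Reward-model data synthesis: answer each prompt twice, once under a positive
# and once under a negative system prompt; the positive answer becomes 'chosen'
# and the negative one 'rejected'. Note that `message` is a two-turn list here,
# so the Jinja template is assumed to accept both a single dict and a list.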
def generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
positive_system_prompt = config["inference"]["positive_system_prompt"]
negative_system_prompt = config["inference"]["negative_system_prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
positive_new_batch = []
negative_new_batch = []
for sample in batch:
positive_message = [
{'role': 'system', 'content': positive_system_prompt},
{'role': 'user', 'content': sample}
]
positive_full_text = template.render(
message = positive_message,
add_generation_prompt = True,
add_output = False
)
positive_new_batch.append(positive_full_text)
negative_message = [
{'role': 'system', 'content': negative_system_prompt},
{'role': 'user', 'content': sample}
]
negative_full_text = template.render(
message = negative_message,
add_generation_prompt = True,
add_output = False
)
negative_new_batch.append(negative_full_text)
positive_outputs = llm.generate(
positive_new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
        positive_responses = [output.outputs[0].text for output in positive_outputs]
        positive_gen_data = [{'prompt': batch[i], 'chosen': positive_responses[i]} for i in range(len(batch))]
negative_outputs = llm.generate(
negative_new_batch,
SamplingParams(
n = 1,
top_k = 1,
temperature = config["inference"]["temperature"],
seed = config["inference"]["seed"],
skip_special_tokens = False,
ignore_eos = False,
max_tokens = config["inference"]["max_new_tokens"]
)
)
negative_responses = [output.outputs[0].text for output in negative_outputs]
negative_gen_data = [{'prompt': batch[i], 'rejected': negative_responses[i]} for i in range(len(batch))]
merged_data = merge_outcomes(positive_gen_data, negative_gen_data)
outcomes = outcomes + merged_data
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def merge_outcomes(positive_gen_data, negative_gen_data):
negative_dict = {item['prompt']: item['rejected'] for item in negative_gen_data}
merged_outcomes = []
for positive_item in positive_gen_data:
prompt = positive_item['prompt']
if prompt in negative_dict:
merged_outcome = {
'prompt': prompt,
'chosen': positive_item['chosen'],
'rejected': negative_dict[prompt]
}
merged_outcomes.append(merged_outcome)
return merged_outcomes
def generate_teacher_response_for_reward_model_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
logging.info(model)
positive_system_prompt = config["inference"]["positive_system_prompt"]
negative_system_prompt = config["inference"]["negative_system_prompt"]
stream = config["inference"]["stream"]
outcomes = []
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
positive_message = [
{'role': 'system', 'content': positive_system_prompt},
{'role': 'user', 'content': sample}
]
positive_completion = client.chat.completions.create(
messages = positive_message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
        if stream:
            positive_result = ""
            for chunk in positive_completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip chunks whose content is None
                    positive_result += delta
        else:
            positive_result = positive_completion.choices[0].message.content
negative_message = [
{'role': 'system', 'content': negative_system_prompt},
{'role': 'user', 'content': sample}
]
negative_completion = client.chat.completions.create(
messages = negative_message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
        if stream:
            negative_result = ""
            for chunk in negative_completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip chunks whose content is None
                    negative_result += delta
        else:
            negative_result = negative_completion.choices[0].message.content
outcomes.append({'prompt': sample, 'chosen': positive_result, 'rejected': negative_result})
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def infer_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
data_list = read_json_field(config["dataset"]["instruction_path"])
try:
job_type = config["job_type"]
if job_type == "rl_reward_api":
generate_teacher_response_for_reward_model_api(data_list, config)
elif job_type == "rl_reward_local":
tokenizer, llm = load_tokenizer_and_vllm(config)
generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
infer_with_teacher_model(config)
if __name__ == "__main__":
main()


@@ -0,0 +1,107 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
from jinja2 import Environment, FileSystemLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
from datasets import Dataset
def process_dataset(dataset_path, tokenizer, config, template):
kwargs = {"padding": "max_length", "truncation": True, "max_length": config["training"]["max_length"], "return_tensors": "pt"}
examples = []
try:
with open(dataset_path, 'r') as file:
examples = json.load(file)
    except FileNotFoundError:
        logging.error(f"Error: The file '{dataset_path}' was not found.")
    except json.JSONDecodeError:
        logging.error(f"Error: The file '{dataset_path}' is not a valid JSON file.")
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
    logging.debug(examples)
output_dataset = []
# use chat template
for i in range(len(examples)):
try:
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
prompt_plus_chosen_response = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
prompt_plus_rejected_response = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
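            # RewardTrainer expects these four paired columns: the tokenized
            # chosen/rejected sequences plus their attention masks.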
sample = {
"input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0]
}
output_dataset.append(sample)
        except Exception as e:
            logging.warning(f"Error processing sample: {e}")
dataset = Dataset.from_list(output_dataset)
return dataset
def train(config):
dataset_path = config["dataset"]["labeled_path"]
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
dataset = process_dataset(dataset_path, student_tokenizer, config, template)
student_model = AutoModelForSequenceClassification.from_pretrained(
config["models"]["student"],
num_labels=1,
trust_remote_code=True
)
student_model.config.pad_token_id = student_tokenizer.pad_token_id
training_arguments = RewardConfig(**config["training"])
trainer = RewardTrainer(
model=student_model,
processing_class=student_tokenizer,
args=training_arguments,
train_dataset=dataset
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
student_tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()


@@ -0,0 +1,274 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import jsonlines
import logging
import os
from jinja2 import Environment, FileSystemLoader
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
from utils import write_data_to_json_file
def cot_generate_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"]
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
if result is not None:
outcomes.append({"instruction": sample, "output": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def cot_generate_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
        new_batch = []
        prompted_batch = []  # keep each prompted instruction so outputs pair with the right input
        for sample in batch:
            sample = prompt + "\n" + sample
            logging.info(sample)
            prompted_batch.append(sample)
            message = {"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
        for i in range(len(batch)):
            if responses[i] is not None:
                outcomes.append((prompted_batch[i], responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
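# Long-to-short CoT rewriting: ask the teacher to compress an existing reasoning
# trace into a simplified one for the same problem.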
def cot_long2short_api(data_list_ins, data_list_out, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
if result is not None:
outcomes.append((sample,result))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
        new_batch = []
        prompted_batch = []  # keep each prompted instruction so outputs pair with the right input
        for ins, out in batch:
            sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
            logging.info(sample)
            prompted_batch.append(sample)
            message = {"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
        for i in range(len(batch)):
            if responses[i] is not None:
                outcomes.append((prompted_batch[i], responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
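# Short-to-long CoT rewriting: the mirror operation, asking the teacher to
# expand a terse answer into a more detailed reasoning trace.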
def cot_short2long_api(data_list_ins, data_list_out, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
if result is not None:
outcomes.append((sample,result))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)
def cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
        new_batch = []
        prompted_batch = []  # keep each prompted instruction so outputs pair with the right input
        for ins, out in batch:
            sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
            logging.info(sample)
            prompted_batch.append(sample)
            message = {"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
outcomes = []
        for i in range(len(batch)):
            if responses[i] is not None:
                outcomes.append((prompted_batch[i], responses[i]))
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
for ins,result in outcomes:
gen_data = {"instruction": ins, "output": result}
writer.write(gen_data)


@@ -0,0 +1,293 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import os
from jinja2 import Environment, FileSystemLoader
from vllm import LLM, SamplingParams
from tqdm import tqdm
from openai import OpenAI
import random
import re
from utils import read_json_field, write_data_to_json_file, load_tokenizer_and_vllm
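# The synthesis prompts ask the teacher to wrap its output in <answer>...</answer>
# (or <instruction>...</instruction> plus <response>...</response>) tags; the
# helpers below extract those spans and drop malformed generations.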
def extract_answer(content):
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, content, re.DOTALL)
if match:
return match.group(1)
else:
return None
def extract_instruction_response(content):
instruction_pattern = r'<instruction>(.*?)</instruction>'
instruction_match = re.search(instruction_pattern, content, re.DOTALL)
response_pattern = r'<response>(.*?)</response>'
response_match = re.search(response_pattern, content, re.DOTALL)
if instruction_match and response_match:
return instruction_match.group(1), response_match.group(1)
else:
return None, None
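# Build num_output_samples expansion prompts, each seeded with
# num_in_context_samples randomly drawn in-context examples from the corpus.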
def generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples):
if num_in_context_samples > len(data_list):
raise ValueError("num_in_context_samples cannot be larger than the length of data_list")
output_list = []
for _ in range(num_output_samples):
selected_samples = random.sample(data_list, num_in_context_samples)
combined_prompts = prompt + "\n" + "".join([sample + "\n" for sample in selected_samples])
output_list.append(combined_prompts)
return output_list
def expand_instruction_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
num_output_samples = config["dataset"]["num_output_samples"]
num_in_context_samples = config["dataset"]["num_in_context_samples"]
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
outcomes = []
for sample in tqdm(prompt_list, desc="Calling remote model and generating responses"):
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream,
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
result = extract_answer(result)
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def expand_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
num_output_samples = config["dataset"]["num_output_samples"]
num_in_context_samples = config["dataset"]["num_in_context_samples"]
prompt = config["inference"]["prompt"]
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
outcomes = []
batches = [prompt_list[i:i + batch_size] for i in range(0, len(prompt_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"]
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
result = extract_answer(responses[i])
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def refine_instruction_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream = stream
)
        if stream:
            result = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content
                if delta:  # skip role-only / final chunks whose content is None
                    result += delta
        else:
            result = completion.choices[0].message.content
result = extract_answer(result)
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def refine_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
sample = prompt + "\n" + sample
logging.info(sample)
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
result = extract_answer(responses[i])
if result is not None:
outcomes.append({"instruction": result})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def instruction_response_extraction_api(data_list, config):
client = OpenAI(
api_key = config["inference"]["api_key"],
base_url = config["inference"]["base_url"],
)
models = client.models.list()
model = models.data[0].id
prompt = config["inference"]["prompt"]
stream = config["inference"]["stream"]
logging.info(model)
outcomes = []
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
sample = prompt + "\n" + sample
logging.info(sample)
message = [
{"role": "user", "content": sample}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = config["inference"]["max_new_tokens"],
stream= stream,
)
if stream:
result = ""
for chunk in completion:
                result += chunk.choices[0].delta.content or ""  # delta.content can be None on the final streamed chunk
else:
result = completion.choices[0].message.content
new_instruction, new_response = extract_instruction_response(result)
if new_instruction is not None and new_response is not None:
outcomes.append({"instruction": new_instruction, "output": new_response})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
def instruction_response_extraction_batch(tokenizer, llm, data_list, config, batch_size=32):
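    """Extract instruction-response pairs from raw text in batch with a local vLLM teacher and write them to the configured output path."""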
full_path = config["dataset"]["template"]
template_dir = os.path.dirname(full_path)
template_file = os.path.basename(full_path)
env = Environment(loader=FileSystemLoader(template_dir))
template = env.get_template(template_file)
prompt = config["inference"]["prompt"]
outcomes = []
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
for batch in tqdm(batches, desc="Generating responses"):
new_batch = []
for sample in batch:
logging.info(sample)
sample = prompt + "\n" + sample
message={"role": "user", "content": sample}
full_text = template.render(
message=message,
add_generation_prompt=True,
add_output=False
)
new_batch.append(full_text)
outputs = llm.generate(
new_batch,
SamplingParams(
n=1,
top_k=1,
temperature=config["inference"]["temperature"],
seed=config["inference"]["seed"],
skip_special_tokens=False,
ignore_eos=False,
max_tokens=config["inference"]["max_new_tokens"],
)
)
responses = [output.outputs[0].text for output in outputs]
for i in range(len(batch)):
new_instruction, new_response = extract_instruction_response(responses[i])
if new_instruction is not None and new_response is not None:
outcomes.append({"instruction": new_instruction, "output": new_response})
write_data_to_json_file(outcomes, config["dataset"]["output_path"])

View File

@@ -0,0 +1,107 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import logging
import json
from instruct_synthesis import (
expand_instruction_api,
expand_instruction_batch,
refine_instruction_api,
refine_instruction_batch,
instruction_response_extraction_api,
instruction_response_extraction_batch
)
from cot_synthesis import (
cot_generate_api,
cot_generate_batch,
cot_long2short_api,
cot_long2short_batch,
cot_short2long_api,
cot_short2long_batch
)
from utils import read_json_field, load_tokenizer_and_vllm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def data_synthesis_with_teacher_model(config):
logging.info('Generating distillation data from the teacher model!')
job_type = config["job_type"]
if job_type == "instruction_response_extraction_api":
data_list = read_json_field(config["dataset"]["input_path"], field_name="data")
elif job_type in ["cot_long2short_api","cot_long2short_batch","cot_short2long_api","cot_short2long_batch"]:
data_list_ins = read_json_field(config["dataset"]["input_path"])
data_list_out = read_json_field(config["dataset"]["input_path"], field_name="output")
else:
data_list = read_json_field(config["dataset"]["input_path"])
try:
if job_type == "instruction_expansion_api":
expand_instruction_api(data_list, config)
elif job_type == "instruction_expansion_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
expand_instruction_batch(tokenizer, llm, data_list, config)
elif job_type == "instruction_refinement_api":
refine_instruction_api(data_list, config)
elif job_type == "instruction_refinement_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
refine_instruction_batch(tokenizer, llm, data_list, config)
elif job_type == "instruction_response_extraction_api":
instruction_response_extraction_api(data_list, config)
elif job_type == "instruction_response_extraction_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
instruction_response_extraction_batch(tokenizer, llm, data_list, config)
elif job_type == "cot_generation_api":
cot_generate_api(data_list, config)
elif job_type == "cot_generation_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_generate_batch(tokenizer, llm, data_list, config)
elif job_type == "cot_long2short_api":
cot_long2short_api(data_list_ins, data_list_out, config)
elif job_type == "cot_long2short_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config)
elif job_type == "cot_short2long_api":
cot_short2long_api(data_list_ins, data_list_out, config)
elif job_type == "cot_short2long_batch":
tokenizer, llm = load_tokenizer_and_vllm(config)
cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config)
else:
logging.error(f"Invalid job type: {job_type}")
raise ValueError(f"Invalid job type: {job_type}")
except ValueError as e:
logging.error(f"Training job terminated: {e}")
return
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
data_synthesis_with_teacher_model(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,85 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import torch
import logging
from vllm import LLM
from transformers import AutoTokenizer
def read_json_field(filename, field_name='instruction'):
try:
with open(filename, 'r') as file:
data = json.load(file)
output_fields = []
for item in data:
if field_name in item:
output_fields.append(item[field_name])
return output_fields
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def load_tokenizer_and_vllm(config, eos_token=None):
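    """Load the teacher tokenizer and a tensor-parallel vLLM engine as configured."""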
teacher_model_path = config["models"]["teacher"]
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
tokenizer.padding_side = "left"
if eos_token:
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
logging.info(f"eos_token {eos_token} from user input")
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
else:
raise ValueError("No available eos_token or eos_token_id.")
try:
tokenizer.eos_token = eos_token
tokenizer.eos_token_id = eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id
    except Exception:
        logging.warning("Cannot set tokenizer.eos_token")
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
num_gpus = torch.cuda.device_count()
llm = LLM(
model=teacher_model_path,
tensor_parallel_size=num_gpus,
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
trust_remote_code=config["inference"]["trust_remote_code"],
dtype=torch.bfloat16,
enforce_eager=config["inference"]["enforce_eager"],
max_model_len=config["inference"]["max_model_len"],
)
logging.info("vLLM model loaded successfully")
return tokenizer, llm

View File

@@ -0,0 +1,76 @@
# DistilQwen2.5-0324: training fast-thinking models
## Brief Introduction
In the rapid advancement of large language models, effectively balancing the trade-off between efficient inference and deep thinking capabilities has been a key focus in both academia and industry. DeepSeek-V3-0324 does not employ a deep thinking mode by default, which accelerates inference while remaining capable on complex tasks. The DistilQwen2.5-0324 series not only inherits the essence of the original models' chain-of-thought distillation but also introduces fast-thinking strategies, significantly boosting inference speed. This enables the models to execute complex tasks efficiently on resource-constrained devices and in edge computing scenarios.
## Detailed Steps
### Processing of Instructional Dataset
DistilQwen2.5-0324 was trained using data distilled from DeepSeek-V3-0324, as well as data distilled from DeepSeek-R1 and then rewritten with long2short. For DeepSeek-V3-0324, the official recommendation is not to use a system prompt; for the long2short scenario, the following prompt was used. You can employ this method to shorten the outputs of DeepSeek-R1 and distill your own model.
```json
{
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
}
```
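With this prompt configured in the stage-1 config, run the long2short rewriting job: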
```bash
python easydistill/kd/infer.py --config=distilqwen2.5-0324_stage1.json
```
The training dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "Step 1: Determine the total number of incisors in the upper jaw...The final answer is: \\boxed{8}"
}
]
```
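A quick sanity check that a labeled file (e.g., `distil_qwen_0324.json` from the stage-2 config) matches this schema:

```python
import json

with open("distil_qwen_0324.json") as f:  # path from the stage-2 config
    data = json.load(f)

# Every record should carry the two fields expected by the training job.
assert all("instruction" in r and "output" in r for r in data)
print(f"{len(data)} instruction-response pairs loaded")
```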
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we can run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-0324_stage2.json
```
Please refer to the config file `distilqwen2.5-0324_stage2.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-DS3-0324-7B`, `alibaba-pai/DistilQwen2.5-DS3-0324-14B`, and `alibaba-pai/DistilQwen2.5-DS3-0324-32B`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-7B/")
# Download the 14B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-14B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-14B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-32B/")
```
## Performance
- **32B Model** approaches the performance of closed-source models with 10x the parameters on the GPQA Diamond benchmark
- **Significant Improvement in Reasoning Efficiency** (see comparison table below)
| Model | MMLU_PRO Tokens | AIME2024 Tokens | Speed Gain |
|--------------------------------|-----------------|-----------------|------------|
| DistilQwen2.5-R1-32B (Slow-Thinking) | 4198 | 12178 | 1x |
| DistilQwen2.5-DS3-0324-32B | 690 | 4177 | 5-8x |

View File

@@ -0,0 +1,14 @@
{
"job_type": "cot_long2short_api",
"dataset": {
"input_path": "./raw.json",
"output_path": "./raw_simplified.json"
},
"inference":{
"base_url": "ENDPOINT",
"api_key": "TOKEN",
"stream": true,
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
"max_new_tokens": 1024
}
}

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_0324.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,142 @@
# DistilQwen2.5-R1: training distilled reasoning models based on CoTs generated by DeepSeek-R1
## Brief Introduction
As large language models (LLMs) evolve toward deep reasoning capabilities, deploying them in resource-constrained environments (e.g., mobile devices, edge computing) remains challenging. The DistilQwen2.5-R1 series addresses this by transferring reasoning capabilities from ultra-large models (e.g., DeepSeek-R1) to compact models through innovative distillation techniques, achieving high performance while reducing computational costs.
## Data Generation Detailed Steps
### 1. Generate Thinking Dataset
DistilQwen2.5-R1 is trained using chain-of-thought data distilled from DeepSeek-R1. We provide the system prompts used for distilling the R1 data and for training Qwen2.5. You can use these system prompts to call DeepSeek-R1 to generate your own data and train the model.
```json
{
"system": "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:"
}
```
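For illustration, a minimal sketch of calling an OpenAI-compatible DeepSeek-R1 endpoint with this system prompt (`ENDPOINT`, `TOKEN`, and the model id are placeholders):

```python
from openai import OpenAI

client = OpenAI(api_key="TOKEN", base_url="ENDPOINT")  # placeholders
system_prompt = "Your role as an assistant involves thoroughly exploring questions ..."  # the prompt above

completion = client.chat.completions.create(
    model="deepseek-reasoner",  # assumed model id; use your endpoint's R1 deployment
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "If x + 2 = 5, what is x?"},
    ],
)
print(completion.choices[0].message.content)
```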
### 2. Determine the Difficulty Level
Next, we critique CoT quality according to the cognitive capabilities of smaller models. You can use the following system prompt with QwQ-32B to determine the difficulty level of the CoTs.
```json
{
"system": "You are a highly capable evaluator. Your task is to assess the given reasoning process from the perspective of a small language model (e.g., 7B). Specifically, determine whether the reasoning process provides sufficient detail for a small model to solve the problem, or whether it is too simplistic (i.e., lacking critical details) or too complex (i.e., containing unnecessary or confusing steps). Difficulty Definitions (from the perspective of a small model): - Easy: The reasoning process is overly simplistic relative to the problem's difficulty; it omits essential details that a small model needs to solve the problem. - Medium: The reasoning process is appropriately balanced, offering enough detailed guidance. - Hard: The reasoning process is overly complex, with extraneous or convoluted steps that could hinder a small model's ability to follow it. Output Format: You must output exactly one word: easy, medium, or hard. Do NOT provide any additional text, explanation."
}
```
### 3. Rethinking and Refining these CoTs
We then rethink and refine these CoTs based on the critiques, using the following prompts:
#### easy
```json
{
"system": "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem, its answer, and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary."
}
```
#### hard
```json
{
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer, and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
}
```
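Putting Steps 2 and 3 together: the judge's one-word verdict selects which refinement prompt (if any) to apply. A minimal sketch, where `EXTEND_PROMPT` and `SIMPLIFY_PROMPT` stand in for the easy/hard prompts above:

```python
from typing import Optional

EXTEND_PROMPT = "..."    # the "easy" prompt above (placeholder)
SIMPLIFY_PROMPT = "..."  # the "hard" prompt above (placeholder)

def pick_refine_prompt(difficulty: str) -> Optional[str]:
    # The judge (Step 2) outputs exactly one of: easy, medium, hard.
    if difficulty == "easy":
        return EXTEND_PROMPT     # add detail to an overly simple CoT
    if difficulty == "hard":
        return SIMPLIFY_PROMPT   # simplify an overly complex CoT
    return None                  # medium: keep the CoT as-is
```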
The training dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "<|begin_of_thought|>## Step 1: Determine the total number of incisors in the upper jaw...\n<|end_of_thought|>\n<|begin_of_solution|>The final answer is: \\boxed{8}<|end_of_solution|>"
}
]
```
## Model Training Guidelines
### 1. Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-r1_stage1.json
```
Please refer to the config file `distilqwen2.5-r1_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### 2. CogPO
CogPO (Cognitive Preference Optimization) is a novel algorithm designed to enhance the reasoning abilities of small language models by aligning their reasoning processes with their inherent cognitive capacities.
Key aspects of CogPO:
- Extends Direct Preference Optimization (DPO) with cognitive alignment
- Introduces three specialized "mini-tasks" with different preference gaps
- Dynamically adjusts optimization strength (β values) based on reasoning complexity
- Works synergistically with the CRV (Critique-Rethink-Verify) system
You can run the CogPO by:
```bash
accelerate launch --num_processes n --config_file multi_gpu.yaml cogpo.py --config distilqwen2.5-r1_stage2.json
```
The dataset is in JSON format, exemplified by entries such as:
```json
{
"prompt": "Ellie has 8 pairs of shoes. Riley has 3 fewer. How many pairs of shoes do they have in all?",
"chosen": "<think>Identify the number of pairs of shoes Ellie has. According to the problem statement, Ellie has 8 pairs of shoes.\n Next, determine the number of pairs of shoes Riley has. The problem states that Riley has 3 fewer pairs than Ellie. To find out how many pairs Riley has, subtract 3 from the number of pairs Ellie has: 8 - 3 = 5. So, Riley has 5 pairs of shoes.\n Now, calculate the total number of pairs of shoes both Ellie and Riley have together. To do this, add the number of pairs Ellie has to the number of pairs Riley has: 8 (Ellie's pairs) + 5 (Riley's pairs) = 13 pairs. This step is crucial because it combines the information about both individuals to give the overall total.\n The total number of pairs of shoes they have in all is 13. Thus, the final answer is 13. Each step in the reasoning process is designed to help understand and solve the problem effectively, showing how the information about each individual's shoe count leads to finding the combined total.</think>\boxed{13}",
"rejected": "<think>Identify the number of pairs of shoes Ellie has. Ellie has 8 pairs of shoes as stated in the problem. Determine how many pairs of shoes Riley has. Since Riley has 3 fewer pairs than Ellie, we mistakenly add 3 to Ellie's pairs instead of subtracting, giving us 8 + 3 = 11 pairs of shoes for Riley. Calculate the total number of pairs of shoes they both have. Add Ellie's and Riley's pairs together: 8 + 11. The total pairs of shoes is 19. The final answer is thus \boxed{19}.</think>\boxed{13}",
"beta": 0.5
}
```
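The per-sample `beta` is consumed by the collator and trainer in `cogpo.py`. As a hedged illustration of assembling records, with assumed beta values for the three mini-tasks (only the 0.5 value is confirmed by the example above):

```python
# Assumed mapping from mini-task preference gap to beta; set the actual
# values when constructing your own dataset.
BETA_BY_TASK = {"large_gap": 0.1, "medium_gap": 0.5, "small_gap": 0.9}

def to_cogpo_record(prompt, chosen, rejected, task):
    return {
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
        "beta": BETA_BY_TASK[task],
    }
```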
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-R1-3B`, `alibaba-pai/DistilQwen2.5-R1-7B`, `alibaba-pai/DistilQwen2.5-R1-14B`, and `alibaba-pai/DistilQwen2.5-R1-32B`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 3B model
model_name = "alibaba-pai/DistilQwen2.5-R1-3B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-3B/")
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-R1-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-7B/")
# Download the 14B model
model_name = "alibaba-pai/DistilQwen2.5-R1-14B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-14B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen2.5-R1-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-32B/")
```
## Performance
We compared DistilQwen2.5-R1 series with leading reasoning models across four benchmarks:
### 7B Model Comparison
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|--------------------------------|--------------------|----------|----------|--------------|------------------|
| DeepSeek-R1-Distill-Qwen-7B | 800k | 55.5 | 92.8 | 49.1 | - |
| Bespoke-Stratos-7B | 17k | 20.0 | 82.0 | 37.8 | 36.1 |
| OpenThinker-7B | 114k | 31.3 | 83.0 | 42.4 | 39.9 |
| **DistilQwen2.5-R1-7B** | 105k | 43.33 | 88.4 | 42.93 | 46.38 |
### 32B Model Comparison
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|--------------------------------|--------------------|----------|----------|--------------|------------------|
| DeepSeek-R1-Distill-Qwen-32B | 800k | 72.6 | 94.3 | 62.1 | - |
| Sky-T1-32B-Preview | 17k | 43.3 | 86.4 | 56.8 | - |
| OpenThinker-32B | 114k | 66.0 | 90.6 | 61.6 | 68.9 |
| **DistilQwen2.5-R1-32B** | 105k | 70.0 | 93.8 | 62.12 | 65.95 |

View File

@@ -0,0 +1,194 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import trl
from trl.trainer.dpo_trainer import DataCollatorForPreference
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType
from trl.trainer.utils import cap_exp
import json
import argparse
@dataclass
class DataCollatorForPreferenceWithBeta(DataCollatorForPreference):
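    """Preference collator that also batches a per-sample 'beta' scalar."""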
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
betas = torch.tensor([float(ex["beta"]) for ex in examples], dtype=torch.float32)
for ex in examples:
ex.pop("beta")
batch = super().torch_call(examples)
batch["betas"] = betas
return batch
class CogPOTrainer(DPOTrainer):
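    """DPO trainer variant (CogPO) whose loss strength beta varies per sample."""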
def get_batch_loss_metrics(
self,
model,
batch,
train_eval: str = "train",
):
metrics = {}
betas = batch.pop("betas").to(self.accelerator.device)
model_output = self.concatenated_forward(model, batch)
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
ref_chosen_logps = batch["ref_chosen_logps"]
ref_rejected_logps = batch["ref_rejected_logps"]
else:
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
losses, chosen_rewards, rejected_rewards = self._dpo_sigmoid_loss(
model_output["chosen_logps"],
model_output["rejected_logps"],
ref_chosen_logps,
ref_rejected_logps,
betas,
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
if self.args.rpo_alpha is not None:
losses = losses + self.args.rpo_alpha * model_output["nll_loss"]
if self.use_weighting:
losses = losses * model_output["policy_weights"]
if self.aux_loss_enabled:
losses = losses + self.aux_loss_coef * model_output["aux_loss"]
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
)
if self.args.rpo_alpha is not None:
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
)
if self.aux_loss_enabled:
metrics[f"{prefix}aux_loss"] = (
self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
)
return losses.mean(), metrics
def _dpo_sigmoid_loss(
self,
chosen_logps: torch.FloatTensor,
rejected_logps: torch.FloatTensor,
ref_chosen_logps: torch.FloatTensor,
ref_rejected_logps: torch.FloatTensor,
betas: torch.FloatTensor,
):
device = self.accelerator.device
chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
        # Δ = (log p_c - log p_r) - (log p̂_c - log p̂_r), scaled per sample by beta below
if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
else:
logratios = chosen_logps - rejected_logps
if self.reference_free:
ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
else:
ref_logratios = ref_chosen_logps - ref_rejected_logps
logratios = logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = logratios - ref_logratios
if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
losses = (
-F.logsigmoid(betas * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-betas * logits) * self.label_smoothing
)
chosen_rewards = betas * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
rejected_rewards = betas * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
return losses, chosen_rewards, rejected_rewards
def train(config):
model_name = config["models"]["student"]
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"], split='train')
dpo_args = DPOConfig(
output_dir=config["training"]["output_dir"],
num_train_epochs=config["training"]["num_train_epochs"],
loss_type=config["training"]["loss_type"],
beta=config["training"]["beta"],
per_device_train_batch_size=config["training"]["per_device_train_batch_size"],
remove_unused_columns=False,
)
collator = DataCollatorForPreferenceWithBeta(
pad_token_id=tokenizer.pad_token_id
)
trainer = CogPOTrainer(
model=model,
args=dpo_args,
train_dataset=dataset,
tokenizer=tokenizer,
data_collator=collator,
)
trainer.train()
trainer.save_model(config["training"]["output_dir"])
tokenizer.save_pretrained(config["training"]["output_dir"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
train(config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_r1.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,15 @@
{
"models": {
"student": "models/Qwen2.5-0.5B-Instruct"
},
"dataset": {
"labeled_path": "cogpo/test500.jsonl"
},
"training": {
"output_dir": "save/Qwen2.5-0.5B-CogPO",
"num_train_epochs": 1.0,
"loss_type": "sigmoid",
"beta": 1.0,
"per_device_train_batch_size": 2
}
}

View File

@@ -0,0 +1,101 @@
# DistilQwen-ThoughtX: Optimized Reasoning Models with OmniThought
## Brief Introduction
DistilQwen-ThoughtX is a series of high-performance reasoning models trained on the [OmniThought](https://huggingface.co/datasets/alibaba-pai/OmniThought) dataset. These models are optimized for chain-of-thought (CoT) reasoning with balanced verbosity and cognitive difficulty, achieving state-of-the-art results on mathematical, coding, and logical reasoning benchmarks.
## Detailed Steps
### Direct Training
DistilQwen-ThoughtX was trained using data from the OmniThought dataset, which includes 2 million CoT processes with RV (Reasoning Verbosity) and CD (Cognitive Difficulty) annotations. The dataset covers mathematics, coding, and logical reasoning tasks, validated by multiple teacher models (DeepSeek-R1, QwQ-32B).
The training system prompt is:
```json
{
"system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
}
```
Using the OmniThought dataset, we can run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5-thoughtx-train.json
```
Remember to filter the samples by their RV and CD annotations so that they fall within the desired range before training your own model; a filtering sketch follows the table below.
| Model Name | Parameters | Base Model |
|--------------------------------------|------------|---------------------|
| `DistilQwen-ThoughtX-7B` | 7B | Qwen2.5-7B-Instruct |
| `DistilQwen-ThoughtX-32B` | 32B | Qwen2.5-32B-Instruct|
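As a rough illustration of such filtering (the file path, the `RV`/`CD` field names, and the ranges below are assumptions; check the actual dataset schema):

```python
import json

with open("omnithought_subset.json") as f:  # hypothetical local export
    samples = json.load(f)

def in_target_band(sample, rv_range=(4, 7), cd_range=(2, 6)):
    # Assumed annotation fields and ranges; tune both to your student size.
    return (rv_range[0] <= sample["RV"] <= rv_range[1]
            and cd_range[0] <= sample["CD"] <= cd_range[1])

filtered = [s for s in samples if in_target_band(s)]
```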
### Process Your Own Data
To obtain the RV and CD values of your own data, you can use the following prompts to call QwQ-32B/DeepSeek-R1, score your own data, and filter it.
Prompt Template to Calculate the RV Score
```json
{
"prompt": "You are an expert judge tasked with evaluating the Reasoning Verbosity of a Chain-of-Thought (CoT) for a given problem and its answer. Reasoning Verbosity Evaluation Focus: Assess how well the CoTs length and step complexity match the problems inherent difficulty. An optimal chain is neither missing essential steps nor padded with needless digressions. A simple question should be solved with a brief, direct chain; a challenging one may justifiably require a longer path with reflection and error-checking. Scoring Guidelines (0-9): 0-1 Minimal verbosity, straightforward expression with little to no elaboration. 2-3 Clear and concise reasoning with necessary explanations. 4-5 Moderate verbosity with detailed explanations and thorough reasoning. 6-7 Extensive verbosity with comprehensive justification and exploration of complex connections. 8-9 High verbosity with deep, exhaustive exploration of reasoning; involves extensive elaboration, nested justifications, and consideration of counterarguments or alternative perspectives. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Reasoning Verbosity 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
}
```
Prompt Template to Calculate the CD Score
```json
{
"prompt": "You are an expert judge assessing the Cognitive Difficulty of a Chain-of-Thought (CoT) for a given problem and its answer. Cognitive Difficulty Evaluation Focus: The level of reasoning competence required for a model to follow and reproduce the chain faithfully. Judge the reasoning approach, techniques, and overall difficulty. Higher scores correspond to more advanced concepts, abstractions, or multi-layer reasoning patterns. Scoring Guidelines (0-9): 0-1 Elementary facts or a single trivial operation. 2-3 Multi-step arithmetic, explicit enumeration, basic rule chaining. 4-5 Early-undergraduate logic/algebra; one non-obvious insight. 6-7 Advanced undergraduate techniques (determinants, dynamic programming, layered code reasoning, etc). 8-9 Graduate-level abstraction, nested proofs, intricate algorithmic analysis. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Cognitive Difficulty 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
}
```
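A minimal sketch of scoring one sample with either template via an OpenAI-compatible endpoint (`ENDPOINT`, `TOKEN`, and the model id are placeholders):

```python
from openai import OpenAI

client = OpenAI(api_key="TOKEN", base_url="ENDPOINT")  # placeholders

def judge_score(template, problem, thought, solution, model="qwq-32b"):
    prompt = template.format(problem=problem, thought=thought, solution=solution)
    reply = client.chat.completions.create(
        model=model, messages=[{"role": "user", "content": prompt}]
    )
    return int(reply.choices[0].message.content.strip())  # judge outputs 0-9
```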
## Model Download
We have open-sourced our distilled models on HuggingFace. The available models are named `alibaba-pai/DistilQwen-ThoughtX-7B` and `alibaba-pai/DistilQwen-ThoughtX-32B`.
Users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 7B model
model_name = "alibaba-pai/DistilQwen-ThoughtX-7B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-7B/")
# Download the 32B model
model_name = "alibaba-pai/DistilQwen-ThoughtX-32B"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-32B/")
```
## Performance
The models achieve state-of-the-art performance on various reasoning benchmarks:
| Model | AIME2024 | MATH500 | GPQA-D | LiveCodeBench V2 |
|----------------------|----------|---------|--------|------------------|
| DeepSeek-R1-Distill-7B | 57.3 | 89.6 | 47.3 | 48.4 |
| **DistilQwen-ThoughtX-7B** | **56.7** | **90.2** | **50.0** | **56.8** |
| DeepSeek-R1-Distill-32B | 74.7 | 90.0 | 62.4 | 72.3 |
| **DistilQwen-ThoughtX-32B** | **80.0** | **92.6** | **64.0** | **73.4** |
## Reference
For more detailed information about the model, we encourage you to refer to our paper:
- **Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations**
Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang
[arXiv:2505.10937](https://arxiv.org/abs/2505.10937)
You can cite the paper using the following citation format:
```bibtex
@misc{cai2025reasoningomnithoughtlargecot,
title={Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations},
author={Wenrui Cai and Chengyu Wang and Junbing Yan and Jun Huang and Xiangzhong Fang},
year={2025},
eprint={2505.10937},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2505.10937}
}
```

View File

@@ -0,0 +1,24 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_thoughtX.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
},
"training": {
"output_dir": "result/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"max_length":4096,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,135 @@
# DistilQwen2.5: Combining Black-Box and White-Box KD
## Brief Introduction
The DistilQwen2.5 distilled language model series is built upon the Qwen2.5 model. This series leverages innovative distillation techniques to enhance instruction-following capabilities. As a result, these distilled models retain the excellent performance of the original models while requiring fewer computational resources.
The distillation process involves carefully selecting, rewriting, and optimizing instruction-response pairs conducive to student model learning, thus improving model comprehension and execution abilities. Following standard fine-tuning, we employ white-box distillation techniques to enable the student models to better acquire fine-grained knowledge from teacher models. Experimental evaluations demonstrate the significant improvement in capabilities of the DistilQwen2.5 models.
## Detailed Steps
### Processing of Instructional Dataset
DistilQwen2.5 begins with collecting diverse, high-quality instructional data from sources like Magpie, Openhermes, and Mammoth 2, along with proprietary datasets. The data includes Chinese and English instructions, which are scored for difficulty and task relevance. This process is very similar to the recipe of DistilQwen2.
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
"output": "I'd be happy to help you review your lecture text..."
}
]
```
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
```python
# Validate SDK token
from modelscope.hub.api import HubApi
api = HubApi()
api.login('your_token_id')
# Dataset download
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('PAI/DistilQwen_100k')
```
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2.5_stage1.json
```
Please refer to the config file `distilqwen2.5_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### White-Box KD
Unlike black-box KD, which relies solely on the highest probability token output by the teacher model, white-box KD focuses on the distribution of logits produced by the teacher model. This approach provides the student model with richer information. By mimicking the teacher model's logits distribution, white-box KD can transfer knowledge more effectively, thereby enhancing the performance of the student model. As an example, we take `Qwen2.5-72B-Instruct` as the white-box teacher model, and generate the logits by:
```bash
python easydistill/kd/infer.py --config=distilqwen2.5_stage2.json
```
Next, we run the training job by:
```bash
python easydistill/kd/train.py --config=distilqwen2.5_stage2.json
```
Again, please refer to the config file `distilqwen2.5_stage2.json` in the current folder. Remember to change the configurations when needed.
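For intuition, here is a minimal sketch of the forward-KLD objective selected by `"distillation_type": "forward_kld"` in the stage-2 config; it is illustrative, not the exact EasyDistill implementation:

```python
import torch
import torch.nn.functional as F

def forward_kld(student_logits, teacher_logits, temperature=1.0):
    # KL(teacher || student) on the softened vocabulary distributions.
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2
```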
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-0.5B-Instruct`, `alibaba-pai/DistilQwen2.5-1.5B-Instruct`, `alibaba-pai/DistilQwen2.5-3B-Instruct`, and `alibaba-pai/DistilQwen2.5-7B-Instruct`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
# Download the 0.5B model
model_name = "alibaba-pai/DistilQwen2.5-0.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-0.5B/")
# Download the 1.5B model
model_name = "alibaba-pai/DistilQwen2.5-1.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-1.5B/")
# Download the 3B model
model_name = "alibaba-pai/DistilQwen2.5-3B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-3B/")
# Download the 7B model
model_name = "alibaba-pai/DistilQwen2.5-7B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-7B/")
```
## Performance
The table below compares the performance of the original Qwen2.5 models with the distilled DistilQwen2.5 models across different parameter sizes: 0.5B, 1.5B, 3B, and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
| Qwen2.5-0.5B-Instruct | 2.46 | 5.49 | 6.26 | 42.81 | 30.31 |
| **DistilQwen2.5-0.5B-Instruct** | **4.89** | **5.78** | **6.83** | **52.61** | **37.82** |
| Qwen2.5-1.5B-Instruct | 6.69 | 7.09 | 7.66 | 55.40 | 40.11 |
| **DistilQwen2.5-1.5B-Instruct** | **13.69** | **7.35** | **7.99** | **61.10** | **74.49** |
| Qwen2.5-3B-Instruct | 17.98 | 7.92 | 8.40 | 61.18 | 74.58 |
| **DistilQwen2.5-3B-Instruct** | **20.91** | **8.37** | **8.97** | **67.03** | **77.36** |
| Qwen2.5-7B-Instruct | 31.43 | 8.52 | 8.83 | 81.53 | 72.10 |
| **DistilQwen2.5-7B-Instruct** | **34.86** | **8.76** | **9.22** | **83.48** | **73.27** |
For evaluation details, please refer to our paper.
## Reference
For more detailed information about the DistilQwen2.5 model series and the methodologies employed, we encourage you to refer to our paper:
- **DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models**
Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang
[arXiv:2504.15027](https://arxiv.org/abs/2504.15027)
You can cite the paper using the following citation format:
```bibtex
@misc{wang2025distilqwen25industrialpracticestraining,
title={DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models},
author={Chengyu Wang and Junbing Yan and Yuanhao Yue and Jun Huang},
year={2025},
eprint={2504.15027},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2504.15027},
}
```

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,40 @@
{
"job_type": "kd_white_box",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"logits_path": "logits.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 512
},
"distillation": {
"kd_ratio": 0.5,
"max_seq_length": 512,
"distillation_type": "forward_kld"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-72B-Instruct/",
"student": "result_stage1/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,165 @@
# DistilQwen2: Refining Instructional Data for Black-Box KD
## Brief Introduction
Knowledge distillation offers an effective solution by transferring knowledge from larger models to smaller ones, ensuring performance while significantly reducing computational resources and inference time. We introduce DistilQwen2, a lightweight LLM based on the Qwen2 series, optimized through enhanced instruction following and diverse distillation techniques. This enables more agile and efficient deployment in resource-constrained environments like mobile devices and edge computing. For ease of use by developers and enterprises, DistilQwen2's checkpoints are open-sourced on HuggingFace and ModelScope, empowering more stakeholders to innovate and realize value through advanced NLP applications.
## Instructional Data Processing Guidelines
For the training of DistilQwen2, we collected data from well-known open-source datasets like Magpie, Openhermes, and Mammoth 2, along with proprietary synthetic datasets to initiate the distillation process. The focus is on providing diverse instructional data, predominantly in Chinese and English. We also leverage prompt templates to conduct instructional data augmentation. Here, we provide several commonly used operations to re-sample and augment the dataset.
### Instruction Set Expansion
The instruction expansion operator is employed to generate a diverse set of instruction variations, ensuring that student models are exposed to a comprehensive range of instructions. After instruction expansion, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_expansion_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_expansion_batch.json`.
### Instruction Refinement
The instruction refinement operator further enhances the quality and diversity of the training data while preserving the semantic integrity of the tasks expressed in the instructions, ensuring that the rewritten content remains faithful to the original intent and task category. After instruction refinement, we can also call the teacher model to generate responses for the new instructions. An example of calling this operator is as follows:
```bash
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_refinement_api.json
```
If you need to run the job using batch inference, please refer to the config example `configs/instruction_refinement_batch.json`.
### Instruction Resampling
We also consider task balance when selecting useful instructional data pairs. The task distributions are defined based on our paper in the reference. You can run the job by:
```bash
python task_resampling.py --input-file input.json --output-file output.json --api-key <your_api_key> --base-url <base_url>
```
The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth..."
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?"
}
]
```
After processing the instructions, you can generate the teacher model's responses.
### Open-Source Dataset
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
},
{
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
"output": "I'd be happy to help you review your lecture text..."
}
]
```
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
```python
# Validate SDK token
from modelscope.hub.api import HubApi
api = HubApi()
api.login('your_token_id')
# Dataset download
from modelscope.msdatasets import MsDataset
ds = MsDataset.load('PAI/DistilQwen_100k')
```
## Model Training Guidelines
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. For simplicity, we use the `DistilQwen_100k` dataset in this tutorial; we only need to run the training job:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage1.json
```
Please refer to the config file `distilqwen2_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### Preference Rank Optimization
For more challenging instruction tasks, SFT alone may not yield optimal results. To address this, we further refine the model using Direct Preference Optimization (DPO), enabling more granular fine-tuning and improved performance. First, we generate the student's outputs as the rejected responses; the instructions and outputs in the SFT dataset are regarded as the prompts and chosen responses. Please refer to the following script:
```bash
python dpo_student_infer_only.py --config=distilqwen2_stage2.json
```
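Conceptually, each resulting DPO record pairs the SFT data with the student's own generation; a minimal sketch with the dataset's field names:

```python
def make_dpo_record(sft_sample, student_response):
    # SFT instruction -> prompt, SFT output -> chosen, student output -> rejected.
    return {
        "prompt": sft_sample["instruction"],
        "chosen": sft_sample["output"],
        "rejected": student_response,
    }
```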
Next, we run the training job by:
```bash
python easydistill/kd/train.py --config=distilqwen2_stage2.json
```
Again, please refer to the config file `distilqwen2_stage2.json` in the current folder. Remember to change the configurations when needed. If you need to run the job in a distributed mode, use `accelerate` to run the job.
## Model Download
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2-1.5B-Instruct` and `alibaba-pai/DistilQwen2-7B-Instruct`.
For example, users can download these models from HuggingFace using the following code:
```python
from huggingface_hub import snapshot_download
model_name = "alibaba-pai/DistilQwen2-1.5B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-1.5B/")
model_name = "alibaba-pai/DistilQwen2-7B-Instruct"
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-7B/")
```
## Performance
The table below compares the performance of the original Qwen2 models with the distilled DistilQwen2 models across different parameter sizes: 1.5B and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
## Reference
For more detailed information about the DistilQwen2 model series and the methodologies employed, we encourage you to refer to our paper:
- **Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning**
Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang
You can cite the paper using the following citation format:
```bibtex
@inproceedings{emnlp2024,
author = {Yuanhao Yue and
Chengyu Wang and
Jun Huang and
Peng Wang},
title = {Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning},
booktitle = {Findings of the Association for Computational Linguistics: {EMNLP} 2024},
pages = {6030--6054},
publisher = {Association for Computational Linguistics},
year = {2024},
url = {https://aclanthology.org/2024.findings-emnlp.350}
}
```

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "distil_qwen_100k.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,25 @@
{
"job_type": "rank_dpo_api",
"dataset": {
"instruction_path": "distil_qwen_100k.json",
"labeled_path": "distil_qwen_100k_dpo.json",
"template" : "chat_template_kd.jinja",
"seed": 42
    },
    "inference": {
        "max_new_tokens": 1024
    },
    "models": {
        "student": "result_stage1/"
    },
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"beta": 0.1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,105 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import argparse
import logging
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def read_json_field(filename):
    try:
        with open(filename, 'r') as file:
            data = json.load(file)
        samples = []
        for item in data:
            # The SFT instruction becomes the DPO prompt; the SFT output is the chosen response
            samples.append({"prompt": item["instruction"], "chosen": item["output"]})
        return samples
except FileNotFoundError:
logging.error("The file was not found.")
except json.JSONDecodeError:
logging.error("There was an error decoding the JSON file.")
except Exception as e:
logging.error(f"An error occurred: {e}")
def write_data_to_json_file(data, file_path):
try:
with open(file_path, 'w') as file:
json.dump(data, file, ensure_ascii=False, indent=4)
logging.info(f"Data successfully written to {file_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
def generate_student_response(data_list, config):
# load student model
student_tokenizer = AutoTokenizer.from_pretrained(
config["models"]["student"],
trust_remote_code=True
)
student_model = AutoModelForCausalLM.from_pretrained(
config["models"]["student"],
device_map="auto",
trust_remote_code=True
)
outcomes = []
    for sample in tqdm(data_list, desc="Generating student responses"):
prompt = sample["prompt"]
chosen = sample["chosen"]
# for student model
messages = [
{"role": "user", "content": prompt}
]
text = student_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
generated_ids = student_model.generate(
**model_inputs,
max_new_tokens=config["inference"]["max_new_tokens"]
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
gen_data = {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
outcomes.append(gen_data)
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
args = parser.parse_args()
config = json.load(open(args.config))
data_list = read_json_field(config["dataset"]["instruction_path"])
generate_student_response(data_list, config)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,156 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import re
import logging
from openai import OpenAI
from collections import Counter
import random
import argparse
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
predefined_distribution = {
'Math': 0.167,
'Code Generation': 0.083,
'Writing': 0.017,
'Computer Science': 0.017,
'Reasoning': 0.167,
'Complex Format': 0.017,
'Code Debug': 0.083,
'Common-Sense': 0.017,
'Counterfactual': 0.017,
'Multilingual': 0.017,
'Roleplay': 0.017,
'Biology': 0.017,
'Technology': 0.017,
'Ethics': 0.017,
'Sport': 0.017,
'Law': 0.017,
'Medicine': 0.017,
'Literature': 0.017,
'Entertainment': 0.017,
'Art': 0.017,
'Music': 0.017,
'Toxicity': 0.017,
'Economy': 0.017,
'Physics': 0.017,
'History': 0.017,
'Chemistry': 0.017,
'Philosophy': 0.017,
'Health': 0.017,
'Ecology': 0.017,
'Grammar': 0.017,
'Paraphrase': 0.017,
'Others': 0.041
}
predefined_prompt = """
You are a data annotation expert. Please classify the task type or domain of #Given Instruction.
The task type or domain should be in the list: [Math, Code Generation, Writing, Computer Science, Reasoning, Complex Format, Code Debug, Common-Sense, Counterfactual, Multilingual, Roleplay, Biology, Technology, Ethics, Sport, Law, Medicine, Literature, Entertainment, Art, Music, Toxicity, Economy, Physics, History, Chemistry, Philosophy, Health, Ecology, Grammar, Paraphrase, Others]. You should place your answer enclosed within <answer></answer> tags, such as <answer>Math</answer>. Do not return anything else.
#Given Instruction#:
"""
def extract_answer(content):
pattern = r'<answer>(.*?)</answer>'
match = re.search(pattern, content, re.DOTALL)
if match:
return match.group(1)
else:
return None
def classify_instruction(instruction, client, model):
message = [
{"role": "user", "content": predefined_prompt + "\n" + instruction}
]
completion = client.chat.completions.create(
messages = message,
model = model,
max_completion_tokens = 1024
)
    content = completion.choices[0].message.content.strip()
    category = extract_answer(content)
    # Fall back to 'Others' for unparsable or unknown categories
    if category is None or category not in predefined_distribution:
        category = 'Others'
    logging.info(category)
    return category
def main(args):
# Load dataset
with open(args.input_file, 'r') as file:
data = json.load(file)
# Initialize OpenAI client
client = OpenAI(
api_key=args.api_key,
base_url=args.base_url
)
    # Use the first model served by the endpoint
    models = client.models.list()
    model = models.data[0].id
    logging.info(model)
    # Classify each instruction
    classified_data = []
    for count, item in enumerate(data, start=1):
        category = classify_instruction(item['instruction'], client, model)
        classified_data.append({'instruction': item['instruction'], 'category': category})
        if count % 100 == 0:
            logging.info(f"Classified {count}/{len(data)} instructions")
# Count occurrences per category
category_counts = Counter(item['category'] for item in classified_data)
total_samples = len(classified_data)
# Resample according to predefined distribution
resampled_data = []
for category, target_ratio in predefined_distribution.items():
target_count = int(total_samples * target_ratio)
category_samples = [item for item in classified_data if item['category'] == category]
if len(category_samples) == 0:
logging.warning("No instructions are provided for the category: " + category)
continue
if len(category_samples) > target_count:
            logging.info(f"Downsampling '{category}': {len(category_samples)} -> {target_count}")
# Randomly sample the required number of instructions
resampled_category_samples = random.sample(category_samples, target_count)
else:
            # If not enough samples, repeat the existing ones and top up with a random remainder
            repeats = target_count // len(category_samples)
            remainder = target_count % len(category_samples)
            resampled_category_samples = category_samples * repeats + random.sample(category_samples, remainder)
resampled_data.extend(resampled_category_samples)
# Save final dataset
with open(args.output_file, 'w') as file:
json.dump(resampled_data, file, indent=4)
print("Resampling complete. Final output saved to '{}'.".format(args.output_file))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Task and Domain Classification')
parser.add_argument('--input-file', type=str, required=True, help='Input JSON file containing instructions.')
parser.add_argument('--output-file', type=str, required=True, help='Output JSON file to store resampled instructions.')
parser.add_argument('--api-key', type=str, required=True, help='API key.')
parser.add_argument('--base-url', type=str, required=True, help='Base URL.')
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,67 @@
# Code Generation: Combining Black-Box and White-Box KD
## Brief Introduction
The code generation distillation process involves carefully selecting and optimizing code generation instruction-response pairs conducive to student model learning. Following standard fine-tuning, we employ white-box distillation techniques to enable the student models to better acquire fine-grained knowledge from teacher models.
## Detailed Steps
### Processing of Code Generation Dataset
The process begins with collecting diverse, high-quality code generation data from various programming languages and tasks. This data includes code completion, code generation from natural language descriptions, bug fixing, and code explanation tasks. We score examples for difficulty and relevance to code generation tasks.
The dataset is in JSON format, exemplified by entries such as:
```json
[
{
"instruction": "Write a Python function that calculates the Fibonacci sequence up to n elements using dynamic programming.",
"output": "def fibonacci(n):\n if n <= 0:\n return []\n elif n == 1:\n return [0]\n elif n == 2:\n return [0, 1]\n \n fib = [0, 1]\n for i in range(2, n):\n fib.append(fib[i-1] + fib[i-2])\n \n return fib\n\n# Example usage\nprint(fibonacci(10)) # Output: [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]"
},
{
"instruction": "Create a JavaScript function that takes an array of numbers and returns the sum of all even numbers in the array.",
"output": "function sumEvenNumbers(numbers) {\n return numbers\n .filter(num => num % 2 === 0)\n .reduce((sum, num) => sum + num, 0);\n}\n\n// Example usage\nconst numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];\nconsole.log(sumEvenNumbers(numbers)); // Output: 30"
}
]
```
### Black-Box KD
The black-box KD process follows a supervised learning paradigm, utilizing enhanced code instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the code generation capabilities of the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle programming tasks but also enables it to perform better across multiple programming languages and paradigms.
To run the black-box KD training:
```bash
python easydistill/kd/train.py --config=code_generation_stage1.json
```
Please refer to the config file `code_generation_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
### White-Box KD
Unlike black-box KD, which relies solely on the highest probability token output by the teacher model, white-box KD focuses on the distribution of logits produced by the teacher model. This approach provides the student model with richer information about code structure and syntax. By mimicking the teacher model's logits distribution, white-box KD can transfer programming knowledge more effectively, thereby enhancing the performance of the student model.
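As a rough illustration of this objective (a sketch only, not the exact EasyDistill implementation; `kd_ratio` and the forward-KLD choice mirror the fields in the config shown below):
```python
import torch.nn.functional as F

def forward_kld_loss(student_logits, teacher_logits, labels, kd_ratio=0.1, temperature=1.0):
    """Blend token-level cross-entropy with forward KLD against the teacher's distribution."""
    t_probs = F.softmax(teacher_logits / temperature, dim=-1)
    s_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # Forward KLD, i.e. KL(teacher || student): the student must cover the teacher's distribution
    kld = F.kl_div(s_log_probs, t_probs, reduction="batchmean") * (temperature ** 2)
    # Standard cross-entropy against the ground-truth tokens
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))
    return (1.0 - kd_ratio) * ce + kd_ratio * kld
```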
To generate the logits with the teacher model:
```bash
python easydistill/kd/infer.py --config=code_generation_stage2.json
```
Next, run the training job:
```bash
python easydistill/kd/train.py --config=code_generation_stage2.json
```
Please refer to the config file `code_generation_stage2.json` in the current folder. Remember to change the configurations when needed.
## Performance
We trained the model using data from nvidia/OpenCodeReasoning, and the final model performance is as follows (speed is relative to the 7B models at 1x):
| Model                     | LiveCodeBench V2 | Speed  |
|---------------------------|------------------|--------|
| Qwen2.5-3B-Instruct | 11.35 | 2.3x |
| Qwen2.5-3B-Code-Optimize | 16.62 | 2.3x |
| Qwen2.5-7B-Instruct | 30.72 | 1x |
| Qwen2.5-7B-Code-Optimize | 35.32 | 1x |

View File

@@ -0,0 +1,23 @@
{
"job_type": "kd_black_box_api",
"dataset": {
"labeled_path": "code_generation_dataset.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"models": {
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
},
"training": {
"output_dir": "result_stage1/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,40 @@
{
"job_type": "kd_white_box",
"dataset": {
"labeled_path": "code_generation_dataset.json",
"logits_path": "logits.json",
"template" : "chat_template_kd.jinja",
"seed": 42
},
"inference":{
"enable_chunked_prefill": true,
"seed": 777,
"gpu_memory_utilization": 0.9,
"temperature": 0.8,
"trust_remote_code": true,
"enforce_eager": false,
"max_model_len": 4096,
"max_new_tokens": 1024
},
"distillation": {
"kd_ratio": 0.1,
"max_seq_length": 1024,
"distillation_type": "forward_kld"
},
"models": {
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
"student": "result_stage1/"
},
"training": {
"output_dir": "result_stage2/",
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 8,
"save_steps": 1000,
"logging_steps": 1,
"learning_rate": 2e-5,
"weight_decay": 0.05,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine"
}
}

View File

@@ -0,0 +1,50 @@
# DistilQwen-100k/DistilQwen-1M: High-Quality Instruction-Tuning Datasets
## Overview
To empower community developers in enhancing the **instruction-following capabilities** of large language models (LLMs), we open-source **`DistilQwen-100k`** and **`DistilQwen-1M`**, subsets of the training data used for the **DistilQwen model series**. The datasets provide diverse, high-quality samples to improve model performance in key areas.
## Dataset Features
- **Scale**: **100 thousand**/**1 million** meticulously distilled entries.
- **Coverage**: Balanced mix of:
- **Mathematics**
- **Code generation & understanding**
- **Knowledge-based QA**
- **Instruction following**
- **Creative generation**
- **Purpose**: Optimized for **instruction tuning**, helping models retain generalization while adapting to downstream tasks.
## Use Cases
- **Fine-tuning LLMs**: Mitigate *catastrophic forgetting* by combining with custom datasets (see the sketch after the loading example below).
- **Multi-task learning**: Improve coherence in mathematical reasoning, coding, and creative tasks.
- **Research**: Study distillation techniques or instruction-tuning efficacy.
## Use the Datasets
```python
from datasets import load_dataset
# Login using e.g. `huggingface-cli login` to access this dataset
ds = load_dataset("alibaba-pai/DistilQwen_100k")
ds = load_dataset("alibaba-pai/DistilQwen_1M")
```
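For instance, to mitigate catastrophic forgetting as noted above, the distilled data can be interleaved with a custom SFT set (a minimal sketch; `my_sft_data.json` is a hypothetical file sharing the same `instruction`/`output` schema):
```python
from datasets import load_dataset, concatenate_datasets

distil = load_dataset("alibaba-pai/DistilQwen_100k", split="train")
# Hypothetical custom data with the same instruction/output fields
custom = load_dataset("json", data_files="my_sft_data.json", split="train")

# Shuffle the combined pool so fine-tuning sees both general and custom samples
mixed = concatenate_datasets([distil, custom]).shuffle(seed=42)
```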
## Reference
For more detailed information about the dataset construction process, we encourage you to refer to our paper:
- **DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models**
Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang
[arXiv:2504.15027](https://arxiv.org/abs/2504.15027)
You can cite the paper using the following citation format:
```bibtex
@misc{wang2025distilqwen25industrialpracticestraining,
title={DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models},
author={Chengyu Wang and Junbing Yan and Yuanhao Yue and Jun Huang},
year={2025},
eprint={2504.15027},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2504.15027}
}
```

View File

@@ -0,0 +1,58 @@
# OmniThought: A Large-Scale Chain-of-Thought Dataset for Advancing Large Reasoning Models
## Overview
The rise of **Large Reasoning Models (LRMs)** has revolutionized **Natural Language Processing (NLP)**, enabling breakthroughs in complex tasks like **mathematical problem-solving** and **code generation**. These models rely on **Chain-of-Thought (CoT)** processes to mimic human-like reasoning. However, progress in LRMs is limited by the scarcity of **high-quality, large-scale CoT datasets**—existing resources often lack:
- **Diverse reasoning problems** with well-structured CoT processes.
- **Multi-teacher distillation** to ensure reasoning quality.
- **Fine-grained annotations** describing CoT properties.
To bridge this gap, we introduce **`OmniThought`**, a **2-million-scale CoT dataset** generated and validated by **two powerful LRMs**. Each CoT process is annotated with:
- **Reasoning Verbosity (RV)**: Measures the optimal verbosity of reasoning steps.
- **Cognitive Difficulty (CD)**: Assesses the complexity of reasoning for model comprehension.
We also propose a **self-reliant pipeline** for dataset curation, ensuring high-quality reasoning traces.
## Key Features
- **2 million high-quality CoT processes** covering diverse reasoning tasks.
- **RV-CD scores** to guide model training for better reasoning performance.
- **Multi-teacher distillation** for robust and coherent reasoning paths.
- **Optimized for LRM training**: improves reasoning ability and output quality.
## Experiments & Results
Extensive experiments with **Qwen2.5 models** (various sizes) confirm that:
- Training with **RV-CD scores** enhances **LRM reasoning effectiveness**.
- Models trained on `OmniThought` achieve **stronger reasoning abilities** with **optimal CoT length and difficulty**.
Based on this dataset, we release **a series of high-performance LRMs** with superior reasoning capabilities.
## Use the Datasets
```python
from datasets import load_dataset
# Login using e.g. `huggingface-cli login` to access this dataset
ds = load_dataset("alibaba-pai/OmniThought")
```
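Since every CoT process carries RV and CD annotations, the corpus can be filtered to a target verbosity/difficulty band before training. A minimal sketch, assuming the annotations are exposed as `RV` and `CD` fields (check the dataset card for the exact schema):
```python
from datasets import load_dataset

ds = load_dataset("alibaba-pai/OmniThought", split="train")
# Field names and thresholds are illustrative; keep moderately verbose, easier traces
filtered = ds.filter(lambda ex: 2 <= ex["RV"] <= 4 and ex["CD"] <= 5)
```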
## Reference
For more detailed information about the dataset construction process, we encourage you to refer to our paper:
- **Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations**
Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang
[arXiv:2505.10937](https://arxiv.org/abs/2505.10937)
You can cite the paper using the following citation format:
```bibtex
@misc{cai2025reasoningomnithoughtlargecot,
title={Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations},
author={Wenrui Cai and Chengyu Wang and Junbing Yan and Jun Huang and Xiangzhong Fang},
year={2025},
eprint={2505.10937},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2505.10937}
}
```

7
requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
transformers==4.51.0
transformers-stream-generator==0.0.5
trl==0.17.0
tokenizers==0.21.1
vllm==0.8.5
openai
jinja2

BIN
resources/framework.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

25
setup.py Normal file
View File

@@ -0,0 +1,25 @@
#!/usr/bin/env python
from setuptools import find_packages, setup
with open('README.md') as f:
readme = f.read()
with open('requirements.txt') as f:
    # install_requires expects a list of requirement strings, not a single blob
    requirements = f.read().splitlines()
setup(
# Metadata
name='easydistill',
version='0.0.1',
python_requires='>=3.6',
author='PAI',
description='PAI EasyDistill Toolkit',
long_description=readme,
entry_points={'console_scripts': ['easydistill=easydistill.cli:main']},
long_description_content_type='text/markdown',
packages=find_packages(),
license='Apache-2.0',
    # Package info
install_requires=requirements)