init commit
This commit is contained in:
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
13
NOTICE
Normal file
13
NOTICE
Normal file
@@ -0,0 +1,13 @@
|
||||
=============================================================
|
||||
EasyDistill is a open-source tool developed by Alibaba PAI Team
|
||||
Licensed under the Apache License, Version 2.0
|
||||
|
||||
=============================================================
|
||||
This toolkit implements some modules referring to some repositories under
|
||||
the same/different open source licenses.
|
||||
|
||||
-----------------------------
|
||||
Apache License, Version 2.0
|
||||
The HuggingFace Inc. team
|
||||
The OpenAI Team
|
||||
The vLLM Team
|
188
README.md
188
README.md
@@ -1 +1,187 @@
|
||||
# easydistill
|
||||
# EasyDistill: Easy Knowledge Distillation for Large Language Models
|
||||
|
||||
Introducing **EasyDistill**, a pioneering toolkit on knowledge distillation (KD) for large language models (LLMs). With the growing complexity and size of LLMs, **EasyDistill** offers a versatile and user-friendly platform to streamline the KD process, supporting both black-box and white-box methodologies. It facilitates efficient model training, enabling smaller models to emulate the performance of larger ones without compromising accuracy. **EasyDistill** boasts an extensive range of features, including data synthesis, supervised fine-tuning, ranking optimization, and reinforcement learning, all tailored for various KD scenarios. Designed to accommodate both System 1 (fast, intuitive) and System 2 (slow, analytical) cognitive models, the toolkit is modular and easy to use, with a simple command-line interface guiding users. Beyond academic exploration, **EasyDistill** anchors practical industrial solutions, offering robust distilled models and open-source datasets, while also showcasing seamless integration with Alibaba Cloud’s AI platform, PAI. Committed to bridging theoretical advancements with practical needs, **EasyDistill** empowers the NLP community, making state-of-the-art KD strategies accessible to researchers and industry practitioners alike.
|
||||
|
||||
# Technical Articles
|
||||
|
||||
We have a series of technical articles on the functionalities of EasyDistill.
|
||||
|
||||
- [人工智能平台 PAI DistilQwen2.5-DS3-0324发布:知识蒸馏+快思考=更高效解决推理难题](https://developer.aliyun.com/article/1661734)
|
||||
- [DistilQwen2.5-R1发布:知识蒸馏助推小模型深度思考](https://developer.aliyun.com/article/1659288)
|
||||
- [DistilQwen2.5发布:通义千问蒸馏小模型再升级](https://developer.aliyun.com/article/1653842)
|
||||
- [DistilQwen2:通义千问大模型的知识蒸馏实践](https://developer.aliyun.com/article/1633882)
|
||||
- [基于多轮课程学习的大语言模型蒸馏算法TAPIR](https://developer.aliyun.com/article/1635146)
|
||||
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||

|
||||
|
||||
- **Toolkit Features**: EasyDistill provides versatile functionalities, including data synthesis, supervised fine-tuning, logits distillation, ranking optimization, and reinforcement learning techniques tailored for KD scenarios.
|
||||
- **Compatibility**: It supports both System 1 (fast, intuitive) and System 2 (slow, analytical) models.
|
||||
- **User-Friendly**: With its modular design and simple command-line interface, EasyDistill makes experimentation and implementation of KD strategies straightforward.
|
||||
- **Industrial Integration**: Incorporates KD-based solutions and supports integration with platforms such as Alibaba Cloud’s Platform for AI (PAI).
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd EasyDistill
|
||||
```
|
||||
|
||||
2. Install the required dependencies:
|
||||
```bash
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
3. Explore the usage of EasyDistill through the command-line interface:
|
||||
```bash
|
||||
easydistill --config <config-file-path>
|
||||
```
|
||||
|
||||
The config file expresses the detailed settings of any knowledge distillation jobs that **EasyDistill** supports. A sample of black-box distillation config can be shown below:
|
||||
```json
|
||||
{
|
||||
"job_type": "kd_black_box_local",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## DistilQWen Series
|
||||
|
||||
The **DistilQwen** models represent a robust suite of distilled language models derived from the **EasyDistill** toolkit. Designed to capitalize on the principles of knowledge distillation, DistilQwen models offer a significant reduction in model size while maintaining high performance, making them ideal for resource-constrained environments. Whether you're aiming for efficient deployment in industrial scenarios or seeking to explore advanced KD methodologies, **DistilQwen** models are poised to meet diverse application needs with agility and precision.
|
||||
|
||||
|
||||
### What's New: Adaptive Thinking Models
|
||||
|
||||
The most recent **DistilQwen** series is **DistilQwen-ThoughtX**, which exhibits improved reasoning abilities and generates CoTs with more optimal lengths compared to its predecessors. This model series is developed from the innovative **OmniThought** dataset by utilizing the novel Reasoning Verbosity (RV) and Cognitive Difficulty (CD) scores, which ensure that models receive rich, high-quality training data reflecting optimal CoT output length and difficulty. **DistilQwen-ThoughtX** outperforms other KD models in the open-source community. The performance of **DistilQwen-ThoughtX** is shown below.
|
||||
|
||||
|
||||
| **Model** | **AIME2024** | **MATH500** | **GPQA-D** | **LCB V2** | **Avg.** |
|
||||
|-----------------------------------------------|--------------|-------------|------------|------------|-----------|
|
||||
| OpenThinker-7B | 31.3 | 83.0 | 42.4 | 39.9 | 49.1 |
|
||||
| DeepSeek-R1-Distill-Qwen-7B | **57.3** | _89.6_ | 47.3 | 48.4 | 60.6 |
|
||||
| OpenThinker2-7B | 50.0 | 88.4 | _49.3_ | _55.6_ | _60.8_ |
|
||||
| **DistilQwen-ThoughtX-7B** | _56.7_ | **90.2** | **50.0** | **56.8** | **63.4** |
|
||||
| LIMO-32B | 56.7 | 86.6 | 58.1 | 60.0 | 65.3 |
|
||||
| OpenThinker-32B | 66.0 | 90.6 | 61.6 | 68.9 | 71.7 |
|
||||
| DeepSeek-R1-Distill-Qwen-32B | 74.7 | 90.0 | 62.4 | 72.3 | 74.8 |
|
||||
| OpenThinker2-32B | _76.7_ | _90.8_ | **64.1** | _72.5_ | _76.0_ |
|
||||
| Light-R1-32B | 74.7 | 90.4 | 62.0 | 56.0 | 70.7 |
|
||||
| s1.1-32B | 59.3 | 87.4 | 62.0 | 58.7 | 66.8 |
|
||||
| **DistilQwen-ThoughtX-32B** | **80.0** | **92.6** | _64.0_ | **73.4** | **77.5** |
|
||||
|
||||
The **OmniThought** datasets are also publicly available. Refer to the Datasets section.
|
||||
|
||||
### System 1 Models
|
||||
|
||||
**DistilQwen2** is an enhanced version of the Qwen2 models, equipped with improved instruction-following capabilities for various NLP tasks. We employ GPT-4 and Qwen-max as teacher models to generate high-quality responses, with the balance on the task distributions of input instructions. Following SFT, a rank optimization process is performed using the DPO algorithm to enhance alignment between the student models and the teacher models. **DistilQwen2.5** models are trained using a combination of black-box and white-box KD algorithms. We adhere to the same instruction data processing and black-box SFT procedure as employed in the production of **DistilQwen2**. Subsequently, white-box training is applied to refine the students' acquisition of intricate knowledge from the teacher models, specifically utilizing Qwen2.5-72B-Instruct as open-source teacher models. The performance of **DistilQwen2** and **DistilQwen2.5** is shown below.
|
||||
|
||||
| **Model** | **AlpacaEval 2.0 (length control)** | **MT-Bench** | **MT-Bench (single)** | **IFEval (instruct-loose)** | **IFEval (strict-prompt)** |
|
||||
|------------------------------------|-------------------------------------|--------------|-----------------------|-----------------------------|----------------------------|
|
||||
| Qwen2.5-0.5B-Instruct | 2.46 | 5.49 | 6.26 | 42.81 | 30.31 |
|
||||
| **DistilQwen2.5-0.5B-Instruct** | **4.89** | **5.78** | **6.83** | **52.61** | **37.82** |
|
||||
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
|
||||
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
|
||||
| Qwen2.5-1.5B-Instruct | 6.69 | 7.09 | 7.66 | 55.40 | 40.11 |
|
||||
| **DistilQwen2.5-1.5B-Instruct** | **13.69** | **7.35** | **7.99** | **61.10** | **74.49** |
|
||||
| Qwen2.5-3B-Instruct | 17.98 | 7.92 | 8.40 | 61.18 | 74.58 |
|
||||
| **DistilQwen2.5-3B-Instruct** | **20.91** | **8.37** | **8.97** | **67.03** | **77.36** |
|
||||
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
|
||||
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
|
||||
| Qwen2.5-7B-Instruct | 31.43 | 8.52 | 8.83 | 81.53 | 72.10 |
|
||||
| **DistilQwen2.5-7B-Instruct** | **34.86** | **8.76** | **9.22** | **83.48** | **73.27** |
|
||||
|
||||
|
||||
We have released two instruction following datasets to public. Refer to the Datasets section.
|
||||
|
||||
|
||||
### System 2 Models
|
||||
|
||||
The **DistilQwen2.5-R1** model series utilizes DeepSeek-R1 as the teacher model. To align the reasoning abilities of smaller distilled models with their intrinsic cognitive capacities, the models are further refined using our CogPO algorithm, which outperforms other training methods. Additionally, we transfer the fast-thinking reasoning capabilities from DeepSeek-V3-0324 to the **DistilQwen2.5-DS3-0324** models. To shorten the reasoning process, the CoT simplification operator are employed to reduce the number of tokens in the training data for **DistilQwen2.5-R1**. Combined with a rewritten dataset comprising DeepSeek-V3-0324's CoT distillation data, we develop the **DistilQwen2.5-DS3-0324** models. The performance of **DistilQwen2.5-R1** and **DistilQwen2.5-DS3-0324** is shown below.
|
||||
|
||||
| **Model** | **AIME2024** | **MATH-500** | **GPQA Diamond** | **LiveCodeBench V2** |
|
||||
|---------------------------------------|--------------|--------------|------------------|----------------------|
|
||||
| Qwen2.5-3B-Instruct | 6.67 | 62.6 | 32.83 | 11.35 |
|
||||
| **DistilQwen2.5-DS3-0324-3B** | **16.67** | **70.0** | **34.34** | **18.00** |
|
||||
| Qwen2.5-7B-Instruct | 10.0 | 73.6 | 33.30 | 30.72 |
|
||||
| **DistilQwen2.5-7B-R1** | **23.33** | **77.8** | **37.88** | **36.40** |
|
||||
| **DistilQwen2.5-DS3-0324-7B** | **43.33** | **88.4** | **42.93** | **46.38** |
|
||||
| Qwen2.5-14B-Instruct | 16.7 | 78.2 | 43.43 | 37.38 |
|
||||
| **DistilQwen2.5-14B-R1** | **26.67** | **82.6** | **45.45** | **41.49** |
|
||||
| **DistilQwen2.5-DS3-0324-14B** | **46.67** | **90.8** | **51.52** | **54.40** |
|
||||
| Qwen2.5-32B-Instruct | 16.67 | 81.4 | 45.50 | 47.36 |
|
||||
| **DistilQwen2.5-32B-R1** | **46.67** | **87.0** | **48.99** | **55.97** |
|
||||
| **DistilQwen2.5-DS3-0324-32B** | **70.00** | **93.8** | **62.12** | **65.95** |
|
||||
|
||||
All the **DistilQwen** models are publicly available in HuggingFace and ModelScope.
|
||||
|
||||
## Released Datasets
|
||||
|
||||
We have also released several datasets based on the **EasyDistill** framework.
|
||||
|
||||
### Instruction Following Datasets
|
||||
|
||||
To assist community developers in avoiding catastrophic forgetting when fine-tuning the **DistilQwen** model, we have open-sourced two datasets: **DistilQwen_100K** and **DistilQwen_1M**. These datasets are intended to provide a solid foundation for model fine-tuning, enhancing adaptability to new tasks while retaining performance on previous tasks. Additionally, it can be utilized to improve instruction-following capabilities when fine-tuning other similar large language models. These datasets cover a range of contents, including mathematics, code, knowledge-based Q&A, instruction following, and creative generation, with a total dataset size of 100K and 1M entries. Users can integrate **DistilQwen_100K** and **DistilQwen_1M**, or its subsets, with their own data during model fine-tuning to ensure excellent downstream task performance while maintaining the model's general capabilities, thus preserving its ability to generalize.
|
||||
|
||||
|
||||
### Chain-of-Thought Reasoning Datasets
|
||||
|
||||
**OmniThought** is a large-scale dataset featuring **2 million** Chain-of-Thought (CoT) processes generated and validated by DeepSeek-R1 and QwQ-32B. Each CoT process in **OmniThought** is annotated with novel Reasoning Verbosity (RV) and Cognitive Difficulty (CD) scores, which describe the appropriateness of CoT verbosity and cognitive difficulty level for models to comprehend these reasoning processes. Based on our **OmniThought** dataset, we further train and release a series of high-performing models (**DistilQwen-ThoughtX-7B** and **DistilQwen-ThoughtX-32B**), specifically equipped with stronger reasoning abilities and optimal CoT output length and difficulty level. Refer to `recipes/open_datasets` for details.
|
||||
|
||||
All the datasets are publicly available in HuggingFace and ModelScope.
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
We have [an arxiv paper](TBD) for you to cite for the EasyDistill library. Below are other papers related to our project.
|
||||
|
||||
- Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang. Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations. arXiv preprint
|
||||
- Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang. Training Small Reasoning LLMs with Cognitive Preference Alignment. arXiv preprint
|
||||
- Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang. DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models. **ACL 2025**
|
||||
- Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang. Building a Family of Data Augmentation Models for Low-cost LLM Fine-tuning on the Cloud. **COLING 2025**
|
||||
- Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang. Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning. **EMNLP 2024**
|
||||
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the [Apache License (Version 2.0)](LICENSE). This toolkit also contains some code modified from other repos under other open-source licenses. See the [NOTICE](NOTICE) file for more information.
|
||||
|
||||
|
||||
## Join in the Discussion
|
||||
|
||||
We welcome community partners to collaborate and contribute to the development, and welcome to join the DingTalk group: 117440002081 to participate in the discussion.
|
19
configs/accelerate_config/muti_gpu.yaml
Normal file
19
configs/accelerate_config/muti_gpu.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
8
configs/chat_template/chat_template_kd.jinja
Normal file
8
configs/chat_template/chat_template_kd.jinja
Normal file
@@ -0,0 +1,8 @@
|
||||
{{'<|im_start|>system\nYou are a helpful assistant.<|im_end|>'}}
|
||||
{{'<|im_start|>user\n' + message['content'] + '<|im_end|>'-}}
|
||||
{% if add_generation_prompt %}
|
||||
{{'<|im_start|>assistant'-}}
|
||||
{% endif %}
|
||||
{% if add_output %}
|
||||
{{'<|im_start|>assistant\n' + message['output'] + '<|im_end|>-'}}
|
||||
{% endif %}
|
14
configs/cot_generation_api.json
Normal file
14
configs/cot_generation_api.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "cot_generation_api",
|
||||
"dataset": {
|
||||
"input_path": "./cot_question.json",
|
||||
"output_path": "./cot_question_with_answer.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:",
|
||||
"max_new_tokens": 1024
|
||||
}
|
||||
}
|
22
configs/cot_generation_batch.json
Normal file
22
configs/cot_generation_batch.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"job_type": "cot_generation_batch",
|
||||
"dataset": {
|
||||
"input_path": "./cot_question.json",
|
||||
"output_path": "./cot_question_with_answer.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference":{
|
||||
"prompt" : "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
14
configs/cot_long2short_api.json
Normal file
14
configs/cot_long2short_api.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "cot_long2short_api",
|
||||
"dataset": {
|
||||
"input_path": "./raw.json",
|
||||
"output_path": "./raw_simplified.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
|
||||
"max_new_tokens": 1024
|
||||
}
|
||||
}
|
22
configs/cot_long2short_batch.json
Normal file
22
configs/cot_long2short_batch.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"job_type": "cot_long2short_batch",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_simplified.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference":{
|
||||
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
14
configs/cot_short2long_api.json
Normal file
14
configs/cot_short2long_api.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "cot_short2long_api",
|
||||
"dataset": {
|
||||
"input_path": "./raw.json",
|
||||
"output_path": "./raw_extended.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem ,its answer and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps, so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\\n\\n), your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary.",
|
||||
"max_new_tokens": 1024
|
||||
}
|
||||
}
|
22
configs/cot_short2long_batch.json
Normal file
22
configs/cot_short2long_batch.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"job_type": "cot_short2long_batch",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_extended.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference":{
|
||||
"prompt" : "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem ,its answer and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps, so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\\n\\n), your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary.",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
16
configs/instruction_expansion_api.json
Normal file
16
configs/instruction_expansion_api.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"job_type": "instruction_expansion_api",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_extended.json",
|
||||
"num_in_context_samples": 3,
|
||||
"num_output_samples": 10
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "Assume you are a data synthesis expert. Given a few instructions as in-context examples, you should generate a new instruction similar to the examples to support the training of large language models. You should place your answer enclosed within <answer></answer> tags. The examples are as follows:",
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
24
configs/instruction_expansion_batch.json
Normal file
24
configs/instruction_expansion_batch.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"job_type": "instruction_expansion_batch",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_extended.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja",
|
||||
"num_in_context_samples": 3,
|
||||
"num_output_samples": 10
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference":{
|
||||
"prompt" : "Assume you are a data synthesis expert. Given a few instructions as in-context examples, you should generate a new instruction similar to the examples to support the training of large language models. You should place your answer enclosed within <answer></answer> tags. The examples are as follows:",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
14
configs/instruction_refinement_api.json
Normal file
14
configs/instruction_refinement_api.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "instruction_refinement_api",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_refined.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "Assume you are a prompt re-writing expert. Given an instruction as input, you should generate a new instruction semantically similar to the input to support the training of large language models. Transform the input raw prompt into a detailed prompt that comprehensively captures the user’s request. Make sure to maintain the original intent while significantly enhancing clarity and depth. You should place your answer enclosed within <answer></answer> tags. The input prompt is as follows:",
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
22
configs/instruction_refinement_batch.json
Normal file
22
configs/instruction_refinement_batch.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"job_type": "instruction_refinement_batch",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_refined.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference": {
|
||||
"prompt" : "Assume you are a prompt re-writing expert. Given an instruction as input, you should generate a new instruction semantically similar to the input to support the training of large language models. Transform the input raw prompt into a detailed prompt that comprehensively captures the user’s request. Make sure to maintain the original intent while significantly enhancing clarity and depth. You should place your answer enclosed within <answer></answer> tags. The input prompt is as follows:",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
14
configs/instruction_response_extraction_api.json
Normal file
14
configs/instruction_response_extraction_api.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "instruction_response_extraction_api",
|
||||
"dataset": {
|
||||
"input_path": "./raw.json",
|
||||
"output_path": "./raw_extracted.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "Assume you are a data synthesis expert. Given plain text as input, you should generate an instruction-response pair where the instruction and the response are derived from the knowledge of the plain text to support the training of large language models. The response should properly answer the instruction. You should place your instruction enclosed within <instruction></instruction> tags, and place your response enclosed within <response></response> tags. The input plain text is as follows:",
|
||||
"max_new_tokens": 1024
|
||||
}
|
||||
}
|
22
configs/instruction_response_extraction_batch.json
Normal file
22
configs/instruction_response_extraction_batch.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"job_type": "instruction_response_extraction_batch",
|
||||
"dataset": {
|
||||
"input_path": "./train.json",
|
||||
"output_path": "./train_extended.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/"
|
||||
},
|
||||
"inference":{
|
||||
"prompt" : "Assume you are a data synthesis expert. Given plain text as input, you should generate an instruction-response pair where the instruction and the response are derived from the knowledge of the plain text to support the training of large language models. The response should properly answer the instruction. You should place your instruction enclosed within <instruction></instruction> tags, and place your response enclosed within <response></response> tags. The input plain text is as follows:",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
}
|
||||
}
|
32
configs/kd_black_box_api.json
Normal file
32
configs/kd_black_box_api.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"system_prompt" : "You are a helpful assistant.",
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"max_length":512,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
36
configs/kd_black_box_local.json
Normal file
36
configs/kd_black_box_local.json
Normal file
@@ -0,0 +1,36 @@
|
||||
{
|
||||
"job_type": "kd_black_box_local",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"max_length":512,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
42
configs/kd_white_box.json
Normal file
42
configs/kd_white_box.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"job_type": "kd_white_box",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"logits_path": "./logits.json",
|
||||
"template" : "./chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512,
|
||||
"top_logits_num": 10
|
||||
},
|
||||
"distillation": {
|
||||
"kd_ratio": 0.5,
|
||||
"max_seq_length": 512,
|
||||
"distillation_type": "forward_kld"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
32
configs/rank_dpo_api.json
Normal file
32
configs/rank_dpo_api.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"job_type": "rank_dpo_api",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"system_prompt" : "You are a helpful assistant.",
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"beta": 0.1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
37
configs/rank_dpo_local.json
Normal file
37
configs/rank_dpo_local.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"job_type": "rank_dpo_api",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "chat_template/chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"system_prompt" : "You are a helpful assistant.",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
|
||||
"student": "student/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"beta": 0.1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
25
configs/rl_grpo.json
Normal file
25
configs/rl_grpo.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"job_type": "rl_grpo",
|
||||
"dataset": {
|
||||
"instruction_path": "sample.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"train_ratio": 0.7,
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"reward": "reward/",
|
||||
"student": "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"num_train_epochs": 3,
|
||||
"save_steps": 100,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
28
configs/rl_ppo.json
Normal file
28
configs/rl_ppo.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"job_type": "rl_ppo",
|
||||
"dataset": {
|
||||
"instruction_path": "sample.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"train_ratio": 0.7,
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"reward": "reward/",
|
||||
"student": "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"total_episodes": 1000,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 100,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine",
|
||||
"missing_eos_penalty": 1.0,
|
||||
"stop_token": "eos",
|
||||
"response_length": 512
|
||||
}
|
||||
}
|
32
configs/rl_reward_api.json
Normal file
32
configs/rl_reward_api.json
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"job_type": "rl_reward_api",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "chat_template_kd.jinja"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "http://1157703270994901.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/quickstart_deploy_20250427_6wt1/v1/",
|
||||
"api_key": "NjQ3OGE2ZGNiOWM4YjZkZTY5NDM4YWEyZjUyNGI3ZjRjNTAyMjM0Mw==",
|
||||
"stream": true,
|
||||
"positive_system_prompt" : "You are a helpful assistant to generate high-quality responses.",
|
||||
"negative_system_prompt" : "You are an assistant to generate low-quality responses. This is for the training of my reward model. Plese remember to generate low-quality responses.",
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"max_length": 1024,
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
37
configs/rl_reward_local.json
Normal file
37
configs/rl_reward_local.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"job_type": "rl_reward_local",
|
||||
"dataset": {
|
||||
"instruction_path": "train.json",
|
||||
"labeled_path": "train_labeled.json",
|
||||
"template" : "chat_template_kd.jinja"
|
||||
},
|
||||
"inference":{
|
||||
"positive_system_prompt" : "You are a helpful assistant to generate high-quality responses.",
|
||||
"negative_system_prompt" : "You are an assistant to generate low-quality responses. This is for the training of my reward model. Plese remember to generate low-quality responses.",
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"models": {
|
||||
"teacher": "model/Qwen/Qwen2.5-3B-Instruct/",
|
||||
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "./result/",
|
||||
"max_length": 1024,
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
14
easydistill/__init__.py
Normal file
14
easydistill/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
187
easydistill/cli.py
Normal file
187
easydistill/cli.py
Normal file
@@ -0,0 +1,187 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from socket import socket
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir))
|
||||
|
||||
def run_cmd(cmd):
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout
|
||||
shell=True,
|
||||
universal_newlines=True # Ensure output is in text mode
|
||||
)
|
||||
|
||||
error_detected = False
|
||||
error_keywords = [
|
||||
"ERROR",
|
||||
"Error",
|
||||
"error"
|
||||
"Unrecognized model",
|
||||
"failed",
|
||||
"exception",
|
||||
"Traceback"
|
||||
]
|
||||
|
||||
# Read output in real-time and detect errors
|
||||
while True:
|
||||
line = p.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
logging.info(line.rstrip()) # Log normally
|
||||
|
||||
# Check if any error keywords are present
|
||||
if any(keyword.lower() in line.lower() for keyword in error_keywords):
|
||||
error_detected = True
|
||||
logging.error(f"Detected error in output: {line.strip()}")
|
||||
|
||||
# Wait for process to finish
|
||||
returncode = p.wait()
|
||||
|
||||
# If errors were detected or return code is non-zero, return False
|
||||
if error_detected or returncode != 0:
|
||||
logging.error(f"Command failed (returncode={returncode}, errors detected)")
|
||||
return False
|
||||
|
||||
return True # Return True indicates success
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error running command: {e}")
|
||||
return False
|
||||
|
||||
def process(job_type, config):
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(script_dir, config)
|
||||
|
||||
# Knowledge Distillation tasks
|
||||
if job_type in ['kd_black_box_train_only', 'kd_white_box_train_only']:
|
||||
cmd_train = [
|
||||
'accelerate', 'launch',
|
||||
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
|
||||
os.path.join(script_dir, 'kd/train.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_train = ' '.join(cmd_train)
|
||||
logging.info(f"Running command: {cmd_train}")
|
||||
run_cmd(cmd_train)
|
||||
|
||||
elif job_type in ['kd_black_box_api', 'kd_black_box_local', 'kd_white_box']:
|
||||
cmd_infer = [
|
||||
'python', os.path.join(script_dir, 'kd/infer.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_infer = ' '.join(cmd_infer)
|
||||
logging.info(f"Running command: {cmd_infer}")
|
||||
infer_success = run_cmd(cmd_infer)
|
||||
if infer_success:
|
||||
cmd_train = [
|
||||
'accelerate', 'launch',
|
||||
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
|
||||
os.path.join(script_dir, 'kd/train.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd_train = ' '.join(cmd_train)
|
||||
logging.info(f"Running command: {cmd_train}")
|
||||
run_cmd(cmd_train)
|
||||
else:
|
||||
logging.error("Infer failed, skipping training")
|
||||
|
||||
# Reinforcement Learning tasks
|
||||
elif job_type in ['rl_ppo', 'rl_grpo']:
|
||||
cmd = [
|
||||
'accelerate', 'launch',
|
||||
'--config_file', os.path.join(parent_dir, 'configs/accelerate_config/muti_gpu.yaml'),
|
||||
os.path.join(script_dir, f'rl/{job_type.split("_")[1]}.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd = ' '.join(cmd)
|
||||
logging.info(f"Running command: {cmd}")
|
||||
run_cmd(cmd)
|
||||
|
||||
elif job_type in ['rl_reward_api', 'rl_reward_local']:
|
||||
cmd = [
|
||||
'python',
|
||||
os.path.join(script_dir, 'rl/reward.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd = ' '.join(cmd)
|
||||
logging.info(f"Running command: {cmd}")
|
||||
run_cmd(cmd)
|
||||
|
||||
# Instruction Processing tasks
|
||||
elif job_type.startswith('instruction_'):
|
||||
task_type = job_type.replace('instruction_', '')
|
||||
cmd = [
|
||||
'python',
|
||||
os.path.join(script_dir, f'synthesis/synthesis_main.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd = ' '.join(cmd)
|
||||
logging.info(f"Running command: {cmd}")
|
||||
run_cmd(cmd)
|
||||
|
||||
# Chain of Thought tasks
|
||||
elif job_type.startswith('cot_'):
|
||||
task_type = job_type.replace('cot_', '')
|
||||
cmd = [
|
||||
'python',
|
||||
os.path.join(script_dir, f'synthesis/synthesis_main.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd = ' '.join(cmd)
|
||||
logging.info(f"Running command: {cmd}")
|
||||
run_cmd(cmd)
|
||||
|
||||
# Ranking and DPO tasks
|
||||
elif job_type.startswith('rank_'):
|
||||
task_type = job_type.replace('rank_', '')
|
||||
cmd = [
|
||||
'python',
|
||||
os.path.join(script_dir, f'rank/{task_type}.py'),
|
||||
'--config', config
|
||||
]
|
||||
cmd = ' '.join(cmd)
|
||||
logging.info(f"Running command: {cmd}")
|
||||
run_cmd(cmd)
|
||||
|
||||
else:
|
||||
logging.error(f"Unknown job type: {job_type}")
|
||||
sys.exit(1)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config_path = args.config
|
||||
config = json.load(open(config_path))
|
||||
job_type = config["job_type"]
|
||||
process(job_type, config_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
247
easydistill/kd/infer.py
Normal file
247
easydistill/kd/infer.py
Normal file
@@ -0,0 +1,247 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json, jsonlines
|
||||
import argparse
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
import math
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def read_json_field(filename, field_name='instruction'):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output_fields = []
|
||||
for item in data:
|
||||
if field_name in item:
|
||||
output_fields.append(item[field_name])
|
||||
return output_fields
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def load_tokenizer_and_vllm(config, eos_token=None):
|
||||
teacher_model_path = config["models"]["teacher"]
|
||||
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
|
||||
tokenizer.padding_side = "left"
|
||||
if eos_token:
|
||||
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
||||
logging.info(f"eos_token {eos_token} from user input")
|
||||
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
|
||||
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
|
||||
else:
|
||||
raise ValueError("No available eos_token or eos_token_id.")
|
||||
try:
|
||||
tokenizer.eos_token = eos_token
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
tokenizer.pad_token = eos_token
|
||||
tokenizer.pad_token_id = eos_token_id
|
||||
except:
|
||||
logging.info(f"[WARNING] Cannot set tokenizer.eos_token")
|
||||
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
|
||||
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=teacher_model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
|
||||
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
|
||||
trust_remote_code=config["inference"]["trust_remote_code"],
|
||||
dtype=torch.bfloat16,
|
||||
enforce_eager=config["inference"]["enforce_eager"],
|
||||
max_model_len=config["inference"]["max_model_len"],
|
||||
)
|
||||
logging.info("vLLM model loaded successfully")
|
||||
return tokenizer, llm
|
||||
|
||||
|
||||
def generate_teacher_response_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
message = {"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message = message,
|
||||
add_generation_prompt = True,
|
||||
add_output = False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n = 1,
|
||||
top_k = 1,
|
||||
temperature = config["inference"]["temperature"],
|
||||
seed = config["inference"]["seed"],
|
||||
skip_special_tokens = False,
|
||||
ignore_eos = False,
|
||||
max_tokens = config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
gen_data = [{'instruction': batch[i], 'output': responses[i]} for i in range(len(batch))]
|
||||
outcomes = outcomes + gen_data
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def generate_teacher_logits_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
|
||||
outputs = llm.generate(
|
||||
new_batch, # Pass the raw text directly
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=True,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
logprobs=config["inference"]["top_logits_num"],
|
||||
)
|
||||
)
|
||||
# Extract the generated logits
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
logits=[output.outputs[0].logprobs for output in outputs]
|
||||
for logit in logits:
|
||||
for pos in logit:
|
||||
for k,v in pos.items():
|
||||
pos[k]=math.exp(v.logprob)
|
||||
|
||||
with jsonlines.open(config["dataset"]["logits_path"], mode='a') as writer:
|
||||
for row in logits:
|
||||
#for item in row:
|
||||
writer.write(row)
|
||||
|
||||
|
||||
def generate_teacher_response_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"]
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
logging.info(model)
|
||||
system_prompt = config["inference"]["system_prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
|
||||
if system_prompt == "":
|
||||
message = [
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
else:
|
||||
message = [
|
||||
{'role': 'system', 'content': system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
outcomes.append({'instruction': sample, 'output': result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def infer_with_teacher_model(config):
|
||||
logging.info('Generating distillation data from the teacher model!')
|
||||
data_list = read_json_field(config["dataset"]["instruction_path"])
|
||||
try:
|
||||
job_type = config["job_type"]
|
||||
if job_type == "kd_black_box_api":
|
||||
generate_teacher_response_api(data_list, config)
|
||||
elif job_type == "kd_black_box_local":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
generate_teacher_response_batch(tokenizer, llm, data_list, config)
|
||||
elif job_type == "kd_white_box":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
generate_teacher_logits_batch(tokenizer, llm, data_list, config)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
infer_with_teacher_model(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
218
easydistill/kd/train.py
Normal file
218
easydistill/kd/train.py
Normal file
@@ -0,0 +1,218 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, BaseLoader, FileSystemLoader
|
||||
from datasets import load_dataset,Dataset
|
||||
from typing import Optional, Dict, Union, List
|
||||
from datasets import Dataset
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase,AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
from trl import SFTTrainer,SFTConfig
|
||||
import torch
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class DistillSFTTrainer(SFTTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits_dir: str = None,
|
||||
teacher_vocab_size = None,
|
||||
kd_ratio: float = 0.5,
|
||||
max_seq_length : int = 1024,
|
||||
distillation_type: str = "forward_kld",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.logits_dir = logits_dir
|
||||
self.teacher_vocab_size = teacher_vocab_size
|
||||
self.kd_ratio = kd_ratio
|
||||
self.max_seq_length = max_seq_length
|
||||
self.distillation_type = distillation_type
|
||||
self.teacher_logits = []
|
||||
with jsonlines.open(self.logits_dir) as reader:
|
||||
for obj in reader:
|
||||
self.teacher_logits.append(obj)
|
||||
|
||||
|
||||
def _load_teacher_logits(self, batch_size: int, it: int, dp_rank: int, device: torch.device, no_model_batch: Dict):
|
||||
start_idx = dp_rank * batch_size + batch_size * it
|
||||
end_idx = dp_rank * batch_size + batch_size * (it + 1)
|
||||
loaded_data = self.teacher_logits[start_idx:end_idx]
|
||||
arr = np.zeros((batch_size, self.max_seq_length, self.teacher_vocab_size))
|
||||
for i in range(len(loaded_data)):
|
||||
for j in range(len(loaded_data[i])):
|
||||
keys = np.array(list(loaded_data[i][j].keys()), dtype=int)
|
||||
values = np.array(list(loaded_data[i][j].values()))
|
||||
arr[i, j, keys] = values
|
||||
|
||||
logits_tensor = torch.tensor(arr, dtype=torch.bfloat16, device=device)
|
||||
return self._shift_tensor_right(logits_tensor, no_model_batch['label'], pad_value=0)
|
||||
|
||||
|
||||
def _compute_white_box_distillation_loss(self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: Optional[torch.Tensor]):
|
||||
student_logits = student_logits[:, :self.max_seq_length, :]
|
||||
teacher_probs = teacher_logits[:, :student_logits.size(1), :student_logits.size(-1)]
|
||||
mask = (labels != -100).float() if labels is not None else torch.ones_like(student_logits[:, :, 0])
|
||||
|
||||
if self.distillation_type == "forward_kld":
|
||||
# Forward KLD: student learns from teacher (original implementation)
|
||||
loss = F.kl_div(
|
||||
F.log_softmax(student_logits, dim=-1),
|
||||
teacher_probs,
|
||||
reduction='none',
|
||||
log_target=False
|
||||
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
|
||||
elif self.distillation_type == "reverse_kld":
|
||||
# Reverse KLD: teacher provides certainty to student
|
||||
loss = F.kl_div(
|
||||
torch.log(teacher_probs.clamp(min=1e-10)), # avoid log(0)
|
||||
F.softmax(student_logits, dim=-1),
|
||||
reduction='none',
|
||||
log_target=False
|
||||
).sum(dim=-1)/torch.sum(mask.view(-1), dim=0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported distillation type: {self.distillation_type}. Use 'forward_kld' or 'reverse_kld'")
|
||||
|
||||
return (loss * mask).sum() / mask.sum()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _shift_tensor_right(inputs: torch.Tensor, labels: torch.Tensor, pad_value: float = 0.0):
|
||||
batch_size, seqlen, vocab_size = inputs.shape
|
||||
device = inputs.device
|
||||
labels_ne = labels != -100
|
||||
shift_distances = torch.argmax(labels_ne.int(), dim=1)
|
||||
idx = torch.arange(seqlen, device=device).unsqueeze(0).expand(batch_size, seqlen)
|
||||
shifted_idx = idx - shift_distances.unsqueeze(1)
|
||||
mask = shifted_idx >= 0
|
||||
shifted_idx = shifted_idx.clamp(min=0)
|
||||
inputs_flat = inputs.view(batch_size, seqlen, vocab_size)
|
||||
shifted_idx = shifted_idx.unsqueeze(2).expand(-1, -1, vocab_size)
|
||||
gathered = torch.gather(inputs_flat, 1, shifted_idx)
|
||||
mask = mask.unsqueeze(2).expand(-1, -1, vocab_size)
|
||||
return torch.where(mask, gathered, torch.full_like(gathered, pad_value))
|
||||
|
||||
|
||||
def compute_loss(self, model: PreTrainedModel, inputs: Dict[str, torch.Tensor], return_outputs=False, num_items_in_batch=None):
|
||||
outputs = model(**inputs)
|
||||
lm_loss = outputs.loss
|
||||
if self.logits_dir:
|
||||
teacher_logits = self._load_teacher_logits(
|
||||
batch_size=inputs['input_ids'].size(0),
|
||||
it=self.state.global_step,
|
||||
dp_rank=torch.distributed.get_rank() if torch.distributed.is_initialized() else 0,
|
||||
device=model.device,
|
||||
no_model_batch={'label': inputs.get('labels', None)}
|
||||
)
|
||||
distil_loss = self._compute_white_box_distillation_loss(
|
||||
student_logits=outputs.logits,
|
||||
teacher_logits=teacher_logits,
|
||||
labels=inputs.get('labels', None)
|
||||
)
|
||||
total_loss = (1 - self.kd_ratio) * lm_loss + self.kd_ratio * distil_loss
|
||||
else:
|
||||
total_loss = lm_loss
|
||||
return (total_loss, outputs) if return_outputs else total_loss
|
||||
|
||||
|
||||
def formatting_func(examples):
|
||||
env = Environment(loader=BaseLoader())
|
||||
try:
|
||||
message = {"content": examples["instruction"],"output":examples["output"]}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=False,
|
||||
add_output=True
|
||||
)
|
||||
return full_text
|
||||
except Exception as e:
|
||||
logging.warning(f"Error processing sample: {str(e)}")
|
||||
return ""
|
||||
|
||||
|
||||
def train(config):
|
||||
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"])
|
||||
|
||||
student_tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
student_model = AutoModelForCausalLM.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
global template
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
training_arguments = SFTConfig(**config["training"])
|
||||
|
||||
try:
|
||||
job_type = config["job_type"]
|
||||
if "kd_black_box" in job_type:
|
||||
dataset = dataset.shuffle(seed=config["dataset"]["seed"])
|
||||
trainer = SFTTrainer(
|
||||
model=student_model,
|
||||
processing_class=student_tokenizer,
|
||||
args=training_arguments,
|
||||
train_dataset=dataset["train"],
|
||||
formatting_func=formatting_func
|
||||
)
|
||||
elif "kd_white_box" in job_type:
|
||||
teacher_vocab_size=json.load(open(os.path.join(config["models"]["teacher"], 'config.json')))['vocab_size']
|
||||
trainer = DistillSFTTrainer(
|
||||
logits_dir=config["dataset"]["logits_path"],
|
||||
teacher_vocab_size=teacher_vocab_size,
|
||||
kd_ratio=config["distillation"]["kd_ratio"],
|
||||
max_seq_length=config["distillation"]["max_seq_length"],
|
||||
distillation_type=config["distillation"].get("distillation_type", "forward_kld"),
|
||||
model=student_model,
|
||||
processing_class=student_tokenizer,
|
||||
args=training_arguments,
|
||||
train_dataset=dataset["train"],
|
||||
formatting_func=formatting_func
|
||||
)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
student_tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
262
easydistill/rank/infer.py
Normal file
262
easydistill/rank/infer.py
Normal file
@@ -0,0 +1,262 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def read_json_field(filename, field_name='prompt'):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output_fields = []
|
||||
for item in data:
|
||||
if field_name in item:
|
||||
output_fields.append(item[field_name])
|
||||
return output_fields
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def load_tokenizer_and_vllm(config, eos_token=None, is_teacher_model=True):
|
||||
if is_teacher_model:
|
||||
model_path = config["models"]["teacher"]
|
||||
else:
|
||||
model_path = config["models"]["student"]
|
||||
logging.info(f"Loading ckpt and tokenizer: {model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
tokenizer.padding_side = "left"
|
||||
if eos_token:
|
||||
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
||||
logging.info(f"eos_token {eos_token} from user input")
|
||||
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
|
||||
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
|
||||
else:
|
||||
raise ValueError("No available eos_token or eos_token_id.")
|
||||
try:
|
||||
tokenizer.eos_token = eos_token
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
tokenizer.pad_token = eos_token
|
||||
tokenizer.pad_token_id = eos_token_id
|
||||
except:
|
||||
logging.info(f"[WARNING] Cannot set tokenizer.eos_token")
|
||||
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
|
||||
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
|
||||
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
|
||||
trust_remote_code=config["inference"]["trust_remote_code"],
|
||||
dtype=torch.bfloat16,
|
||||
enforce_eager=config["inference"]["enforce_eager"],
|
||||
max_model_len=config["inference"]["max_model_len"],
|
||||
)
|
||||
logging.info("vLLM model loaded successfully")
|
||||
return tokenizer, llm
|
||||
|
||||
|
||||
def generate_teacher_student_response_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key=config["inference"]["api_key"],
|
||||
base_url=config["inference"]["base_url"]
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
logging.info(model)
|
||||
system_prompt = config["inference"]["system_prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
|
||||
# load student model
|
||||
student_tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
student_model = AutoModelForCausalLM.from_pretrained(
|
||||
config["models"]["student"],
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
|
||||
# for teacher model
|
||||
if system_prompt == "":
|
||||
message=[
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
else:
|
||||
message=[
|
||||
{'role': 'system', 'content': system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages=message,
|
||||
model=model,
|
||||
max_completion_tokens=config["inference"]["max_new_tokens"],
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
# for student model
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
text = student_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
|
||||
|
||||
generated_ids = student_model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=config["inference"]["max_new_tokens"]
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
gen_data = {'prompt': sample, 'chosen': result, 'rejected': rejected}
|
||||
outcomes.append(gen_data)
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def generate_model_response_batch(tokenizer, llm, data_list, config, batch_size=32, is_teacher_model=True):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
model_outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
model_responses = [output.outputs[0].text for output in model_outputs]
|
||||
if is_teacher_model:
|
||||
gen_data = [{'prompt': batch[i], 'chosen': model_responses[i]} for i in range(len(batch))]
|
||||
else:
|
||||
gen_data = [{'prompt': batch[i], 'rejected': model_responses[i]} for i in range(len(batch))]
|
||||
outcomes = outcomes + gen_data
|
||||
return outcomes
|
||||
|
||||
|
||||
|
||||
def merge_outcomes(teacher_outcomes, student_outcomes, config):
|
||||
try:
|
||||
student_dict = {item['prompt']: item['rejected'] for item in student_outcomes}
|
||||
merged_outcomes = []
|
||||
for teacher_item in teacher_outcomes:
|
||||
prompt = teacher_item['prompt']
|
||||
if prompt in student_dict:
|
||||
merged_outcome = {
|
||||
'prompt': prompt,
|
||||
'chosen': teacher_item['chosen'],
|
||||
'rejected': student_dict[prompt]
|
||||
}
|
||||
merged_outcomes.append(merged_outcome)
|
||||
with open(config["dataset"]["labeled_path"], 'w') as file:
|
||||
json.dump(merged_outcomes, file, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def infer_with_teacher_model(config):
|
||||
logging.info('Generating distillation data from the teacher model!')
|
||||
data_list = read_json_field(config["dataset"]["instruction_path"])
|
||||
try:
|
||||
job_type = config["job_type"]
|
||||
if job_type == "rank_dpo_api":
|
||||
generate_teacher_student_response_api(data_list, config)
|
||||
elif job_type == "rank_dpo_local":
|
||||
teacher_tokenizer, teacher_llm = load_tokenizer_and_vllm(config, is_teacher_model=True)
|
||||
teacher_outcomes = generate_model_response_batch(teacher_tokenizer, teacher_llm, data_list, config, is_teacher_model=True)
|
||||
del teacher_llm
|
||||
student_tokenizer, student_llm = load_tokenizer_and_vllm(config, is_teacher_model=False)
|
||||
student_outcomes = generate_model_response_batch(student_tokenizer, student_llm, data_list, config, is_teacher_model=False)
|
||||
del student_llm
|
||||
merge_outcomes(teacher_outcomes, student_outcomes, config)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
infer_with_teacher_model(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
105
easydistill/rank/train.py
Normal file
105
easydistill/rank/train.py
Normal file
@@ -0,0 +1,105 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, BaseLoader, FileSystemLoader
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
import copy
|
||||
|
||||
|
||||
def process_dataset(dataset_path, dataset_seed, env, template):
|
||||
examples = []
|
||||
with open(dataset_path, 'r') as file:
|
||||
examples = json.load(file)
|
||||
output_text = {
|
||||
"prompt": [],
|
||||
"chosen": [],
|
||||
"rejected": []
|
||||
}
|
||||
# use chat template
|
||||
for i in range(len(examples)):
|
||||
try:
|
||||
prompt_message = {"content": examples[i]["prompt"]}
|
||||
prompt = template.render(message=prompt_message, add_generation_prompt=False, add_output=False)
|
||||
|
||||
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
|
||||
chosen = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
|
||||
chosen = chosen[len(prompt):]
|
||||
|
||||
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
|
||||
rejected = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
|
||||
rejected = rejected[len(prompt):]
|
||||
|
||||
output_text["prompt"].append(prompt)
|
||||
output_text["chosen"].append(chosen)
|
||||
output_text["rejected"].append(rejected)
|
||||
except:
|
||||
logging.warning(f"Error processing sample.")
|
||||
|
||||
dataset = Dataset.from_dict(output_text)
|
||||
dataset = dataset.shuffle(seed=dataset_seed)
|
||||
return dataset
|
||||
|
||||
|
||||
def train(config):
|
||||
dataset_path = config["dataset"]["labeled_path"]
|
||||
dataset_seed = config["dataset"]["seed"]
|
||||
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
dataset = process_dataset(dataset_path, dataset_seed, env, template)
|
||||
|
||||
student_tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
student_model = AutoModelForCausalLM.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
training_arguments = DPOConfig(**config["training"])
|
||||
trainer = DPOTrainer(
|
||||
student_model,
|
||||
ref_model=copy.deepcopy(student_model),
|
||||
args=training_arguments,
|
||||
train_dataset=dataset,
|
||||
processing_class=student_tokenizer
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
student_tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
111
easydistill/rl/grpo_train.py
Normal file
111
easydistill/rl/grpo_train.py
Normal file
@@ -0,0 +1,111 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from jinja2 import Environment, BaseLoader, FileSystemLoader
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
|
||||
def process_dataset(dataset_path, dataset_seed, env, template, train_ratio):
|
||||
examples = []
|
||||
try:
|
||||
with open(dataset_path, 'r') as file:
|
||||
examples = json.load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{dataset_path}' was not found.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
output_dataset = []
|
||||
# use chat template
|
||||
for i in range(len(examples)):
|
||||
try:
|
||||
message = {"content": examples[i]["prompt"]}
|
||||
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
|
||||
sample = {"prompt": rendered}
|
||||
output_dataset.append(sample)
|
||||
except:
|
||||
logging.warning(f"Error processing sample.")
|
||||
|
||||
random.shuffle(output_dataset)
|
||||
random.seed(dataset_seed)
|
||||
split_index = int(len(output_dataset) * train_ratio)
|
||||
train_list = output_dataset[:split_index]
|
||||
eval_list = output_dataset[split_index:]
|
||||
|
||||
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
|
||||
|
||||
|
||||
def train(config):
|
||||
dataset_path = config["dataset"]["instruction_path"]
|
||||
dataset_seed = config["dataset"]["seed"]
|
||||
train_ratio = config["dataset"]["train_ratio"]
|
||||
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, train_ratio)
|
||||
print(train_dataset)
|
||||
print(eval_dataset)
|
||||
|
||||
reward_model_path = config["models"]["reward"]
|
||||
sft_model_path = config["models"]["student"]
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_model_path, trust_remote_code=True, num_labels=1
|
||||
)
|
||||
sft_model = AutoModelForCausalLM.from_pretrained(
|
||||
sft_model_path, trust_remote_code=True
|
||||
)
|
||||
|
||||
training_arguments = GRPOConfig(**config["training"])
|
||||
trainer = GRPOTrainer(
|
||||
args=training_arguments,
|
||||
processing_class=tokenizer,
|
||||
model=sft_model,
|
||||
reward_funcs=reward_model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
122
easydistill/rl/ppo_train.py
Normal file
122
easydistill/rl/ppo_train.py
Normal file
@@ -0,0 +1,122 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from jinja2 import Environment, BaseLoader, FileSystemLoader
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import PPOConfig, PPOTrainer
|
||||
|
||||
|
||||
def process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio):
|
||||
examples = []
|
||||
try:
|
||||
with open(dataset_path, 'r') as file:
|
||||
examples = json.load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{dataset_path}' was not found.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
output_dataset = []
|
||||
# use chat template
|
||||
for i in range(len(examples)):
|
||||
try:
|
||||
message = {"content": examples[i]["instruction"]}
|
||||
rendered = template.render(message=message, add_generation_prompt=True, add_output=False)
|
||||
tokens = tokenizer.encode(rendered)
|
||||
sample = {"input_ids": tokens}
|
||||
output_dataset.append(sample)
|
||||
except:
|
||||
logging.warning(f"Error processing sample.")
|
||||
|
||||
random.shuffle(output_dataset)
|
||||
random.seed(dataset_seed)
|
||||
split_index = int(len(output_dataset) * train_ratio)
|
||||
train_list = output_dataset[:split_index]
|
||||
eval_list = output_dataset[split_index:]
|
||||
|
||||
return Dataset.from_list(train_list), Dataset.from_list(eval_list)
|
||||
|
||||
|
||||
def train(config):
|
||||
dataset_path = config["dataset"]["instruction_path"]
|
||||
dataset_seed = config["dataset"]["seed"]
|
||||
train_ratio = config["dataset"]["train_ratio"]
|
||||
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
train_dataset, eval_dataset = process_dataset(dataset_path, dataset_seed, env, template, tokenizer, train_ratio)
|
||||
assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"
|
||||
|
||||
print(train_dataset)
|
||||
print(eval_dataset)
|
||||
|
||||
reward_model_path = config["models"]["reward"]
|
||||
sft_model_path = config["models"]["student"]
|
||||
value_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_model_path, trust_remote_code=True, num_labels=1
|
||||
)
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
reward_model_path, trust_remote_code=True, num_labels=1
|
||||
)
|
||||
ref_policy = AutoModelForCausalLM.from_pretrained(
|
||||
sft_model_path, trust_remote_code=True
|
||||
)
|
||||
policy = AutoModelForCausalLM.from_pretrained(
|
||||
sft_model_path, trust_remote_code=True
|
||||
)
|
||||
|
||||
training_arguments = PPOConfig(**config["training"])
|
||||
trainer = PPOTrainer(
|
||||
config=training_arguments,
|
||||
processing_class=tokenizer,
|
||||
policy=policy,
|
||||
ref_policy=ref_policy,
|
||||
reward_model=reward_model,
|
||||
value_model=value_model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
258
easydistill/rl/reward_infer.py
Normal file
258
easydistill/rl/reward_infer.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def read_json_field(filename, field_name='prompt'):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output_fields = []
|
||||
for item in data:
|
||||
if field_name in item:
|
||||
output_fields.append(item[field_name])
|
||||
return output_fields
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def load_tokenizer_and_vllm(config, eos_token=None):
|
||||
teacher_model_path = config["models"]["teacher"]
|
||||
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
|
||||
tokenizer.padding_side = "left"
|
||||
if eos_token:
|
||||
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
||||
logging.info(f"eos_token {eos_token} from user input")
|
||||
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
|
||||
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
|
||||
else:
|
||||
raise ValueError("No available eos_token or eos_token_id.")
|
||||
try:
|
||||
tokenizer.eos_token = eos_token
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
tokenizer.pad_token = eos_token
|
||||
tokenizer.pad_token_id = eos_token_id
|
||||
except:
|
||||
logging.info(f"[WARNING] Cannot set tokenizer.eos_token")
|
||||
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
|
||||
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=teacher_model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
|
||||
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
|
||||
trust_remote_code=config["inference"]["trust_remote_code"],
|
||||
dtype=torch.bfloat16,
|
||||
enforce_eager=config["inference"]["enforce_eager"],
|
||||
max_model_len=config["inference"]["max_model_len"],
|
||||
)
|
||||
logging.info("vLLM model loaded successfully")
|
||||
return tokenizer, llm
|
||||
|
||||
|
||||
def generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
positive_system_prompt = config["inference"]["positive_system_prompt"]
|
||||
negative_system_prompt = config["inference"]["negative_system_prompt"]
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
positive_new_batch = []
|
||||
negative_new_batch = []
|
||||
for sample in batch:
|
||||
positive_message = [
|
||||
{'role': 'system', 'content': positive_system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
positive_full_text = template.render(
|
||||
message = positive_message,
|
||||
add_generation_prompt = True,
|
||||
add_output = False
|
||||
)
|
||||
positive_new_batch.append(positive_full_text)
|
||||
negative_message = [
|
||||
{'role': 'system', 'content': negative_system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
negative_full_text = template.render(
|
||||
message = negative_message,
|
||||
add_generation_prompt = True,
|
||||
add_output = False
|
||||
)
|
||||
negative_new_batch.append(negative_full_text)
|
||||
|
||||
positive_outputs = llm.generate(
|
||||
positive_new_batch,
|
||||
SamplingParams(
|
||||
n = 1,
|
||||
top_k = 1,
|
||||
temperature = config["inference"]["temperature"],
|
||||
seed = config["inference"]["seed"],
|
||||
skip_special_tokens = False,
|
||||
ignore_eos = False,
|
||||
max_tokens = config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
positve_responses = [output.outputs[0].text for output in positive_outputs]
|
||||
positive_gen_data = [{'prompt': batch[i], 'chosen': positve_responses[i]} for i in range(len(batch))]
|
||||
|
||||
negative_outputs = llm.generate(
|
||||
negative_new_batch,
|
||||
SamplingParams(
|
||||
n = 1,
|
||||
top_k = 1,
|
||||
temperature = config["inference"]["temperature"],
|
||||
seed = config["inference"]["seed"],
|
||||
skip_special_tokens = False,
|
||||
ignore_eos = False,
|
||||
max_tokens = config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
negative_responses = [output.outputs[0].text for output in negative_outputs]
|
||||
negative_gen_data = [{'prompt': batch[i], 'rejected': negative_responses[i]} for i in range(len(batch))]
|
||||
|
||||
merged_data = merge_outcomes(positive_gen_data, negative_gen_data)
|
||||
outcomes = outcomes + merged_data
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def merge_outcomes(positive_gen_data, negative_gen_data):
|
||||
negative_dict = {item['prompt']: item['rejected'] for item in negative_gen_data}
|
||||
merged_outcomes = []
|
||||
for positive_item in positive_gen_data:
|
||||
prompt = positive_item['prompt']
|
||||
if prompt in negative_dict:
|
||||
merged_outcome = {
|
||||
'prompt': prompt,
|
||||
'chosen': positive_item['chosen'],
|
||||
'rejected': negative_dict[prompt]
|
||||
}
|
||||
merged_outcomes.append(merged_outcome)
|
||||
return merged_outcomes
|
||||
|
||||
|
||||
def generate_teacher_response_for_reward_model_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"]
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
logging.info(model)
|
||||
positive_system_prompt = config["inference"]["positive_system_prompt"]
|
||||
negative_system_prompt = config["inference"]["negative_system_prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
|
||||
positive_message = [
|
||||
{'role': 'system', 'content': positive_system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
positive_completion = client.chat.completions.create(
|
||||
messages = positive_message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
positive_result = ""
|
||||
for chunk in positive_completion:
|
||||
positive_result += chunk.choices[0].delta.content
|
||||
else:
|
||||
positive_result = positive_completion.choices[0].message.content
|
||||
|
||||
negative_message = [
|
||||
{'role': 'system', 'content': negative_system_prompt},
|
||||
{'role': 'user', 'content': sample}
|
||||
]
|
||||
negative_completion = client.chat.completions.create(
|
||||
messages = negative_message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
negative_result = ""
|
||||
for chunk in negative_completion:
|
||||
negative_result += chunk.choices[0].delta.content
|
||||
else:
|
||||
negative_result = negative_completion.choices[0].message.content
|
||||
outcomes.append({'prompt': sample, 'chosen': positive_result, 'rejected': negative_result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def infer_with_teacher_model(config):
|
||||
logging.info('Generating distillation data from the teacher model!')
|
||||
data_list = read_json_field(config["dataset"]["instruction_path"])
|
||||
try:
|
||||
job_type = config["job_type"]
|
||||
if job_type == "rl_reward_api":
|
||||
generate_teacher_response_for_reward_model_api(data_list, config)
|
||||
elif job_type == "rl_reward_local":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
generate_teacher_response_for_reward_model_local(tokenizer, llm, data_list, config)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
infer_with_teacher_model(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
107
easydistill/rl/reward_train.py
Normal file
107
easydistill/rl/reward_train.py
Normal file
@@ -0,0 +1,107 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def process_dataset(dataset_path, tokenizer, config, template):
|
||||
kwargs = {"padding": "max_length", "truncation": True, "max_length": config["training"]["max_length"], "return_tensors": "pt"}
|
||||
examples = []
|
||||
try:
|
||||
with open(dataset_path, 'r') as file:
|
||||
examples = json.load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{dataset_path}' was not found.")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: The file '{dataset_path}' is not a valid JSON file.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
print(examples)
|
||||
output_dataset = []
|
||||
# use chat template
|
||||
for i in range(len(examples)):
|
||||
try:
|
||||
chosen_message = {"content": examples[i]["prompt"], "output": examples[i]["chosen"]}
|
||||
prompt_plus_chosen_response = template.render(message=chosen_message, add_generation_prompt=False, add_output=True)
|
||||
|
||||
rejected_message = {"content": examples[i]["prompt"], "output": examples[i]["rejected"]}
|
||||
prompt_plus_rejected_response = template.render(message=rejected_message, add_generation_prompt=False, add_output=True)
|
||||
|
||||
tokens_chosen = tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)
|
||||
tokens_rejected = tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)
|
||||
sample = {
|
||||
"input_ids_chosen": tokens_chosen["input_ids"][0], "attention_mask_chosen": tokens_chosen["attention_mask"][0],
|
||||
"input_ids_rejected": tokens_rejected["input_ids"][0], "attention_mask_rejected": tokens_rejected["attention_mask"][0]
|
||||
}
|
||||
output_dataset.append(sample)
|
||||
except:
|
||||
logging.warning(f"Error processing sample.")
|
||||
dataset = Dataset.from_list(output_dataset)
|
||||
return dataset
|
||||
|
||||
|
||||
def train(config):
|
||||
dataset_path = config["dataset"]["labeled_path"]
|
||||
student_tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
dataset = process_dataset(dataset_path, student_tokenizer, config, template)
|
||||
|
||||
student_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
config["models"]["student"],
|
||||
num_labels=1,
|
||||
trust_remote_code=True
|
||||
)
|
||||
student_model.config.pad_token_id = student_tokenizer.pad_token_id
|
||||
|
||||
training_arguments = RewardConfig(**config["training"])
|
||||
trainer = RewardTrainer(
|
||||
model=student_model,
|
||||
processing_class=student_tokenizer,
|
||||
args=training_arguments,
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
student_tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
274
easydistill/synthesis/cot_synthesis.py
Normal file
274
easydistill/synthesis/cot_synthesis.py
Normal file
@@ -0,0 +1,274 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import jsonlines
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
|
||||
from utils import write_data_to_json_file
|
||||
|
||||
|
||||
# I have checked this function.
|
||||
def cot_generate_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"]
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": sample, "output": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def cot_generate_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_long2short_api(data_list_ins, data_list_out, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
if result is not None:
|
||||
outcomes.append((sample,result))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for ins,out in batch:
|
||||
sample = f"{prompt} Simplify the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nSimplified Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_short2long_api(data_list_ins, data_list_out, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
for ins,out in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
|
||||
if result is not None:
|
||||
outcomes.append((sample,result))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
||||
|
||||
|
||||
def cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
data_list=[(ins,out) for ins,out in zip(data_list_ins,data_list_out)]
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for ins,out in batch:
|
||||
sample = f"{prompt} Extend the reasoning process for the problem below.\n\nProblem:\n{ins}\n\nAnswer:\n{out}\n\nExtended Reasoning Process:"
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
outcomes = []
|
||||
for i in range(len(batch)):
|
||||
if responses[i] is not None:
|
||||
outcomes.append((sample,responses[i]))
|
||||
|
||||
with jsonlines.open(config["dataset"]["output_path"], mode='a') as writer:
|
||||
for ins,result in outcomes:
|
||||
gen_data = {"instruction": ins, "output": result}
|
||||
writer.write(gen_data)
|
293
easydistill/synthesis/instruct_synthesis.py
Normal file
293
easydistill/synthesis/instruct_synthesis.py
Normal file
@@ -0,0 +1,293 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import logging
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from vllm import LLM, SamplingParams
|
||||
from tqdm import tqdm
|
||||
from openai import OpenAI
|
||||
import random
|
||||
import re
|
||||
|
||||
from utils import read_json_field, write_data_to_json_file, load_tokenizer_and_vllm
|
||||
|
||||
|
||||
def extract_answer(content):
|
||||
pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def extract_instruction_response(content):
|
||||
instruction_pattern = r'<instruction>(.*?)</instruction>'
|
||||
instruction_match = re.search(instruction_pattern, content, re.DOTALL)
|
||||
response_pattern = r'<response>(.*?)</response>'
|
||||
response_match = re.search(response_pattern, content, re.DOTALL)
|
||||
if instruction_match and response_match:
|
||||
return instruction_match.group(1), response_match.group(1)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
def generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples):
|
||||
if num_in_context_samples > len(data_list):
|
||||
raise ValueError("num_in_context_samples cannot be larger than the length of data_list")
|
||||
output_list = []
|
||||
for _ in range(num_output_samples):
|
||||
selected_samples = random.sample(data_list, num_in_context_samples)
|
||||
combined_prompts = prompt + "\n" + "".join([sample + "\n" for sample in selected_samples])
|
||||
output_list.append(combined_prompts)
|
||||
return output_list
|
||||
|
||||
|
||||
def expand_instruction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
num_output_samples = config["dataset"]["num_output_samples"]
|
||||
num_in_context_samples = config["dataset"]["num_in_context_samples"]
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
|
||||
outcomes = []
|
||||
for sample in tqdm(prompt_list, desc="Calling remote model and generating responses"):
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
result = extract_answer(result)
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def expand_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
|
||||
num_output_samples = config["dataset"]["num_output_samples"]
|
||||
num_in_context_samples = config["dataset"]["num_in_context_samples"]
|
||||
prompt = config["inference"]["prompt"]
|
||||
prompt_list = generate_prompt_list(data_list, prompt, num_in_context_samples, num_output_samples)
|
||||
|
||||
outcomes = []
|
||||
batches = [prompt_list[i:i + batch_size] for i in range(0, len(prompt_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"]
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
result = extract_answer(responses[i])
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def refine_instruction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream = stream
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
result = extract_answer(result)
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def refine_instruction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
result = extract_answer(responses[i])
|
||||
if result is not None:
|
||||
outcomes.append({"instruction": result})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def instruction_response_extraction_api(data_list, config):
|
||||
client = OpenAI(
|
||||
api_key = config["inference"]["api_key"],
|
||||
base_url = config["inference"]["base_url"],
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
prompt = config["inference"]["prompt"]
|
||||
stream = config["inference"]["stream"]
|
||||
logging.info(model)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Calling remote model and generating responses"):
|
||||
sample = prompt + "\n" + sample
|
||||
logging.info(sample)
|
||||
message = [
|
||||
{"role": "user", "content": sample}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = config["inference"]["max_new_tokens"],
|
||||
stream= stream,
|
||||
)
|
||||
if stream:
|
||||
result = ""
|
||||
for chunk in completion:
|
||||
result += chunk.choices[0].delta.content
|
||||
else:
|
||||
result = completion.choices[0].message.content
|
||||
new_instruction, new_response = extract_instruction_response(result)
|
||||
if new_instruction is not None and new_response is not None:
|
||||
outcomes.append({"instruction": new_instruction, "output": new_response})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
||||
|
||||
|
||||
def instruction_response_extraction_batch(tokenizer, llm, data_list, config, batch_size=32):
|
||||
full_path = config["dataset"]["template"]
|
||||
template_dir = os.path.dirname(full_path)
|
||||
template_file = os.path.basename(full_path)
|
||||
env = Environment(loader=FileSystemLoader(template_dir))
|
||||
template = env.get_template(template_file)
|
||||
prompt = config["inference"]["prompt"]
|
||||
|
||||
outcomes = []
|
||||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||||
for batch in tqdm(batches, desc="Generating responses"):
|
||||
new_batch = []
|
||||
for sample in batch:
|
||||
logging.info(sample)
|
||||
sample = prompt + "\n" + sample
|
||||
message={"role": "user", "content": sample}
|
||||
full_text = template.render(
|
||||
message=message,
|
||||
add_generation_prompt=True,
|
||||
add_output=False
|
||||
)
|
||||
new_batch.append(full_text)
|
||||
outputs = llm.generate(
|
||||
new_batch,
|
||||
SamplingParams(
|
||||
n=1,
|
||||
top_k=1,
|
||||
temperature=config["inference"]["temperature"],
|
||||
seed=config["inference"]["seed"],
|
||||
skip_special_tokens=False,
|
||||
ignore_eos=False,
|
||||
max_tokens=config["inference"]["max_new_tokens"],
|
||||
)
|
||||
)
|
||||
responses = [output.outputs[0].text for output in outputs]
|
||||
for i in range(len(batch)):
|
||||
new_instruction, new_response = extract_instruction_response(responses[i])
|
||||
if new_instruction is not None and new_response is not None:
|
||||
outcomes.append({"instruction": new_instruction, "output": new_response})
|
||||
write_data_to_json_file(outcomes, config["dataset"]["output_path"])
|
107
easydistill/synthesis/synthesis_main.py
Normal file
107
easydistill/synthesis/synthesis_main.py
Normal file
@@ -0,0 +1,107 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
|
||||
from instruct_synthesis import (
|
||||
expand_instruction_api,
|
||||
expand_instruction_batch,
|
||||
refine_instruction_api,
|
||||
refine_instruction_batch,
|
||||
instruction_response_extraction_api,
|
||||
instruction_response_extraction_batch
|
||||
)
|
||||
from cot_synthesis import (
|
||||
cot_generate_api,
|
||||
cot_generate_batch,
|
||||
cot_long2short_api,
|
||||
cot_long2short_batch,
|
||||
cot_short2long_api,
|
||||
cot_short2long_batch
|
||||
)
|
||||
from utils import read_json_field, load_tokenizer_and_vllm
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def data_synthesis_with_teacher_model(config):
|
||||
logging.info('Generating distillation data from the teacher model!')
|
||||
job_type = config["job_type"]
|
||||
if job_type == "instruction_response_extraction_api":
|
||||
data_list = read_json_field(config["dataset"]["input_path"], field_name="data")
|
||||
elif job_type in ["cot_long2short_api","cot_long2short_batch","cot_short2long_api","cot_short2long_batch"]:
|
||||
data_list_ins = read_json_field(config["dataset"]["input_path"])
|
||||
data_list_out = read_json_field(config["dataset"]["input_path"], field_name="output")
|
||||
else:
|
||||
data_list = read_json_field(config["dataset"]["input_path"])
|
||||
|
||||
try:
|
||||
if job_type == "instruction_expansion_api":
|
||||
expand_instruction_api(data_list, config)
|
||||
elif job_type == "instruction_expansion_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
expand_instruction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "instruction_refinement_api":
|
||||
refine_instruction_api(data_list, config)
|
||||
elif job_type == "instruction_refinement_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
refine_instruction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "instruction_response_extraction_api":
|
||||
instruction_response_extraction_api(data_list, config)
|
||||
elif job_type == "instruction_response_extraction_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
instruction_response_extraction_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "cot_generation_api":
|
||||
cot_generate_api(data_list, config)
|
||||
elif job_type == "cot_generation_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_generate_batch(tokenizer, llm, data_list, config)
|
||||
|
||||
elif job_type == "cot_long2short_api":
|
||||
cot_long2short_api(data_list_ins, data_list_out, config)
|
||||
elif job_type == "cot_long2short_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_long2short_batch(tokenizer, llm, data_list_ins, data_list_out, config)
|
||||
|
||||
elif job_type == "cot_short2long_api":
|
||||
cot_short2long_api(data_list_ins, data_list_out, config)
|
||||
elif job_type == "cot_short2long_batch":
|
||||
tokenizer, llm = load_tokenizer_and_vllm(config)
|
||||
cot_short2long_batch(tokenizer, llm, data_list_ins, data_list_out, config)
|
||||
else:
|
||||
logging.error(f"Invalid job type: {job_type}")
|
||||
raise ValueError(f"Invalid job type: {job_type}")
|
||||
except ValueError as e:
|
||||
logging.error(f"Training job terminated: {e}")
|
||||
return
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
data_synthesis_with_teacher_model(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
85
easydistill/synthesis/utils.py
Normal file
85
easydistill/synthesis/utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
from vllm import LLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def read_json_field(filename, field_name='instruction'):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output_fields = []
|
||||
for item in data:
|
||||
if field_name in item:
|
||||
output_fields.append(item[field_name])
|
||||
return output_fields
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def load_tokenizer_and_vllm(config, eos_token=None):
|
||||
teacher_model_path = config["models"]["teacher"]
|
||||
logging.info(f"Loading ckpt and tokenizer: {teacher_model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path, trust_remote_code=True)
|
||||
tokenizer.padding_side = "left"
|
||||
if eos_token:
|
||||
eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
|
||||
logging.info(f"eos_token {eos_token} from user input")
|
||||
elif hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id:
|
||||
logging.info(f"Initial eos_token_id {tokenizer.eos_token_id} from tokenizer")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
eos_token = tokenizer.convert_ids_to_tokens(eos_token_id)
|
||||
else:
|
||||
raise ValueError("No available eos_token or eos_token_id.")
|
||||
try:
|
||||
tokenizer.eos_token = eos_token
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
tokenizer.pad_token = eos_token
|
||||
tokenizer.pad_token_id = eos_token_id
|
||||
except:
|
||||
logging.info(f"[WARNING] Cannot set tokenizer.eos_token")
|
||||
logging.info(f"tokenizer's eos_token: {tokenizer.eos_token}, pad_token: {tokenizer.pad_token}")
|
||||
logging.info(f"tokenizer's eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
|
||||
num_gpus = torch.cuda.device_count()
|
||||
llm = LLM(
|
||||
model=teacher_model_path,
|
||||
tensor_parallel_size=num_gpus,
|
||||
enable_chunked_prefill=config["inference"]["enable_chunked_prefill"],
|
||||
gpu_memory_utilization=config["inference"]["gpu_memory_utilization"],
|
||||
trust_remote_code=config["inference"]["trust_remote_code"],
|
||||
dtype=torch.bfloat16,
|
||||
enforce_eager=config["inference"]["enforce_eager"],
|
||||
max_model_len=config["inference"]["max_model_len"],
|
||||
)
|
||||
logging.info("vLLM model loaded successfully")
|
||||
return tokenizer, llm
|
76
recipes/distilqwen_series/distillqwen2.5-0324/README.md
Normal file
76
recipes/distilqwen_series/distillqwen2.5-0324/README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# DistilQwen2.5-0324: training fast-thinking models
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
In the rapid advancement of large language models, effectively balancing the trade-off between efficient inference and model thinking capabilities has been a key focus in both academia and industry. DeepSeekV3-0324, by default, does not employ deep thinking mode, which accelerates model inference while maintaining a balance between swift reasoning and handling complex tasks. The DistilQwen2.5-0324 series not only inherits the essence of the original model's chain-of-thought distillation but also introduces fast-thinking strategies, significantly boosting inference speed. This enables these models to efficiently execute complex tasks on resource-constrained devices and in edge computing scenarios.
|
||||
|
||||
## Detailed Steps
|
||||
|
||||
### Processing of Instructional Dataset
|
||||
|
||||
DistilQwen2.5-0324 was trained using data distilled from Deepseek-V3-0324 as well as data rewritten with long2short after distillation from Deepseek-R1. For Deepseek-V3-0324, the official recommendation is not to use a system prompt; for the long2short scenario, the following prompt was used. You can employ this method to reduce the output of Deepseek-R1 and distill your own model.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
python easydistill/kd/infer.py --config=distilqwen2.5-0324_stage1.json
|
||||
```
|
||||
|
||||
The training dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
|
||||
"output": "Step 1: Determine the total number of incisors in the upper jaw...The final answer is: \\boxed{8}"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Black-Box KD
|
||||
|
||||
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we can run the training job:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2.5-0324_stage2.json
|
||||
```
|
||||
|
||||
Plese refer to the config file `distilqwen2.5-0324_stage2.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
## Model Download
|
||||
|
||||
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-DS3-0324-7B`, `alibaba-pai/DistilQwen2.5-DS3-0324-14B`, and `alibaba-pai/DistilQwen2.5-DS3-0324-32B`.
|
||||
|
||||
For example, users can download these models from HuggingFace using the following code:
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the 1.5B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-7B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-7B/")
|
||||
|
||||
# Download the 3B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-14B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-14B/")
|
||||
|
||||
# Download the 7B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-DS3-0324-32B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-DS3-0324-32B/")
|
||||
```
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
- **32B Model** approaches the performance of closed-source models with 10x the parameters on the GPQA Diamond benchmark
|
||||
- **Significant Improvement in Reasoning Efficiency** (see comparison table below)
|
||||
|
||||
| Model | MMLU_PRO Tokens | AIME2024 Tokens | Speed Gain |
|
||||
|--------------------------------|-----------------|-----------------|------------|
|
||||
| DistilQwen2.5-R1-32B (Slow-Thinking) | 4198 | 12178 | 1x |
|
||||
| DistilQwen2.5-DS3-0324-32B | 690 | 4177 | 5-8x |
|
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"job_type": "cot_long2short_api",
|
||||
"dataset": {
|
||||
"input_path": "./raw.json",
|
||||
"output_path": "./raw_simplified.json"
|
||||
},
|
||||
"inference":{
|
||||
"base_url": "ENDPOINT",
|
||||
"api_key": "TOKEN",
|
||||
"stream": true,
|
||||
"prompt" : "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters (\n\n), your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary.",
|
||||
"max_new_tokens": 1024
|
||||
}
|
||||
}
|
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_0324.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage2/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
142
recipes/distilqwen_series/distillqwen2.5-r1/README.md
Normal file
142
recipes/distilqwen_series/distillqwen2.5-r1/README.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# DistilQwen2.5-R1: training distilled reasonin models based on CoTs generated by Deepseek-R1
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
As large language models (LLMs) evolve toward deep reasoning capabilities, deploying them in resource-constrained environments (e.g., mobile devices, edge computing) remains challenging. The DistilQwen2.5-R1 series addresses this by transferring reasoning capabilities from ultra-large models (e.g., DeepSeek-R1) to compact models through innovative distillation techniques, achieving high performance while reducing computational costs.
|
||||
|
||||
## Data Generation Detailed Steps
|
||||
|
||||
### 1. Generate Thinking Dataset
|
||||
|
||||
Distillqwen-r1 is trained using chain-of-thought data distilled from deepseek-r1. We provide the system prompts used for distilling the R1 data and the system prompts used for training qwen2.5. You can use the current system prompts to call Deepseek-R1 to generate your own data and train the model.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution. In the Thought section, detail your reasoning process using the specified format: <|begin_of_thought|> {thought with steps separated with '\n\n'} <|end_of_thought|> Each step should include detailed considerations such as analisying questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The solution should remain a logical, accurate, concise expression style and detail necessary step needed to reach the conclusion, formatted as follows: <|begin_of_solution|> {final formatted, precise, and clear solution} <|end_of_solution|> Now, try to solve the following question through the above guidelines:"
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Determine the Difficulty Level
|
||||
|
||||
Critiquing the CoT qualities according to the cognitive capabilities of smaller models. You can use the current system prompts using QwQ-32B to determine the difficulty level of the CoTs.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "You are a highly capable evaluator. Your task is to assess the given reasoning process from the perspective of a small language model (e.g., 7B). Specifically, determine whether the reasoning process provides sufficient detail for a small model to solve the problem, or whether it is too simplistic (i.e., lacking critical details) or too complex (i.e., containing unnecessary or confusing steps). Difficulty Definitions (from the perspective of a small model): - Easy: The reasoning process is overly simplistic relative to the problem's difficulty; it omits essential details that a small model needs to solve the problem. - Medium: The reasoning process is appropriately balanced, offering enough detailed guidance. - Hard: The reasoning process is overly complex, with extraneous or convoluted steps that could hinder a small model's ability to follow it. Output Format: You must output exactly one word: easy, medium, or hard. Do NOT provide any additional text, explanation."
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Rethinking and Refining these CoTs
|
||||
|
||||
Rethinking and refining these CoTs based on the critiques using following prompts:
|
||||
|
||||
#### easy
|
||||
```json
|
||||
{
|
||||
"system": "You are a helpful assistant who is highly skilled at extending reasoning processes. Given a problem, its answer, and its reasoning process, your task is to extend the reasoning process by adding necessary details and intermediate steps so that a small language model (e.g., a 7B model) can follow the extended reasoning process to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the extended reasoning process with no additional explanation or commentary."
|
||||
}
|
||||
```
|
||||
|
||||
#### hard
|
||||
```json
|
||||
{
|
||||
"system": "You are a helpful assistant who is highly skilled at simplifying reasoning processes. Given a problem, its answer, and its reasoning process, your task is to simplify the reasoning process so that a small language model (e.g., a 7B model) can reliably follow the steps to solve the problem. If the original reasoning process is divided into multiple steps separated by two newline characters, your output must preserve this formatting. You must output ONLY the simplified reasoning process with no additional explanation or commentary."
|
||||
}
|
||||
```
|
||||
|
||||
The training dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
|
||||
"output": "<|begin_of_thought|>## Step 1: Determine the total number of incisors in the upper jaw...\n<|end_of_thought|>\n<|begin_of_solution|>The final answer is: \\boxed{8}<|end_of_solution|>"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Model Training Guidelines
|
||||
|
||||
### 1. Black-Box KD
|
||||
|
||||
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we need to run the training job only:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2.5-r1_stage1.json
|
||||
```
|
||||
|
||||
Plese refer to the config file `distilqwen2.5-r1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
### 2. CogPO
|
||||
|
||||
CogPO (Cognitive Preference Optimization) is a novel algorithm designed to enhance the reasoning abilities of small language models (LLMs) by aligning their reasoning processes with their inherent cognitive capacities.
|
||||
|
||||
Key aspects of CogPO:
|
||||
- Extends Direct Preference Optimization (DPO) with cognitive alignment
|
||||
- Introduces three specialized "mini-tasks" with different preference gaps
|
||||
- Dynamically adjusts optimization strength (β values) based on reasoning complexity
|
||||
- Works synergistically with the CRV (Critique-Rethink-Verify) system
|
||||
|
||||
You can run the CogPO by:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_processes n --config_file multi_gpu.yaml cogpo.py --config distilqwen2.5-r1_stage2.json
|
||||
```
|
||||
|
||||
The dataset is in JSON format, exemplified by entries such as:
|
||||
```json
|
||||
{
|
||||
"prompt": "Ellie has 8 pairs of shoes. Riley has 3 fewer. How many pairs of shoes do they have in all?",
|
||||
"chosen": "<think>Identify the number of pairs of shoes Ellie has. According to the problem statement, Ellie has 8 pairs of shoes.\n Next, determine the number of pairs of shoes Riley has. The problem states that Riley has 3 fewer pairs than Ellie. To find out how many pairs Riley has, subtract 3 from the number of pairs Ellie has: 8 - 3 = 5. So, Riley has 5 pairs of shoes.\n Now, calculate the total number of pairs of shoes both Ellie and Riley have together. To do this, add the number of pairs Ellie has to the number of pairs Riley has: 8 (Ellie's pairs) + 5 (Riley's pairs) = 13 pairs. This step is crucial because it combines the information about both individuals to give the overall total.\n The total number of pairs of shoes they have in all is 13. Thus, the final answer is 13. Each step in the reasoning process is designed to help understand and solve the problem effectively, showing how the information about each individual's shoe count leads to finding the combined total.</think>\boxed{13}",
|
||||
"rejected": "<think>Identify the number of pairs of shoes Ellie has. Ellie has 8 pairs of shoes as stated in the problem. Determine how many pairs of shoes Riley has. Since Riley has 3 fewer pairs than Ellie, we mistakenly add 3 to Ellie's pairs instead of subtracting, giving us 8 + 3 = 11 pairs of shoes for Riley. Calculate the total number of pairs of shoes they both have. Add Ellie's and Riley's pairs together: 8 + 11. The total pairs of shoes is 19. The final answer is thus \boxed{19}.</think>\boxed{13}",
|
||||
"beta": 0.5
|
||||
}
|
||||
```
|
||||
|
||||
## Model Download
|
||||
|
||||
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-R1-3B`, `alibaba-pai/DistilQwen2.5-R1-7B`, `alibaba-pai/DistilQwen2.5-R1-14B`, and `alibaba-pai/DistilQwen2.5-R1-32B`.
|
||||
|
||||
For example, users can download these models from HuggingFace using the following code:
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the 3B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-R1-3B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-3B/")
|
||||
|
||||
# Download the 7B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-R1-7B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-7B/")
|
||||
|
||||
# Download the 14B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-R1-14B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-14B/")
|
||||
|
||||
# Download the 32B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-R1-32B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-R1-32B/")
|
||||
```
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
We compared DistilQwen2.5-R1 series with leading reasoning models across four benchmarks:
|
||||
|
||||
### 7B Model Comparison
|
||||
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|
||||
|--------------------------------|--------------------|----------|----------|--------------|------------------|
|
||||
| DeepSeek-R1-Distill-Qwen-7B | 800k | 55.5 | 92.8 | 49.1 | - |
|
||||
| Bespoke-Stratos-7B | 17k | 20.0 | 82.0 | 37.8 | 36.1 |
|
||||
| OpenThinker-7B | 114k | 31.3 | 83.0 | 42.4 | 39.9 |
|
||||
| **DistilQwen2.5-R1-7B** | 105k | 43.33 | 88.4 | 42.93 | 46.38 |
|
||||
|
||||
### 32B Model Comparison
|
||||
| Model | Training Data Size | AIME2024 | MATH-500 | GPQA Diamond | LiveCodeBench V2 |
|
||||
|--------------------------------|--------------------|----------|----------|--------------|------------------|
|
||||
| DeepSeek-R1-Distill-Qwen-32B | 800k | 72.6 | 94.3 | 62.1 | - |
|
||||
| Sky-T1-32B-Preview | 17k | 43.3 | 86.4 | 56.8 | - |
|
||||
| OpenThinker-32B | 114k | 66.0 | 90.6 | 61.6 | 68.9 |
|
||||
| **DistilQwen2.5-R1-32B** | 105k | 70.0 | 93.8 | 62.12 | 65.95 |
|
194
recipes/distilqwen_series/distillqwen2.5-r1/cogpo.py
Normal file
194
recipes/distilqwen_series/distillqwen2.5-r1/cogpo.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import trl
|
||||
from trl.trainer.dpo_trainer import DataCollatorForPreference
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
import torch, torch.nn.functional as F
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer, FDivergenceConstants, FDivergenceType
|
||||
from trl.trainer.utils import cap_exp
|
||||
import json
|
||||
import argparse
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPreferenceWithBeta(DataCollatorForPreference):
|
||||
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
|
||||
betas = torch.tensor([float(ex["beta"]) for ex in examples], dtype=torch.float32)
|
||||
for ex in examples:
|
||||
ex.pop("beta")
|
||||
batch = super().torch_call(examples)
|
||||
batch["betas"] = betas
|
||||
return batch
|
||||
|
||||
|
||||
class CogPOTrainer(DPOTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch,
|
||||
train_eval: str = "train",
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
betas = batch.pop("betas").to(self.accelerator.device)
|
||||
model_output = self.concatenated_forward(model, batch)
|
||||
|
||||
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
|
||||
ref_chosen_logps = batch["ref_chosen_logps"]
|
||||
ref_rejected_logps = batch["ref_rejected_logps"]
|
||||
else:
|
||||
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self._dpo_sigmoid_loss(
|
||||
model_output["chosen_logps"],
|
||||
model_output["rejected_logps"],
|
||||
ref_chosen_logps,
|
||||
ref_rejected_logps,
|
||||
betas,
|
||||
)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
losses = losses + self.args.rpo_alpha * model_output["nll_loss"]
|
||||
|
||||
if self.use_weighting:
|
||||
losses = losses * model_output["policy_weights"]
|
||||
|
||||
if self.aux_loss_enabled:
|
||||
losses = losses + self.aux_loss_coef * model_output["aux_loss"]
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||
metrics[f"{prefix}rewards/margins"] = (
|
||||
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logps/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logps/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logits/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item()
|
||||
)
|
||||
metrics[f"{prefix}logits/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item()
|
||||
)
|
||||
if self.args.rpo_alpha is not None:
|
||||
metrics[f"{prefix}nll_loss"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item()
|
||||
)
|
||||
if self.aux_loss_enabled:
|
||||
metrics[f"{prefix}aux_loss"] = (
|
||||
self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item()
|
||||
)
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
||||
def _dpo_sigmoid_loss(
|
||||
self,
|
||||
chosen_logps: torch.FloatTensor,
|
||||
rejected_logps: torch.FloatTensor,
|
||||
ref_chosen_logps: torch.FloatTensor,
|
||||
ref_rejected_logps: torch.FloatTensor,
|
||||
betas: torch.FloatTensor,
|
||||
):
|
||||
|
||||
device = self.accelerator.device
|
||||
chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
|
||||
rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
|
||||
|
||||
# 2) Δ = (log p_c - log p_r) - (log p̂_c - log p̂_r)
|
||||
if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
|
||||
alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
|
||||
if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params:
|
||||
alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY])
|
||||
logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef
|
||||
else:
|
||||
logratios = chosen_logps - rejected_logps
|
||||
if self.reference_free:
|
||||
ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device)
|
||||
else:
|
||||
ref_logratios = ref_chosen_logps - ref_rejected_logps
|
||||
logratios = logratios.to(self.accelerator.device)
|
||||
ref_logratios = ref_logratios.to(self.accelerator.device)
|
||||
logits = logratios - ref_logratios
|
||||
|
||||
if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
|
||||
logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
|
||||
|
||||
|
||||
losses = (
|
||||
-F.logsigmoid(betas * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-betas * logits) * self.label_smoothing
|
||||
)
|
||||
|
||||
chosen_rewards = betas * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
|
||||
rejected_rewards = betas * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
|
||||
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
|
||||
def train(config):
|
||||
model_name = config["models"]["student"]
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
dataset = load_dataset("json", data_files=config["dataset"]["labeled_path"], split='train')
|
||||
|
||||
dpo_args = DPOConfig(
|
||||
output_dir=config["training"]["output_dir"],
|
||||
num_train_epochs=config["training"]["num_train_epochs"],
|
||||
loss_type=config["training"]["loss_type"],
|
||||
beta=config["training"]["beta"],
|
||||
per_device_train_batch_size=config["training"]["per_device_train_batch_size"],
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
|
||||
collator = DataCollatorForPreferenceWithBeta(
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
trainer = CogPOTrainer(
|
||||
model=model,
|
||||
args=dpo_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(config["training"]["output_dir"])
|
||||
tokenizer.save_pretrained(config["training"]["output_dir"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
train(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_r1.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage1/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"models": {
|
||||
"student": "models/Qwen2.5-0.5B-Instruct"
|
||||
},
|
||||
"dataset": {
|
||||
"labeled_path": "cogpo/test500.jsonl"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "save/Qwen2.5-0.5B-CogPO",
|
||||
"num_train_epochs": 1.0,
|
||||
"loss_type": "sigmoid",
|
||||
"beta": 1.0,
|
||||
"per_device_train_batch_size": 2
|
||||
}
|
||||
}
|
101
recipes/distilqwen_series/distillqwen2.5-thoughtX/README.md
Normal file
101
recipes/distilqwen_series/distillqwen2.5-thoughtX/README.md
Normal file
@@ -0,0 +1,101 @@
|
||||
# DistilQwen-ThoughtX: Optimized Reasoning Models with OmniThought
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
DistilQwen-ThoughtX is a series of high-performance reasoning models trained on the [OmniThought](https://huggingface.co/datasets/alibaba-pai/OmniThought) dataset. These models are optimized for chain-of-thought (CoT) reasoning with balanced verbosity and cognitive difficulty, achieving state-of-the-art results on mathematical, coding, and logical reasoning benchmarks.
|
||||
|
||||
## Detailed Steps
|
||||
|
||||
### Direct Training
|
||||
|
||||
DistilQwen-ThoughtX was trained using data from the OmniThought dataset, which includes 2 million CoT processes with RV (Reasoning Verbosity) and CD (Cognitive Difficulty) annotations. The dataset covers mathematics, coding, and logical reasoning tasks, validated by multiple teacher models (DeepSeek-R1, QwQ-32B).
|
||||
|
||||
The training system prompt is:
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
|
||||
}
|
||||
```
|
||||
|
||||
Using the OmniThought dataset, we can run the training job:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2.5-thoughtx-train.json
|
||||
```
|
||||
|
||||
Remember to filter the RV and CD annotations to ensure they are within the desired range to train your own model.
|
||||
|
||||
| Model Name | Parameters | Base Model |
|
||||
|--------------------------------------|------------|---------------------|
|
||||
| `DistilQwen-ThoughtX-7B` | 7B | Qwen2.5-7B-Instruct |
|
||||
| `DistilQwen-ThoughtX-32B` | 32B | Qwen2.5-32B-Instruct|
|
||||
|
||||
### Process Your Own Data
|
||||
|
||||
To obtain the RV and CD values of your own data, you can use the following prompt to call QwQ-32B/Deepseek-R1, score your own data, and filter it.
|
||||
|
||||
Prompt Template to Calculate the RV Score:
|
||||
```json
|
||||
{
|
||||
"prompt": "You are an expert judge tasked with evaluating the Reasoning Verbosity of a Chain-of-Thought (CoT) for a given problem and its answer. Reasoning Verbosity Evaluation Focus: Assess how well the CoT’s length and step complexity match the problem’s inherent difficulty. An optimal chain is neither missing essential steps nor padded with needless digressions. A simple question should be solved with a brief, direct chain; a challenging one may justifiably require a longer path with reflection and error-checking. Scoring Guidelines (0-9): 0-1 Minimal verbosity, straightforward expression with little to no elaboration. 2-3 Clear and concise reasoning with necessary explanations. 4-5 Moderate verbosity with detailed explanations and thorough reasoning. 6-7 Extensive verbosity with comprehensive justification and exploration of complex connections. 8-9 High verbosity with deep, exhaustive exploration of reasoning; involves extensive elaboration, nested justifications, and consideration of counterarguments or alternative perspectives. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Reasoning Verbosity 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
|
||||
}
|
||||
```
|
||||
|
||||
Prompt Template to Calculate the CD Score:
|
||||
```json
|
||||
{
|
||||
"prompt": "You are an expert judge assessing the Cognitive Difficulty of a Chain-of-Thought (CoT) for a given problem and its answer. Cognitive Difficulty Evaluation Focus: The level of reasoning competence required for a model to follow and reproduce the chain faithfully. Judge the reasoning approach, techniques, and overall difficulty. Higher scores correspond to more advanced concepts, abstractions, or multi-layer reasoning patterns. Scoring Guidelines (0-9): 0-1 Elementary facts or a single trivial operation. 2-3 Multi-step arithmetic, explicit enumeration, basic rule chaining. 4-5 Early-undergraduate logic/algebra; one non-obvious insight. 6-7 Advanced undergraduate techniques (determinants, dynamic programming, layered code reasoning, etc). 8-9 Graduate-level abstraction, nested proofs, intricate algorithmic analysis. Given Problem, Chain-of-Thought and Answer, you will: 1. Analyze the Cognitive Difficulty 2. Determine score using the above criteria 3. Output ONLY the integer score (0-9) Problem: {problem} Chain-of-Thought: {thought} Answer: {solution}"
|
||||
}
|
||||
```
|
||||
|
||||
## Model Download
|
||||
|
||||
We have open-sourced our distilled models on HuggingFace. The available models are named `alibaba-pai/DistilQwen-ThoughtX-7B` and `alibaba-pai/DistilQwen-ThoughtX-32B`.
|
||||
|
||||
Users can download these models from HuggingFace using the following code:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the 7B model
|
||||
model_name = "alibaba-pai/DistilQwen-ThoughtX-7B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-7B/")
|
||||
|
||||
# Download the 32B model
|
||||
model_name = "alibaba-pai/DistilQwen-ThoughtX-32B"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen-ThoughtX-32B/")
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
The models achieve state-of-the-art performance on various reasoning benchmarks:
|
||||
|
||||
| Model | AIME2024 | MATH500 | GPQA-D | LiveCodeBench V2 |
|
||||
|----------------------|----------|---------|--------|------------------|
|
||||
| DeepSeek-R1-Distill-7B | 57.3 | 89.6 | 47.3 | 48.4 |
|
||||
| **DistilQwen-ThoughtX-7B** | **56.7** | **90.2** | **50.0** | **56.8** |
|
||||
| DeepSeek-R1-Distill-32B | 74.7 | 90.0 | 62.4 | 72.3 |
|
||||
| **DistilQwen-ThoughtX-32B** | **80.0** | **92.6** | **64.0** | **73.4** |
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information about the model, we encourage you to refer to our paper:
|
||||
|
||||
- **Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations**
|
||||
Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang
|
||||
[arXiv:2505.10937](https://arxiv.org/abs/2505.10937)
|
||||
|
||||
You can cite the paper using the following citation format:
|
||||
|
||||
```bibtex
|
||||
@misc{cai2025reasoningomnithoughtlargecot,
|
||||
title={Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations},
|
||||
author={Wenrui Cai and Chengyu Wang and Junbing Yan and Jun Huang and Xiangzhong Fang},
|
||||
year={2025},
|
||||
eprint={2505.10937},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2505.10937}
|
||||
}
|
||||
```
|
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_thoughtX.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-1.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"max_length":4096,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
135
recipes/distilqwen_series/distillqwen2.5/README.md
Normal file
135
recipes/distilqwen_series/distillqwen2.5/README.md
Normal file
@@ -0,0 +1,135 @@
|
||||
# DistilQwen2.5: Combining Black-Box and White Box KD
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
The DistilQwen2.5 distilled language model series is built upon the Qwen2.5 model. This series leverages innovative distillation techniques to enhance instruction-following capabilities. As a result, these distilled models retain the excellent performance of the original models while requiring fewer computational resources.
|
||||
|
||||
The distillation process involves carefully selecting, rewriting, and optimizing instruction-response pairs conducive to student model learning, thus improving model comprehension and execution abilities. Following standard fine-tuning, we employ white-box distillation techniques to enable the student models to better acquire fine-grained knowledge from teacher models. Experimental evaluations demonstrate the significant improvement in capabilities of the DistilQwen2.5 models.
|
||||
|
||||
## Detailed Steps
|
||||
|
||||
### Processing of Instructional Dataset
|
||||
|
||||
DistilQwen2.5 begins with collecting diverse, high-quality instructional data from sources like Magpie, Openhermes, and Mammoth 2, along with proprietary datasets. This data includes Chinese and English instructions, scoring them for difficulty and task relevance. This process is very similar to the recipe of DistilQwen2.
|
||||
|
||||
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
|
||||
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
|
||||
},
|
||||
{
|
||||
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
|
||||
"output": "I'd be happy to help you review your lecture text..."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
|
||||
|
||||
```python
|
||||
# Validate SDK token
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
api.login('your_token_id')
|
||||
|
||||
# Dataset download
|
||||
from modelscope.msdatasets import MsDataset
|
||||
ds = MsDataset.load('PAI/DistilQwen_100k')
|
||||
```
|
||||
|
||||
### Black-Box KD
|
||||
|
||||
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. Because we have already obtained the teacher's responses in the dataset, we need to run the training job only:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2.5_stage1.json
|
||||
```
|
||||
|
||||
Plese refer to the config file `distilqwen2.5_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
### White-Box KD
|
||||
|
||||
Unlike black-box KD, which relies solely on the highest probability token output by the teacher model, white-box KD focuses on the distribution of logits produced by the teacher model. This approach provides the student model with richer information. By mimicking the teacher model's logits distribution, white-box KD can transfer knowledge more effectively, thereby enhancing the performance of the student model. As an example, we take `Qwen2.5-72B-Instruct` as the white-box teacher model, and generate the logits by:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/infer.py --config=distilqwen2.5_stage2.json
|
||||
```
|
||||
|
||||
Next, we run the training job by:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2.5_stage2.json
|
||||
```
|
||||
|
||||
Again, please refer to the config file `distilqwen2.5_stage2.json` in the current folder. Remember to change the configurations when needed.
|
||||
|
||||
## Model Download
|
||||
|
||||
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2.5-0.5B-Instruct`, `alibaba-pai/DistilQwen2.5-1.5B-Instruct`, `alibaba-pai/DistilQwen2.5-3B-Instruct`, and `alibaba-pai/DistilQwen2.5-7B-Instruct`.
|
||||
|
||||
For example, users can download these models from HuggingFace using the following code:
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the 0.5B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-0.5B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-0.5B/")
|
||||
|
||||
# Download the 1.5B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-1.5B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-1.5B/")
|
||||
|
||||
# Download the 3B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-3B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-3B/")
|
||||
|
||||
# Download the 7B model
|
||||
model_name = "alibaba-pai/DistilQwen2.5-7B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2.5-7B/")
|
||||
```
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
The table below compares the performance of the original Qwen2.5 models with the distilled DistilQwen2.5 models across different parameter sizes: 0.5B, 1.5B, 3B, and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
|
||||
|
||||
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|
||||
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
|
||||
| Qwen2.5-0.5B-Instruct | 2.46 | 5.49 | 6.26 | 42.81 | 30.31 |
|
||||
| **DistilQwen2.5-0.5B-Instruct** | **4.89** | **5.78** | **6.83** | **52.61** | **37.82** |
|
||||
| Qwen2.5-1.5B-Instruct | 6.69 | 7.09 | 7.66 | 55.40 | 40.11 |
|
||||
| **DistilQwen2.5-1.5B-Instruct** | **13.69** | **7.35** | **7.99** | **61.10** | **74.49** |
|
||||
| Qwen2.5-3B-Instruct | 17.98 | 7.92 | 8.40 | 61.18 | 74.58 |
|
||||
| **DistilQwen2.5-3B-Instruct** | **20.91** | **8.37** | **8.97** | **67.03** | **77.36** |
|
||||
| Qwen2.5-7B-Instruct | 31.43 | 8.52 | 8.83 | 81.53 | 72.10 |
|
||||
| **DistilQwen2.5-7B-Instruct** | **34.86** | **8.76** | **9.22** | **83.48** | **73.27** |
|
||||
|
||||
|
||||
For evaluation details, please refer to our paper.
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information about the DistilQwen2.5 model series and the methodologies employed, we encourage you to refer to our paper:
|
||||
|
||||
- **DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models**
|
||||
Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang
|
||||
[arXiv:2504.15027](https://arxiv.org/abs/2504.15027)
|
||||
|
||||
You can cite the paper using the following citation format:
|
||||
|
||||
```bibtex
|
||||
@misc{wang2025distilqwen25industrialpracticestraining,
|
||||
title={DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models},
|
||||
author={Chengyu Wang and Junbing Yan and Yuanhao Yue and Jun Huang},
|
||||
year={2025},
|
||||
eprint={2504.15027},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2504.15027},
|
||||
}
|
||||
```
|
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_100k.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage1/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"job_type": "kd_white_box",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_100k.json",
|
||||
"logits_path": "logits.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 512
|
||||
},
|
||||
"distillation": {
|
||||
"kd_ratio": 0.5,
|
||||
"max_seq_length": 512,
|
||||
"distillation_type": "forward_kld"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-72B-Instruct/",
|
||||
"student": "result_stage1/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage2/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
165
recipes/distilqwen_series/distillqwen2/README.md
Normal file
165
recipes/distilqwen_series/distillqwen2/README.md
Normal file
@@ -0,0 +1,165 @@
|
||||
# DistilQwen2: Refining Instructional Data for Black-Box KD
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
Knowledge distillation offers an effective solution by transferring knowledge from larger models to smaller ones, ensuring performance while significantly reducing computational resources and inference time. We introduce DistilQwen2, a lightweight LLM based on the Qwen2 series, optimized through enhanced instruction following and diverse distillation techniques. This enables more agile and efficient deployment in resource-constrained environments like mobile devices and edge computing. For ease of use by developers and enterprises, DistilQwen2's checkpoints are open-sourced on HuggingFace and ModelScope, empowering more stakeholders to innovate and realize value through advanced NLP applications.
|
||||
|
||||
## Instructional Data Processing Guidelines
|
||||
|
||||
For the training of DistilQwen2, we collected data from well-known open-source datasets like Magpie, Openhermes, and Mammoth 2, along with proprietary synthetic datasets to initiate the distillation process. The focus is on providing diverse instructional data, predominantly in Chinese and English. We also leverage prompt templates to conduct instructional data augmentation. Here, we provide several commonly used operations to re-sample and augement the dataset.
|
||||
|
||||
### Instruction Set Expansion
|
||||
|
||||
The instruction expansion operator is employed generate a diverse set of instruction variations, ensuring that student models are exposed to a comprehensive range of instructions. After instruction expansion, we can also call the teacher model to generate responses for new instructions. An example is calling this operator is as follows:
|
||||
|
||||
```bash
|
||||
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_expansion_api.json
|
||||
```
|
||||
|
||||
If you need to run the job using batch inference, please refer to the config example `configs/instruction_expansion_batch.json`.
|
||||
|
||||
### Instruction Refinement
|
||||
|
||||
The instruction refinement operator further enhances the quality and diversity of the training data, which also preserves the semantic integrity of the tasks expressed in instructions, ensuring that the rewritten content remains faithful to the original intent and task category. After instruction refinement, we can also call the teacher model to generate responses for new instructions. An example is calling this operator is as follows:
|
||||
|
||||
```bash
|
||||
python easydistill/synthesis/synthesis_main.py --config=configs/instruction_refinement_api.json
|
||||
```
|
||||
|
||||
If you need to run the job using batch inference, please refer to the config example `configs/instruction_refinement_batch.json`.
|
||||
|
||||
### Instruction Resampling
|
||||
|
||||
We also consider task balance when selecting useful instructional data pairs. The task distrubutions are defined based on our paper in the reference. You can run the job by:
|
||||
|
||||
```bash
|
||||
python task_resampling.py --input-file input.json --output-file output.json --api-key <your_api_key> --base-url <base_url>
|
||||
```
|
||||
|
||||
The dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth..."
|
||||
},
|
||||
{
|
||||
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
After the processing of intructions, you can generate the responses of the teacher model.
|
||||
|
||||
|
||||
### Open-Source Dataset
|
||||
|
||||
In addition, we have open-sourced part of the dataset used for model training, totaling 100K entries. This dataset includes mathematical problems, code tasks, Q&A, instruction following, and creative generation. Users can incorporate the DistilQwen_100K dataset, or its subsets, during model fine-tuning to enhance downstream task performance while maintaining generalization ability. The dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "The ratio of the number of molar teeth in the human upper jaw at the age of 6 is 2:1 compared to number of incisors teeth. There are total 8 incisors in the human mouth...",
|
||||
"output": "## Step 1: Determine the total number of incisors in the upper jaw...\n\nThe final answer is: \\boxed{8}"
|
||||
},
|
||||
{
|
||||
"instruction": "This is the text of a lecture I am giving tomorrow. Can you go over it and make recommendations to improve clarity and flow?",
|
||||
"output": "I'd be happy to help you review your lecture text..."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
The dataset is available on ModelScope and Hugging Face. Users can download it using ModelScope's scripts and command-line tools.
|
||||
|
||||
```python
|
||||
# Validate SDK token
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
api.login('your_token_id')
|
||||
|
||||
# Dataset download
|
||||
from modelscope.msdatasets import MsDataset
|
||||
ds = MsDataset.load('PAI/DistilQwen_100k')
|
||||
```
|
||||
|
||||
## Model Training Guidelines
|
||||
|
||||
### Black-Box KD
|
||||
|
||||
The black-box KD process follows a supervised learning paradigm, utilizing enhanced instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the knowledge imparted by the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle tasks but also enables it to perform better in multi-task scenarios. For simplicity, we use the `DistilQwen_100k` dataset as a tutorial, we need to run the training job only:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2_stage1.json
|
||||
```
|
||||
|
||||
Plese refer to the config file `distilqwen2_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
### Preference Rank Optimization
|
||||
|
||||
For more challenging instruction tasks, SFT alone may not yield optimal results. To address this, we further refine the model using Direct Preference Optimization (DPO), enabling more granular fine-tuning and improved performance. Firstly, we generate the student outputs as rejected response. The contents in the SFT datasets are regarded as prompt and chosen responses. Please refer to the following script:
|
||||
|
||||
```bash
|
||||
python dpo_student_infer_only.py --config=distilqwen2_stage2.json
|
||||
```
|
||||
|
||||
Next, we run the training job by:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=distilqwen2_stage2.json
|
||||
```
|
||||
|
||||
Again, please refer to the config file `distilqwen2_stage2.json` in the current folder. Remember to change the configurations when needed. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
## Model Download
|
||||
|
||||
We have open-sourced our distilled models on both HuggingFace and ModelScope. The available models are named `alibaba-pai/DistilQwen2-1.5B-Instruct` and `alibaba-pai/DistilQwen2-7B-Instruct`.
|
||||
|
||||
For example, users can download these models from HuggingFace using the following code:
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
model_name = "alibaba-pai/DistilQwen2-1.5B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-1.5B/")
|
||||
|
||||
model_name = "alibaba-pai/DistilQwen2-7B-Instruct"
|
||||
snapshot_download(repo_id=model_name, cache_dir="./DistilQwen2-7B/")
|
||||
```
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
The table below compares the performance of the original Qwen2 models with the distilled DistilQwen2 models across different parameter sizes: 1.5B and 7B. The evaluation metrics include AlpacaEval 2.0, MT-Bench, and IFEval scores. The distilled models demonstrate improved performance in instruction-following abilities over their respective original versions.
|
||||
|
||||
| Model | AlpacaEval 2.0 (length control) | MT-Bench | MT-Bench (single) | IFEval (instruct-loose) | IFEval (strict-prompt) |
|
||||
|-------------------------------|---------------------------------|------------------|-------------------|-------------------------|------------------------|
|
||||
| Qwen2-1.5B-Instruct | 5.22 | 5.85 | 6.45 | 41.37 | 28.10 |
|
||||
| **DistilQwen2-1.5B-Instruct** | **8.28** | **6.42** | **7.12** | **49.76** | **36.04** |
|
||||
| Qwen2-7B-Instruct | 24.33 | 8.27 | 8.68 | 66.67 | 52.31 |
|
||||
| **DistilQwen2-7B-Instruct** | **25.35** | **8.40** | **9.03** | **71.46** | **60.26** |
|
||||
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information about the DistilQwen2 model series and the methodologies employed, we encourage you to refer to our paper:
|
||||
|
||||
- **Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning**
|
||||
Yuanhao Yue, Chengyu Wang, Jun Huang, Peng Wang
|
||||
|
||||
You can cite the paper using the following citation format:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{emnlp2024,
|
||||
author = {Yuanhao Yue and
|
||||
Chengyu Wang and
|
||||
Jun Huang and
|
||||
Peng Wang},
|
||||
title = {Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning},
|
||||
booktitle = {Findings of the Association for Computational Linguistics: {EMNLP} 2024},
|
||||
pages = {6030--6054},
|
||||
publisher = {Association for Computational Linguistics},
|
||||
year = {2024},
|
||||
url = {https://aclanthology.org/2024.findings-emnlp.350}
|
||||
}
|
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "distil_qwen_100k.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage1/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"job_type": "rank_dpo_api",
|
||||
"dataset": {
|
||||
"instruction_path": "distil_qwen_100k.json",
|
||||
"labeled_path": "distil_qwen_100k_dpo.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "result_stage1/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage2/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"beta": 0.1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
105
recipes/distilqwen_series/distillqwen2/dpo_student_infer_only.py
Normal file
105
recipes/distilqwen_series/distillqwen2/dpo_student_infer_only.py
Normal file
@@ -0,0 +1,105 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def read_json_field(filename):
|
||||
try:
|
||||
with open(filename, 'r') as file:
|
||||
data = json.load(file)
|
||||
output = []
|
||||
for item in data:
|
||||
instruction = item["instruction"]
|
||||
output = item["output"]
|
||||
output.append({"prompt": instruction, "chosen": output})
|
||||
return output
|
||||
except FileNotFoundError:
|
||||
logging.error("The file was not found.")
|
||||
except json.JSONDecodeError:
|
||||
logging.error("There was an error decoding the JSON file.")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def write_data_to_json_file(data, file_path):
|
||||
try:
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(data, file, ensure_ascii=False, indent=4)
|
||||
logging.info(f"Data successfully written to {file_path}")
|
||||
except Exception as e:
|
||||
logging.error(f"An error occurred: {e}")
|
||||
|
||||
|
||||
def generate_student_response(data_list, config):
|
||||
# load student model
|
||||
student_tokenizer = AutoTokenizer.from_pretrained(
|
||||
config["models"]["student"],
|
||||
trust_remote_code=True
|
||||
)
|
||||
student_model = AutoModelForCausalLM.from_pretrained(
|
||||
config["models"]["student"],
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
outcomes = []
|
||||
for sample in tqdm(data_list, desc="Call remote model and generating responses"):
|
||||
prompt = sample["prompt"]
|
||||
chosen = sample["chosen"]
|
||||
# for student model
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
text = student_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
model_inputs = student_tokenizer([text], return_tensors="pt").to(student_model.device)
|
||||
generated_ids = student_model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=config["inference"]["max_new_tokens"]
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
rejected = student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
gen_data = {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
|
||||
outcomes.append(gen_data)
|
||||
write_data_to_json_file(outcomes, config["dataset"]["labeled_path"])
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=True, help='path to the json config file')
|
||||
args = parser.parse_args()
|
||||
config = json.load(open(args.config))
|
||||
data_list = read_json_field(config["dataset"]["instruction_path"])
|
||||
generate_student_response(data_list, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
156
recipes/distilqwen_series/distillqwen2/task_resampling.py
Normal file
156
recipes/distilqwen_series/distillqwen2/task_resampling.py
Normal file
@@ -0,0 +1,156 @@
|
||||
|
||||
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
from collections import Counter
|
||||
import random
|
||||
import argparse
|
||||
|
||||
|
||||
predefined_distribution = {
|
||||
'Math': 0.167,
|
||||
'Code Generation': 0.083,
|
||||
'Writing': 0.017,
|
||||
'Computer Science': 0.017,
|
||||
'Reasoning': 0.167,
|
||||
'Complex Format': 0.017,
|
||||
'Code Debug': 0.083,
|
||||
'Common-Sense': 0.017,
|
||||
'Counterfactual': 0.017,
|
||||
'Multilingual': 0.017,
|
||||
'Roleplay': 0.017,
|
||||
'Biology': 0.017,
|
||||
'Technology': 0.017,
|
||||
'Ethics': 0.017,
|
||||
'Sport': 0.017,
|
||||
'Law': 0.017,
|
||||
'Medicine': 0.017,
|
||||
'Literature': 0.017,
|
||||
'Entertainment': 0.017,
|
||||
'Art': 0.017,
|
||||
'Music': 0.017,
|
||||
'Toxicity': 0.017,
|
||||
'Economy': 0.017,
|
||||
'Physics': 0.017,
|
||||
'History': 0.017,
|
||||
'Chemistry': 0.017,
|
||||
'Philosophy': 0.017,
|
||||
'Health': 0.017,
|
||||
'Ecology': 0.017,
|
||||
'Grammar': 0.017,
|
||||
'Paraphrase': 0.017,
|
||||
'Others': 0.041
|
||||
}
|
||||
|
||||
predefined_prompt = """
|
||||
You are a data annotation expert. Please classify the task type or domain of #Given Instruction.
|
||||
The task type or domain should be in the list: [’Math’, ’Code Generation’, ’Writing’, ’Computer Science’, ’Reasoning’, ’Complex Format’, ’Code Debug’, ’Common-Sense’, ’Counterfactual’, ’Multilingual’, ’Roleplay’,’Biology’, ’Technology’, ’Ethics’, ’Sport’, ’Law’, ’Medicine’, ’Literature’, ’Entertainment’, ’Art’, ’Music’, ’Toxicity’, ’Economy’, ’Physics’, ’History’, ’Chemistry’, ’Philosophy’,’Health’,’Ecology’,’Grammar’,’Paraphrase’, ’Others’]. You should place your answer enclosed within <answer></answer> tags, such as <answer>Math</answer>. Do not return anything else.
|
||||
#Given Instruction#:
|
||||
"""
|
||||
|
||||
|
||||
def extract_answer(content):
|
||||
pattern = r'<answer>(.*?)</answer>'
|
||||
match = re.search(pattern, content, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def classify_instruction(instruction, client, model):
|
||||
message = [
|
||||
{"role": "user", "content": predefined_prompt + "\n" + instruction}
|
||||
]
|
||||
completion = client.chat.completions.create(
|
||||
messages = message,
|
||||
model = model,
|
||||
max_completion_tokens = 1024
|
||||
)
|
||||
result = completion.choices[0].message.content.strip()
|
||||
print(result)
|
||||
result = extract_answer(result)
|
||||
if result is None or result not in predefined_distribution.keys():
|
||||
result = 'Others'
|
||||
print(result)
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
# Load dataset
|
||||
with open(args.input_file, 'r') as file:
|
||||
data = json.load(file)
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = OpenAI(
|
||||
api_key=args.api_key,
|
||||
base_url=args.base_url
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
logging.info(model)
|
||||
|
||||
# Classify each instruction
|
||||
classified_data = []
|
||||
count = 0
|
||||
for item in data:
|
||||
category = classify_instruction(item['instruction'], client, model)
|
||||
classified_data.append({'instruction': item['instruction'], 'category': category})
|
||||
count += 1
|
||||
print(count)
|
||||
|
||||
# Count occurrences per category
|
||||
category_counts = Counter(item['category'] for item in classified_data)
|
||||
total_samples = len(classified_data)
|
||||
|
||||
# Resample according to predefined distribution
|
||||
resampled_data = []
|
||||
for category, target_ratio in predefined_distribution.items():
|
||||
target_count = int(total_samples * target_ratio)
|
||||
category_samples = [item for item in classified_data if item['category'] == category]
|
||||
if len(category_samples) == 0:
|
||||
logging.warning("No instructions are provided for the category: " + category)
|
||||
continue
|
||||
if len(category_samples) > target_count:
|
||||
print(category)
|
||||
print(len(category_samples))
|
||||
print(target_count)
|
||||
# Randomly sample the required number of instructions
|
||||
resampled_category_samples = random.sample(category_samples, target_count)
|
||||
else:
|
||||
# If not enough samples, repeat the existing ones
|
||||
resampled_category_samples = category_samples * (target_count // len(category_samples)) + random.sample(category_samples, target_count % len(category_samples))
|
||||
resampled_data.extend(resampled_category_samples)
|
||||
|
||||
# Save final dataset
|
||||
with open(args.output_file, 'w') as file:
|
||||
json.dump(resampled_data, file, indent=4)
|
||||
|
||||
print("Resampling complete. Final output saved to '{}'.".format(args.output_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Task and Domain Classification')
|
||||
parser.add_argument('--input-file', type=str, required=True, help='Input JSON file containing instructions.')
|
||||
parser.add_argument('--output-file', type=str, required=True, help='Output JSON file to store resampled instructions.')
|
||||
parser.add_argument('--api-key', type=str, required=True, help='API key.')
|
||||
parser.add_argument('--base-url', type=str, required=True, help='Base URL.')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
67
recipes/domain_specific/code_generation/README.md
Normal file
67
recipes/domain_specific/code_generation/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# Code Generation: Combining Black-Box and White Box KD
|
||||
|
||||
## Brief Introduction
|
||||
|
||||
The code generation distillation process involves carefully selecting and optimizing code generation instruction-response pairs conducive to student model learning. Following standard fine-tuning, we employ white-box distillation techniques to enable the student models to better acquire fine-grained knowledge from teacher models.
|
||||
|
||||
## Detailed Steps
|
||||
|
||||
### Processing of Code Generation Dataset
|
||||
|
||||
The process begins with collecting diverse, high-quality code generation data from various programming languages and tasks. This data includes code completion, code generation from natural language descriptions, bug fixing, and code explanation tasks. We score examples for difficulty and relevance to code generation tasks.
|
||||
|
||||
The dataset is in JSON format, exemplified by entries such as:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "Write a Python function that calculates the Fibonacci sequence up to n elements using dynamic programming.",
|
||||
"output": "def fibonacci(n):\n if n <= 0:\n return []\n elif n == 1:\n return [0]\n elif n == 2:\n return [0, 1]\n \n fib = [0, 1]\n for i in range(2, n):\n fib.append(fib[i-1] + fib[i-2])\n \n return fib\n\n# Example usage\nprint(fibonacci(10)) # Output: [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]"
|
||||
},
|
||||
{
|
||||
"instruction": "Create a JavaScript function that takes an array of numbers and returns the sum of all even numbers in the array.",
|
||||
"output": "function sumEvenNumbers(numbers) {\n return numbers\n .filter(num => num % 2 === 0)\n .reduce((sum, num) => sum + num, 0);\n}\n\n// Example usage\nconst numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];\nconsole.log(sumEvenNumbers(numbers)); // Output: 30"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Black-Box KD
|
||||
|
||||
The black-box KD process follows a supervised learning paradigm, utilizing enhanced code instruction-response pairs as training samples. Through this approach, the student model can effectively absorb and understand the code generation capabilities of the larger model, even with a limited number of parameters. This method not only boosts the student model's ability to tackle programming tasks but also enables it to perform better across multiple programming languages and paradigms.
|
||||
|
||||
To run the black-box KD training:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=code_generation_stage1.json
|
||||
```
|
||||
|
||||
Please refer to the config file `code_generation_stage1.json` in the current folder. If you need to run the job in a distributed mode, use `accelerate` to run the job.
|
||||
|
||||
### White-Box KD
|
||||
|
||||
Unlike black-box KD, which relies solely on the highest probability token output by the teacher model, white-box KD focuses on the distribution of logits produced by the teacher model. This approach provides the student model with richer information about code structure and syntax. By mimicking the teacher model's logits distribution, white-box KD can transfer programming knowledge more effectively, thereby enhancing the performance of the student model.
|
||||
|
||||
To generate the logits with the teacher model:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/infer.py --config=code_generation_stage2.json
|
||||
```
|
||||
|
||||
Next, run the training job:
|
||||
|
||||
```bash
|
||||
python easydistill/kd/train.py --config=code_generation_stage2.json
|
||||
```
|
||||
|
||||
Please refer to the config file `code_generation_stage2.json` in the current folder. Remember to change the configurations when needed.
|
||||
|
||||
## Performance
|
||||
|
||||
We trained the model using data from nvidia/OpenCodeReasoning, and the final model performance is as follows:
|
||||
|
||||
| Model | LiveCodeBench V2 | speed |
|
||||
|---------------------------|------------------|--------|
|
||||
| Qwen2.5-3B-Instruct | 11.35 | 2.3x |
|
||||
| Qwen2.5-3B-Code-Optimize | 16.62 | 2.3x |
|
||||
| Qwen2.5-7B-Instruct | 30.72 | 1x |
|
||||
| Qwen2.5-7B-Code-Optimize | 35.32 | 1x |
|
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"job_type": "kd_black_box_api",
|
||||
"dataset": {
|
||||
"labeled_path": "code_generation_dataset.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"models": {
|
||||
"student": "model/Qwen/Qwen2.5-0.5B-Instruct/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage1/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"job_type": "kd_white_box",
|
||||
"dataset": {
|
||||
"labeled_path": "code_generation_dataset.json",
|
||||
"logits_path": "logits.json",
|
||||
"template" : "chat_template_kd.jinja",
|
||||
"seed": 42
|
||||
},
|
||||
"inference":{
|
||||
"enable_chunked_prefill": true,
|
||||
"seed": 777,
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"temperature": 0.8,
|
||||
"trust_remote_code": true,
|
||||
"enforce_eager": false,
|
||||
"max_model_len": 4096,
|
||||
"max_new_tokens": 1024
|
||||
},
|
||||
"distillation": {
|
||||
"kd_ratio": 0.1,
|
||||
"max_seq_length": 1024,
|
||||
"distillation_type": "forward_kld"
|
||||
},
|
||||
"models": {
|
||||
"teacher": "teacher/Qwen/Qwen2.5-7B-Instruct/",
|
||||
"student": "result_stage1/"
|
||||
},
|
||||
"training": {
|
||||
"output_dir": "result_stage2/",
|
||||
"num_train_epochs": 3,
|
||||
"per_device_train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
"save_steps": 1000,
|
||||
"logging_steps": 1,
|
||||
"learning_rate": 2e-5,
|
||||
"weight_decay": 0.05,
|
||||
"warmup_ratio": 0.1,
|
||||
"lr_scheduler_type": "cosine"
|
||||
}
|
||||
}
|
50
recipes/open_datasets/distilqwen_datasets.md
Normal file
50
recipes/open_datasets/distilqwen_datasets.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# DistilQwen-100k/DistilQwen-1M: High-Quality Instruction-Tuning Datasets
|
||||
|
||||
## Overview
|
||||
To empower community developers in enhancing the **instruction-following capabilities** of large language models (LLMs), we open-source **`DistilQwen-100k`** and **`DistilQwen-1M`**, subsets of the training data used for the **DistilQwen model series**. The datasets provide diverse, high-quality samples to improve model performance in key areas.
|
||||
|
||||
## Dataset Features
|
||||
- **Scale**: **100 thousand**/**1 million** meticulously distilled entries.
|
||||
- **Coverage**: Balanced mix of:
|
||||
- **Mathematics**
|
||||
- **Code generation & understanding**
|
||||
- **Knowledge-based QA**
|
||||
- **Instruction following**
|
||||
- **Creative generation**
|
||||
- **Purpose**: Optimized for **instruction tuning**, helping models retain generalization while adapting to downstream tasks.
|
||||
|
||||
## Use Cases
|
||||
- **Fine-tuning LLMs**: Mitigate *catastrophic forgetting* by combining with custom datasets.
|
||||
- **Multi-task learning**: Improve coherence in mathematical reasoning, coding, and creative tasks.
|
||||
- **Research**: Study distillation techniques or instruction-tuning efficacy.
|
||||
|
||||
## Use the Datasets
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Login using e.g. `huggingface-cli login` to access this dataset
|
||||
ds = load_dataset("alibaba-pai/DistilQwen_100k")
|
||||
ds = load_dataset("alibaba-pai/DistilQwen_1M")
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information about the dataset construction process, we encourage you to refer to our paper:
|
||||
|
||||
- **DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models**
|
||||
Chengyu Wang, Junbing Yan, Yuanhao Yue, Jun Huang
|
||||
[arXiv:2504.15027](https://arxiv.org/abs/2504.15027)
|
||||
|
||||
You can cite the paper using the following citation format:
|
||||
|
||||
```bibtex
|
||||
@misc{wang2025distilqwen25industrialpracticestraining,
|
||||
title={DistilQwen2.5: Industrial Practices of Training Distilled Open Lightweight Language Models},
|
||||
author={Chengyu Wang and Junbing Yan and Yuanhao Yue and Jun Huang},
|
||||
year={2025},
|
||||
eprint={2504.15027},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2504.15027}
|
||||
}
|
||||
```
|
58
recipes/open_datasets/omni_thought.md
Normal file
58
recipes/open_datasets/omni_thought.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# OmniThought: A Large-Scale Chain-of-Thought Dataset for Advancing Large Reasoning Models
|
||||
|
||||
## Overview
|
||||
The rise of **Large Reasoning Models (LRMs)** has revolutionized **Natural Language Processing (NLP)**, enabling breakthroughs in complex tasks like **mathematical problem-solving** and **code generation**. These models rely on **Chain-of-Thought (CoT)** processes to mimic human-like reasoning. However, progress in LRMs is limited by the scarcity of **high-quality, large-scale CoT datasets**—existing resources often lack:
|
||||
- **Diverse reasoning problems** with well-structured CoT processes.
|
||||
- **Multi-teacher distillation** to ensure reasoning quality.
|
||||
- **Fine-grained annotations** describing CoT properties.
|
||||
|
||||
To bridge this gap, we introduce **`OmniThought`**, a **2-million-scale CoT dataset** generated and validated by **two powerful LRMs**. Each CoT process is annotated with:
|
||||
- **Reasoning Verbosity (RV)**: Measures the optimal verbosity of reasoning steps.
|
||||
- **Cognitive Difficulty (CD)**: Assesses the complexity of reasoning for model comprehension.
|
||||
|
||||
We also propose a **self-reliant pipeline** for dataset curation, ensuring high-quality reasoning traces.
|
||||
|
||||
## Key Features
|
||||
✅ **2 million high-quality CoT processes** covering diverse reasoning tasks.
|
||||
✅ **RV-CD scores** to guide model training for better reasoning performance.
|
||||
✅ **Multi-teacher distillation** for robust and coherent reasoning paths.
|
||||
✅ **Optimized for LRM training**—improves reasoning ability and output quality.
|
||||
|
||||
## Experiments & Results
|
||||
Extensive experiments with **Qwen2.5 models** (various sizes) confirm that:
|
||||
- Training with **RV-CD scores** enhances **LRM reasoning effectiveness**.
|
||||
- Models trained on `OmniThought` achieve **stronger reasoning abilities** with **optimal CoT length and difficulty**.
|
||||
|
||||
Based on this dataset, we release **a series of high-performance LRMs** with superior reasoning capabilities.
|
||||
|
||||
## Use the Datasets
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Login using e.g. `huggingface-cli login` to access this dataset
|
||||
ds = load_dataset("alibaba-pai/OmniThought")
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
For more detailed information about the dataset construction process, we encourage you to refer to our paper:
|
||||
|
||||
- **Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations**
|
||||
Wenrui Cai, Chengyu Wang, Junbing Yan, Jun Huang, Xiangzhong Fang
|
||||
[arXiv:2505.10937](https://arxiv.org/abs/2505.10937)
|
||||
|
||||
You can cite the paper using the following citation format:
|
||||
|
||||
```bibtex
|
||||
@misc{cai2025reasoningomnithoughtlargecot,
|
||||
title={Reasoning with OmniThought: A Large CoT Dataset with Verbosity and Cognitive Difficulty Annotations},
|
||||
author={Wenrui Cai and Chengyu Wang and Junbing Yan and Jun Huang and Xiangzhong Fang},
|
||||
year={2025},
|
||||
eprint={2505.10937},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2505.10937}
|
||||
}
|
||||
```
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
transformers==4.51.0
|
||||
transformers-stream-generator==0.0.5
|
||||
trl==0.17.0
|
||||
tokenizers==0.21.1
|
||||
vllm==0.8.5
|
||||
openai
|
||||
jinja2
|
BIN
resources/framework.png
Normal file
BIN
resources/framework.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 156 KiB |
25
setup.py
Normal file
25
setup.py
Normal file
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
with open('README.md') as f:
|
||||
readme = f.read()
|
||||
|
||||
with open('requirements.txt') as f:
|
||||
requirements = f.read()
|
||||
|
||||
setup(
|
||||
# Metadata
|
||||
name='easydistill',
|
||||
version='0.0.1',
|
||||
python_requires='>=3.6',
|
||||
author='PAI',
|
||||
description='PAI EasyDistill Toolkit',
|
||||
long_description=readme,
|
||||
entry_points={'console_scripts': ['easydistill=easydistill.cli:main']},
|
||||
long_description_content_type='text/markdown',
|
||||
packages=find_packages(),
|
||||
license='Apache-2.0',
|
||||
|
||||
#Package info
|
||||
install_requires=requirements)
|
Reference in New Issue
Block a user