feat: add function to convert training data into conversation format
This commit is contained in:
159
easydistill/mmkd/convert_conversation_json.py
Normal file
159
easydistill/mmkd/convert_conversation_json.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
# ---------------------------
|
||||
# Helper functions
|
||||
# ---------------------------
|
||||
|
||||
def load_json_data(filepath: str):
|
||||
"""Load JSON file from disk."""
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def get_img_path(data: List[dict]):
|
||||
"""Extract image paths from a single record."""
|
||||
return [item["image"] for item in data[1].get("content", []) if item.get("type") == "image"]
|
||||
|
||||
# def get_conversations(data: List[dict]):
|
||||
# """Extract conversations in desired format."""
|
||||
# conversation_data = []
|
||||
# for item in data:
|
||||
# if item.get("role") == "system":
|
||||
# conversation_data.append({"from": "system", "value": item.get("content")})
|
||||
# elif item.get("role") == "user":
|
||||
# texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
|
||||
# conversation_data.append({"from": "human", "value": texts})
|
||||
# elif item.get("role") == "assistant" or item.get("role") == "assistant_gt":
|
||||
# #texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"][0]
|
||||
# #conversation_data.append({"from": "gpt", "value": texts})
|
||||
# conversation_data.append({"from": "gpt", "value": item.get("content")})
|
||||
# return conversation_data
|
||||
|
||||
def get_conversations_v2(data: List[dict]) -> List[Dict[str, str]]:
|
||||
"""Extract conversations in desired format, handling multiple text items."""
|
||||
conversation_data = []
|
||||
|
||||
for item in data:
|
||||
role = item.get("role")
|
||||
|
||||
if role == "system":
|
||||
conversation_data.append({"from": "system", "value": item.get("content")})
|
||||
|
||||
elif role == "user":
|
||||
texts = [x["text"] for x in item.get("content", []) if x.get("type") == "text"]
|
||||
|
||||
if texts:
|
||||
conversation_data.append({"from": "human", "value": texts[0]})
|
||||
|
||||
elif role in ["assistant", "assistant_gt"]:
|
||||
content = item.get("content")
|
||||
|
||||
if isinstance(content, list): # list of dicts
|
||||
texts = [x["text"] for x in content if x.get("type") == "text"]
|
||||
if texts:
|
||||
conversation_data.append({"from": "gpt", "value": texts[0]})
|
||||
|
||||
elif isinstance(content, str): # single string
|
||||
conversation_data.append({"from": "gpt", "value": content})
|
||||
|
||||
else: # raw content
|
||||
conversation_data.append({"from": "gpt", "value": str(content)})
|
||||
return conversation_data
|
||||
|
||||
def convert(images: List[str] = [], conversation: List[Dict[str,str]] = []):
|
||||
"""Convert raw data into docai_mgp_facture_data instance."""
|
||||
new_data = docai_mgp_facture_data()
|
||||
new_data.id = str(uuid.uuid4())
|
||||
new_data.images["images"] = images
|
||||
new_data.conversations["conversations"] = conversation
|
||||
return new_data
|
||||
|
||||
# ---------------------------
|
||||
# Data class
|
||||
# ---------------------------
|
||||
|
||||
class docai_mgp_facture_data:
|
||||
id: str
|
||||
images: Dict[str, List[str]]
|
||||
conversations: Dict[str, List[Dict[str,str]]]
|
||||
|
||||
def __init__(self):
|
||||
self.id = ""
|
||||
self.images = {"images": []}
|
||||
self.conversations = {"conversations": [{"from": "", "value": ""}]}
|
||||
|
||||
def display_data(self):
|
||||
print("Current data in instance:")
|
||||
print(f"ID: {self.id}")
|
||||
print("Images:")
|
||||
for img in self.images.get("images", []):
|
||||
print(f" - {img}")
|
||||
print("Conversations:")
|
||||
for conv in self.conversations.get("conversations", []):
|
||||
print(f" - from: {conv.get('from')}, value: {conv.get('value')}")
|
||||
|
||||
def write_to_json(self, filename: str):
|
||||
"""Write the current instance data to a JSON file (overwrite)."""
|
||||
data_dict = {
|
||||
"id": self.id,
|
||||
"images": self.images["images"],
|
||||
"conversations": self.conversations["conversations"]
|
||||
}
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(data_dict, f, ensure_ascii=False, indent=4)
|
||||
print(f"Data written to {filename}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
'''
|
||||
Input: one or more JSON files path
|
||||
Output: one JSON file under conversation format
|
||||
|
||||
Ex: python3 ../convert_conversation_json.py \
|
||||
--source_path data1.json data2.json ... \
|
||||
--destination_path dest_path.json
|
||||
'''
|
||||
parser = argparse.ArgumentParser(description="Convert one or more JSON files to conversation-form JSON.")
|
||||
parser.add_argument(
|
||||
"--source_path",
|
||||
type=str,
|
||||
nargs='+', # allow multiple files
|
||||
required=True,
|
||||
help="Path(s) to the source JSON file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--destination_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the destination JSON file."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
all_data = []
|
||||
|
||||
for source_path in args.source_path: # match the argument name
|
||||
source_data = load_json_data(source_path)
|
||||
for record_data in source_data:
|
||||
images = get_img_path(record_data)
|
||||
conversations = get_conversations_v2(record_data)
|
||||
record = convert(images=images, conversation=conversations)
|
||||
all_data.append({
|
||||
"id": record.id,
|
||||
"images": record.images["images"],
|
||||
"conversations": record.conversations["conversations"]
|
||||
})
|
||||
|
||||
with open(args.destination_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
print(f"✅ All data from {len(args.source_path)} file(s) saved to {args.destination_path}")
|
||||
|
||||
# ---------------------------
|
||||
# Main script
|
||||
# ---------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user