"""Convert role/content message records into conversation-format JSON.

Each input file is expected to hold a JSON list of records, where a record is
a list of messages shaped like ``{"role": ..., "content": ...}``.  Every record
is turned into ``{"id": ..., "images": [...], "conversations": [...]}`` and all
converted records are written to a single destination file.
"""

import json
import uuid
from typing import Any, Dict, List, Optional
import argparse
import shutil  # noqa: F401 - not used by the code in this module; kept so the import set is unchanged

# ---------------------------
# Helper functions
# ---------------------------


def load_json_data(filepath: str) -> Any:
    """Load a UTF-8 encoded JSON file from disk and return the parsed value."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def get_img_path(data: List[dict]) -> List[str]:
    """Extract image paths from a single record.

    Every message in the record is scanned, so the result does not depend on
    the image-carrying message sitting at a fixed position (a record without a
    leading system message, or with fewer than two messages, is handled).

    Args:
        data: One record, i.e. a list of message dicts.

    Returns:
        The ``"image"`` values of all ``{"type": "image"}`` content items, in
        the order they appear. Messages whose content is not a list (for
        example a plain string) contribute nothing.
    """
    images: List[str] = []
    for message in data:
        if not isinstance(message, dict):
            continue
        content = message.get("content")
        if not isinstance(content, list):
            # Plain-string (or missing) content cannot carry image items.
            continue
        for part in content:
            if (
                isinstance(part, dict)
                and part.get("type") == "image"
                and "image" in part
            ):
                images.append(part["image"])
    return images


def _text_items(content: Any) -> List[str]:
    """Return the ``"text"`` values of all text items in a list-style content."""
    if not isinstance(content, list):
        return []
    return [
        part["text"]
        for part in content
        if isinstance(part, dict) and part.get("type") == "text" and "text" in part
    ]


def get_conversations_v2(data: List[dict]) -> List[Dict[str, str]]:
    """Extract the conversation turns of a record in ``from``/``value`` form.

    Role mapping: ``system`` -> ``"system"``, ``user`` -> ``"human"``,
    ``assistant`` / ``assistant_gt`` -> ``"gpt"``. Messages with any other
    role are skipped.

    Content handling:
        * system: the content is copied as-is.
        * user / assistant with list content: the first text item is used;
          if there is no text item the message is skipped.
        * user / assistant with string content: the string is used directly.
        * assistant with any other content: ``str(content)`` is used.

    Args:
        data: One record, i.e. a list of message dicts.

    Returns:
        A list of ``{"from": ..., "value": ...}`` dicts in message order.
    """
    conversation_data = []

    for item in data:
        role = item.get("role")
        content = item.get("content")

        if role == "system":
            conversation_data.append({"from": "system", "value": content})

        elif role == "user":
            if isinstance(content, str):
                # Text-only user turns may be stored as a bare string.
                conversation_data.append({"from": "human", "value": content})
            else:
                texts = _text_items(content)
                if texts:
                    conversation_data.append({"from": "human", "value": texts[0]})

        elif role in ("assistant", "assistant_gt"):
            if isinstance(content, list):  # list of dicts
                texts = _text_items(content)
                if texts:
                    conversation_data.append({"from": "gpt", "value": texts[0]})
            elif isinstance(content, str):  # single string
                conversation_data.append({"from": "gpt", "value": content})
            else:  # raw content, stringified so the value is always a str
                conversation_data.append({"from": "gpt", "value": str(content)})

    return conversation_data


# ---------------------------
# Data class
# ---------------------------


class docai_mgp_facture_data:
    """Container for one converted record: id, image paths and conversation."""

    id: str
    images: Dict[str, List[str]]
    conversations: Dict[str, List[Dict[str, str]]]

    def __init__(self):
        self.id = ""
        self.images = {"images": []}
        self.conversations = {"conversations": [{"from": "", "value": ""}]}

    def to_dict(self) -> Dict[str, Any]:
        """Return the flat ``id`` / ``images`` / ``conversations`` mapping."""
        return {
            "id": self.id,
            "images": self.images["images"],
            "conversations": self.conversations["conversations"],
        }

    def display_data(self) -> None:
        """Print the instance content to stdout in a human-readable layout."""
        print("Current data in instance:")
        print(f"ID: {self.id}")
        print("Images:")
        for img in self.images.get("images", []):
            print(f" - {img}")
        print("Conversations:")
        for conv in self.conversations.get("conversations", []):
            print(f" - from: {conv.get('from')}, value: {conv.get('value')}")

    def write_to_json(self, filename: str) -> None:
        """Write the current instance data to a JSON file (overwrite)."""
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=4)
        print(f"Data written to {filename}")


def convert(
    images: Optional[List[str]] = None,
    conversation: Optional[List[Dict[str, str]]] = None,
) -> docai_mgp_facture_data:
    """Convert raw data into a docai_mgp_facture_data instance.

    Args:
        images: Image paths of the record. ``None`` means no images.
        conversation: Conversation turns of the record. ``None`` means no
            turns.

    Returns:
        A new instance with a random UUID4 id. A fresh list is created for
        each omitted argument, so instances never share state through the
        defaults.
    """
    new_data = docai_mgp_facture_data()
    new_data.id = str(uuid.uuid4())
    new_data.images["images"] = [] if images is None else images
    new_data.conversations["conversations"] = (
        [] if conversation is None else conversation
    )
    return new_data


def main() -> None:
    r"""Convert one or more source JSON files into one conversation-format file.

    Input: one or more JSON files path
    Output: one JSON file under conversation format

    Ex: python3 ../convert_conversation_json.py \
        --source_path data1.json data2.json ... \
        --destination_path dest_path.json
    """
    parser = argparse.ArgumentParser(
        description="Convert one or more JSON files to conversation-form JSON."
    )
    parser.add_argument(
        "--source_path",
        type=str,
        nargs='+',  # allow multiple files
        required=True,
        help="Path(s) to the source JSON file.",
    )
    parser.add_argument(
        "--destination_path",
        type=str,
        required=True,
        help="Path to the destination JSON file.",
    )
    args = parser.parse_args()

    all_data = []

    for source_path in args.source_path:
        source_data = load_json_data(source_path)
        for record_data in source_data:
            record = convert(
                images=get_img_path(record_data),
                conversation=get_conversations_v2(record_data),
            )
            all_data.append(record.to_dict())

    with open(args.destination_path, "w", encoding="utf-8") as f:
        json.dump(all_data, f, ensure_ascii=False, indent=4)

    print(f"✅ All data from {len(args.source_path)} file(s) saved to {args.destination_path}")


# ---------------------------
# Main script
# ---------------------------

if __name__ == "__main__":
    main()