import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm

# --- Configuration ---
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa"  # Donut model fine-tuned for document VQA
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
BATCH_SIZE = 4  # Smaller batch size for Donut, as it is memory intensive
# --- End Configuration ---

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Donut model and processor
print("Loading Donut model and processor...")
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()

# Use half precision on GPU for efficiency
if device == "cuda":
    model = model.half()


def get_document_embeddings(image_paths):
    """
    Process a batch of document images and extract their embeddings using Donut.
    Uses the encoder part of the VisionEncoderDecoder model to get visual representations.
    """
    images_pil = []
    valid_paths = []
    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                # Load and convert image to RGB
                image = Image.open(path).convert("RGB")
                images_pil.append(image)
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")

    if not images_pil:
        return np.array([]), []

    embeddings_list = []

    # Process images one by one to avoid memory issues
    for image in images_pil:
        try:
            # Preprocess the image
            pixel_values = processor(image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)
            if device == "cuda":
                pixel_values = pixel_values.half()

            with torch.no_grad():
                # Get encoder outputs (visual features)
                encoder_outputs = model.encoder(pixel_values=pixel_values)

                # Use the last hidden state and apply global average pooling
                # to get a fixed-size representation
                last_hidden_state = encoder_outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

                # Global average pooling across the sequence dimension
                embedding = torch.mean(last_hidden_state, dim=1)  # [batch_size, hidden_size]
                embeddings_list.append(embedding.squeeze().cpu().float().numpy())

        except Exception as e:
            print(f"Warning: Could not process image. Error: {e}")
            # Add a zero embedding for failed images so embeddings stay aligned with valid_paths
            embeddings_list.append(np.zeros(model.config.encoder.hidden_size))

    return np.array(embeddings_list), valid_paths


def extract_document_info(image_path, question="What information is in this document?"):
    """
    Extract specific information from a document using Donut's text generation capability.
    This function demonstrates how to use Donut for document VQA-style understanding tasks.
    """
    try:
        image = Image.open(image_path).convert("RGB")

        # Build the DocVQA task prompt expected by this checkpoint
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"

        # Encode the image and the prompt separately: the prompt is fed to the
        # decoder as decoder_input_ids, the image as pixel_values
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        if device == "cuda":
            pixel_values = pixel_values.half()
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        with torch.no_grad():
            # Generate the answer
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=512,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Decode the generated sequence and strip special tokens
        decoded_text = processor.batch_decode(outputs.sequences)[0]
        decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )

        # Extract the answer part from the task markup
        answer = decoded_text.split("<s_answer>")[-1].replace("</s_answer>", "").strip()
        return answer

    except Exception as e:
        print(f"Error extracting info from {image_path}: {e}")
        return ""


# --- Process all images in the directory ---
print("Scanning for image files...")
image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR)
               if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"Found {len(image_files)} image files")

# Extract embeddings and stream them to a JSON array on disk
print("Extracting embeddings using Donut...")
with open("embeddings_factures_donut.json", "w") as f:
    f.write("[\n")
    first = True

    for i in tqdm(range(0, len(image_files), BATCH_SIZE), desc="Processing batches"):
        batch_paths = image_files[i:i + BATCH_SIZE]
        batch_embeddings, valid_paths = get_document_embeddings(batch_paths)

        if len(batch_embeddings) > 0:
            embeddings_list = [emb.tolist() for emb in batch_embeddings]
            for path, emb in zip(valid_paths, embeddings_list):
                if not first:
                    f.write(",\n")
                entry = {
                    "filepath": path,
                    "embedding": emb,
                    "model": "donut-base-finetuned-docvqa",
                    "embedding_size": len(emb)
                }
                json.dump(entry, f)
                first = False

    f.write("\n]\n")

print("Embeddings extracted and saved to 'embeddings_factures_donut.json'")

# Optional: Extract some sample document information
print("\nExtracting sample document information...")
sample_images = image_files[:3]  # Process the first 3 images as samples
sample_info = []

for img_path in sample_images:
    print(f"Processing: {os.path.basename(img_path)}")

    # Extract different types of information
    questions = [
        "What is the total amount?",
        "What is the invoice number?",
        "What is the date?",
        "Who is the vendor?",
        "What are the main items?"
    ]

    info = {"filepath": img_path, "extracted_info": {}}
    for question in questions:
        answer = extract_document_info(img_path, question)
        info["extracted_info"][question] = answer
        print(f"  {question}: {answer}")

    sample_info.append(info)

# Save sample extraction results
with open("donut_sample_extractions.json", "w") as f:
    json.dump(sample_info, f, indent=2, ensure_ascii=False)

print("Sample document information extracted and saved to 'donut_sample_extractions.json'")
print("Processing completed!")