# File: embedding-clustering/extract/extract_donut.py
# (Header copied from the repository file view: 202 lines, 7.1 KiB, Python.)

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm
# --- Configuration ---
# Donut checkpoint fine-tuned for document visual question answering
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa"
# Directory scanned (non-recursively) for .png/.jpg/.jpeg documents.
# Set the DONUT_IMAGE_DIR environment variable to point the script at another
# dataset or machine; the original path is kept as the default.
IMAGE_DIR = os.environ.get(
    "DONUT_IMAGE_DIR",
    "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/",
)
# Number of image paths handed to get_document_embeddings per chunk; kept small
# because Donut is memory intensive
BATCH_SIZE = 4
# --- End Configuration ---
# Run on CUDA when a GPU is visible to torch, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Donut processor and the vision-encoder / text-decoder model,
# move the model to the chosen device and switch it to inference mode.
# (nn.Module.to() and .eval() both return the module itself, so they chain.)
print("Loading Donut model and processor...")
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device).eval()

# On GPU, cast the weights to float16 to reduce memory use; inputs are cast to
# half precision to match wherever the model is called.
if device == "cuda":
    model = model.half()
def get_document_embeddings(image_paths):
    """
    Compute one Donut encoder embedding per document image in ``image_paths``.

    Images are run through ``model.encoder`` one at a time (to bound memory),
    and the encoder's last hidden state is mean-pooled over dimension 1 into a
    single fixed-size vector per image.

    Args:
        image_paths: Iterable of file paths. Paths whose lowercase form does not
            end in .png/.jpg/.jpeg are ignored.

    Returns:
        ``(embeddings, paths)``: a NumPy array with one row per successfully
        embedded image, and the list of those images' paths in the same order.
        Images that fail to load or to encode are skipped (a warning is printed)
        and appear in neither output. Returns ``(np.array([]), [])`` when no
        image could be embedded.
    """
    # Decode every candidate image first; unreadable files are reported and skipped
    loaded = []  # (path, RGB PIL image) pairs
    for path in image_paths:
        if not path.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        try:
            # `with` releases the underlying file handle once converted
            with Image.open(path) as raw:
                loaded.append((path, raw.convert("RGB")))
        except Exception as e:
            print(f"Warning: Could not load image {path}. Skipping. Error: {e}")

    embeddings_list = []
    valid_paths = []
    # Process images one by one to avoid memory issues
    for path, image in loaded:
        try:
            pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
            if device == "cuda":
                # The model was cast to fp16 at load time on GPU; match its dtype
                pixel_values = pixel_values.half()
            with torch.no_grad():
                encoder_outputs = model.encoder(pixel_values=pixel_values)
                last_hidden_state = encoder_outputs.last_hidden_state  # [batch=1, seq_len, hidden_size]
                # Global average pooling across the sequence dimension
                pooled = torch.mean(last_hidden_state, dim=1)  # [1, hidden_size]
            vector = pooled[0].cpu().float().numpy()
        except Exception as e:
            # Drop the image entirely (embedding AND path). A placeholder zero
            # vector would be indistinguishable from real data downstream.
            print(f"Warning: Could not process image {path}. Skipping. Error: {e}")
            continue
        embeddings_list.append(vector)
        valid_paths.append(path)

    if not embeddings_list:
        return np.array([]), []
    return np.array(embeddings_list), valid_paths
def extract_document_info(image_path, question="What information is in this document?"):
    """
    Ask Donut a DocVQA-style question about one document image.

    The question is wrapped in the DocVQA task prompt and fed to the decoder as
    its starting tokens. The model then generates the answer, which is cut out
    of the ``<s_answer>...</s_answer>`` span of the decoded sequence.

    Args:
        image_path: Path to the document image.
        question: Natural-language question to ask about the document.

    Returns:
        The answer text with task/EOS/padding tokens removed, or ``""`` if any
        step fails (the error is printed).
    """
    try:
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")

        # Prepare the task prompt for document VQA
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"

        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        if device == "cuda":
            # Match the fp16 weights; token ids below must stay integer
            pixel_values = pixel_values.half()

        # The prompt must seed the DECODER. Passing it as `text` to the
        # processor only produces `labels`, which generate() does not use as a
        # prompt, so the answer would not be conditioned on the question.
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        with torch.no_grad():
            generated = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=512,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        decoded_text = processor.batch_decode(generated.sequences)[0]
        # The decoded sequence echoes the prompt; keep only the text between
        # <s_answer> and </s_answer>, then strip trailing EOS/padding tokens.
        answer = decoded_text.split("<s_answer>")[-1].split("</s_answer>")[0]
        for special in (processor.tokenizer.eos_token, processor.tokenizer.pad_token):
            if special:
                answer = answer.replace(special, "")
        return answer.strip()
    except Exception as e:
        print(f"Error extracting info from {image_path}: {e}")
        return ""
# --- Process all images in the directory ---
print("Scanning for image files...")
# Sort so the output order (and the sample subset taken later) is reproducible;
# os.listdir returns entries in an arbitrary, filesystem-dependent order.
image_files = sorted(
    os.path.join(IMAGE_DIR, name)
    for name in os.listdir(IMAGE_DIR)
    if name.lower().endswith(('.png', '.jpg', '.jpeg'))
)
print(f"Found {len(image_files)} image files")

# Extract embeddings and write them incrementally as a JSON array, one object
# per successfully embedded image.
print("Extracting embeddings using Donut...")
with open("embeddings_factures_donut.json", "w", encoding="utf-8") as f:
    f.write("[\n")
    first = True  # no comma before the first array element
    for i in tqdm(range(0, len(image_files), BATCH_SIZE), desc="Processing batches"):
        batch_paths = image_files[i:i + BATCH_SIZE]
        batch_embeddings, valid_paths = get_document_embeddings(batch_paths)
        # zip over empty results is a no-op, so no separate length check is needed
        for path, emb in zip(valid_paths, batch_embeddings):
            if not first:
                f.write(",\n")
            vector = emb.tolist()
            entry = {
                "filepath": path,
                "embedding": vector,
                "model": "donut-base-finetuned-docvqa",
                "embedding_size": len(vector),
            }
            json.dump(entry, f)
            first = False
    f.write("\n]\n")
print("Embeddings extracted and saved to 'embeddings_factures_donut.json'")
# Optional: Extract some sample document information
print("\nExtracting sample document information...")
sample_images = image_files[:3]  # Process first 3 images as samples
# Questions asked of every sample document (loop-invariant, so defined once)
questions = [
    "What is the total amount?",
    "What is the invoice number?",
    "What is the date?",
    "Who is the vendor?",
    "What are the main items?",
]
sample_info = []
for img_path in sample_images:
    print(f"Processing: {os.path.basename(img_path)}")
    info = {"filepath": img_path, "extracted_info": {}}
    for question in questions:
        answer = extract_document_info(img_path, question)
        info["extracted_info"][question] = answer
        print(f" {question}: {answer}")
    sample_info.append(info)

# Save sample extraction results. ensure_ascii=False writes non-ASCII characters
# (accents, currency symbols) verbatim, so the file must be opened as UTF-8
# rather than with the locale-dependent default encoding.
with open("donut_sample_extractions.json", "w", encoding="utf-8") as f:
    json.dump(sample_info, f, indent=2, ensure_ascii=False)
print("Sample document information extracted and saved to 'donut_sample_extractions.json'")
print("Processing completed!")