import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm

# --- Configuration ---
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa"  # Donut model fine-tuned for document VQA
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
BATCH_SIZE = 4  # Smaller batch size, since Donut is memory-intensive
# --- End Configuration ---

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Donut model and processor
print("Loading Donut model and processor...")
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()

# Use half precision on GPU for memory and speed efficiency
if device == "cuda":
    model = model.half()

def get_document_embeddings(image_paths):
    """
    Processes a batch of document images and extracts their embeddings using Donut.
    Uses the encoder part of the VisionEncoderDecoder model to get visual representations.
    """
    images_pil = []
    valid_paths = []

    for path in image_paths:
        if path.lower().endswith(('.png', '.jpg', '.jpeg')):
            try:
                # Load and convert the image to RGB
                image = Image.open(path).convert("RGB")
                images_pil.append(image)
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not load image {path}. Skipping. Error: {e}")

    if not images_pil:
        return np.array([]), []

    embeddings_list = []

    # Process images one by one to avoid memory issues
    for image in images_pil:
        try:
            # Preprocess the image
            pixel_values = processor(image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(device)

            if device == "cuda":
                pixel_values = pixel_values.half()

            with torch.no_grad():
                # Get encoder outputs (visual features)
                encoder_outputs = model.encoder(pixel_values=pixel_values)

                # Use the last hidden state and apply global average pooling
                # to get a fixed-size representation
                last_hidden_state = encoder_outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

                # Global average pooling across the sequence dimension
                embedding = torch.mean(last_hidden_state, dim=1)  # [batch_size, hidden_size]

            embeddings_list.append(embedding.squeeze().cpu().float().numpy())

        except Exception as e:
            print(f"Warning: Could not process image. Error: {e}")
            # Add a zero embedding for failed images to keep alignment with valid_paths
            embeddings_list.append(np.zeros(model.config.encoder.hidden_size))

    return np.array(embeddings_list), valid_paths

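# Illustrative usage sketch (hypothetical path; the real batch loop runs further below):
#   embs, paths = get_document_embeddings(["/path/to/invoice.png"])
#   embs.shape == (1, hidden_size)  # 1024 for the Swin encoder of donut-base
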
def extract_document_info(image_path, question="What information is in this document?"):
    """
    Extract specific information from a document using Donut's text generation capability.
    This function demonstrates how to use Donut for document understanding tasks.
    """
    try:
        image = Image.open(image_path).convert("RGB")

        # Prepare the task prompt for document VQA
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"

        # The image goes through the processor; the prompt is tokenized separately
        # and fed to the decoder as decoder_input_ids
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        if device == "cuda":
            pixel_values = pixel_values.half()

        with torch.no_grad():
            # Generate the answer, capped at the decoder's positional limit
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=model.decoder.config.max_position_embeddings,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Decode the generated sequence and strip special tokens
        decoded_text = processor.batch_decode(outputs.sequences)[0]
        decoded_text = decoded_text.replace(processor.tokenizer.eos_token, "")
        decoded_text = decoded_text.replace(processor.tokenizer.pad_token, "")
        # Extract the answer part
        answer = decoded_text.split("<s_answer>")[-1].replace("</s_answer>", "").strip()

        return answer

    except Exception as e:
        print(f"Error extracting info from {image_path}: {e}")
        return ""

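# Note: for prompts that elicit structured output, DonutProcessor also provides
# processor.token2json(sequence) to turn the generated tag sequence into a dict;
# for single-answer DocVQA prompts, the string split above is sufficient.
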
# --- Process all images in the directory ---
print("Scanning for image files...")
image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR)
               if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"Found {len(image_files)} image files")

# Extract embeddings and stream them to JSON
print("Extracting embeddings using Donut...")
with open("embeddings_factures_donut.json", "w") as f:
    f.write("[\n")
    first = True

    for i in tqdm(range(0, len(image_files), BATCH_SIZE), desc="Processing batches"):
        batch_paths = image_files[i:i + BATCH_SIZE]
        batch_embeddings, valid_paths = get_document_embeddings(batch_paths)

        if len(batch_embeddings) > 0:
            embeddings_list = [emb.tolist() for emb in batch_embeddings]

            for path, emb in zip(valid_paths, embeddings_list):
                if not first:
                    f.write(",\n")

                entry = {
                    "filepath": path,
                    "embedding": emb,
                    "model": "donut-base-finetuned-docvqa",
                    "embedding_size": len(emb),
                }

                json.dump(entry, f)
                first = False

    f.write("\n]\n")

print("Embeddings extracted and saved to 'embeddings_factures_donut.json'")

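# A minimal sketch of consuming the saved file (illustrative, not required by the
# extraction itself): reload the embeddings into a matrix and compare the first
# two documents by cosine similarity, e.g. to spot near-duplicate invoices.
with open("embeddings_factures_donut.json") as f:
    saved_entries = json.load(f)

if len(saved_entries) >= 2:
    emb_matrix = np.array([entry["embedding"] for entry in saved_entries])
    # Normalize rows; the clip guards against all-zero embeddings from failed images
    norms = np.clip(np.linalg.norm(emb_matrix, axis=1, keepdims=True), 1e-12, None)
    unit = emb_matrix / norms
    print(f"Cosine similarity of first two documents: {float(unit[0] @ unit[1]):.4f}")
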
# Optional: extract some sample document information
print("\nExtracting sample document information...")
sample_images = image_files[:3]  # Process the first 3 images as samples

sample_info = []
for img_path in sample_images:
    print(f"Processing: {os.path.basename(img_path)}")

    # Extract different types of information
    questions = [
        "What is the total amount?",
        "What is the invoice number?",
        "What is the date?",
        "Who is the vendor?",
        "What are the main items?",
    ]

    info = {"filepath": img_path, "extracted_info": {}}

    for question in questions:
        answer = extract_document_info(img_path, question)
        info["extracted_info"][question] = answer
        print(f"  {question}: {answer}")

    sample_info.append(info)

# Save sample extraction results
with open("donut_sample_extractions.json", "w") as f:
    json.dump(sample_info, f, indent=2, ensure_ascii=False)

print("Sample document information extracted and saved to 'donut_sample_extractions.json'")
print("Processing completed!")