check visison extract model
This commit is contained in:
201
extract/extract_donut.py
Normal file
201
extract/extract_donut.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import torch
|
||||
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa" # Donut model for document VQA
|
||||
IMAGE_DIR = "/home/nguyendc/model-factory/Finetuning-Automation/etc/data/media/docai_mgp_facture_v2_0/"
|
||||
BATCH_SIZE = 4 # Smaller batch size for Donut as it's memory intensive
|
||||
# --- End Configuration ---
|
||||
|
||||
# Check for GPU availability
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load the Donut model and processor
|
||||
print("Loading Donut model and processor...")
|
||||
processor = DonutProcessor.from_pretrained(MODEL_NAME)
|
||||
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Set model to half precision for efficiency if using GPU
|
||||
if device == "cuda":
|
||||
model = model.half()
|
||||
|
||||
|
||||
def get_document_embeddings(image_paths):
|
||||
"""
|
||||
Processes a batch of document images and extracts their embeddings using Donut.
|
||||
Uses the encoder part of the VisionEncoderDecoder model to get visual representations.
|
||||
"""
|
||||
images_pil = []
|
||||
valid_paths = []
|
||||
|
||||
for path in image_paths:
|
||||
if path.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||||
try:
|
||||
# Load and convert image to RGB
|
||||
image = Image.open(path).convert("RGB")
|
||||
images_pil.append(image)
|
||||
valid_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load image {path}. Skipping. Error: {e}")
|
||||
|
||||
if not images_pil:
|
||||
return np.array([]), []
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
# Process images one by one to avoid memory issues
|
||||
for image in images_pil:
|
||||
try:
|
||||
# Preprocess the image
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
if device == "cuda":
|
||||
pixel_values = pixel_values.half()
|
||||
|
||||
with torch.no_grad():
|
||||
# Get encoder outputs (visual features)
|
||||
encoder_outputs = model.encoder(pixel_values=pixel_values)
|
||||
|
||||
# Use the last hidden state and apply global average pooling
|
||||
# to get a fixed-size representation
|
||||
last_hidden_state = encoder_outputs.last_hidden_state # [batch_size, seq_len, hidden_size]
|
||||
|
||||
# Global average pooling across the sequence dimension
|
||||
embedding = torch.mean(last_hidden_state, dim=1) # [batch_size, hidden_size]
|
||||
|
||||
embeddings_list.append(embedding.squeeze().cpu().float().numpy())
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process image. Error: {e}")
|
||||
# Add zero embedding for failed images to maintain consistency
|
||||
embeddings_list.append(np.zeros(model.config.encoder.hidden_size))
|
||||
|
||||
return np.array(embeddings_list), valid_paths
|
||||
|
||||
|
||||
def extract_document_info(image_path, question="What information is in this document?"):
|
||||
"""
|
||||
Extract specific information from a document using Donut's text generation capability.
|
||||
This function demonstrates how to use Donut for document understanding tasks.
|
||||
"""
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Prepare the task prompt for document VQA
|
||||
task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
|
||||
|
||||
# Process the image and prompt
|
||||
inputs = processor(image, task_prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
if device == "cuda":
|
||||
inputs["pixel_values"] = inputs["pixel_values"].half()
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate answer
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
max_length=512,
|
||||
early_stopping=True,
|
||||
pad_token_id=processor.tokenizer.pad_token_id,
|
||||
eos_token_id=processor.tokenizer.eos_token_id,
|
||||
use_cache=True,
|
||||
num_beams=1,
|
||||
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Decode the generated answer
|
||||
decoded_text = processor.batch_decode(generated_ids.sequences)[0]
|
||||
# Extract the answer part
|
||||
answer = decoded_text.split("<s_answer>")[-1].replace("</s_answer>", "").strip()
|
||||
|
||||
return answer
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting info from {image_path}: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
# --- Process all images in the directory ---
|
||||
print("Scanning for image files...")
|
||||
image_files = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR)
|
||||
if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
||||
print(f"Found {len(image_files)} image files")
|
||||
|
||||
all_embeddings = []
|
||||
filepaths = []
|
||||
|
||||
# Extract embeddings and save to JSON
|
||||
print("Extracting embeddings using Donut...")
|
||||
with open("embeddings_factures_donut.json", "w") as f:
|
||||
f.write("[\n")
|
||||
first = True
|
||||
|
||||
for i in tqdm(range(0, len(image_files), BATCH_SIZE), desc="Processing batches"):
|
||||
batch_paths = image_files[i:i+BATCH_SIZE]
|
||||
batch_embeddings, valid_paths = get_document_embeddings(batch_paths)
|
||||
|
||||
if len(batch_embeddings) > 0:
|
||||
embeddings_list = [emb.tolist() for emb in batch_embeddings]
|
||||
|
||||
for path, emb in zip(valid_paths, embeddings_list):
|
||||
if not first:
|
||||
f.write(",\n")
|
||||
|
||||
entry = {
|
||||
"filepath": path,
|
||||
"embedding": emb,
|
||||
"model": "donut-base-finetuned-docvqa",
|
||||
"embedding_size": len(emb)
|
||||
}
|
||||
|
||||
json.dump(entry, f)
|
||||
first = False
|
||||
|
||||
f.write("\n]\n")
|
||||
|
||||
print("Embeddings extracted and saved to 'embeddings_factures_donut.json'")
|
||||
|
||||
# Optional: Extract some sample document information
|
||||
print("\nExtracting sample document information...")
|
||||
sample_images = image_files[:3] # Process first 3 images as samples
|
||||
|
||||
sample_info = []
|
||||
for img_path in sample_images:
|
||||
print(f"Processing: {os.path.basename(img_path)}")
|
||||
|
||||
# Extract different types of information
|
||||
questions = [
|
||||
"What is the total amount?",
|
||||
"What is the invoice number?",
|
||||
"What is the date?",
|
||||
"Who is the vendor?",
|
||||
"What are the main items?"
|
||||
]
|
||||
|
||||
info = {"filepath": img_path, "extracted_info": {}}
|
||||
|
||||
for question in questions:
|
||||
answer = extract_document_info(img_path, question)
|
||||
info["extracted_info"][question] = answer
|
||||
print(f" {question}: {answer}")
|
||||
|
||||
sample_info.append(info)
|
||||
|
||||
# Save sample extraction results
|
||||
with open("donut_sample_extractions.json", "w") as f:
|
||||
json.dump(sample_info, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print("Sample document information extracted and saved to 'donut_sample_extractions.json'")
|
||||
print("Processing completed!")
|
Reference in New Issue
Block a user