# GOT-OCR2.0 CPU demo script.
# (The original paste carried non-Python scrape residue from the hosting
# page — Space status banner, commit hashes, and a line-number gutter —
# which has been converted to this comment header so the file parses.)
import os
import torch
from transformers import AutoModel, AutoTokenizer
# Load model and tokenizer
# NOTE: trust_remote_code=True executes code shipped with the model repo —
# only acceptable because the model source is explicitly chosen here.
model_name = "srimanth-d/GOT_CPU" # Using GOT model on CPU
# NOTE(review): return_tensors='pt' is not a documented AutoTokenizer
# init kwarg — presumably ignored/passed through; verify it is needed.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, return_tensors='pt')
# Load the model
model = AutoModel.from_pretrained(
model_name,
trust_remote_code=True,
low_cpu_mem_usage=True,  # stream weights to keep peak RAM low on CPU hosts
use_safetensors=True,
pad_token_id=tokenizer.eos_token_id,  # GOT repo convention: pad with EOS
)
# Ensure the model is in evaluation mode and loaded on CPU
# NOTE(review): the model is never moved with .to(device); it stays where
# from_pretrained placed it (CPU by default) — confirm if a GPU path is added.
device = torch.device("cpu")
model = model.eval()
# OCR function to extract text
# OCR function to extract text
def extract_text_got(uploaded_file):
    """Use GOT-OCR2.0 model to extract text from the uploaded image.

    Args:
        uploaded_file: a file-like object (e.g. a Streamlit upload) whose
            ``read()`` yields the raw image bytes.

    Returns:
        The extracted text from the first OCR pass that produces output,
        ``"No text extracted."`` when every pass comes back empty, or an
        ``"Error during text extraction: ..."`` message on failure.
    """
    temp_file_path = 'temp_image.jpg'
    try:
        # Save the uploaded file temporarily — model.chat takes a file path.
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())
        print(f"Processing image from path: {temp_file_path}")
        ocr_types = ['ocr', 'format']
        results = []
        # Run OCR on the image; 'format' acts as a fallback when plain
        # 'ocr' yields nothing.
        for ocr_type in ocr_types:
            with torch.no_grad():
                print(f"Running OCR with type: {ocr_type}")
                outputs = model.chat(tokenizer, temp_file_path, ocr_type=ocr_type)
            # model.chat may return a plain string or a list of strings.
            # Normalize to one stripped string: the previous code indexed
            # outputs[0] on a str result, returning only its first character.
            if isinstance(outputs, list):
                text = outputs[0].strip() if outputs else ""
            else:
                text = str(outputs).strip() if outputs else ""
            if text:
                return text  # Return the result if successful
            results.append("No result")
        # Nothing produced text for any OCR type.
        return results[0] if results else "No text extracted."
    except Exception as e:
        return f"Error during text extraction: {str(e)}"
    finally:
        # Clean up temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Temporary file {temp_file_path} removed.")
# Function to clean extracted text using AI
# Function to clean extracted text using AI
def clean_text_with_ai(extracted_text):
    """Clean *extracted_text* by running it back through the language model,
    relying on the model to drop extraneous whitespace.

    Returns the cleaned text, or an error message string on failure.
    """
    try:
        # Tokenize the raw text and place the tensors on the target device.
        encoded = tokenizer(extracted_text, return_tensors="pt").to(device)
        # Generate without tracking gradients; cap the continuation length.
        with torch.no_grad():
            generated = model.generate(**encoded, max_new_tokens=100)  # Adjust max_new_tokens as needed
        # Turn the token ids back into a plain string, dropping specials.
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.strip()
    except Exception as e:
        return f"Error during AI text cleaning: {str(e)}"