# BrailleMenuGenV2/models/text_processor.py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch
import json

# Primary model (gated on the Hub, so access approval is required) and a
# more widely available fallback; both are loaded in 8-bit to fit Spaces memory
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
FALLBACK_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Initialized lazily on first use so the Space starts quickly
tokenizer = None
text_generation_pipeline = None

def _load_pipeline(model_id):
    """Load a tokenizer and an 8-bit quantized model, wrapped in a text-generation pipeline."""
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # 8-bit quantization (via bitsandbytes) reduces memory usage enough
    # to fit these 7-8B models on a Spaces GPU
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.15,
    )


def get_text_pipeline():
    """
    Initialize or return the cached text generation pipeline.

    Tries the primary model first and falls back to the Mistral model,
    which is more widely available, if the primary cannot be loaded.
    Returns None if neither model loads.
    """
    global text_generation_pipeline
    if text_generation_pipeline is None:
        try:
            text_generation_pipeline = _load_pipeline(MODEL_ID)
        except Exception as e:
            print(f"Error loading primary model: {e}")
            print(f"Falling back to {FALLBACK_MODEL_ID}")
            try:
                text_generation_pipeline = _load_pipeline(FALLBACK_MODEL_ID)
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                return None
    return text_generation_pipeline
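
# Usage sketch: repeated calls reuse the module-level cache, so a model is
# loaded at most once per process (assumes bitsandbytes is installed and
# the runtime has enough memory):
#
#   pipe = get_text_pipeline()       # loads a model on the first call
#   same_pipe = get_text_pipeline()  # returns the cached pipeline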

def process_menu_text(raw_text):
    """
    Process raw OCR text with an LLM to improve its structure and readability.

    Args:
        raw_text: Raw text extracted from a menu image.

    Returns:
        A dict containing the structured menu text, parsed menu data when
        available, and success/error metadata.
    """
    # Named text_pipeline to avoid shadowing transformers.pipeline
    text_pipeline = get_text_pipeline()
    if text_pipeline is None:
        # Fall back to returning the raw text unchanged if no model is available
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': "LLM model not available"
        }

    # Build the prompt with the tokenizer's chat template so the special
    # tokens match whichever model (Llama 3 or Mistral) was actually loaded.
    # The instructions go in the user turn because Mistral-7B-Instruct-v0.2's
    # chat template does not define a separate system role.
    user_message = f"""You are an AI assistant that structures menu text from OCR.
Clean up the text, correct obvious OCR errors, and identify menu sections, items, and prices.

Here is the raw text extracted from a menu image:

{raw_text}

Format your response as JSON with the following structure:
{{
    "menu_sections": [
        {{
            "section_name": "Section name (e.g., Appetizers, Main Course, etc.)",
            "items": [
                {{
                    "name": "Item name",
                    "description": "Item description if available",
                    "price": "Price if available"
                }}
            ]
        }}
    ]
}}"""
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_message}],
        tokenize=False,
        add_generation_prompt=True,
    )

    try:
        # Generate a response and keep only the newly generated text
        response = text_pipeline(prompt, return_full_text=False)[0]['generated_text']

        # Extract the first JSON object from the response, tolerating any
        # surrounding prose the model may have added
        response_text = response.strip()
        json_start = response_text.find('{')
        json_end = response_text.rfind('}') + 1

        if json_start >= 0 and json_end > json_start:
            json_str = response_text[json_start:json_end]
            menu_data = json.loads(json_str)

            # Reconstruct readable structured text from the parsed JSON
            structured_text = ""
            for section in menu_data.get('menu_sections', []):
                section_name = section.get('section_name', 'Menu Items')
                structured_text += f"{section_name}\n"
                structured_text += "-" * len(section_name) + "\n\n"
                for item in section.get('items', []):
                    structured_text += f"{item.get('name', '')}"
                    if item.get('price'):
                        structured_text += f" - {item.get('price')}"
                    structured_text += "\n"
                    if item.get('description'):
                        structured_text += f"  {item.get('description')}\n"
                    structured_text += "\n"
                structured_text += "\n"

            return {
                'structured_text': structured_text,
                'menu_data': menu_data,
                'success': True
            }
        else:
            # No JSON found; return the raw text so downstream steps still work
            return {
                'structured_text': raw_text,
                'menu_sections': [],
                'success': False,
                'error': "Failed to parse LLM response as JSON"
            }
    except Exception as e:
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': str(e)
        }