from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch
import json
# Primary model; 8-bit quantization (below) keeps it within typical Space GPU memory
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Fallback if the primary model is unavailable (e.g., gated access or memory limits)
FALLBACK_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
# Initialized to None; loaded lazily on first use
tokenizer = None
text_generation_pipeline = None


def _load_pipeline(model_id):
    """Load a tokenizer plus an 8-bit quantized model and wrap them in a pipeline."""
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # BitsAndBytesConfig replaces the deprecated load_in_8bit kwarg
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    return tok, pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.15,
    )


def get_text_pipeline():
    """
    Initialize the text generation pipeline on first call and cache it.

    Tries the primary model first, then falls back to the Mistral model,
    which is more widely available. Returns None if neither model loads.
    """
    global tokenizer, text_generation_pipeline
    if text_generation_pipeline is None:
        for model_id in (MODEL_ID, FALLBACK_MODEL_ID):
            try:
                tokenizer, text_generation_pipeline = _load_pipeline(model_id)
                break
            except Exception as e:
                print(f"Error loading {model_id}: {e}")
        else:
            # Neither model could be loaded
            return None
    return text_generation_pipeline
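

# Usage sketch (a minimal direct call, assuming the model loaded successfully;
# the prompt string here is illustrative, and return_full_text=False keeps
# only the newly generated tokens):
#
#   pipe = get_text_pipeline()
#   if pipe is not None:
#       out = pipe("Summarize: ...", return_full_text=False)[0]["generated_text"]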


def process_menu_text(raw_text):
    """
    Process raw OCR text with the LLM to improve structure and readability.

    Args:
        raw_text: Raw text extracted from a menu image.

    Returns:
        Dict containing the structured menu text, the parsed menu data on
        success, a success flag, and an error message on failure.
    """
    # Fetch the cached pipeline; named text_pipeline so it does not shadow
    # the imported transformers.pipeline function
    text_pipeline = get_text_pipeline()
    if text_pipeline is None:
        # Fall back to returning the raw text when no model is available
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': "LLM model not available"
        }
# Construct prompt for the LLM
prompt = f"""<|system|>
You are an AI assistant that helps structure menu text from OCR.
Your task is to clean up the text, correct obvious OCR errors, and structure it properly.
Identify menu sections, items, and prices.
Format your response as JSON with menu sections, items, and prices.
<|user|>
Here is the raw text extracted from a menu image:
{raw_text}
Please clean and structure this menu text. Format your response as JSON with the following structure:
{{
"menu_sections": [
{{
"section_name": "Section name (e.g., Appetizers, Main Course, etc.)",
"items": [
{{
"name": "Item name",
"description": "Item description if available",
"price": "Price if available"
}}
]
}}
]
}}
<|assistant|>
"""
    try:
        # Generate a response; return_full_text=False strips the prompt
        response = text_pipeline(prompt, return_full_text=False)[0]['generated_text']
        response_text = response.strip()
        # Extract the JSON object by locating the outermost braces
        json_start = response_text.find('{')
        json_end = response_text.rfind('}') + 1
        if json_start >= 0 and json_end > json_start:
            json_str = response_text[json_start:json_end]
            menu_data = json.loads(json_str)
            # Reconstruct a plain-text menu from the parsed structure
            structured_text = ""
            for section in menu_data.get('menu_sections', []):
                section_name = section.get('section_name', 'Menu Items')
                structured_text += f"{section_name}\n"
                structured_text += "-" * len(section_name) + "\n\n"
                for item in section.get('items', []):
                    structured_text += f"{item.get('name', '')}"
                    if item.get('price'):
                        structured_text += f" - {item.get('price')}"
                    structured_text += "\n"
                    if item.get('description'):
                        structured_text += f"  {item.get('description')}\n"
                    structured_text += "\n"
                structured_text += "\n"
            return {
                'structured_text': structured_text,
                'menu_data': menu_data,
                'success': True
            }
        else:
            # No JSON object found; fall back to the raw text
            return {
                'structured_text': raw_text,
                'menu_sections': [],
                'success': False,
                'error': "Failed to parse LLM response as JSON"
            }
    except Exception as e:
        # Covers generation errors and json.JSONDecodeError alike
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': str(e)
        }
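

# A minimal self-test sketch (assumption: run as a script; the sample menu
# lines below are illustrative OCR output, not from a real menu image).
if __name__ == "__main__":
    sample_ocr_text = (
        "APPETIZERS\n"
        "Garlic Bread 4.50\n"
        "Bruschetta with tomato and basil 6.00\n"
        "MAIN COURSE\n"
        "Margherita Pizza 11.00\n"
        "Spaghetti Carbonara 13.50\n"
    )
    result = process_menu_text(sample_ocr_text)
    if result['success']:
        print(result['structured_text'])
    else:
        print(f"Using raw text, LLM step failed: {result['error']}")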