import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Primary instruction-tuned model; loaded with 8-bit quantization to reduce memory usage.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Used only when the primary model cannot be loaded.
FALLBACK_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Initialize with None - populated lazily by get_text_pipeline() on first use.
tokenizer = None
text_generation_pipeline = None

# Sampling settings shared by the primary and fallback pipelines.
_GENERATION_KWARGS = {
    "max_new_tokens": 1024,
    "do_sample": True,
    "temperature": 0.3,
    "top_p": 0.95,
    "repetition_penalty": 1.15,
}

_SYSTEM_PROMPT = (
    "You are an AI assistant that helps structure menu text from OCR.\n"
    "Your task is to clean up the text, correct obvious OCR errors, and structure it properly.\n"
    "Identify menu sections, items, and prices.\n"
    "Format your response as JSON with menu sections, items, and prices."
)

# Filled with str.format(raw_text=...); doubled braces are literal JSON braces.
_USER_PROMPT_TEMPLATE = """Here is the raw text extracted from a menu image:

{raw_text}

Please clean and structure this menu text. Format your response as JSON with the following structure:
{{
  "menu_sections": [
    {{
      "section_name": "Section name (e.g., Appetizers, Main Course, etc.)",
      "items": [
        {{
          "name": "Item name",
          "description": "Item description if available",
          "price": "Price if available"
        }}
      ]
    }}
  ]
}}"""


def _load_pipeline(model_id):
    """Load *model_id* in 8-bit and wrap it in a text-generation pipeline.

    Returns:
        A ``(tokenizer, pipeline)`` tuple for the model.

    Raises:
        Whatever ``from_pretrained`` / ``pipeline`` raise (missing weights,
        gated repo, missing quantization backend, out of memory, ...).
    """
    model_tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # Passing load_in_8bit directly to from_pretrained is deprecated.
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=model_tokenizer,
        **_GENERATION_KWARGS,
    )
    return model_tokenizer, generator


def get_text_pipeline():
    """
    Initialize or return the cached text-generation pipeline.

    Tries MODEL_ID first and falls back to FALLBACK_MODEL_ID. The module
    globals ``tokenizer`` and ``text_generation_pipeline`` are updated only
    after a model has loaded completely.

    Returns:
        The pipeline, or None when neither model could be loaded. Failures are
        not cached, so the next call retries loading.
    """
    global tokenizer, text_generation_pipeline
    if text_generation_pipeline is not None:
        return text_generation_pipeline

    try:
        tokenizer, text_generation_pipeline = _load_pipeline(MODEL_ID)
    except Exception as e:
        print(f"Error loading primary model: {str(e)}")
        print(f"Falling back to {FALLBACK_MODEL_ID}")
        try:
            tokenizer, text_generation_pipeline = _load_pipeline(FALLBACK_MODEL_ID)
        except Exception as e2:
            print(f"Error loading fallback model: {str(e2)}")
            return None
    return text_generation_pipeline


def _build_prompt(chat_tokenizer, raw_text):
    """Render the menu-structuring request with the model's own chat template.

    Args:
        chat_tokenizer: Tokenizer of the loaded model.
        raw_text: OCR text to embed in the user turn.

    Returns:
        The prompt string, ending with the template's generation prompt.

    Raises:
        Whatever ``apply_chat_template`` raises when even a single user turn
        cannot be rendered (e.g. the tokenizer has no chat template).
    """
    user_message = _USER_PROMPT_TEMPLATE.format(raw_text=raw_text)
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    try:
        return chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Some chat templates (e.g. some Mistral instruct revisions) reject a
        # system turn; fold the instructions into the single user turn instead.
        merged = [{"role": "user", "content": f"{_SYSTEM_PROMPT}\n\n{user_message}"}]
        return chat_tokenizer.apply_chat_template(
            merged, tokenize=False, add_generation_prompt=True
        )


def _extract_menu_json(text):
    """Return the first JSON object in *text* that has a "menu_sections" key.

    Uses ``raw_decode`` so prose or stray braces after the JSON are ignored.
    Returns None when no such object can be decoded.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            # Require the expected top-level key so an inner fragment of a
            # truncated response is not mistaken for the whole menu.
            if isinstance(candidate, dict) and "menu_sections" in candidate:
                return candidate
        start = text.find("{", start + 1)
    return None


def _format_menu(sections):
    """Render menu sections as plain text.

    Each section gets a dashed underline; each item is rendered as
    "name - price", with the description on the following line. Entries
    that are not JSON objects are skipped.
    """
    parts = []
    for section in sections:
        if not isinstance(section, dict):
            continue
        title = str(section.get("section_name") or "Menu Items")
        parts.append(f"{title}\n{'-' * len(title)}\n\n")

        items = section.get("items")
        for item in items if isinstance(items, list) else []:
            if not isinstance(item, dict):
                continue
            entry = str(item.get("name") or "")
            if item.get("price"):
                entry += f" - {item.get('price')}"
            entry += "\n"
            if item.get("description"):
                entry += f" {item.get('description')}\n"
            parts.append(entry + "\n")
        parts.append("\n")
    return "".join(parts)


def _failure_result(raw_text, error):
    """Build the result returned when the text could not be structured."""
    return {
        'structured_text': raw_text,
        'menu_sections': [],
        'success': False,
        'error': error,
    }


def process_menu_text(raw_text):
    """
    Process raw OCR text using LLM to improve structure and readability.

    Args:
        raw_text: Raw text extracted from menu image

    Returns:
        A dict that always contains 'structured_text', 'menu_sections' and
        'success'. On success it also contains 'menu_data' (the parsed JSON),
        and 'structured_text' is the rendered menu. On failure it also
        contains 'error', and 'structured_text' is the unchanged raw_text.
    """
    # Don't load the model or let it invent a menu when OCR found nothing.
    if not raw_text or not raw_text.strip():
        return _failure_result(raw_text, "No menu text to process")

    text_pipeline = get_text_pipeline()
    if text_pipeline is None:
        # Fallback to simple processing if model not available
        return _failure_result(raw_text, "LLM model not available")

    try:
        prompt = _build_prompt(text_pipeline.tokenizer, raw_text)
        # The chat template already inserted the BOS token, so the pipeline
        # must not add special tokens a second time.
        outputs = text_pipeline(prompt, return_full_text=False, add_special_tokens=False)
        response_text = outputs[0]['generated_text'].strip()

        menu_data = _extract_menu_json(response_text)
        if menu_data is None:
            return _failure_result(raw_text, "Failed to parse LLM response as JSON")

        sections = menu_data.get('menu_sections')
        if not isinstance(sections, list):
            sections = []
        return {
            'structured_text': _format_menu(sections),
            'menu_sections': sections,
            'menu_data': menu_data,
            'success': True,
        }
    except Exception as e:
        return _failure_result(raw_text, str(e))