from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import torch
import json

# Primary and fallback model IDs; both are instruct models small enough
# to run on a Space GPU with 8-bit quantization
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
FALLBACK_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Initialize with None - will be loaded on first use
tokenizer = None
text_generation_pipeline = None

def _load_pipeline(model_id):
    """
    Load the tokenizer and an 8-bit quantized model for model_id, and
    wrap them in a text-generation pipeline.
    """
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Quantize to 8-bit via bitsandbytes to reduce memory usage
    # (quantization_config replaces the deprecated load_in_8bit kwarg)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True)
    )

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.15
    )

def get_text_pipeline():
    """
    Initialize the text generation pipeline on first use, then return
    the cached instance.

    Tries the primary model first and falls back to the secondary model
    if loading fails (e.g., the gated Llama weights are not accessible).
    Returns None if neither model can be loaded.
    """
    global text_generation_pipeline

    if text_generation_pipeline is None:
        try:
            text_generation_pipeline = _load_pipeline(MODEL_ID)
        except Exception as e:
            print(f"Error loading primary model: {e}")
            print(f"Falling back to {FALLBACK_MODEL_ID}")

            try:
                # Fall back to the Mistral model, which is more widely available
                text_generation_pipeline = _load_pipeline(FALLBACK_MODEL_ID)
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                return None

    return text_generation_pipeline
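
# Optional warm-up (a sketch of a deployment choice, not part of the
# module's required behavior): calling get_text_pipeline() once at
# startup moves the model loading cost out of the first request.
# get_text_pipeline()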

def process_menu_text(raw_text):
    """
    Process raw OCR text with the LLM to improve structure and readability.

    Args:
        raw_text: Raw text extracted from a menu image

    Returns:
        dict with the structured menu text, the parsed menu data on
        success, a success flag, and an error message on failure
    """
    # Get the pipeline (named to avoid shadowing transformers.pipeline)
    text_pipeline = get_text_pipeline()

    if text_pipeline is None:
        # Fall back to returning the raw text if no model is available
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': "LLM model not available"
        }
    
    # Build the prompt with the tokenizer's chat template so the special
    # tokens match whichever model was actually loaded (Llama 3 and
    # Mistral use different chat formats). The instructions go in the
    # user turn because some templates (e.g., Mistral v0.2) do not
    # accept a system role.
    user_message = f"""You are an AI assistant that helps structure menu text from OCR.
Clean up the text, correct obvious OCR errors, and structure it properly.
Identify menu sections, items, and prices.

Here is the raw text extracted from a menu image:

{raw_text}

Please clean and structure this menu text. Format your response as JSON with the following structure:
{{
    "menu_sections": [
        {{
            "section_name": "Section name (e.g., Appetizers, Main Course, etc.)",
            "items": [
                {{
                    "name": "Item name",
                    "description": "Item description if available",
                    "price": "Price if available"
                }}
            ]
        }}
    ]
}}"""

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_message}],
        tokenize=False,
        add_generation_prompt=True
    )
    
    try:
        # Generate a response from the LLM
        response = text_pipeline(prompt, return_full_text=False)[0]['generated_text']

        # Extract the JSON object from the response: take the span from
        # the first '{' to the last '}', since the model may wrap the
        # JSON in extra text
        response_text = response.strip()
        json_start = response_text.find('{')
        json_end = response_text.rfind('}') + 1
        
        if json_start >= 0 and json_end > json_start:
            json_str = response_text[json_start:json_end]
            menu_data = json.loads(json_str)
            
            # Reconstruct structured text
            structured_text = ""
            for section in menu_data.get('menu_sections', []):
                section_name = section.get('section_name', 'Menu Items')
                structured_text += f"{section_name}\n"
                structured_text += "-" * len(section_name) + "\n\n"
                
                for item in section.get('items', []):
                    structured_text += f"{item.get('name', '')}"
                    if item.get('price'):
                        structured_text += f" - {item.get('price')}"
                    structured_text += "\n"
                    
                    if item.get('description'):
                        structured_text += f"  {item.get('description')}\n"
                    
                    structured_text += "\n"
                
                structured_text += "\n"
            
            return {
                'structured_text': structured_text,
                'menu_data': menu_data,
                'success': True
            }
        else:
            # No JSON object found; fall back to returning the raw text
            return {
                'structured_text': raw_text,
                'menu_sections': [],
                'success': False,
                'error': "Failed to parse LLM response as JSON"
            }
            
    except Exception as e:
        return {
            'structured_text': raw_text,
            'menu_sections': [],
            'success': False,
            'error': str(e)
        }
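
# Example usage (a minimal sketch: the sample OCR text below is
# illustrative, not output from a real menu scan):
if __name__ == "__main__":
    sample_ocr = (
        "APPETIZERS\n"
        "Garlic Bread 4.50\n"
        "Brusch etta tomato basil olive oil 6.00\n"
        "MAIN COURSE\n"
        "Margherita Pizza 11.00\n"
    )
    result = process_menu_text(sample_ocr)
    if result['success']:
        print(result['structured_text'])
    else:
        print(f"Processing failed: {result['error']}")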