DHEIVER commited on
Commit
aace76b
·
verified ·
1 Parent(s): bc85663

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -150
app.py CHANGED
@@ -1,180 +1,249 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import (
4
- Blip2Processor, Blip2ForConditionalGeneration,
5
- AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq
6
- )
7
  from PIL import Image
8
  import numpy as np
 
 
 
9
 
10
- class ModelManager:
 
 
 
 
11
  def __init__(self):
12
- self.current_model = None
13
- self.current_processor = None
14
- self.model_name = None
15
 
16
- def load_blip2(self):
17
- """Carrega modelo BLIP-2"""
18
- self.model_name = "Salesforce/blip2-opt-2.7b"
19
- self.current_processor = Blip2Processor.from_pretrained(self.model_name)
20
- self.current_model = Blip2ForConditionalGeneration.from_pretrained(
21
- self.model_name,
22
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
- device_map="auto"
24
- )
25
- return "BLIP-2 carregado com sucesso!"
26
-
27
- def load_llava(self):
28
- """Carrega modelo LLaVA"""
29
- self.model_name = "llava-hf/llava-1.5-7b-hf"
30
- self.current_processor = AutoProcessor.from_pretrained(self.model_name)
31
- self.current_model = AutoModelForVision2Seq.from_pretrained(
32
- self.model_name,
33
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
34
- device_map="auto"
35
- )
36
- return "LLaVA carregado com sucesso!"
37
-
38
- def load_git(self):
39
- """Carrega modelo GIT"""
40
- self.model_name = "microsoft/git-base-coco"
41
- self.current_processor = AutoProcessor.from_pretrained(self.model_name)
42
- self.current_model = AutoModelForCausalLM.from_pretrained(
43
- self.model_name,
44
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
45
- device_map="auto"
46
- )
47
- return "GIT carregado com sucesso!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def analyze_image(self, image, question, model_choice):
50
- """Analisa imagem com foco nutricional"""
51
  try:
52
- # Carrega o modelo apropriado se necessário
53
- if model_choice == "BLIP-2" and (self.model_name != "Salesforce/blip2-opt-2.7b"):
54
- status = self.load_blip2()
55
- elif model_choice == "LLaVA" and (self.model_name != "llava-hf/llava-1.5-7b-hf"):
56
- status = self.load_llava()
57
- elif model_choice == "GIT" and (self.model_name != "microsoft/git-base-coco"):
58
- status = self.load_git()
59
 
60
- # Adiciona contexto nutricional à pergunta
61
- nutritional_prompt = (
62
- "Como nutricionista, analise este prato considerando: "
63
- "1. Lista de ingredientes principais\n"
64
- "2. Estimativa calórica total\n"
65
- "3. Sugestões para uma versão mais saudável\n"
66
- "4. Análise de grupos alimentares\n"
67
- f"Pergunta do usuário: {question}"
68
- "\nPor favor, responda em português com detalhes nutricionais."
69
- )
70
-
71
- # Prepara a imagem
72
- if isinstance(image, str):
73
- image = Image.open(image)
74
- elif isinstance(image, np.ndarray):
75
- image = Image.fromarray(image)
76
 
77
- # Processa a entrada
78
- inputs = self.current_processor(
79
- images=image,
 
 
 
 
80
  text=nutritional_prompt,
81
  return_tensors="pt"
82
- ).to(self.current_model.device)
83
 
84
- # Gera a resposta
85
- outputs = self.current_model.generate(
86
- **inputs,
87
- max_new_tokens=200, # Aumentado para respostas mais completas
88
- num_beams=5,
89
- temperature=0.7,
90
- top_p=0.9
91
- )
 
 
 
 
 
 
 
 
 
92
 
93
- # Decodifica e formata a resposta
94
- response = self.current_processor.decode(outputs[0], skip_special_tokens=True)
95
- formatted_response = response.replace(". ", ".\n").replace("; ", ";\n")
96
- return f"**Análise Nutricional:**\n{formatted_response}"
97
-
98
  except Exception as e:
99
- return f"Erro na análise: {str(e)}"
 
100
 
101
- # Cria instância do gerenciador de modelos
102
- model_manager = ModelManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- # Interface Gradio
105
- with gr.Blocks(theme=gr.themes.Soft()) as iface:
106
- gr.Markdown("""
107
- # 🥗 Analisador Nutricional Inteligente
108
- Escolha o modelo que deseja usar para analisar seu prato e obter recomendações nutricionais.
109
- """)
110
 
111
- with gr.Row():
112
- with gr.Column():
113
- # Inputs
114
- model_choice = gr.Radio(
115
- choices=["BLIP-2", "LLaVA", "GIT"],
116
- label="Escolha o Modelo",
117
- value="BLIP-2"
118
- )
119
-
120
- # Substitui gr.Box() por gr.Group() para compatibilidade
121
- with gr.Group():
122
- gr.Markdown("""
123
- ### 📝 Características dos Modelos:
124
 
125
- **BLIP-2:**
126
- - Análise detalhada de ingredientes
127
- - Estimativas calóricas mais precisas
128
- - Recomendações técnicas
 
129
 
130
- **LLaVA:**
131
- - Explicações mais conversacionais
132
- - Sugestões práticas para o dia a dia
133
- - Foco em hábitos alimentares
 
134
 
135
- **GIT:**
136
- - Respostas rápidas e diretas
137
- - Ideal para análises simples
138
- - Menor consumo de recursos
139
- """)
140
-
141
- image_input = gr.Image(
142
- type="pil",
143
- label="Foto do Prato"
144
- )
145
-
146
- question_input = gr.Textbox(
147
- label="Sua Pergunta",
148
- placeholder="Ex: Quantas calorias tem este prato? Como posso torná-lo mais saudável?"
149
- )
150
 
151
- analyze_btn = gr.Button("🔍 Analisar", variant="primary")
 
152
 
153
- with gr.Column():
154
- # Output
155
- with gr.Group(): # Substitui gr.Box() por gr.Group()
156
- gr.Markdown("### 📊 Resultado da Análise")
157
- output_text = gr.Markdown()
 
 
 
158
 
159
- with gr.Accordion("💡 Sugestões de Perguntas", open=False):
160
- gr.Markdown("""
161
- 1. Quantas calorias tem este prato?
162
- 2. Quais são os ingredientes principais?
163
- 3. Como posso tornar este prato mais saudável?
164
- 4. Este prato é adequado para uma dieta low-carb?
165
- 5. Quais nutrientes estão presentes neste prato?
166
- 6. Este prato é rico em proteínas?
167
- 7. Como posso substituir ingredientes para reduzir calorias?
168
- 8. Este prato é indicado para quem tem restrição a glúten/lactose?
169
- """)
 
170
 
171
- # Eventos
172
- analyze_btn.click(
173
- fn=model_manager.analyze_image,
174
- inputs=[image_input, question_input, model_choice],
175
- outputs=output_text
176
- )
177
 
178
  if __name__ == "__main__":
179
- print(f"Dispositivo: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
180
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
 
 
 
4
  from PIL import Image
5
  import numpy as np
6
+ import os
7
+ from huggingface_hub import snapshot_download
8
+ import logging
9
 
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class NutritionalAnalyzer:
15
  def __init__(self):
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ self.models = {}
18
+ self.processors = {}
19
 
20
+ def initialize_model(self, model_name):
21
+ """Initialize a specific model with error handling and caching"""
22
+ try:
23
+ if model_name not in self.models:
24
+ logger.info(f"Initializing {model_name}...")
25
+
26
+ # Model-specific configurations
27
+ model_configs = {
28
+ "llava": {
29
+ "repo": "llava-hf/llava-1.5-7b-hf",
30
+ "local_cache": "models/llava"
31
+ },
32
+ "git": {
33
+ "repo": "microsoft/git-base-coco",
34
+ "local_cache": "models/git"
35
+ }
36
+ }
37
+
38
+ config = model_configs.get(model_name)
39
+ if not config:
40
+ raise ValueError(f"Unsupported model: {model_name}")
41
+
42
+ # Ensure cache directory exists
43
+ os.makedirs(config["local_cache"], exist_ok=True)
44
+
45
+ # Download model if not cached
46
+ if not os.path.exists(os.path.join(config["local_cache"], "model.safetensors")):
47
+ snapshot_download(
48
+ repo_id=config["repo"],
49
+ local_dir=config["local_cache"],
50
+ ignore_patterns=["*.md", "*.txt"]
51
+ )
52
+
53
+ # Load processor and model
54
+ self.processors[model_name] = AutoProcessor.from_pretrained(
55
+ config["local_cache"],
56
+ local_files_only=True
57
+ )
58
+
59
+ self.models[model_name] = AutoModelForVision2Seq.from_pretrained(
60
+ config["local_cache"],
61
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
62
+ device_map="auto",
63
+ local_files_only=True
64
+ )
65
+
66
+ logger.info(f"{model_name} initialized successfully")
67
+ return True
68
+
69
+ return True
70
+
71
+ except Exception as e:
72
+ logger.error(f"Error initializing {model_name}: {str(e)}")
73
+ return False
74
+
75
+ def prepare_image(self, image):
76
+ """Prepare image for model input"""
77
+ if isinstance(image, str):
78
+ image = Image.open(image)
79
+ elif isinstance(image, np.ndarray):
80
+ image = Image.fromarray(image)
81
+
82
+ # Ensure image is in RGB mode
83
+ if image.mode != "RGB":
84
+ image = image.convert("RGB")
85
+
86
+ return image
87
+
88
+ def generate_nutritional_prompt(self, user_question):
89
+ """Generate a comprehensive nutritional analysis prompt"""
90
+ return f"""Como nutricionista especializado, analise esta refeição detalhadamente:
91
+
92
+ 1. Composição do Prato:
93
+ - Ingredientes principais
94
+ - Proporções aproximadas
95
+ - Método de preparo aparente
96
+
97
+ 2. Análise Nutricional:
98
+ - Estimativa calórica
99
+ - Macronutrientes (proteínas, carboidratos, gorduras)
100
+ - Principais micronutrientes
101
+
102
+ 3. Recomendações:
103
+ - Sugestões para versão mais saudável
104
+ - Porção recomendada
105
+ - Adequação para dietas específicas
106
+
107
+ Pergunta específica do usuário: {user_question}
108
+
109
+ Por favor, forneça uma análise detalhada em português."""
110
 
111
  def analyze_image(self, image, question, model_choice):
112
+ """Analyze image with nutritional focus"""
113
  try:
114
+ # Convert model choice to internal name
115
+ model_name = model_choice.lower().replace("-", "")
 
 
 
 
 
116
 
117
+ # Initialize model if needed
118
+ if not self.initialize_model(model_name):
119
+ return "Erro: Não foi possível inicializar o modelo. Por favor, tente novamente."
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ # Prepare image and prompt
122
+ processed_image = self.prepare_image(image)
123
+ nutritional_prompt = self.generate_nutritional_prompt(question)
124
+
125
+ # Process input
126
+ inputs = self.processors[model_name](
127
+ images=processed_image,
128
  text=nutritional_prompt,
129
  return_tensors="pt"
130
+ ).to(self.device)
131
 
132
+ # Generate response with enhanced parameters
133
+ with torch.no_grad():
134
+ outputs = self.models[model_name].generate(
135
+ **inputs,
136
+ max_new_tokens=300,
137
+ num_beams=5,
138
+ temperature=0.7,
139
+ top_p=0.9,
140
+ repetition_penalty=1.2,
141
+ length_penalty=1.0
142
+ )
143
+
144
+ # Decode and format response
145
+ response = self.processors[model_name].decode(outputs[0], skip_special_tokens=True)
146
+ formatted_response = self.format_response(response)
147
+
148
+ return formatted_response
149
 
 
 
 
 
 
150
  except Exception as e:
151
+ logger.error(f"Analysis error: {str(e)}")
152
+ return f"Erro na análise: {str(e)}\nPor favor, tente novamente ou escolha outro modelo."
153
 
154
+ def format_response(self, response):
155
+ """Format the response for better readability"""
156
+ sections = [
157
+ "Composição do Prato",
158
+ "Análise Nutricional",
159
+ "Recomendações"
160
+ ]
161
+
162
+ formatted = "# 📊 Análise Nutricional\n\n"
163
+
164
+ # Split response into paragraphs
165
+ paragraphs = response.split("\n")
166
+
167
+ current_section = ""
168
+ for paragraph in paragraphs:
169
+ # Check if paragraph starts a new section
170
+ for section in sections:
171
+ if section.lower() in paragraph.lower():
172
+ current_section = f"\n## {section}\n"
173
+ formatted += current_section
174
+ break
175
+
176
+ # Add paragraph to current section
177
+ if paragraph.strip() and current_section:
178
+ formatted += f"- {paragraph.strip()}\n"
179
+ elif paragraph.strip():
180
+ formatted += f"{paragraph.strip()}\n"
181
+
182
+ return formatted
183
 
184
+ # Create interface
185
+ def create_interface():
186
+ analyzer = NutritionalAnalyzer()
 
 
 
187
 
188
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
189
+ gr.Markdown("""
190
+ # 🥗 Análise Nutricional Inteligente
191
+ Faça upload da foto do seu prato para receber uma análise nutricional detalhada.
192
+ """)
193
+
194
+ with gr.Row():
195
+ with gr.Column(scale=2):
196
+ image_input = gr.Image(
197
+ type="pil",
198
+ label="📸 Foto do Prato",
199
+ height=400
200
+ )
201
 
202
+ question_input = gr.Textbox(
203
+ label="💭 Sua Pergunta",
204
+ placeholder="Ex: Quais são os nutrientes principais deste prato?",
205
+ lines=2
206
+ )
207
 
208
+ model_choice = gr.Radio(
209
+ choices=["LLaVA", "GIT"],
210
+ value="LLaVA",
211
+ label="🤖 Escolha o Modelo de Análise"
212
+ )
213
 
214
+ analyze_btn = gr.Button(
215
+ "🔍 Analisar Prato",
216
+ variant="primary",
217
+ scale=1
218
+ )
 
 
 
 
 
 
 
 
 
 
219
 
220
+ with gr.Column(scale=3):
221
+ output = gr.Markdown(label="Resultado da Análise")
222
 
223
+ # Add examples and tips
224
+ with gr.Accordion("💡 Dicas de Uso", open=False):
225
+ gr.Markdown("""
226
+ ### Sugestões de Perguntas:
227
+ - Qual o valor nutricional aproximado deste prato?
228
+ - Como tornar esta refeição mais equilibrada?
229
+ - Este prato é adequado para dieta low-carb?
230
+ - Quais nutrientes importantes estão presentes?
231
 
232
+ ### Dicas para Melhores Resultados:
233
+ 1. Tire a foto com boa iluminação
234
+ 2. Capture todos os elementos do prato
235
+ 3. Evite ângulos muito inclinados
236
+ 4. Seja específico em suas perguntas
237
+ """)
238
+
239
+ analyze_btn.click(
240
+ fn=analyzer.analyze_image,
241
+ inputs=[image_input, question_input, model_choice],
242
+ outputs=output
243
+ )
244
 
245
+ return iface
 
 
 
 
 
246
 
247
  if __name__ == "__main__":
248
+ iface = create_interface()
249
  iface.launch()