DHEIVER commited on
Commit
548dad1
·
verified ·
1 Parent(s): 30e029e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -88
app.py CHANGED
@@ -1,72 +1,73 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
  import pandas as pd
5
- from PIL import Image
6
  import numpy as np
 
7
 
8
- # Inicializa o modelo BLIP2
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
11
- model = Blip2ForConditionalGeneration.from_pretrained(
12
- "Salesforce/blip2-opt-2.7b",
13
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
- device_map="auto"
15
- )
 
 
16
 
17
- # Base de dados nutricional
18
  NUTRITION_DB = {
19
  "arroz": {"calorias": 130, "proteinas": 2.7, "carboidratos": 28, "gorduras": 0.3},
20
  "feijão": {"calorias": 77, "proteinas": 5.2, "carboidratos": 13.6, "gorduras": 0.5},
21
  "frango": {"calorias": 165, "proteinas": 31, "carboidratos": 0, "gorduras": 3.6},
22
  "salada": {"calorias": 15, "proteinas": 1.4, "carboidratos": 2.9, "gorduras": 0.2},
23
- "batata": {"calorias": 93, "proteinas": 2.5, "carboidratos": 21, "gorduras": 0.1},
24
- "carne": {"calorias": 250, "proteinas": 26, "carboidratos": 0, "gorduras": 17},
25
- "peixe": {"calorias": 206, "proteinas": 22, "carboidratos": 0, "gorduras": 12},
26
- "macarrão": {"calorias": 158, "proteinas": 5.8, "carboidratos": 31, "gorduras": 1.2},
27
- "ovo": {"calorias": 155, "proteinas": 13, "carboidratos": 1.1, "gorduras": 11},
28
  }
29
 
30
- def analyze_image(image, progress=gr.Progress()):
31
- """Analisa a imagem usando BLIP2 e retorna descrição dos alimentos"""
32
  try:
33
- progress(0.2, desc="Processando imagem...")
 
34
 
35
- # Garante que a imagem está no formato correto
36
- if isinstance(image, str):
37
- image = Image.open(image)
38
- elif isinstance(image, np.ndarray):
39
  image = Image.fromarray(image)
40
-
41
- # Gera prompt específico para identificação de alimentos
42
- prompt = "Identifique e liste todos os alimentos visíveis nesta imagem de refeição. Forneça uma lista detalhada."
43
-
44
- progress(0.4, desc="Analisando alimentos...")
45
- inputs = processor(image, text=prompt, return_tensors="pt").to(device)
46
-
47
- progress(0.6, desc="Gerando descrição...")
48
- with torch.no_grad():
49
- outputs = model.generate(
50
- **inputs,
51
- max_new_tokens=100,
52
- num_beams=5,
53
- temperature=1.0,
54
- top_p=0.9,
55
- )
56
-
57
- food_description = processor.decode(outputs[0], skip_special_tokens=True)
58
- progress(0.8, desc="Calculando nutrientes...")
59
 
60
- # Identifica alimentos conhecidos na descrição
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  found_foods = []
62
  for food in NUTRITION_DB.keys():
63
- if food in food_description.lower():
64
  found_foods.append(food)
65
 
66
  if not found_foods:
67
- return None, None, None, "Não foi possível identificar alimentos conhecidos na imagem."
68
 
69
- # Calcula nutrientes totais
70
  total_nutrients = {
71
  "calorias": 0,
72
  "proteinas": 0,
@@ -74,13 +75,12 @@ def analyze_image(image, progress=gr.Progress()):
74
  "gorduras": 0
75
  }
76
 
77
- # Soma nutrientes de cada alimento encontrado
78
  for food in found_foods:
79
  for nutrient, value in NUTRITION_DB[food].items():
80
  total_nutrients[nutrient] += value
81
 
82
  # Prepara dados para visualização
83
- nutrients_table = [
84
  ["Calorias", f"{total_nutrients['calorias']:.1f} kcal"],
85
  ["Proteínas", f"{total_nutrients['proteinas']:.1f}g"],
86
  ["Carboidratos", f"{total_nutrients['carboidratos']:.1f}g"],
@@ -97,71 +97,79 @@ def analyze_image(image, progress=gr.Progress()):
97
  ]
98
  })
99
 
100
- # Gera análise
101
  analysis = f"""### Alimentos Identificados:
102
- {food_description}
103
 
104
- ### Alimentos na Base de Dados:
105
- {', '.join(found_foods)}
106
 
107
  ### Análise Nutricional:
108
- - Valor Calórico: {total_nutrients['calorias']:.1f} kcal
109
- - Proporção de Macronutrientes:
110
- Proteínas: {total_nutrients['proteinas']:.1f}g
111
- Carboidratos: {total_nutrients['carboidratos']:.1f}g
112
- • Gorduras: {total_nutrients['gorduras']:.1f}g
113
  """
114
 
115
- progress(1.0, desc="Concluído!")
116
- return analysis, nutrients_table, plot_data, None
117
 
118
  except Exception as e:
119
- return None, None, None, f"Erro na análise: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Interface Gradio
122
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
123
- gr.Markdown("# 🍽️ Análise Nutricional com IA")
 
 
 
124
 
125
  with gr.Row():
126
- with gr.Column(scale=1):
127
- # Input
128
  image_input = gr.Image(
129
  type="pil",
130
  label="Foto do Prato",
131
  sources=["upload", "webcam"]
132
  )
133
- analyze_btn = gr.Button("Analisar", variant="primary")
 
 
 
 
 
134
 
135
- with gr.Column(scale=2):
136
- # Output
137
- error_output = gr.Markdown(visible=False)
 
 
138
 
139
- with gr.Tabs():
140
- with gr.Tab("Análise"):
141
- analysis_output = gr.Markdown()
142
-
143
- with gr.Tab("Nutrientes"):
144
- nutrients_table = gr.Dataframe(
145
- headers=["Nutriente", "Quantidade"],
146
- label="Informação Nutricional"
147
- )
148
- nutrients_plot = gr.BarPlot(
149
- x="Nutriente",
150
- y="Quantidade",
151
- title="Macronutrientes (g)",
152
- height=300
153
- )
154
 
155
  # Eventos
156
  analyze_btn.click(
157
  fn=analyze_image,
158
  inputs=[image_input],
159
- outputs=[
160
- analysis_output,
161
- nutrients_table,
162
- nutrients_plot,
163
- error_output
164
- ]
165
  )
166
 
167
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoProcessor, AutoModelForCausalLM
4
  import pandas as pd
 
5
  import numpy as np
6
+ from PIL import Image
7
 
8
+ def get_model():
9
+ """Inicializa o modelo uma única vez"""
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model_name = "microsoft/git-base-coco"
12
+
13
+ processor = AutoProcessor.from_pretrained(model_name)
14
+ model = AutoModelForCausalLM.from_pretrained(model_name)
15
+ model = model.to(device)
16
+
17
+ return processor, model, device
18
 
19
+ # Base de dados nutricional simplificada
20
  NUTRITION_DB = {
21
  "arroz": {"calorias": 130, "proteinas": 2.7, "carboidratos": 28, "gorduras": 0.3},
22
  "feijão": {"calorias": 77, "proteinas": 5.2, "carboidratos": 13.6, "gorduras": 0.5},
23
  "frango": {"calorias": 165, "proteinas": 31, "carboidratos": 0, "gorduras": 3.6},
24
  "salada": {"calorias": 15, "proteinas": 1.4, "carboidratos": 2.9, "gorduras": 0.2},
 
 
 
 
 
25
  }
26
 
27
+ def process_image(image, progress=gr.Progress()):
28
+ """Processa a imagem e retorna a descrição"""
29
  try:
30
+ progress(0.3, desc="Carregando modelo...")
31
+ processor, model, device = get_model()
32
 
33
+ progress(0.5, desc="Processando imagem...")
34
+ if isinstance(image, np.ndarray):
 
 
35
  image = Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Processa a imagem
38
+ inputs = processor(images=image, return_tensors="pt").to(device)
39
+
40
+ progress(0.7, desc="Gerando descrição...")
41
+ # Gera a descrição
42
+ outputs = model.generate(
43
+ **inputs,
44
+ max_new_tokens=50,
45
+ num_beams=1,
46
+ temperature=1.0,
47
+ )
48
+
49
+ # Decodifica a saída
50
+ description = processor.decode(outputs[0], skip_special_tokens=True)
51
+
52
+ progress(1.0, desc="Concluído!")
53
+ return description.strip()
54
+
55
+ except Exception as e:
56
+ raise gr.Error(f"Erro no processamento da imagem: {str(e)}")
57
+
58
+ def analyze_foods(description):
59
+ """Analisa a descrição e retorna informações nutricionais"""
60
+ try:
61
+ # Identifica alimentos da base de dados na descrição
62
  found_foods = []
63
  for food in NUTRITION_DB.keys():
64
+ if food in description.lower():
65
  found_foods.append(food)
66
 
67
  if not found_foods:
68
+ return "Nenhum alimento conhecido identificado.", None, None
69
 
70
+ # Calcula nutrientes
71
  total_nutrients = {
72
  "calorias": 0,
73
  "proteinas": 0,
 
75
  "gorduras": 0
76
  }
77
 
 
78
  for food in found_foods:
79
  for nutrient, value in NUTRITION_DB[food].items():
80
  total_nutrients[nutrient] += value
81
 
82
  # Prepara dados para visualização
83
+ table_data = [
84
  ["Calorias", f"{total_nutrients['calorias']:.1f} kcal"],
85
  ["Proteínas", f"{total_nutrients['proteinas']:.1f}g"],
86
  ["Carboidratos", f"{total_nutrients['carboidratos']:.1f}g"],
 
97
  ]
98
  })
99
 
 
100
  analysis = f"""### Alimentos Identificados:
101
+ {', '.join(found_foods)}
102
 
103
+ ### Descrição do Modelo:
104
+ {description}
105
 
106
  ### Análise Nutricional:
107
+ Calorias Totais: {total_nutrients['calorias']:.1f} kcal
108
+ Proteínas: {total_nutrients['proteinas']:.1f}g
109
+ Carboidratos: {total_nutrients['carboidratos']:.1f}g
110
+ Gorduras: {total_nutrients['gorduras']:.1f}g
 
111
  """
112
 
113
+ return analysis, table_data, plot_data
 
114
 
115
  except Exception as e:
116
+ raise gr.Error(f"Erro na análise: {str(e)}")
117
+
118
+ def analyze_image(image):
119
+ """Função principal que coordena o processo de análise"""
120
+ try:
121
+ # Processa a imagem
122
+ description = process_image(image)
123
+
124
+ # Analisa os alimentos
125
+ analysis, table_data, plot_data = analyze_foods(description)
126
+
127
+ return analysis, table_data, plot_data
128
+
129
+ except Exception as e:
130
+ return str(e), None, None
131
 
132
  # Interface Gradio
133
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
134
+ gr.Markdown("""
135
+ # 🍽️ Análise Nutricional com IA
136
+ Faça upload de uma foto do seu prato para análise nutricional.
137
+ """)
138
 
139
  with gr.Row():
140
+ # Coluna de Input
141
+ with gr.Column():
142
  image_input = gr.Image(
143
  type="pil",
144
  label="Foto do Prato",
145
  sources=["upload", "webcam"]
146
  )
147
+ analyze_btn = gr.Button("📊 Analisar", variant="primary")
148
+
149
+ # Coluna de Output
150
+ with gr.Column():
151
+ # Análise textual
152
+ output_text = gr.Markdown(label="Análise")
153
 
154
+ # Tabela nutricional
155
+ output_table = gr.Dataframe(
156
+ headers=["Nutriente", "Quantidade"],
157
+ label="Informação Nutricional"
158
+ )
159
 
160
+ # Gráfico
161
+ output_plot = gr.BarPlot(
162
+ x="Nutriente",
163
+ y="Quantidade",
164
+ title="Macronutrientes (g)",
165
+ height=300
166
+ )
 
 
 
 
 
 
 
 
167
 
168
  # Eventos
169
  analyze_btn.click(
170
  fn=analyze_image,
171
  inputs=[image_input],
172
+ outputs=[output_text, output_table, output_plot]
 
 
 
 
 
173
  )
174
 
175
  if __name__ == "__main__":