Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
def get_model(): | |
"""Inicializa o modelo uma única vez""" | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model_name = "microsoft/git-base-coco" | |
processor = AutoProcessor.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained(model_name) | |
model = model.to(device) | |
return processor, model, device | |
# Base de dados nutricional simplificada | |
NUTRITION_DB = { | |
"arroz": {"calorias": 130, "proteinas": 2.7, "carboidratos": 28, "gorduras": 0.3}, | |
"feijão": {"calorias": 77, "proteinas": 5.2, "carboidratos": 13.6, "gorduras": 0.5}, | |
"frango": {"calorias": 165, "proteinas": 31, "carboidratos": 0, "gorduras": 3.6}, | |
"salada": {"calorias": 15, "proteinas": 1.4, "carboidratos": 2.9, "gorduras": 0.2}, | |
} | |
def process_image(image, progress=gr.Progress()): | |
"""Processa a imagem e retorna a descrição""" | |
try: | |
progress(0.3, desc="Carregando modelo...") | |
processor, model, device = get_model() | |
progress(0.5, desc="Processando imagem...") | |
if isinstance(image, np.ndarray): | |
image = Image.fromarray(image) | |
# Processa a imagem | |
inputs = processor(images=image, return_tensors="pt").to(device) | |
progress(0.7, desc="Gerando descrição...") | |
# Gera a descrição | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=50, | |
num_beams=1, | |
temperature=1.0, | |
) | |
# Decodifica a saída | |
description = processor.decode(outputs[0], skip_special_tokens=True) | |
progress(1.0, desc="Concluído!") | |
return description.strip() | |
except Exception as e: | |
raise gr.Error(f"Erro no processamento da imagem: {str(e)}") | |
def analyze_foods(description): | |
"""Analisa a descrição e retorna informações nutricionais""" | |
try: | |
# Identifica alimentos da base de dados na descrição | |
found_foods = [] | |
for food in NUTRITION_DB.keys(): | |
if food in description.lower(): | |
found_foods.append(food) | |
if not found_foods: | |
return "Nenhum alimento conhecido identificado.", None, None | |
# Calcula nutrientes | |
total_nutrients = { | |
"calorias": 0, | |
"proteinas": 0, | |
"carboidratos": 0, | |
"gorduras": 0 | |
} | |
for food in found_foods: | |
for nutrient, value in NUTRITION_DB[food].items(): | |
total_nutrients[nutrient] += value | |
# Prepara dados para visualização | |
table_data = [ | |
["Calorias", f"{total_nutrients['calorias']:.1f} kcal"], | |
["Proteínas", f"{total_nutrients['proteinas']:.1f}g"], | |
["Carboidratos", f"{total_nutrients['carboidratos']:.1f}g"], | |
["Gorduras", f"{total_nutrients['gorduras']:.1f}g"] | |
] | |
# Dados para o gráfico | |
plot_data = pd.DataFrame({ | |
'Nutriente': ['Proteínas', 'Carboidratos', 'Gorduras'], | |
'Quantidade': [ | |
total_nutrients['proteinas'], | |
total_nutrients['carboidratos'], | |
total_nutrients['gorduras'] | |
] | |
}) | |
analysis = f"""### Alimentos Identificados: | |
• {', '.join(found_foods)} | |
### Descrição do Modelo: | |
{description} | |
### Análise Nutricional: | |
• Calorias Totais: {total_nutrients['calorias']:.1f} kcal | |
• Proteínas: {total_nutrients['proteinas']:.1f}g | |
• Carboidratos: {total_nutrients['carboidratos']:.1f}g | |
• Gorduras: {total_nutrients['gorduras']:.1f}g | |
""" | |
return analysis, table_data, plot_data | |
except Exception as e: | |
raise gr.Error(f"Erro na análise: {str(e)}") | |
def analyze_image(image): | |
"""Função principal que coordena o processo de análise""" | |
try: | |
# Processa a imagem | |
description = process_image(image) | |
# Analisa os alimentos | |
analysis, table_data, plot_data = analyze_foods(description) | |
return analysis, table_data, plot_data | |
except Exception as e: | |
return str(e), None, None | |
# Interface Gradio | |
with gr.Blocks(theme=gr.themes.Soft()) as iface: | |
gr.Markdown(""" | |
# 🍽️ Análise Nutricional com IA | |
Faça upload de uma foto do seu prato para análise nutricional. | |
""") | |
with gr.Row(): | |
# Coluna de Input | |
with gr.Column(): | |
image_input = gr.Image( | |
type="pil", | |
label="Foto do Prato", | |
sources=["upload", "webcam"] | |
) | |
analyze_btn = gr.Button("📊 Analisar", variant="primary") | |
# Coluna de Output | |
with gr.Column(): | |
# Análise textual | |
output_text = gr.Markdown(label="Análise") | |
# Tabela nutricional | |
output_table = gr.Dataframe( | |
headers=["Nutriente", "Quantidade"], | |
label="Informação Nutricional" | |
) | |
# Gráfico | |
output_plot = gr.BarPlot( | |
x="Nutriente", | |
y="Quantidade", | |
title="Macronutrientes (g)", | |
height=300 | |
) | |
# Eventos | |
analyze_btn.click( | |
fn=analyze_image, | |
inputs=[image_input], | |
outputs=[output_text, output_table, output_plot] | |
) | |
if __name__ == "__main__": | |
iface.launch() |