DHEIVER's picture
Update app.py
bc85663 verified
raw
history blame
7.04 kB
import gradio as gr
import torch
from transformers import (
Blip2Processor, Blip2ForConditionalGeneration,
AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq
)
from PIL import Image
import numpy as np
class ModelManager:
def __init__(self):
self.current_model = None
self.current_processor = None
self.model_name = None
def load_blip2(self):
"""Carrega modelo BLIP-2"""
self.model_name = "Salesforce/blip2-opt-2.7b"
self.current_processor = Blip2Processor.from_pretrained(self.model_name)
self.current_model = Blip2ForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
return "BLIP-2 carregado com sucesso!"
def load_llava(self):
"""Carrega modelo LLaVA"""
self.model_name = "llava-hf/llava-1.5-7b-hf"
self.current_processor = AutoProcessor.from_pretrained(self.model_name)
self.current_model = AutoModelForVision2Seq.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
return "LLaVA carregado com sucesso!"
def load_git(self):
"""Carrega modelo GIT"""
self.model_name = "microsoft/git-base-coco"
self.current_processor = AutoProcessor.from_pretrained(self.model_name)
self.current_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto"
)
return "GIT carregado com sucesso!"
def analyze_image(self, image, question, model_choice):
"""Analisa imagem com foco nutricional"""
try:
# Carrega o modelo apropriado se necessário
if model_choice == "BLIP-2" and (self.model_name != "Salesforce/blip2-opt-2.7b"):
status = self.load_blip2()
elif model_choice == "LLaVA" and (self.model_name != "llava-hf/llava-1.5-7b-hf"):
status = self.load_llava()
elif model_choice == "GIT" and (self.model_name != "microsoft/git-base-coco"):
status = self.load_git()
# Adiciona contexto nutricional à pergunta
nutritional_prompt = (
"Como nutricionista, analise este prato considerando: "
"1. Lista de ingredientes principais\n"
"2. Estimativa calórica total\n"
"3. Sugestões para uma versão mais saudável\n"
"4. Análise de grupos alimentares\n"
f"Pergunta do usuário: {question}"
"\nPor favor, responda em português com detalhes nutricionais."
)
# Prepara a imagem
if isinstance(image, str):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Processa a entrada
inputs = self.current_processor(
images=image,
text=nutritional_prompt,
return_tensors="pt"
).to(self.current_model.device)
# Gera a resposta
outputs = self.current_model.generate(
**inputs,
max_new_tokens=200, # Aumentado para respostas mais completas
num_beams=5,
temperature=0.7,
top_p=0.9
)
# Decodifica e formata a resposta
response = self.current_processor.decode(outputs[0], skip_special_tokens=True)
formatted_response = response.replace(". ", ".\n").replace("; ", ";\n")
return f"**Análise Nutricional:**\n{formatted_response}"
except Exception as e:
return f"Erro na análise: {str(e)}"
# Cria instância do gerenciador de modelos
model_manager = ModelManager()
# Interface Gradio
with gr.Blocks(theme=gr.themes.Soft()) as iface:
gr.Markdown("""
# 🥗 Analisador Nutricional Inteligente
Escolha o modelo que deseja usar para analisar seu prato e obter recomendações nutricionais.
""")
with gr.Row():
with gr.Column():
# Inputs
model_choice = gr.Radio(
choices=["BLIP-2", "LLaVA", "GIT"],
label="Escolha o Modelo",
value="BLIP-2"
)
# Substitui gr.Box() por gr.Group() para compatibilidade
with gr.Group():
gr.Markdown("""
### 📝 Características dos Modelos:
**BLIP-2:**
- Análise detalhada de ingredientes
- Estimativas calóricas mais precisas
- Recomendações técnicas
**LLaVA:**
- Explicações mais conversacionais
- Sugestões práticas para o dia a dia
- Foco em hábitos alimentares
**GIT:**
- Respostas rápidas e diretas
- Ideal para análises simples
- Menor consumo de recursos
""")
image_input = gr.Image(
type="pil",
label="Foto do Prato"
)
question_input = gr.Textbox(
label="Sua Pergunta",
placeholder="Ex: Quantas calorias tem este prato? Como posso torná-lo mais saudável?"
)
analyze_btn = gr.Button("🔍 Analisar", variant="primary")
with gr.Column():
# Output
with gr.Group(): # Substitui gr.Box() por gr.Group()
gr.Markdown("### 📊 Resultado da Análise")
output_text = gr.Markdown()
with gr.Accordion("💡 Sugestões de Perguntas", open=False):
gr.Markdown("""
1. Quantas calorias tem este prato?
2. Quais são os ingredientes principais?
3. Como posso tornar este prato mais saudável?
4. Este prato é adequado para uma dieta low-carb?
5. Quais nutrientes estão presentes neste prato?
6. Este prato é rico em proteínas?
7. Como posso substituir ingredientes para reduzir calorias?
8. Este prato é indicado para quem tem restrição a glúten/lactose?
""")
# Eventos
analyze_btn.click(
fn=model_manager.analyze_image,
inputs=[image_input, question_input, model_choice],
outputs=output_text
)
if __name__ == "__main__":
print(f"Dispositivo: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
iface.launch()