Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
from transformers import ( | |
Blip2Processor, Blip2ForConditionalGeneration, | |
AutoProcessor, AutoModelForCausalLM, AutoModelForVision2Seq | |
) | |
from PIL import Image | |
import numpy as np | |
class ModelManager: | |
def __init__(self): | |
self.current_model = None | |
self.current_processor = None | |
self.model_name = None | |
def load_blip2(self): | |
"""Carrega modelo BLIP-2""" | |
self.model_name = "Salesforce/blip2-opt-2.7b" | |
self.current_processor = Blip2Processor.from_pretrained(self.model_name) | |
self.current_model = Blip2ForConditionalGeneration.from_pretrained( | |
self.model_name, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" | |
) | |
return "BLIP-2 carregado com sucesso!" | |
def load_llava(self): | |
"""Carrega modelo LLaVA""" | |
self.model_name = "llava-hf/llava-1.5-7b-hf" | |
self.current_processor = AutoProcessor.from_pretrained(self.model_name) | |
self.current_model = AutoModelForVision2Seq.from_pretrained( | |
self.model_name, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" | |
) | |
return "LLaVA carregado com sucesso!" | |
def load_git(self): | |
"""Carrega modelo GIT""" | |
self.model_name = "microsoft/git-base-coco" | |
self.current_processor = AutoProcessor.from_pretrained(self.model_name) | |
self.current_model = AutoModelForCausalLM.from_pretrained( | |
self.model_name, | |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
device_map="auto" | |
) | |
return "GIT carregado com sucesso!" | |
def analyze_image(self, image, question, model_choice): | |
"""Analisa imagem com foco nutricional""" | |
try: | |
# Carrega o modelo apropriado se necessário | |
if model_choice == "BLIP-2" and (self.model_name != "Salesforce/blip2-opt-2.7b"): | |
status = self.load_blip2() | |
elif model_choice == "LLaVA" and (self.model_name != "llava-hf/llava-1.5-7b-hf"): | |
status = self.load_llava() | |
elif model_choice == "GIT" and (self.model_name != "microsoft/git-base-coco"): | |
status = self.load_git() | |
# Adiciona contexto nutricional à pergunta | |
nutritional_prompt = ( | |
"Como nutricionista, analise este prato considerando: " | |
"1. Lista de ingredientes principais\n" | |
"2. Estimativa calórica total\n" | |
"3. Sugestões para uma versão mais saudável\n" | |
"4. Análise de grupos alimentares\n" | |
f"Pergunta do usuário: {question}" | |
"\nPor favor, responda em português com detalhes nutricionais." | |
) | |
# Prepara a imagem | |
if isinstance(image, str): | |
image = Image.open(image) | |
elif isinstance(image, np.ndarray): | |
image = Image.fromarray(image) | |
# Processa a entrada | |
inputs = self.current_processor( | |
images=image, | |
text=nutritional_prompt, | |
return_tensors="pt" | |
).to(self.current_model.device) | |
# Gera a resposta | |
outputs = self.current_model.generate( | |
**inputs, | |
max_new_tokens=200, # Aumentado para respostas mais completas | |
num_beams=5, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
# Decodifica e formata a resposta | |
response = self.current_processor.decode(outputs[0], skip_special_tokens=True) | |
formatted_response = response.replace(". ", ".\n").replace("; ", ";\n") | |
return f"**Análise Nutricional:**\n{formatted_response}" | |
except Exception as e: | |
return f"Erro na análise: {str(e)}" | |
# Cria instância do gerenciador de modelos | |
model_manager = ModelManager() | |
# Interface Gradio | |
with gr.Blocks(theme=gr.themes.Soft()) as iface: | |
gr.Markdown(""" | |
# 🥗 Analisador Nutricional Inteligente | |
Escolha o modelo que deseja usar para analisar seu prato e obter recomendações nutricionais. | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
# Inputs | |
model_choice = gr.Radio( | |
choices=["BLIP-2", "LLaVA", "GIT"], | |
label="Escolha o Modelo", | |
value="BLIP-2" | |
) | |
# Substitui gr.Box() por gr.Group() para compatibilidade | |
with gr.Group(): | |
gr.Markdown(""" | |
### 📝 Características dos Modelos: | |
**BLIP-2:** | |
- Análise detalhada de ingredientes | |
- Estimativas calóricas mais precisas | |
- Recomendações técnicas | |
**LLaVA:** | |
- Explicações mais conversacionais | |
- Sugestões práticas para o dia a dia | |
- Foco em hábitos alimentares | |
**GIT:** | |
- Respostas rápidas e diretas | |
- Ideal para análises simples | |
- Menor consumo de recursos | |
""") | |
image_input = gr.Image( | |
type="pil", | |
label="Foto do Prato" | |
) | |
question_input = gr.Textbox( | |
label="Sua Pergunta", | |
placeholder="Ex: Quantas calorias tem este prato? Como posso torná-lo mais saudável?" | |
) | |
analyze_btn = gr.Button("🔍 Analisar", variant="primary") | |
with gr.Column(): | |
# Output | |
with gr.Group(): # Substitui gr.Box() por gr.Group() | |
gr.Markdown("### 📊 Resultado da Análise") | |
output_text = gr.Markdown() | |
with gr.Accordion("💡 Sugestões de Perguntas", open=False): | |
gr.Markdown(""" | |
1. Quantas calorias tem este prato? | |
2. Quais são os ingredientes principais? | |
3. Como posso tornar este prato mais saudável? | |
4. Este prato é adequado para uma dieta low-carb? | |
5. Quais nutrientes estão presentes neste prato? | |
6. Este prato é rico em proteínas? | |
7. Como posso substituir ingredientes para reduzir calorias? | |
8. Este prato é indicado para quem tem restrição a glúten/lactose? | |
""") | |
# Eventos | |
analyze_btn.click( | |
fn=model_manager.analyze_image, | |
inputs=[image_input, question_input, model_choice], | |
outputs=output_text | |
) | |
if __name__ == "__main__": | |
print(f"Dispositivo: {'CUDA' if torch.cuda.is_available() else 'CPU'}") | |
iface.launch() |