DHEIVER's picture
Update app.py
15646ab verified
raw
history blame
11.3 kB
import gradio as gr
import torch
from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
from PIL import Image, ImageOps
import numpy as np
import os
from huggingface_hub import snapshot_download
import logging
from pathlib import Path
import tempfile
import requests
from io import BytesIO
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageHandler:
"""Handle image processing and conversion"""
@staticmethod
def convert_to_rgb(image_path):
"""Convert image to RGB format supporting multiple formats"""
try:
# If image is a URL, download it first
if isinstance(image_path, str) and (image_path.startswith('http://') or image_path.startswith('https://')):
response = requests.get(image_path)
image_data = BytesIO(response.content)
image = Image.open(image_data)
else:
image = Image.open(image_path)
# Convert RGBA to RGB if needed
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (255, 255, 255))
background.paste(image, mask=image.split()[3])
image = background
# Convert any other mode to RGB
elif image.mode != 'RGB':
image = image.convert('RGB')
return image
except Exception as e:
logger.error(f"Error converting image: {str(e)}")
raise ValueError(f"Não foi possível processar a imagem. Erro: {str(e)}")
@staticmethod
def process_image(image):
"""Process image from various input types"""
try:
if isinstance(image, np.ndarray):
return Image.fromarray(image)
elif isinstance(image, Image.Image):
return image
elif isinstance(image, (str, Path)):
return ImageHandler.convert_to_rgb(image)
else:
raise ValueError("Formato de imagem não suportado")
except Exception as e:
logger.error(f"Error processing image: {str(e)}")
raise ValueError(f"Erro no processamento da imagem: {str(e)}")
class NutritionalAnalyzer:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.models = {}
self.processors = {}
self.image_handler = ImageHandler()
def initialize_model(self, model_name):
"""Initialize a specific model with error handling and caching"""
try:
if model_name not in self.models:
logger.info(f"Initializing {model_name}...")
# Model-specific configurations
model_configs = {
"llava": {
"repo": "llava-hf/llava-1.5-7b-hf",
"local_cache": "models/llava"
},
"git": {
"repo": "microsoft/git-base-coco",
"local_cache": "models/git"
}
}
config = model_configs.get(model_name)
if not config:
raise ValueError(f"Modelo não suportado: {model_name}")
# Ensure cache directory exists
os.makedirs(config["local_cache"], exist_ok=True)
# Download model if not cached
if not os.path.exists(os.path.join(config["local_cache"], "model.safetensors")):
snapshot_download(
repo_id=config["repo"],
local_dir=config["local_cache"],
ignore_patterns=["*.md", "*.txt"]
)
# Load processor and model
self.processors[model_name] = AutoProcessor.from_pretrained(
config["local_cache"],
local_files_only=True
)
self.models[model_name] = AutoModelForVision2Seq.from_pretrained(
config["local_cache"],
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
device_map="auto",
local_files_only=True
)
logger.info(f"{model_name} initialized successfully")
return True
return True
except Exception as e:
logger.error(f"Error initializing {model_name}: {str(e)}")
return False
def generate_nutritional_prompt(self, user_question):
"""Generate a comprehensive nutritional analysis prompt"""
return f"""Como nutricionista especializado, analise esta refeição detalhadamente:
1. Composição do Prato:
- Ingredientes principais
- Proporções aproximadas
- Método de preparo aparente
2. Análise Nutricional:
- Estimativa calórica
- Macronutrientes (proteínas, carboidratos, gorduras)
- Principais micronutrientes
3. Recomendações:
- Sugestões para versão mais saudável
- Porção recomendada
- Adequação para dietas específicas
Pergunta específica do usuário: {user_question}
Por favor, forneça uma análise detalhada em português."""
def analyze_image(self, image, question, model_choice):
"""Analyze image with nutritional focus"""
try:
if image is None:
return "Por favor, faça upload de uma imagem para análise."
# Convert model choice to internal name
model_name = model_choice.lower().replace("-", "")
# Initialize model if needed
if not self.initialize_model(model_name):
return "Erro: Não foi possível inicializar o modelo. Por favor, tente novamente."
# Process image with enhanced error handling
try:
processed_image = self.image_handler.process_image(image)
except ValueError as e:
return str(e)
except Exception as e:
return f"Erro no processamento da imagem: {str(e)}"
# Generate and process prompt
nutritional_prompt = self.generate_nutritional_prompt(question)
# Process input
try:
inputs = self.processors[model_name](
images=processed_image,
text=nutritional_prompt,
return_tensors="pt"
).to(self.device)
except Exception as e:
return f"Erro no processamento do modelo: {str(e)}"
# Generate response with enhanced parameters
try:
with torch.no_grad():
outputs = self.models[model_name].generate(
**inputs,
max_new_tokens=300,
num_beams=5,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2,
length_penalty=1.0
)
# Decode and format response
response = self.processors[model_name].decode(outputs[0], skip_special_tokens=True)
formatted_response = self.format_response(response)
return formatted_response
except Exception as e:
return f"Erro na geração da análise: {str(e)}"
except Exception as e:
logger.error(f"Analysis error: {str(e)}")
return f"Erro na análise: {str(e)}\nPor favor, tente novamente ou escolha outro modelo."
def format_response(self, response):
"""Format the response for better readability"""
sections = [
"Composição do Prato",
"Análise Nutricional",
"Recomendações"
]
formatted = "# 📊 Análise Nutricional\n\n"
# Split response into paragraphs
paragraphs = response.split("\n")
current_section = ""
for paragraph in paragraphs:
# Check if paragraph starts a new section
for section in sections:
if section.lower() in paragraph.lower():
current_section = f"\n## {section}\n"
formatted += current_section
break
# Add paragraph to current section
if paragraph.strip() and current_section:
formatted += f"- {paragraph.strip()}\n"
elif paragraph.strip():
formatted += f"{paragraph.strip()}\n"
return formatted
# Create interface
def create_interface():
analyzer = NutritionalAnalyzer()
with gr.Blocks(theme=gr.themes.Soft()) as iface:
gr.Markdown("""
# 🥗 Análise Nutricional Inteligente
Faça upload da foto do seu prato para receber uma análise nutricional detalhada.
""")
with gr.Row():
with gr.Column(scale=2):
image_input = gr.Image(
type="pil",
label="📸 Foto do Prato",
height=400
)
question_input = gr.Textbox(
label="💭 Sua Pergunta",
placeholder="Ex: Quais são os nutrientes principais deste prato?",
lines=2
)
model_choice = gr.Radio(
choices=["LLaVA", "GIT"],
value="LLaVA",
label="🤖 Escolha o Modelo de Análise"
)
analyze_btn = gr.Button(
"🔍 Analisar Prato",
variant="primary",
scale=1
)
with gr.Column(scale=3):
output = gr.Markdown(label="Resultado da Análise")
# Add examples and tips
with gr.Accordion("💡 Dicas de Uso", open=False):
gr.Markdown("""
### Sugestões de Perguntas:
- Qual o valor nutricional aproximado deste prato?
- Como tornar esta refeição mais equilibrada?
- Este prato é adequado para dieta low-carb?
- Quais nutrientes importantes estão presentes?
### Dicas para Melhores Resultados:
1. Tire a foto com boa iluminação
2. Capture todos os elementos do prato
3. Evite ângulos muito inclinados
4. Seja específico em suas perguntas
5. Formatos de imagem suportados: JPG, PNG, WEBP
""")
analyze_btn.click(
fn=analyzer.analyze_image,
inputs=[image_input, question_input, model_choice],
outputs=output
)
return iface
if __name__ == "__main__":
iface = create_interface()
iface.launch()