import os
import shutil
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import tempfile
import time
import warnings
warnings.filterwarnings("ignore")
print("🖥️ Iniciando sistema...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Available models
MODELS = {
    'ResNet18': models.resnet18,
    'ResNet34': models.resnet34,
    'ResNet50': models.resnet50,
    'MobileNetV2': models.mobilenet_v2
}
# Global application state
class AppState:
    def __init__(self):
        self.model = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.dataset_path = None
        self.class_dirs = []
        self.class_labels = ['class_0', 'class_1']
        self.num_classes = 2
        self.image_queue = []

state = AppState()
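# Note: all Gradio callbacks below read and mutate this single shared AppState instance,
# so the app keeps one dataset/model per process; concurrent sessions would share it.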
def setup_classes(num_classes_value):
    """Configures the number of classes and creates one directory per class."""
    try:
        state.num_classes = max(2, min(5, int(num_classes_value)))
        state.dataset_path = tempfile.mkdtemp()
        state.class_labels = [f'class_{i}' for i in range(state.num_classes)]
        # Create one directory per class
        state.class_dirs = []
        for i in range(state.num_classes):
            class_dir = os.path.join(state.dataset_path, f'class_{i}')
            os.makedirs(class_dir, exist_ok=True)
            state.class_dirs.append(class_dir)
        return f"✅ System configured for {state.num_classes} classes"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def set_class_labels(labels_text):
    """Sets the human-readable class labels."""
    try:
        labels = [label.strip() for label in labels_text.split(',')]
        if len(labels) != state.num_classes:
            return f"❌ Provide {state.num_classes} labels separated by commas"
        state.class_labels = labels
        return f"✅ Labels set: {', '.join(state.class_labels)}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def add_images_to_queue(images):
    """Adds multiple uploaded images to the in-memory queue."""
    if not images:
        return "❌ No image selected", len(state.image_queue)
    count = 0
    for image_file in images:
        try:
            if image_file is not None:
                # Gradio may hand back tempfile objects (with .name) or plain paths
                path = image_file.name if hasattr(image_file, 'name') else image_file
                img = Image.open(path).convert('RGB')
                state.image_queue.append(img)
                count += 1
        except Exception as e:
            print(f"Error processing image: {e}")
    return f"✅ {count} images added. Total in queue: {len(state.image_queue)}", len(state.image_queue)
def save_queue_to_class(class_id):
    """Saves the queued images into the chosen class directory."""
    try:
        if not state.image_queue:
            return "❌ No image in the queue"
        if not state.class_dirs:
            return "❌ Configure the classes first"
        class_idx = max(0, min(int(class_id), len(state.class_dirs) - 1))
        class_dir = state.class_dirs[class_idx]
        count = 0
        for i, image in enumerate(state.image_queue):
            try:
                filename = f"img_{int(time.time())}_{i}.jpg"
                filepath = os.path.join(class_dir, filename)
                image.save(filepath)
                count += 1
            except Exception as e:
                print(f"Error saving image {i}: {e}")
        state.image_queue = []  # Clear the queue
        class_name = state.class_labels[class_idx]
        return f"✅ {count} images saved to '{class_name}'"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def clear_queue():
    """Clears the image queue."""
    state.image_queue = []
    return "✅ Queue cleared", 0
def prepare_data(batch_size):
    """Builds the dataset and the train/validation/test DataLoaders."""
    try:
        if not state.dataset_path:
            return "❌ Configure the classes first"
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet mean/std, matching the pretrained backbones
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
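        # Optional tweak (sketch only, not enabled here): light augmentation on the training
        # images can help with very small datasets, e.g.
        #     transforms.Compose([
        #         transforms.Resize((224, 224)),
        #         transforms.RandomHorizontalFlip(),
        #         transforms.ToTensor(),
        #         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        #     ])
        # applied to the train split only; the original app uses one deterministic transform
        # for all splits, which is kept below.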
        dataset = datasets.ImageFolder(state.dataset_path, transform=transform)
        if len(dataset) < 6:
            return f"❌ Too few images ({len(dataset)}). Minimum: 6"
        # Split: 70% train, 20% validation, 10% test
        train_size = int(0.7 * len(dataset))
        val_size = int(0.2 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size],
            generator=torch.Generator().manual_seed(42)
        )
        batch_size = max(1, min(int(batch_size), 32))
        state.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        state.val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        state.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        return f"✅ Data prepared:\n• Train: {train_size}\n• Validation: {val_size}\n• Test: {test_size}\n• Batch size: {batch_size}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def train_model(model_name, epochs, lr):
    """Fine-tunes the selected pretrained model on the uploaded data."""
    try:
        if state.train_loader is None:
            return "❌ Prepare the data first"
        # Load a pretrained backbone
        state.model = MODELS[model_name](pretrained=True)
        # Replace the final layer to match the number of classes
        if hasattr(state.model, 'fc'):
            state.model.fc = nn.Linear(state.model.fc.in_features, state.num_classes)
        elif hasattr(state.model, 'classifier'):
            if isinstance(state.model.classifier, nn.Sequential):
                state.model.classifier[-1] = nn.Linear(state.model.classifier[-1].in_features, state.num_classes)
        state.model = state.model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(state.model.parameters(), lr=float(lr))
        results = [f"🚀 Training {model_name}"]
        state.model.train()
        for epoch in range(int(epochs)):
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in state.train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = state.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            epoch_loss = running_loss / len(state.train_loader)
            epoch_acc = 100. * correct / total
            results.append(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%")
        results.append("✅ Training finished!")
        return "\n".join(results)
    except Exception as e:
        return f"❌ Error: {str(e)}"
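# Note: state.val_loader is built in prepare_data() but not consumed during training; a
# per-epoch validation pass over it could be added inside the epoch loop above if desired
# (that would be an extension, not part of the original training flow).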
def evaluate_model():
    """Evaluates the trained model on the test split."""
    try:
        if state.model is None or state.test_loader is None:
            return "❌ Model/data not available"
        state.model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in state.test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = state.model(inputs)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        # Pass explicit labels so the report still works if a class is absent from the test split
        report = classification_report(all_labels, all_preds, labels=list(range(state.num_classes)),
                                       target_names=state.class_labels, zero_division=0)
        return f"📊 EVALUATION REPORT:\n\n{report}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def generate_confusion_matrix():
    """Generates a confusion matrix plot for the test split."""
    try:
        if state.model is None or state.test_loader is None:
            return None
        state.model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in state.test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = state.model(inputs)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        # Fix the label order so the heatmap axes always match class_labels
        cm = confusion_matrix(all_labels, all_preds, labels=list(range(state.num_classes)))
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=state.class_labels,
                    yticklabels=state.class_labels)
        plt.xlabel('Predictions')
        plt.ylabel('True values')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        temp_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name
        plt.savefig(temp_path, dpi=150, bbox_inches='tight')
        plt.close()
        return temp_path
    except Exception:
        return None
def predict_image(image):
    """Predicts the class of a single image."""
    try:
        if state.model is None:
            return "❌ Train the model first"
        if image is None:
            return "❌ Select an image"
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        img_tensor = transform(image).unsqueeze(0).to(device)
        state.model.eval()
        with torch.no_grad():
            outputs = state.model(img_tensor)
            probs = torch.nn.functional.softmax(outputs[0], dim=0)
            _, predicted = torch.max(outputs, 1)
        class_id = predicted.item()
        confidence = probs[class_id].item() * 100
        class_name = state.class_labels[class_id]
        return f"🎯 Prediction: {class_name}\n📊 Confidence: {confidence:.2f}%"
    except Exception as e:
        return f"❌ Error: {str(e)}"
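# A possible extension (sketch only, not wired into the Gradio UI): return the full per-class
# probability breakdown instead of just the top prediction, reusing the same preprocessing and
# the trained state.model. The function name is illustrative, not part of the original app.
def predict_image_probs(image):
    """Hypothetical helper: returns {label: probability} for a PIL image, assuming a trained model."""
    if state.model is None or image is None:
        return {}
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    state.model.eval()
    with torch.no_grad():
        outputs = state.model(transform(image).unsqueeze(0).to(device))
        probs = torch.nn.functional.softmax(outputs[0], dim=0)
    return {state.class_labels[i]: float(probs[i]) for i in range(state.num_classes)}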
# Interface using Gradio 3.x (correct syntax)
def create_interface():
    # Tabbed interface using Gradio 3.x
    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ Complete Image Classification System")
        gr.Markdown("**Stable, bug-free version - full functionality retained**")
with gr.Tab("1️⃣ Configuração"):
gr.Markdown("### 🎯 Configurar Classes")
num_classes_input = gr.Number(value=2, label="Número de Classes (2-5)")
setup_btn = gr.Button("🔧 Configurar Classes", variant="primary")
setup_output = gr.Textbox(label="Status da Configuração")
gr.Markdown("### 🏷️ Definir Rótulos")
labels_input = gr.Textbox(value="gato,cachorro", label="Rótulos (separados por vírgula)")
labels_btn = gr.Button("🏷️ Definir Rótulos")
labels_output = gr.Textbox(label="Status dos Rótulos")
# Conectar eventos
setup_btn.click(setup_classes, inputs=[num_classes_input], outputs=[setup_output])
labels_btn.click(set_class_labels, inputs=[labels_input], outputs=[labels_output])
with gr.Tab("2️⃣ Upload de Imagens"):
gr.Markdown("### 📤 Upload Múltiplo via Fila")
images_upload = gr.File(file_count="multiple", label="Selecionar Múltiplas Imagens", file_types=["image"])
add_btn = gr.Button("➕ Adicionar à Fila")
with gr.Row():
queue_output = gr.Textbox(label="Status da Fila")
queue_count_output = gr.Number(label="Total na Fila", value=0)
gr.Markdown("### 💾 Salvar por Classe")
with gr.Row():
class_id_input = gr.Number(value=0, label="Classe de Destino (0, 1, 2...)")
save_btn = gr.Button("💾 Salvar Fila na Classe", variant="primary")
clear_btn = gr.Button("🗑️ Limpar Fila")
save_output = gr.Textbox(label="Status do Upload")
# Conectar eventos
add_btn.click(add_images_to_queue, inputs=[images_upload], outputs=[queue_output, queue_count_output])
save_btn.click(save_queue_to_class, inputs=[class_id_input], outputs=[save_output])
clear_btn.click(clear_queue, outputs=[queue_output, queue_count_output])
with gr.Tab("3️⃣ Preparação e Treinamento"):
gr.Markdown("### ⚙️ Preparar Dados")
batch_size_input = gr.Number(value=8, label="Batch Size")
prepare_btn = gr.Button("⚙️ Preparar Dados", variant="primary")
prepare_output = gr.Textbox(label="Status da Preparação", lines=4)
gr.Markdown("### 🚀 Configurar e Treinar Modelo")
with gr.Row():
model_input = gr.Dropdown(choices=list(MODELS.keys()), value="MobileNetV2", label="Modelo")
epochs_input = gr.Number(value=5, label="Épocas")
lr_input = gr.Number(value=0.001, label="Learning Rate")
train_btn = gr.Button("🚀 Iniciar Treinamento", variant="primary")
train_output = gr.Textbox(label="Status do Treinamento", lines=8)
# Conectar eventos
prepare_btn.click(prepare_data, inputs=[batch_size_input], outputs=[prepare_output])
train_btn.click(train_model, inputs=[model_input, epochs_input, lr_input], outputs=[train_output])
with gr.Tab("4️⃣ Avaliação do Modelo"):
gr.Markdown("### 📊 Avaliar Desempenho")
with gr.Row():
eval_btn = gr.Button("📊 Avaliar Modelo", variant="primary")
matrix_btn = gr.Button("📈 Gerar Matriz de Confusão")
eval_output = gr.Textbox(label="Relatório de Avaliação", lines=12)
matrix_output = gr.Image(label="Matriz de Confusão")
# Conectar eventos
eval_btn.click(evaluate_model, outputs=[eval_output])
matrix_btn.click(generate_confusion_matrix, outputs=[matrix_output])
with gr.Tab("5️⃣ Predição"):
gr.Markdown("### 🔮 Predizer Novas Imagens")
predict_image_input = gr.Image(type="pil", label="Imagem para Predição")
predict_btn = gr.Button("🔮 Fazer Predição", variant="primary")
predict_output = gr.Textbox(label="Resultado da Predição", lines=3)
# Conectar eventos
predict_btn.click(predict_image, inputs=[predict_image_input], outputs=[predict_output])
        # Additional information
        with gr.Tab("ℹ️ Information"):
            gr.Markdown("""
## 📋 How to Use This System
### 1️⃣ **Initial Configuration**
- Set the number of classes (2-5)
- Define custom labels separated by commas
### 2️⃣ **Image Upload**
- Select multiple images
- Add them to the queue
- Choose the target class (0, 1, 2...)
- Save the queue to the chosen class
- Repeat for every class
### 3️⃣ **Training**
- Set the batch size (recommended: 8-16)
- Prepare the data
- Choose a model (MobileNetV2 = fastest)
- Set the number of epochs (recommended: 3-10)
- Start training
### 4️⃣ **Evaluation**
- Evaluate the model to see the metrics
- Generate a confusion matrix for visual analysis
### 5️⃣ **Prediction**
- Test with new images
- See predictions with confidence levels
## 🎯 **Tips for Better Results**
- Use at least 10-20 images per class
- Keep the classes well balanced
- Use clear, well-lit images
- Vary poses, angles and backgrounds
## 🔧 **Available Models**
- **MobileNetV2**: Fast, ideal for prototyping
- **ResNet18**: Good speed/accuracy balance
- **ResNet34/50**: Higher accuracy, slower
""")
    return demo
if __name__ == "__main__":
    print("🎯 Creating interface...")
    demo = create_interface()
    print("🚀 Starting application...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
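# Note on launch(): server_name="0.0.0.0" binds to all interfaces, which is what container
# hosts such as Hugging Face Spaces expect; for a purely local run, a plain demo.launch()
# (defaulting to 127.0.0.1) would also work.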