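# Residue from the hosted file viewer this file was copied from (not program code):
# author rmayormartins · commit 6943d4d ("go3") · raw / history / blame · 15.4 kB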
import os
import shutil
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import tempfile
import warnings
# Silence every Python warning for the whole process so console output stays clean.
# NOTE(review): this also hides library deprecation notices (possibly including
# torchvision's) — consider narrowing the filter.
warnings.filterwarnings("ignore")

# Choose the compute device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"🖥️ Usando device: {device}")
# Selectable backbones: UI display name -> torchvision model constructor.
# The keys populate the "Modelo" dropdown; the values are called in start_training.
MODELS = dict(
    ResNet18=models.resnet18,
    MobileNetV2=models.mobilenet_v2,
)
# Global application state
class AppState:
    """Mutable container shared by every UI callback in this module.

    The workflow fills it in step by step: setup_classes / set_class_labels
    set the dataset layout, prepare_data sets the loaders, and start_training
    sets the model.
    """

    def __init__(self):
        # Dataset layout on disk
        self.dataset_path = None
        self.class_dirs = []
        self.class_labels = []
        self.num_classes = 2  # initial default until setup_classes runs
        # Train / validation / test DataLoaders
        self.train_loader = self.val_loader = self.test_loader = None
        # Trained network
        self.model = None


# Single module-level instance read and written by all callbacks.
# NOTE(review): presumably shared by all concurrent Gradio sessions, so two users
# would overwrite each other's dataset/model — consider gr.State if that matters.
app_state = AppState()
def setup_classes(num_classes_value):
    """Configure the number of classes and create one directory per class.

    Creates a fresh temporary dataset root containing ``classe_0`` ...
    ``classe_{n-1}`` and resets the labels to those defaults. Any dataset
    directory, loaders and model from a previous configuration are discarded.

    Args:
        num_classes_value: Requested number of classes (UI value, 2 to 5).

    Returns:
        Tuple ``(status_message, dropdown_update)`` for the status textbox and
        the class selector.
    """
    # The upper bound mirrors the five label textboxes accepted by set_class_labels.
    min_classes, max_classes = 2, 5
    try:
        num_classes = int(num_classes_value)
        if not min_classes <= num_classes <= max_classes:
            return (
                f"❌ Erro: o número de classes deve estar entre {min_classes} e {max_classes}.",
                gr.Dropdown(),
            )

        # Reconfiguring starts a new dataset: delete the previous temp dir so repeated
        # clicks don't leak directories, and drop loaders/model that refer to it.
        if app_state.dataset_path and os.path.isdir(app_state.dataset_path):
            shutil.rmtree(app_state.dataset_path, ignore_errors=True)
        app_state.train_loader = None
        app_state.val_loader = None
        app_state.test_loader = None
        app_state.model = None

        app_state.num_classes = num_classes
        app_state.dataset_path = tempfile.mkdtemp()
        app_state.class_labels = [f'classe_{i}' for i in range(num_classes)]

        # Folder names stay ``classe_{i}`` even after custom labels are set. ImageFolder
        # assigns indices in sorted folder-name order, so folder i maps to class index i
        # (single-digit suffixes sort correctly).
        app_state.class_dirs = []
        for i in range(num_classes):
            class_dir = os.path.join(app_state.dataset_path, f'classe_{i}')
            os.makedirs(class_dir, exist_ok=True)
            app_state.class_dirs.append(class_dir)

        choices = [(f"{i} - {label}", i) for i, label in enumerate(app_state.class_labels)]
        return (
            f"✅ Criados {num_classes} diretórios para classes",
            gr.Dropdown(choices=choices, value=0),
        )
    except Exception as e:
        return f"❌ Erro: {str(e)}", gr.Dropdown()
def set_class_labels(label0, label1, label2, label3, label4):
    """Assign custom display names to the configured classes.

    Textbox ``i`` names class ``i``. The first ``app_state.num_classes`` boxes
    are used positionally, and each must be non-blank and distinct.

    Args:
        label0, label1, label2, label3, label4: Raw textbox values; ``None`` is
            treated as blank.

    Returns:
        Tuple ``(status_message, dropdown_update)``.
    """
    try:
        n = app_state.num_classes
        raw = (label0, label1, label2, label3, label4)
        # Keep positions. Filtering out blanks first would shift a later label onto an
        # earlier class (boxes 0 and 2 filled -> class 1 would get box 2's name).
        labels = [(label or "").strip() for label in raw[:n]]
        if len(labels) != n or not all(labels):
            return f"❌ Erro: Forneça exatamente {n} rótulos.", gr.Dropdown()
        if len(set(labels)) != n:
            return "❌ Erro: Os rótulos devem ser distintos.", gr.Dropdown()

        app_state.class_labels = labels
        choices = [(f"{i} - {label}", i) for i, label in enumerate(labels)]
        return (
            f"✅ Rótulos definidos: {', '.join(labels)}",
            gr.Dropdown(choices=choices, value=0),
        )
    except Exception as e:
        return f"❌ Erro: {str(e)}", gr.Dropdown()
def _unique_destination(directory, filename):
    """Return a path inside *directory* for *filename* that does not exist yet.

    On a collision, ``_1``, ``_2``, ... is appended before the extension, so
    uploads sharing a base name don't overwrite each other.
    """
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    suffix = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate


def upload_images(class_id, images):
    """Copy uploaded image files into the folder of the selected class.

    Args:
        class_id: Index of the target class (value of the class dropdown).
        images: Uploaded files. Either paths, or objects exposing their path as
            ``.name`` (assumed to be an older Gradio file wrapper; confirm for
            your Gradio version).

    Returns:
        Status message prefixed with ✅ or ❌.
    """
    try:
        if not images:
            return "❌ Nenhuma imagem selecionada."
        idx = int(class_id)
        # Reject negatives too: a negative index would silently pick a class from the end.
        if not 0 <= idx < len(app_state.class_dirs):
            return f"❌ Classe {class_id} inválida."
        class_dir = app_state.class_dirs[idx]
        count = 0
        for image in images:
            if image is None:
                continue
            src = image if isinstance(image, (str, os.PathLike)) else image.name
            shutil.copy2(src, _unique_destination(class_dir, os.path.basename(src)))
            count += 1
        class_name = app_state.class_labels[idx]
        return f"✅ {count} imagens salvas na classe {class_id} ({class_name})"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def prepare_data(batch_size):
    """Split the uploaded images into train/validation/test DataLoaders.

    The dataset root (``app_state.dataset_path``) is read with ``ImageFolder``.
    Every image is resized to 224x224, converted to a tensor and normalized
    with fixed per-channel mean/std values (those commonly used with
    ImageNet-pretrained torchvision models). The split is 70% / 20% /
    remainder, seeded with 42 for reproducibility. The three loaders are
    stored on ``app_state``, and only the training loader shuffles.

    Args:
        batch_size: Batch size shared by all three loaders (passed through ``int``).

    Returns:
        A status message prefixed with ✅ or ❌.
    """
    try:
        root = app_state.dataset_path
        if not (root and os.path.exists(root)):
            return "❌ Configure as classes primeiro."

        preprocessing = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_folder = datasets.ImageFolder(root, transform=preprocessing)

        if not image_folder.classes:
            return "❌ Nenhuma classe encontrada. Faça upload das imagens primeiro."
        n_images = len(image_folder)
        if n_images < 6:
            return f"❌ Muito poucas imagens ({n_images}). Adicione pelo menos 2 imagens por classe."

        n_train = int(0.7 * n_images)
        n_val = int(0.2 * n_images)
        n_test = n_images - n_train - n_val
        subsets = random_split(
            image_folder,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(42),
        )

        loader_batch = int(batch_size)
        app_state.train_loader, app_state.val_loader, app_state.test_loader = (
            DataLoader(subset, batch_size=loader_batch, shuffle=shuffle)
            for subset, shuffle in zip(subsets, (True, False, False))
        )
        return f"✅ Dados preparados: {n_train} treino, {n_val} validação, {n_test} teste"
    except Exception as e:
        return f"❌ Erro na preparação: {str(e)}"
def _build_model(model_name, num_classes):
    """Instantiate ``MODELS[model_name]`` with pretrained weights and a new ``num_classes``-way head."""
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour of
    # ``weights=``; confirm against the installed version before changing it.
    model = MODELS[model_name](pretrained=True)
    # Replace the final Linear layer. Models exposing ``fc`` get it swapped directly.
    # Models exposing ``classifier`` get either its last element (if it is a
    # Sequential) or the whole module replaced.
    if hasattr(model, 'fc'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif hasattr(model, 'classifier'):
        head = model.classifier
        if isinstance(head, nn.Sequential):
            head[-1] = nn.Linear(head[-1].in_features, num_classes)
        else:
            model.classifier = nn.Linear(head.in_features, num_classes)
    return model.to(device)


def _validation_metrics(model, loader, criterion):
    """Return ``(mean batch loss, accuracy %)`` of *model* on *loader*, or ``None`` if empty.

    Runs in eval mode without gradients and puts the model back into train mode.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    model.train()
    if total == 0:
        return None
    return total_loss / len(loader), 100.0 * correct / total


def start_training(model_name, epochs, lr):
    """Fine-tune a pretrained backbone on the prepared training split.

    Builds ``MODELS[model_name]`` with a new head of ``app_state.num_classes``
    outputs and trains it with Adam and cross-entropy for ``epochs`` epochs.
    The log reports per-epoch training loss/accuracy, plus validation
    loss/accuracy when a validation loader is available. ``app_state.model``
    is replaced only after training finishes, so a failed run keeps the
    previous model.

    Args:
        model_name: Key of ``MODELS``.
        epochs: Number of epochs (passed through ``int``).
        lr: Adam learning rate (passed through ``float``).

    Returns:
        Multi-line training log, or an error message prefixed with ❌.
    """
    try:
        if app_state.train_loader is None:
            return "❌ Erro: Dados não preparados."
        n_epochs = int(epochs)

        model = _build_model(model_name, app_state.num_classes)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=float(lr))
        model.train()

        results = [f"🚀 Treinando {model_name} por {n_epochs} épocas"]
        for epoch in range(n_epochs):
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in app_state.train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)

            epoch_loss = running_loss / len(app_state.train_loader)
            epoch_acc = 100. * correct / total
            line = f"Época {epoch+1}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%"
            # Report the validation split built by prepare_data, so over-fitting on the
            # small training set is visible.
            if app_state.val_loader is not None:
                val = _validation_metrics(model, app_state.val_loader, criterion)
                if val is not None:
                    line += f", Val Loss={val[0]:.4f}, Val Acc={val[1]:.2f}%"
            results.append(line)

        app_state.model = model
        results.append("✅ Treinamento concluído!")
        return "\n".join(results)
    except Exception as e:
        return f"❌ Erro durante treinamento: {str(e)}"
def evaluate_model():
    """Evaluate the trained model on the held-out test split.

    Returns:
        A text report with scikit-learn's per-class precision/recall/F1 table
        and the confusion matrix (rows = true class, columns = predicted
        class). Returns an error message prefixed with ❌ on failure.
    """
    try:
        if app_state.model is None or app_state.test_loader is None:
            return "❌ Modelo ou dados não disponíveis."

        app_state.model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in app_state.test_loader:
                outputs = app_state.model(inputs.to(device))
                all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
                all_labels.extend(labels.tolist())
        if not all_labels:
            return "❌ Conjunto de teste vazio."

        # Pass every class index explicitly. The tiny test split often lacks some
        # classes, and without ``labels`` classification_report raises when the
        # number of observed classes differs from len(target_names).
        class_ids = list(range(len(app_state.class_labels)))
        report = classification_report(
            all_labels,
            all_preds,
            labels=class_ids,
            target_names=app_state.class_labels,
            zero_division=0,
        )
        matrix = confusion_matrix(all_labels, all_preds, labels=class_ids)
        return (
            f"📊 RELATÓRIO DE CLASSIFICAÇÃO:\n\n{report}\n"
            f"🔢 MATRIZ DE CONFUSÃO (linhas = real, colunas = previsto):\n{matrix}"
        )
    except Exception as e:
        return f"❌ Erro durante avaliação: {str(e)}"
def predict_images(images):
    """Classify each uploaded image with the trained model.

    Each image is converted to RGB, resized to 224x224 and normalized with the
    same statistics used in prepare_data. Files that cannot be opened as
    images are reported individually instead of aborting the whole batch.

    Args:
        images: Uploaded files. Either paths, or objects exposing their path as
            ``.name`` (assumed to be a Gradio file wrapper).

    Returns:
        Multi-line text with the predicted class and softmax confidence per
        image, or an error message prefixed with ❌.
    """
    try:
        if app_state.model is None:
            return "❌ Modelo não treinado."
        if not images:
            return "❌ Nenhuma imagem selecionada."

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        app_state.model.eval()
        results = []
        for item in images:
            if item is None:
                continue
            image_path = item if isinstance(item, (str, os.PathLike)) else item.name
            file_name = os.path.basename(image_path)
            try:
                # ``with`` closes the file handle that PIL keeps open after Image.open.
                with Image.open(image_path) as img:
                    img_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)
            except OSError as e:
                # PIL reports unreadable/non-image files as OSError subclasses.
                results.append(f"📸 {file_name}")
                results.append(f" ❌ Erro: {str(e)}")
                results.append("-" * 40)
                continue

            with torch.no_grad():
                probabilities = torch.softmax(app_state.model(img_tensor)[0], dim=0)
            # The argmax of the softmax equals the argmax of the logits.
            confidence, class_id = probabilities.max(dim=0)
            predicted_class_name = app_state.class_labels[class_id.item()]

            results.append(f"📸 {file_name}")
            results.append(f" 🎯 Classe: {predicted_class_name}")
            results.append(f" 📊 Confiança: {confidence.item() * 100:.2f}%")
            results.append("-" * 40)
        return "\n".join(results) if results else "❌ Nenhuma predição realizada."
    except Exception as e:
        return f"❌ Erro: {str(e)}"
# Gradio interface
def create_interface():
    """Build the five-tab Gradio Blocks app and wire every button to its callback.

    Tabs: 1) class setup and labels, 2) image upload, 3) data preparation and
    training, 4) evaluation, 5) prediction on new images.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """

    def _label_box_visibility(num_classes_value):
        """Return visibility updates for the label boxes of classes 2, 3 and 4.

        Box ``i`` is shown only when at least ``i + 1`` classes are requested.
        Without this the boxes stay hidden and labels can never be set for
        more than two classes.
        """
        try:
            n = int(num_classes_value)
        except (TypeError, ValueError):
            n = 2  # empty/invalid field: fall back to the two always-visible boxes
        return [gr.update(visible=i < n) for i in range(2, 5)]

    with gr.Blocks(title="🖼️ Classificador de Imagens", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# 🖼️ Sistema de Classificação de Imagens\n"
            "**Instruções:**\n"
            "1. Configure as classes e rótulos\n"
            "2. Faça upload das imagens\n"
            "3. Prepare os dados e treine\n"
            "4. Avalie e faça predições!\n"
        )

        with gr.Tab("1️⃣ Configuração"):
            with gr.Row():
                num_classes_input = gr.Number(
                    label="Número de Classes",
                    value=2,
                    minimum=2,
                    maximum=5,
                    precision=0,
                )
                setup_button = gr.Button("🔧 Configurar Classes", variant="primary")
            setup_output = gr.Textbox(label="Status", lines=2)

            gr.Markdown("### Rótulos das Classes")
            with gr.Row():
                label0 = gr.Textbox(label="Classe 0", placeholder="Ex: gato")
                label1 = gr.Textbox(label="Classe 1", placeholder="Ex: cachorro")
            with gr.Row():
                label2 = gr.Textbox(label="Classe 2", placeholder="Ex: pássaro", visible=False)
                label3 = gr.Textbox(label="Classe 3", placeholder="Ex: peixe", visible=False)
                label4 = gr.Textbox(label="Classe 4", placeholder="Ex: hamster", visible=False)
            set_labels_button = gr.Button("🏷️ Definir Rótulos")
            labels_output = gr.Textbox(label="Status dos Rótulos")

            # Dropdown refreshed by setup_classes / set_class_labels
            class_selector = gr.Dropdown(
                label="Selecionar Classe",
                choices=[("Classe 0", 0), ("Classe 1", 1)],
                value=0,
            )

        with gr.Tab("2️⃣ Upload"):
            images_upload = gr.File(
                label="Selecionar Imagens",
                file_count="multiple",
                file_types=["image"],
            )
            upload_button = gr.Button("📤 Fazer Upload", variant="primary")
            upload_output = gr.Textbox(label="Status do Upload")

        with gr.Tab("3️⃣ Treinamento"):
            # precision=0 so batch size and epochs arrive as whole numbers.
            batch_size = gr.Number(label="Batch Size", value=8, minimum=1, maximum=32, precision=0)
            prepare_button = gr.Button("⚙️ Preparar Dados", variant="primary")
            prepare_output = gr.Textbox(label="Status", lines=3)
            with gr.Row():
                model_name = gr.Dropdown(
                    label="Modelo",
                    choices=list(MODELS.keys()),
                    value="MobileNetV2",
                )
                epochs = gr.Number(label="Épocas", value=3, minimum=1, maximum=10, precision=0)
                lr = gr.Number(label="Learning Rate", value=0.001, minimum=0.0001, maximum=0.1)
            train_button = gr.Button("🚀 Treinar", variant="primary")
            train_output = gr.Textbox(label="Status do Treinamento", lines=10)

        with gr.Tab("4️⃣ Avaliação"):
            eval_button = gr.Button("📊 Avaliar", variant="primary")
            eval_output = gr.Textbox(label="Relatório", lines=15)

        with gr.Tab("5️⃣ Predição"):
            predict_images_input = gr.File(
                label="Imagens para Predição",
                file_count="multiple",
                file_types=["image"],
            )
            predict_button = gr.Button("🔮 Predizer", variant="primary")
            predict_output = gr.Textbox(label="Resultados", lines=10)

        # Event wiring
        num_classes_input.change(
            fn=_label_box_visibility,
            inputs=[num_classes_input],
            outputs=[label2, label3, label4],
        )
        setup_button.click(
            fn=setup_classes,
            inputs=[num_classes_input],
            outputs=[setup_output, class_selector],
        )
        set_labels_button.click(
            fn=set_class_labels,
            inputs=[label0, label1, label2, label3, label4],
            outputs=[labels_output, class_selector],
        )
        upload_button.click(
            fn=upload_images,
            inputs=[class_selector, images_upload],
            outputs=[upload_output],
        )
        prepare_button.click(
            fn=prepare_data,
            inputs=[batch_size],
            outputs=[prepare_output],
        )
        train_button.click(
            fn=start_training,
            inputs=[model_name, epochs, lr],
            outputs=[train_output],
        )
        eval_button.click(
            fn=evaluate_model,
            outputs=[eval_output],
        )
        predict_button.click(
            fn=predict_images,
            inputs=[predict_images_input],
            outputs=[predict_output],
        )

    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces, port 7860.
    app = create_interface()
    app.launch(server_port=7860, server_name="0.0.0.0")