# Author: rmayormartins · last change: "go11" (bdd4371) · 7.89 kB
import os
import shutil
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image
import tempfile
import warnings
# Silence every warning process-wide (presumably to hide library
# deprecation noise in the UI logs).
warnings.filterwarnings("ignore")

# Mutable application state shared by the Gradio callbacks.
# model: classifier produced by prepare_and_train.
# train_loader / test_loader: DataLoaders over the random train/test split.
# dataset_path: temporary directory with one sub-folder per class.
model = train_loader = test_loader = dataset_path = None
# Display names, index-aligned with the class folders.
class_names = ["classe_0", "classe_1"]

# Prefer the GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def setup_dataset():
    """Create a fresh temporary dataset directory with one folder per class.

    A directory created by a previous call is deleted first. Once
    ``dataset_path`` is replaced it is unreachable from the app, so keeping
    it would only leak disk space. The cached data loaders index files
    inside that old directory, so they are reset as well. The trained
    ``model`` itself is kept.

    Returns:
        A status message containing the path of the new dataset directory.
    """
    global dataset_path, train_loader, test_loader
    if dataset_path is not None:
        shutil.rmtree(dataset_path, ignore_errors=True)
        train_loader = None
        test_loader = None
    dataset_path = tempfile.mkdtemp()
    # One sub-folder per class; prepare_and_train reads them with ImageFolder.
    for class_idx in range(2):
        os.makedirs(os.path.join(dataset_path, f"classe_{class_idx}"), exist_ok=True)
    return f"✅ Dataset criado em: {dataset_path}"
def save_image(image, class_id):
    """Save an uploaded image into the folder of the chosen class.

    Args:
        image: PIL image from the Gradio upload widget, or None.
        class_id: class index from the numeric selector (0 or 1).

    Returns:
        A status message for the UI. Errors are reported as text, never raised.
    """
    import uuid

    if dataset_path is None:
        return "❌ Execute 'Criar Dataset' primeiro"
    if image is None:
        return "❌ Selecione uma imagem"
    try:
        class_idx = int(class_id)
    except (TypeError, ValueError, OverflowError):
        # e.g. the number field was cleared (None) or holds NaN/inf.
        return "❌ Classe inválida: use 0 ou 1"
    class_dir = os.path.join(dataset_path, f"classe_{class_idx}")
    # Only the folders created by setup_dataset are valid targets.
    if not os.path.isdir(class_dir):
        return "❌ Classe inválida: use 0 ou 1"
    try:
        # A random name replaces the old second-resolution timestamp, which
        # silently overwrote images saved within the same second.
        filepath = os.path.join(class_dir, f"img_{uuid.uuid4().hex}.jpg")
        # Pillow cannot write RGBA/palette images as JPEG; normalise to RGB.
        image.convert("RGB").save(filepath)
        return f"✅ Imagem salva na classe {class_idx}"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def _build_model(num_classes):
    """Return an ImageNet-pretrained MobileNetV2 with a new ``num_classes`` head, on ``device``."""
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
    # of `weights=`; kept as-is for compatibility with the installed version.
    net = models.mobilenet_v2(pretrained=True)
    net.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(net.classifier[1].in_features, num_classes),
    )
    return net.to(device)


def _train_epochs(net, loader, epochs, lr):
    """Fine-tune ``net`` in place on ``loader`` with Adam and cross-entropy loss."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    net.train()
    for _ in range(epochs):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()


def prepare_and_train():
    """Build train/test loaders from the dataset folder and fine-tune MobileNetV2.

    Images are split 70/30 at random and the model is trained for 3 epochs.
    The new model and loaders are published to the module globals only after
    training succeeds. A failure part-way therefore leaves any previously
    trained model and its loaders untouched.

    Returns:
        A status message. Exceptions are caught and reported as text.
    """
    global model, train_loader, test_loader
    try:
        if dataset_path is None:
            return "❌ Crie o dataset primeiro"
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        dataset = datasets.ImageFolder(dataset_path, transform=transform)
        # The message promises a minimum *per class*, so check each class,
        # not only the total.
        per_class = [dataset.targets.count(idx) for idx in range(len(dataset.classes))]
        if len(dataset) < 4 or min(per_class, default=0) < 2:
            return f"❌ Poucas imagens ({len(dataset)}). Adicione pelo menos 2 por classe."
        # 70% train / 30% test.
        train_size = int(0.7 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
        new_train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        new_test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

        new_model = _build_model(len(dataset.classes))
        _train_epochs(new_model, new_train_loader, epochs=3, lr=0.001)

        model, train_loader, test_loader = new_model, new_train_loader, new_test_loader
        return f"✅ Modelo treinado! Dataset: {train_size} treino, {test_size} teste"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def evaluate_model():
    """Measure the trained model's accuracy on the held-out test split.

    Returns:
        A status string with accuracy and hit counts, or an error message.
    """
    if model is None or test_loader is None:
        return "❌ Treine o modelo primeiro"
    try:
        model.eval()
        hits, seen = 0, 0
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                guesses = model(batch_x).argmax(dim=1)
                seen += batch_y.size(0)
                hits += int((guesses == batch_y).sum())
        # An empty test split reports 0% instead of dividing by zero.
        if seen == 0:
            accuracy = 0
        else:
            accuracy = 100 * hits / seen
        return f"📊 Acurácia: {accuracy:.2f}% ({hits}/{seen})"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def predict_single_image(image):
    """Classify a single image with the trained model.

    Args:
        image: PIL image from the Gradio widget, or None.

    Returns:
        A status string with the predicted class name and its softmax
        confidence, or an error message.
    """
    if model is None:
        return "❌ Treine o modelo primeiro"
    if image is None:
        return "❌ Selecione uma imagem"
    try:
        # Same preprocessing as used for training in prepare_and_train.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Training images are loaded by ImageFolder's default loader as RGB.
        # Convert here too, so grayscale/RGBA inputs yield 3 channels instead
        # of failing in Normalize.
        img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
        model.eval()
        with torch.no_grad():
            logits = model(img_tensor)[0]
        probs = torch.nn.functional.softmax(logits, dim=0)
        class_id = int(logits.argmax())
        confidence = probs[class_id].item() * 100
        class_name = class_names[class_id]
        return f"🎯 Predição: {class_name}\n📊 Confiança: {confidence:.2f}%"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def set_class_names(name0, name1):
    """Set the display names for class 0 and class 1.

    Surrounding whitespace is trimmed. The names are stored only when both
    are non-blank after trimming.

    Returns:
        A status message for the UI.
    """
    global class_names
    label_a = name0.strip()
    # Only look at the second name when the first one is usable.
    label_b = name1.strip() if label_a else ""
    if not (label_a and label_b):
        return "❌ Preencha ambos os nomes"
    class_names = [label_a, label_b]
    return f"✅ Classes: {label_a} e {label_b}"
# Minimal UI: one section per workflow step (names -> dataset -> images -> train -> predict).
with gr.Blocks(title="🖼️ Classificador Simples") as demo:
    gr.Markdown("# 🖼️ Classificador de Imagens Simples")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1️⃣ Configurar Classes")
            class_0_name = gr.Textbox(label="Nome Classe 0", value="gato")
            class_1_name = gr.Textbox(label="Nome Classe 1", value="cachorro")
            set_names_btn = gr.Button("🏷️ Definir Nomes")
            names_status = gr.Textbox(label="Status")
            gr.Markdown("### 2️⃣ Criar Dataset")
            create_btn = gr.Button("🔧 Criar Dataset", variant="primary")
            create_status = gr.Textbox(label="Status")
        with gr.Column():
            gr.Markdown("### 3️⃣ Adicionar Imagens")
            upload_image = gr.Image(type="pil", label="Imagem")
            class_selector = gr.Number(label="Classe (0 ou 1)", value=0, precision=0)
            save_btn = gr.Button("💾 Salvar Imagem")
            save_status = gr.Textbox(label="Status")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 4️⃣ Treinar")
            train_btn = gr.Button("🚀 Preparar + Treinar", variant="primary")
            train_status = gr.Textbox(label="Status", lines=3)
            eval_btn = gr.Button("📊 Avaliar")
            eval_status = gr.Textbox(label="Resultado")
        with gr.Column():
            gr.Markdown("### 5️⃣ Predizer")
            predict_image = gr.Image(type="pil", label="Imagem para Predição")
            predict_btn = gr.Button("🔮 Predizer")
            predict_result = gr.Textbox(label="Resultado")

    # Wire each button to its callback (listeners must be registered inside the Blocks context).
    set_names_btn.click(set_class_names, [class_0_name, class_1_name], names_status)
    create_btn.click(setup_dataset, outputs=create_status)
    save_btn.click(save_image, [upload_image, class_selector], save_status)
    train_btn.click(prepare_and_train, outputs=train_status)
    eval_btn.click(evaluate_model, outputs=eval_status)
    predict_btn.click(predict_single_image, predict_image, predict_result)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()