Spaces:
Runtime error
Runtime error
import os | |
import shutil | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import datasets, transforms, models | |
from torch.utils.data import DataLoader, random_split | |
from PIL import Image | |
import tempfile | |
import warnings | |
warnings.filterwarnings("ignore") | |
# Simple module-level state shared by the UI callbacks below.
# Everything starts unset; setup_dataset() fills in dataset_path and
# prepare_and_train() fills in the model and the two data loaders.
model = train_loader = test_loader = dataset_path = None

# Display names for the two classes, indexed by predicted class id.
class_names = ["classe_0", "classe_1"]

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def setup_dataset():
    """Create a new temporary dataset directory with one folder per class.

    The directory is created with ``tempfile.mkdtemp`` and its path is stored
    in the module-level ``dataset_path``. Two sub-folders, ``classe_0`` and
    ``classe_1``, are created inside it.

    Returns:
        A status message that includes the path of the new directory.
    """
    global dataset_path

    root = tempfile.mkdtemp()
    dataset_path = root

    # One folder per class, named after the class index.
    for class_folder in ("classe_0", "classe_1"):
        os.makedirs(os.path.join(root, class_folder), exist_ok=True)

    return f"✅ Dataset criado em: {dataset_path}"
def save_image(image, class_id):
    """Save an uploaded image into the folder of the selected class.

    Args:
        image: Image object exposing ``mode``, ``convert`` and ``save``
            (the UI passes a PIL image), or ``None`` when nothing was
            uploaded.
        class_id: Numeric class index; selects the ``classe_<index>`` folder
            under the module-level ``dataset_path``.

    Returns:
        A status message: success, or the reason the image was not saved.
    """
    # Function-scope import, as the original block did for its own helper.
    import uuid

    if dataset_path is None:
        return "❌ Execute 'Criar Dataset' primeiro"
    if image is None:
        return "❌ Selecione uma imagem"

    # An empty or non-numeric class field must not surface as a raw exception.
    try:
        class_index = int(class_id)
    except (TypeError, ValueError, OverflowError):
        return "❌ Classe inválida: use 0 ou 1"

    # Only folders that already exist are valid targets.
    class_dir = os.path.join(dataset_path, f"classe_{class_index}")
    if not os.path.isdir(class_dir):
        return "❌ Classe inválida: use 0 ou 1"

    try:
        # A random suffix keeps file names unique. A one-second timestamp
        # made two uploads within the same second overwrite each other.
        filepath = os.path.join(class_dir, f"img_{uuid.uuid4().hex}.jpg")

        # JPEG has no alpha/palette support, so normalize other modes first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        image.save(filepath)
        return f"✅ Imagem salva na classe {class_index}"
    except Exception as e:
        # UI boundary: report the failure as a status message.
        return f"❌ Erro: {str(e)}"
def prepare_and_train():
    """Build the data loaders, fine-tune MobileNetV2 and publish the result.

    Reads images from the module-level ``dataset_path`` with ``ImageFolder``,
    splits them 70/30 into train/test subsets, replaces the MobileNetV2
    classifier head with a 2-class head and trains for 3 epochs.

    The module-level ``model``, ``train_loader`` and ``test_loader`` are only
    replaced after training finishes. A failed run therefore leaves the
    previous state intact and never exposes an untrained model to the
    evaluation and prediction callbacks.

    Returns:
        A status message: success with the split sizes, or the failure reason.
    """
    global model, train_loader, test_loader
    try:
        if dataset_path is None:
            return "❌ Crie o dataset primeiro"

        # Input preprocessing: fixed 224x224 size plus per-channel normalization.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Load the dataset from the class folders.
        dataset = datasets.ImageFolder(dataset_path, transform=transform)
        if len(dataset) < 4:
            return f"❌ Poucas imagens ({len(dataset)}). Adicione pelo menos 2 por classe."

        # Split the data: 70% train, 30% test.
        train_size = int(0.7 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
        new_train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        new_test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

        # Load the pretrained backbone. Newer torchvision deprecates
        # ``pretrained=True`` in favour of ``weights=``; IMAGENET1K_V1 is the
        # checkpoint the legacy flag maps to. Older releases without the
        # weights enum keep the legacy call.
        if hasattr(models, "MobileNet_V2_Weights"):
            net = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        else:
            net = models.mobilenet_v2(pretrained=True)

        # Swap the classifier head for a 2-class one.
        net.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(net.classifier[1].in_features, 2),
        )
        net = net.to(device)

        # Train.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        net.train()
        for epoch in range(3):  # Only 3 epochs
            for inputs, labels in new_train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        # Publish the new state only now that training completed without error.
        model = net
        train_loader = new_train_loader
        test_loader = new_test_loader
        return f"✅ Modelo treinado! Dataset: {train_size} treino, {test_size} teste"
    except Exception as e:
        # UI boundary: report the failure as a status message.
        return f"❌ Erro: {str(e)}"
def evaluate_model():
    """Measure accuracy of the module-level model on the test loader.

    Puts the model in eval mode, runs every batch of ``test_loader`` through
    it without gradient tracking and counts the predictions that match the
    labels.

    Returns:
        A status message with the accuracy percentage and the hit/total
        counts, a prompt to train first when no model or test loader exists,
        or an error message if evaluation raises.
    """
    if model is None or test_loader is None:
        return "❌ Treine o modelo primeiro"

    try:
        model.eval()
        hits = 0
        seen = 0

        with torch.no_grad():
            for batch, targets in test_loader:
                batch = batch.to(device)
                targets = targets.to(device)
                logits = model(batch)
                # Index of the highest score per row is the predicted class.
                guesses = torch.max(logits, 1)[1]
                seen += targets.size(0)
                hits += (guesses == targets).sum().item()

        # Guard against an empty test split.
        if seen > 0:
            accuracy = 100 * hits / seen
        else:
            accuracy = 0
        return f"📊 Acurácia: {accuracy:.2f}% ({hits}/{seen})"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def predict_single_image(image):
    """Classify one image with the module-level model.

    Args:
        image: Image to classify (the UI passes a PIL image), or ``None``
            when nothing was uploaded.

    Returns:
        A two-line message with the predicted class name and its softmax
        confidence, a prompt when the model or image is missing, or an error
        message if prediction raises.
    """
    if model is None:
        return "❌ Treine o modelo primeiro"
    if image is None:
        return "❌ Selecione uma imagem"

    try:
        # Same preprocessing steps as the ones used to build the training set.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Add a leading batch dimension of size 1.
        batch = preprocess(image).unsqueeze(0).to(device)

        model.eval()
        with torch.no_grad():
            logits = model(batch)
            probabilities = torch.nn.functional.softmax(logits[0], dim=0)
            winner = torch.max(logits, 1)[1]

        class_id = winner.item()
        confidence = probabilities[class_id].item() * 100
        class_name = class_names[class_id]
        return f"🎯 Predição: {class_name}\n📊 Confiança: {confidence:.2f}%"
    except Exception as e:
        return f"❌ Erro: {str(e)}"
def set_class_names(name0, name1):
    """Set the display names of the two classes.

    Both names are stripped of surrounding whitespace. If either one is empty
    after stripping, the module-level ``class_names`` is left unchanged.

    Args:
        name0: Display name for class 0.
        name1: Display name for class 1.

    Returns:
        A confirmation message listing both names, or a prompt to fill in
        both fields.
    """
    global class_names

    first = name0.strip()
    if not first:
        return "❌ Preencha ambos os nomes"

    second = name1.strip()
    if not second:
        return "❌ Preencha ambos os nomes"

    class_names = [first, second]
    return f"✅ Classes: {class_names[0]} e {class_names[1]}"
# Ultra-simple interface: two rows of two columns, one numbered step per section.
# NOTE(review): the nesting below is reconstructed from statement order — confirm
# against the deployed layout.
with gr.Blocks(title="🖼️ Classificador Simples") as demo:
    gr.Markdown("# 🖼️ Classificador de Imagens Simples")
    with gr.Row():
        with gr.Column():
            # Step 1: display names for the two classes.
            gr.Markdown("### 1️⃣ Configurar Classes")
            class_0_name = gr.Textbox(label="Nome Classe 0", value="gato")
            class_1_name = gr.Textbox(label="Nome Classe 1", value="cachorro")
            set_names_btn = gr.Button("🏷️ Definir Nomes")
            names_status = gr.Textbox(label="Status")
            # Step 2: create the dataset folder structure.
            gr.Markdown("### 2️⃣ Criar Dataset")
            create_btn = gr.Button("🔧 Criar Dataset", variant="primary")
            create_status = gr.Textbox(label="Status")
        with gr.Column():
            # Step 3: upload an image and choose the class it is saved under.
            gr.Markdown("### 3️⃣ Adicionar Imagens")
            upload_image = gr.Image(type="pil", label="Imagem")
            class_selector = gr.Number(label="Classe (0 ou 1)", value=0, precision=0)
            save_btn = gr.Button("💾 Salvar Imagem")
            save_status = gr.Textbox(label="Status")
    with gr.Row():
        with gr.Column():
            # Step 4: train the model, then evaluate it.
            gr.Markdown("### 4️⃣ Treinar")
            train_btn = gr.Button("🚀 Preparar + Treinar", variant="primary")
            train_status = gr.Textbox(label="Status", lines=3)
            eval_btn = gr.Button("📊 Avaliar")
            eval_status = gr.Textbox(label="Resultado")
        with gr.Column():
            # Step 5: predict the class of a single image.
            gr.Markdown("### 5️⃣ Predizer")
            predict_image = gr.Image(type="pil", label="Imagem para Predição")
            predict_btn = gr.Button("🔮 Predizer")
            predict_result = gr.Textbox(label="Resultado")
    # Wire events: each button calls its handler and writes the returned
    # message to the matching status/result textbox.
    set_names_btn.click(set_class_names, [class_0_name, class_1_name], names_status)
    create_btn.click(setup_dataset, outputs=create_status)
    save_btn.click(save_image, [upload_image, class_selector], save_status)
    train_btn.click(prepare_and_train, outputs=train_status)
    eval_btn.click(evaluate_model, outputs=eval_status)
    predict_btn.click(predict_single_image, predict_image, predict_result)
# Launch the app when this file is executed.
demo.launch()