rmayormartins committed
Commit a2ab6d7 · 1 Parent(s): 6943d4d
Files changed (2)
  1. app.py +220 -236
  2. requirements.txt +7 -7
app.py CHANGED
@@ -7,6 +7,8 @@ import torch.optim as optim
  from torchvision import datasets, transforms, models
  from torch.utils.data import DataLoader, random_split
  from PIL import Image
  import matplotlib.pyplot as plt
  import seaborn as sns
  import numpy as np
@@ -15,9 +17,9 @@ import tempfile
  import warnings
  warnings.filterwarnings("ignore")

- # Configuração do device
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"🖥️ Usando device: {device}")

  # Modelos disponíveis
  MODELS = {
@@ -25,77 +27,61 @@ MODELS = {
      'MobileNetV2': models.mobilenet_v2
  }

- # Estado global da aplicação
  class AppState:
      def __init__(self):
          self.model = None
          self.train_loader = None
-         self.val_loader = None
          self.test_loader = None
          self.dataset_path = None
          self.class_dirs = []
          self.class_labels = []
          self.num_classes = 2

- # Instância global do estado
- app_state = AppState()

  def setup_classes(num_classes_value):
-     """Configura o número de classes e cria diretórios"""
      try:
-         app_state.num_classes = int(num_classes_value)

-         # Criar diretório temporário
-         app_state.dataset_path = tempfile.mkdtemp()
-
-         # Inicializar rótulos padrão
-         app_state.class_labels = [f'classe_{i}' for i in range(app_state.num_classes)]
-
-         # Criar diretórios para cada classe
-         app_state.class_dirs = []
-         for i in range(app_state.num_classes):
-             class_dir = os.path.join(app_state.dataset_path, f'classe_{i}')
              os.makedirs(class_dir, exist_ok=True)
-             app_state.class_dirs.append(class_dir)
-
-         choices = [(f"{i} - {app_state.class_labels[i]}", i) for i in range(app_state.num_classes)]

-         return (
-             f"✅ Criados {app_state.num_classes} diretórios para classes",
-             gr.Dropdown(choices=choices, value=0)
-         )
      except Exception as e:
-         return f"❌ Erro: {str(e)}", gr.Dropdown()

- def set_class_labels(label0, label1, label2, label3, label4):
-     """Define rótulos personalizados para as classes"""
      try:
-         labels = [label0, label1, label2, label3, label4]
-         filtered_labels = [label.strip() for label in labels if label.strip()][:app_state.num_classes]

-         if len(filtered_labels) != app_state.num_classes:
-             return f"❌ Erro: Forneça exatamente {app_state.num_classes} rótulos.", gr.Dropdown()

-         app_state.class_labels = filtered_labels
-         choices = [(f"{i} - {app_state.class_labels[i]}", i) for i in range(app_state.num_classes)]
-
-         return (
-             f"✅ Rótulos definidos: {', '.join(app_state.class_labels)}",
-             gr.Dropdown(choices=choices, value=0)
-         )
      except Exception as e:
-         return f"❌ Erro: {str(e)}", gr.Dropdown()

  def upload_images(class_id, images):
-     """Faz upload das imagens para a classe especificada"""
      try:
          if not images:
-             return "❌ Nenhuma imagem selecionada."

-         if int(class_id) >= len(app_state.class_dirs):
-             return f"❌ Classe {class_id} inválida."

-         class_dir = app_state.class_dirs[int(class_id)]
          count = 0

          for image in images:
@@ -103,33 +89,29 @@ def upload_images(class_id, images):
              shutil.copy2(image, class_dir)
              count += 1

-         class_name = app_state.class_labels[int(class_id)]
-         return f"✅ {count} imagens salvas na classe {class_id} ({class_name})"
      except Exception as e:
          return f"❌ Erro: {str(e)}"

  def prepare_data(batch_size):
-     """Prepara os dados para treinamento"""
      try:
-         if not app_state.dataset_path or not os.path.exists(app_state.dataset_path):
-             return "❌ Configure as classes primeiro."

-         # Transformações
          transform = transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.ToTensor(),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          ])

-         dataset = datasets.ImageFolder(app_state.dataset_path, transform=transform)
-
-         if len(dataset.classes) == 0:
-             return "❌ Nenhuma classe encontrada. Faça upload das imagens primeiro."

          if len(dataset) < 6:
-             return f"❌ Muito poucas imagens ({len(dataset)}). Adicione pelo menos 2 imagens por classe."

-         # Divisão dos dados
          train_size = int(0.7 * len(dataset))
          val_size = int(0.2 * len(dataset))
          test_size = len(dataset) - train_size - val_size
@@ -139,268 +121,270 @@ def prepare_data(batch_size):
              generator=torch.Generator().manual_seed(42)
          )

-         app_state.train_loader = DataLoader(train_dataset, batch_size=int(batch_size), shuffle=True)
-         app_state.val_loader = DataLoader(val_dataset, batch_size=int(batch_size), shuffle=False)
-         app_state.test_loader = DataLoader(test_dataset, batch_size=int(batch_size), shuffle=False)

-         return f"✅ Dados preparados: {train_size} treino, {val_size} validação, {test_size} teste"
-
      except Exception as e:
-         return f"❌ Erro na preparação: {str(e)}"

- def start_training(model_name, epochs, lr):
-     """Inicia o treinamento do modelo"""
      try:
-         if app_state.train_loader is None:
-             return "❌ Erro: Dados não preparados."

          # Carregar modelo
-         app_state.model = MODELS[model_name](pretrained=True)
-
-         # Adaptar última camada
-         if hasattr(app_state.model, 'fc'):
-             app_state.model.fc = nn.Linear(app_state.model.fc.in_features, app_state.num_classes)
-         elif hasattr(app_state.model, 'classifier'):
-             if isinstance(app_state.model.classifier, nn.Sequential):
-                 app_state.model.classifier[-1] = nn.Linear(app_state.model.classifier[-1].in_features, app_state.num_classes)
-             else:
-                 app_state.model.classifier = nn.Linear(app_state.model.classifier.in_features, app_state.num_classes)

-         app_state.model = app_state.model.to(device)

          criterion = nn.CrossEntropyLoss()
-         optimizer = optim.Adam(app_state.model.parameters(), lr=float(lr))

-         app_state.model.train()
-
-         results = [f"🚀 Treinando {model_name} por {epochs} épocas"]

          for epoch in range(int(epochs)):
              running_loss = 0.0
              correct = 0
              total = 0

-             for inputs, labels in app_state.train_loader:
                  inputs, labels = inputs.to(device), labels.to(device)

                  optimizer.zero_grad()
-                 outputs = app_state.model(inputs)
                  loss = criterion(outputs, labels)
                  loss.backward()
                  optimizer.step()

                  running_loss += loss.item()
-                 _, predicted = torch.max(outputs.data, 1)
                  total += labels.size(0)
                  correct += (predicted == labels).sum().item()

-             epoch_loss = running_loss / len(app_state.train_loader)
              epoch_acc = 100. * correct / total
              results.append(f"Época {epoch+1}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%")

          results.append("✅ Treinamento concluído!")
          return "\n".join(results)
-
      except Exception as e:
-         return f"❌ Erro durante treinamento: {str(e)}"

  def evaluate_model():
-     """Avalia o modelo no conjunto de teste"""
      try:
-         if app_state.model is None or app_state.test_loader is None:
-             return "❌ Modelo ou dados não disponíveis."

-         app_state.model.eval()
          all_preds = []
          all_labels = []

          with torch.no_grad():
-             for inputs, labels in app_state.test_loader:
                  inputs, labels = inputs.to(device), labels.to(device)
-                 outputs = app_state.model(inputs)
                  _, preds = torch.max(outputs, 1)
                  all_preds.extend(preds.cpu().numpy())
                  all_labels.extend(labels.cpu().numpy())

-         report = classification_report(all_labels, all_preds, target_names=app_state.class_labels, zero_division=0)
-         return f"📊 RELATÓRIO DE CLASSIFICAÇÃO:\n\n{report}"
-
      except Exception as e:
-         return f"❌ Erro durante avaliação: {str(e)}"

  def predict_images(images):
-     """Faz predições em novas imagens"""
      try:
-         if app_state.model is None:
-             return "❌ Modelo não treinado."

          if not images:
-             return "❌ Nenhuma imagem selecionada."

          transform = transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.ToTensor(),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
          ])

-         app_state.model.eval()
          results = []

          for image_path in images:
-             if image_path is not None:
                  image = Image.open(image_path).convert('RGB')
                  img_tensor = transform(image).unsqueeze(0).to(device)

                  with torch.no_grad():
-                     outputs = app_state.model(img_tensor)
-                     probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                      _, predicted = torch.max(outputs, 1)

-                 predicted_class_id = predicted.item()
-                 confidence = probabilities[predicted_class_id].item() * 100
-                 predicted_class_name = app_state.class_labels[predicted_class_id]

                  results.append(f"📸 {os.path.basename(image_path)}")
-                 results.append(f" 🎯 Classe: {predicted_class_name}")
-                 results.append(f" 📊 Confiança: {confidence:.2f}%")
-                 results.append("-" * 40)

-         return "\n".join(results) if results else "❌ Nenhuma predição realizada."
-
      except Exception as e:
          return f"❌ Erro: {str(e)}"

- # Interface Gradio
- def create_interface():
-     with gr.Blocks(title="🖼️ Classificador de Imagens", theme=gr.themes.Soft()) as demo:
-
-         gr.Markdown("""
-         # 🖼️ Sistema de Classificação de Imagens
-
-         **Instruções:**
-         1. Configure as classes e rótulos
-         2. Faça upload das imagens
-         3. Prepare os dados e treine
-         4. Avalie e faça predições!
-         """)
-
-         with gr.Tab("1️⃣ Configuração"):
-             with gr.Row():
-                 num_classes_input = gr.Number(
-                     label="Número de Classes",
-                     value=2,
-                     minimum=2,
-                     maximum=5,
-                     precision=0
-                 )
-                 setup_button = gr.Button("🔧 Configurar Classes", variant="primary")
-
-             setup_output = gr.Textbox(label="Status", lines=2)
-
-             gr.Markdown("### Rótulos das Classes")
-
-             with gr.Row():
-                 label0 = gr.Textbox(label="Classe 0", placeholder="Ex: gato")
-                 label1 = gr.Textbox(label="Classe 1", placeholder="Ex: cachorro")
-
-             with gr.Row():
-                 label2 = gr.Textbox(label="Classe 2", placeholder="Ex: pássaro", visible=False)
-                 label3 = gr.Textbox(label="Classe 3", placeholder="Ex: peixe", visible=False)
-                 label4 = gr.Textbox(label="Classe 4", placeholder="Ex: hamster", visible=False)
-
-             set_labels_button = gr.Button("🏷️ Definir Rótulos")
-             labels_output = gr.Textbox(label="Status dos Rótulos")
-
-             # Dropdown que será atualizado
-             class_selector = gr.Dropdown(
-                 label="Selecionar Classe",
-                 choices=[(f"Classe 0", 0), (f"Classe 1", 1)],
-                 value=0
-             )
-
-         with gr.Tab("2️⃣ Upload"):
-             images_upload = gr.File(
-                 label="Selecionar Imagens",
-                 file_count="multiple",
-                 file_types=["image"]
-             )
-             upload_button = gr.Button("📤 Fazer Upload", variant="primary")
-             upload_output = gr.Textbox(label="Status do Upload")
-
-         with gr.Tab("3️⃣ Treinamento"):
-             batch_size = gr.Number(label="Batch Size", value=8, minimum=1, maximum=32)
-             prepare_button = gr.Button("⚙️ Preparar Dados", variant="primary")
-             prepare_output = gr.Textbox(label="Status", lines=3)
-
-             with gr.Row():
-                 model_name = gr.Dropdown(
-                     label="Modelo",
-                     choices=list(MODELS.keys()),
-                     value="MobileNetV2"
-                 )
-                 epochs = gr.Number(label="Épocas", value=3, minimum=1, maximum=10)
-                 lr = gr.Number(label="Learning Rate", value=0.001, minimum=0.0001, maximum=0.1)
-
-             train_button = gr.Button("🚀 Treinar", variant="primary")
-             train_output = gr.Textbox(label="Status do Treinamento", lines=10)
-
-         with gr.Tab("4️⃣ Avaliação"):
-             eval_button = gr.Button("📊 Avaliar", variant="primary")
-             eval_output = gr.Textbox(label="Relatório", lines=15)
-
-         with gr.Tab("5️⃣ Predição"):
-             predict_images_input = gr.File(
-                 label="Imagens para Predição",
-                 file_count="multiple",
-                 file_types=["image"]
-             )
-             predict_button = gr.Button("🔮 Predizer", variant="primary")
-             predict_output = gr.Textbox(label="Resultados", lines=10)
-
-         # Conectar eventos
-         setup_button.click(
-             fn=setup_classes,
-             inputs=[num_classes_input],
-             outputs=[setup_output, class_selector]
-         )

-         set_labels_button.click(
-             fn=set_class_labels,
-             inputs=[label0, label1, label2, label3, label4],
-             outputs=[labels_output, class_selector]
          )

-         upload_button.click(
-             fn=upload_images,
-             inputs=[class_selector, images_upload],
-             outputs=[upload_output]
-         )

-         prepare_button.click(
-             fn=prepare_data,
-             inputs=[batch_size],
-             outputs=[prepare_output]
          )

-         train_button.click(
-             fn=start_training,
-             inputs=[model_name, epochs, lr],
-             outputs=[train_output]
          )

-         eval_button.click(
-             fn=evaluate_model,
-             outputs=[eval_output]
          )

-         predict_button.click(
-             fn=predict_images,
-             inputs=[predict_images_input],
-             outputs=[predict_output]
          )

-     return demo

  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch(server_name="0.0.0.0", server_port=7860)

  from torchvision import datasets, transforms, models
  from torch.utils.data import DataLoader, random_split
  from PIL import Image
+ import matplotlib
+ matplotlib.use('Agg')  # Use non-interactive backend
  import matplotlib.pyplot as plt
  import seaborn as sns
  import numpy as np

  import warnings
  warnings.filterwarnings("ignore")

+ # Configuração
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"🖥️ Device: {device}")

  # Modelos disponíveis
  MODELS = {
      'MobileNetV2': models.mobilenet_v2
  }

+ # Estado global
  class AppState:
      def __init__(self):
          self.model = None
          self.train_loader = None
+         self.val_loader = None
          self.test_loader = None
          self.dataset_path = None
          self.class_dirs = []
          self.class_labels = []
          self.num_classes = 2

+ state = AppState()

  def setup_classes(num_classes_value):
+     """Configura classes"""
      try:
+         state.num_classes = int(num_classes_value)
+         state.dataset_path = tempfile.mkdtemp()
+         state.class_labels = [f'classe_{i}' for i in range(state.num_classes)]

+         state.class_dirs = []
+         for i in range(state.num_classes):
+             class_dir = os.path.join(state.dataset_path, f'classe_{i}')
              os.makedirs(class_dir, exist_ok=True)
+             state.class_dirs.append(class_dir)

+         return f"✅ Criados {state.num_classes} diretórios"
      except Exception as e:
+         return f"❌ Erro: {str(e)}"

+ def set_class_labels(labels_text):
+     """Define rótulos das classes (separados por vírgula)"""
      try:
+         labels = [label.strip() for label in labels_text.split(',') if label.strip()]

+         if len(labels) != state.num_classes:
+             return f"❌ Forneça {state.num_classes} rótulos separados por vírgula"

+         state.class_labels = labels
+         return f"✅ Rótulos: {', '.join(state.class_labels)}"
      except Exception as e:
+         return f"❌ Erro: {str(e)}"

  def upload_images(class_id, images):
+     """Upload de imagens"""
      try:
          if not images:
+             return "❌ Selecione imagens"

+         class_idx = int(class_id)
+         if class_idx >= len(state.class_dirs):
+             return f"❌ Classe inválida"

+         class_dir = state.class_dirs[class_idx]
          count = 0

          for image in images:
              shutil.copy2(image, class_dir)
              count += 1

+         class_name = state.class_labels[class_idx]
+         return f"✅ {count} imagens {class_name}"
      except Exception as e:
          return f"❌ Erro: {str(e)}"

  def prepare_data(batch_size):
+     """Prepara dados"""
      try:
+         if not state.dataset_path:
+             return "❌ Configure classes primeiro"

          transform = transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
          ])

+         dataset = datasets.ImageFolder(state.dataset_path, transform=transform)

          if len(dataset) < 6:
+             return f"❌ Poucas imagens ({len(dataset)}). Mínimo: 6"

+         # Divisão: 70% treino, 20% val, 10% teste
          train_size = int(0.7 * len(dataset))
          val_size = int(0.2 * len(dataset))
          test_size = len(dataset) - train_size - val_size

              generator=torch.Generator().manual_seed(42)
          )

+         state.train_loader = DataLoader(train_dataset, batch_size=int(batch_size), shuffle=True)
+         state.val_loader = DataLoader(val_dataset, batch_size=int(batch_size), shuffle=False)
+         state.test_loader = DataLoader(test_dataset, batch_size=int(batch_size), shuffle=False)

+         return f"✅ Dados preparados:\n• Treino: {train_size}\n• Validação: {val_size}\n• Teste: {test_size}"
      except Exception as e:
+         return f"❌ Erro: {str(e)}"

+ def train_model(model_name, epochs, lr):
+     """Treina modelo"""
      try:
+         if state.train_loader is None:
+             return "❌ Prepare os dados primeiro"

          # Carregar modelo
+         state.model = MODELS[model_name](pretrained=True)

+         # Adaptar camada final
+         if hasattr(state.model, 'fc'):
+             state.model.fc = nn.Linear(state.model.fc.in_features, state.num_classes)
+         elif hasattr(state.model, 'classifier'):
+             if isinstance(state.model.classifier, nn.Sequential):
+                 state.model.classifier[-1] = nn.Linear(state.model.classifier[-1].in_features, state.num_classes)

+         state.model = state.model.to(device)
          criterion = nn.CrossEntropyLoss()
+         optimizer = optim.Adam(state.model.parameters(), lr=float(lr))

+         state.model.train()
+         results = [f"🚀 Treinando {model_name}"]

          for epoch in range(int(epochs)):
              running_loss = 0.0
              correct = 0
              total = 0

+             for inputs, labels in state.train_loader:
                  inputs, labels = inputs.to(device), labels.to(device)

                  optimizer.zero_grad()
+                 outputs = state.model(inputs)
                  loss = criterion(outputs, labels)
                  loss.backward()
                  optimizer.step()

                  running_loss += loss.item()
+                 _, predicted = torch.max(outputs, 1)
                  total += labels.size(0)
                  correct += (predicted == labels).sum().item()

+             epoch_loss = running_loss / len(state.train_loader)
              epoch_acc = 100. * correct / total
              results.append(f"Época {epoch+1}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%")

          results.append("✅ Treinamento concluído!")
          return "\n".join(results)
      except Exception as e:
+         return f"❌ Erro: {str(e)}"

  def evaluate_model():
+     """Avalia modelo"""
      try:
+         if state.model is None or state.test_loader is None:
+             return "❌ Modelo/dados não disponíveis"

+         state.model.eval()
          all_preds = []
          all_labels = []

          with torch.no_grad():
+             for inputs, labels in state.test_loader:
                  inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = state.model(inputs)
                  _, preds = torch.max(outputs, 1)
                  all_preds.extend(preds.cpu().numpy())
                  all_labels.extend(labels.cpu().numpy())

+         report = classification_report(
+             all_labels, all_preds,
+             target_names=state.class_labels,
+             zero_division=0
+         )
+         return f"📊 RELATÓRIO:\n\n{report}"
      except Exception as e:
+         return f"❌ Erro: {str(e)}"
+
+ def generate_confusion_matrix():
+     """Gera matriz de confusão"""
+     try:
+         if state.model is None or state.test_loader is None:
+             return None
+
+         state.model.eval()
+         all_preds = []
+         all_labels = []
+
+         with torch.no_grad():
+             for inputs, labels in state.test_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = state.model(inputs)
+                 _, preds = torch.max(outputs, 1)
+                 all_preds.extend(preds.cpu().numpy())
+                 all_labels.extend(labels.cpu().numpy())
+
+         cm = confusion_matrix(all_labels, all_preds)
+
+         plt.figure(figsize=(8, 6))
+         sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
+                     xticklabels=state.class_labels,
+                     yticklabels=state.class_labels)
+         plt.xlabel('Predições')
+         plt.ylabel('Valores Reais')
+         plt.title('Matriz de Confusão')
+         plt.tight_layout()
+
+         temp_path = tempfile.NamedTemporaryFile(suffix='.png', delete=False).name
+         plt.savefig(temp_path, dpi=150, bbox_inches='tight')
+         plt.close()
+
+         return temp_path
+     except Exception as e:
+         print(f"Erro matriz confusão: {e}")
+         return None

  def predict_images(images):
+     """Prediz imagens"""
      try:
+         if state.model is None:
+             return "❌ Treine o modelo primeiro"

          if not images:
+             return "❌ Selecione imagens"

          transform = transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
          ])

+         state.model.eval()
          results = []

          for image_path in images:
+             if image_path:
                  image = Image.open(image_path).convert('RGB')
                  img_tensor = transform(image).unsqueeze(0).to(device)

                  with torch.no_grad():
+                     outputs = state.model(img_tensor)
+                     probs = torch.nn.functional.softmax(outputs[0], dim=0)
                      _, predicted = torch.max(outputs, 1)

+                 class_id = predicted.item()
+                 confidence = probs[class_id].item() * 100
+                 class_name = state.class_labels[class_id]

                  results.append(f"📸 {os.path.basename(image_path)}")
+                 results.append(f" 🎯 {class_name}")
+                 results.append(f" 📊 {confidence:.2f}%")
+                 results.append("-" * 30)

+         return "\n".join(results) if results else "❌ Nenhuma predição"
      except Exception as e:
          return f"❌ Erro: {str(e)}"

+ # Interface
+ with gr.Blocks(title="🖼️ Classificador", theme=gr.themes.Soft()) as demo:
+
+     gr.Markdown("""
+     # 🖼️ Sistema de Classificação de Imagens
+     **Instruções:** Configure → Upload → Treine → Avalie → Prediga
+     """)
+
+     with gr.Tab("1️⃣ Configuração"):
+         gr.Markdown("### 🎯 Configurar Classes")

+         num_classes = gr.Slider(
+             minimum=2, maximum=5, value=2, step=1,
+             label="Número de Classes"
          )

+         setup_btn = gr.Button("🔧 Configurar", variant="primary")
+         setup_status = gr.Textbox(label="Status", lines=2)
+
+         gr.Markdown("### 🏷️ Definir Rótulos")

+         labels_input = gr.Textbox(
+             label="Rótulos (separados por vírgula)",
+             placeholder="gato, cachorro",
+             value="gato, cachorro"
          )

+         labels_btn = gr.Button("🏷️ Definir Rótulos")
+         labels_status = gr.Textbox(label="Status Rótulos")
+
+     with gr.Tab("2️⃣ Upload"):
+         gr.Markdown("### 📤 Upload de Imagens")
+
+         class_selector = gr.Slider(
+             minimum=0, maximum=1, value=0, step=1,
+             label="Classe (0, 1, 2...)"
          )

+         images_upload = gr.File(
+             label="Imagens",
+             file_count="multiple",
+             file_types=["image"]
          )

+         upload_btn = gr.Button("📤 Upload", variant="primary")
+         upload_status = gr.Textbox(label="Status")
+
+     with gr.Tab("3️⃣ Treinamento"):
+         gr.Markdown("### ⚙️ Preparar Dados")
+
+         batch_size = gr.Slider(1, 32, 8, step=1, label="Batch Size")
+         prepare_btn = gr.Button("⚙️ Preparar", variant="primary")
+         prepare_status = gr.Textbox(label="Status", lines=4)
+
+         gr.Markdown("### 🚀 Treinar Modelo")
+
+         with gr.Row():
+             model_choice = gr.Radio(
+                 choices=list(MODELS.keys()),
+                 value="MobileNetV2",
+                 label="Modelo"
+             )
+             epochs = gr.Slider(1, 10, 3, step=1, label="Épocas")
+             learning_rate = gr.Slider(0.0001, 0.01, 0.001, label="Learning Rate")
+
+         train_btn = gr.Button("🚀 Treinar", variant="primary")
+         train_status = gr.Textbox(label="Status Treinamento", lines=8)
+
+     with gr.Tab("4️⃣ Avaliação"):
+         gr.Markdown("### 📊 Avaliar Modelo")
+
+         with gr.Row():
+             eval_btn = gr.Button("📊 Avaliar", variant="primary")
+             matrix_btn = gr.Button("📈 Matriz Confusão")
+
+         eval_results = gr.Textbox(label="Relatório", lines=12)
+         confusion_matrix_plot = gr.Image(label="Matriz de Confusão")
+
+     with gr.Tab("5️⃣ Predição"):
+         gr.Markdown("### 🔮 Predizer Novas Imagens")
+
+         predict_images_input = gr.File(
+             label="Imagens para Predição",
+             file_count="multiple",
+             file_types=["image"]
          )
+
+         predict_btn = gr.Button("🔮 Predizer", variant="primary")
+         predict_results = gr.Textbox(label="Resultados", lines=10)

+     # Conectar eventos
+     setup_btn.click(setup_classes, [num_classes], [setup_status])
+     labels_btn.click(set_class_labels, [labels_input], [labels_status])
+     upload_btn.click(upload_images, [class_selector, images_upload], [upload_status])
+     prepare_btn.click(prepare_data, [batch_size], [prepare_status])
+     train_btn.click(train_model, [model_choice, epochs, learning_rate], [train_status])
+     eval_btn.click(evaluate_model, [], [eval_results])
+     matrix_btn.click(generate_confusion_matrix, [], [confusion_matrix_plot])
+     predict_btn.click(predict_images, [predict_images_input], [predict_results])

  if __name__ == "__main__":
+     demo.launch()

requirements.txt CHANGED
@@ -1,8 +1,8 @@
- gradio==4.44.0
- torch==2.1.0
- torchvision==0.16.0
- scikit-learn==1.3.2
- matplotlib==3.8.0
- seaborn==0.13.0
+ gradio==4.20.0
+ torch==2.0.1
+ torchvision==0.15.2
+ scikit-learn==1.3.0
+ matplotlib==3.7.1
+ seaborn==0.12.2
  numpy==1.24.3
- Pillow==10.0.1
+ Pillow==9.5.0