AdrianRevi commited on
Commit
3376df1
·
verified ·
1 Parent(s): bcfdd81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -14
app.py CHANGED
@@ -2,10 +2,15 @@ import torch
2
  import torch.nn.functional as F
3
  from torch_geometric.nn import GCNConv, global_mean_pool
4
  from torch_geometric.utils import from_networkx
 
5
  import networkx as nx
6
  import gradio as gr
 
 
 
 
 
7
 
8
- # --- Mismo modelo que en Colab ---
9
  class GCN(torch.nn.Module):
10
  def __init__(self, hidden_channels=64):
11
  super().__init__()
@@ -21,20 +26,63 @@ class GCN(torch.nn.Module):
21
  x = global_mean_pool(x, batch)
22
  return self.lin(x)
23
 
24
- # --- Carga de modelo entrenado ---
 
25
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
  model = GCN().to(device)
27
  model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
28
  model.eval()
29
 
30
- # --- Función de predicción sobre un grafo ejemplo ---
31
- def demo_predict():
 
32
  G = nx.Graph()
33
- G.add_edges_from([(0, 1), (1, 2)])
34
- nx.set_node_attributes(G, {i: [1, 0, 0, 1, 0, 1, 0] for i in G.nodes}, "x") # vector de 7 dimensiones
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  data = from_networkx(G)
37
- data.x = torch.tensor(list(nx.get_node_attributes(G, "x").values()), dtype=torch.float)
38
  data.edge_index = data.edge_index
39
  data.batch = torch.tensor([0] * data.num_nodes)
40
 
@@ -43,15 +91,43 @@ def demo_predict():
43
  out = model(data.x, data.edge_index, data.batch)
44
  pred = out.argmax(dim=1).item()
45
 
46
- return "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # --- Interfaz Gradio ---
49
  demo = gr.Interface(
50
- fn=demo_predict,
51
- inputs=[],
52
- outputs="text",
53
- title="Clasificador de Moléculas con GCN",
54
- description="Este demo usa una red neuronal en grafo entrenada sobre MUTAG para clasificar moléculas como mutagénicas o no."
 
 
 
 
55
  )
56
 
57
  demo.launch()
 
2
  import torch.nn.functional as F
3
  from torch_geometric.nn import GCNConv, global_mean_pool
4
  from torch_geometric.utils import from_networkx
5
+ from torch_geometric.data import Data
6
  import networkx as nx
7
  import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ import base64
11
+
12
+ # ---------- MODELO GCN ----------
13
 
 
14
  class GCN(torch.nn.Module):
15
  def __init__(self, hidden_channels=64):
16
  super().__init__()
 
26
  x = global_mean_pool(x, batch)
27
  return self.lin(x)
28
 
29
+ # ---------- CARGA DEL MODELO ENTRENADO ----------
30
+
31
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
  model = GCN().to(device)
33
  model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
34
  model.eval()
35
 
36
+ # ---------- FUNCIÓN PARA PARSEAR INPUT ----------
37
+
38
+ def parse_input(num_nodes, edges_str, node_features_str):
39
  G = nx.Graph()
40
+
41
+ try:
42
+ # Añadimos nodos
43
+ for i in range(num_nodes):
44
+ G.add_node(i)
45
+
46
+ # Parseamos aristas
47
+ edges = eval(edges_str) # formato esperado: [(0,1), (1,2)]
48
+ G.add_edges_from(edges)
49
+
50
+ # Parseamos características
51
+ node_features = eval(node_features_str) # formato: [[1,0,1,0,1,0,1], [...]]
52
+ if len(node_features) != num_nodes or any(len(f) != 7 for f in node_features):
53
+ raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")
54
+
55
+ nx.set_node_attributes(G, {i: node_features[i] for i in range(num_nodes)}, "x")
56
+
57
+ return G
58
+
59
+ except Exception as e:
60
+ raise gr.Error(f"Error al procesar el input: {e}")
61
+
62
+ # ---------- VISUALIZACIÓN ----------
63
+
64
+ def draw_graph(G, pred_label):
65
+ pos = nx.spring_layout(G)
66
+ node_colors = 'lightgreen' if pred_label == 1 else 'lightcoral'
67
+ plt.figure(figsize=(4, 4))
68
+ nx.draw(G, pos, with_labels=True, node_color=node_colors, edge_color='gray', node_size=800)
69
+ plt.title("Grafo de entrada")
70
+
71
+ buf = io.BytesIO()
72
+ plt.savefig(buf, format='png')
73
+ buf.seek(0)
74
+ img_base64 = base64.b64encode(buf.read()).decode('utf-8')
75
+ plt.close()
76
+
77
+ return f"data:image/png;base64,{img_base64}"
78
+
79
+ # ---------- FUNCIÓN DE PREDICCIÓN ----------
80
+
81
+ def predict_graph(num_nodes, edges_str, node_features_str):
82
+ G = parse_input(num_nodes, edges_str, node_features_str)
83
 
84
  data = from_networkx(G)
85
+ data.x = torch.tensor([v for v in nx.get_node_attributes(G, "x").values()], dtype=torch.float)
86
  data.edge_index = data.edge_index
87
  data.batch = torch.tensor([0] * data.num_nodes)
88
 
 
91
  out = model(data.x, data.edge_index, data.batch)
92
  pred = out.argmax(dim=1).item()
93
 
94
+ label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
95
+ img = draw_graph(G, pred)
96
+ return label_text, img
97
+
98
+ # ---------- INTERFAZ GRADIO ----------
99
+
100
+ description = """
101
+ Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.
102
+
103
+ 🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.
104
+
105
+ ✅ Cada nodo debe tener **7 características** (como en MUTAG).
106
+ 🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`
107
+ 📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
108
+ """
109
+
110
+ inputs = [
111
+ gr.Number(label="Número de nodos", value=3, precision=0),
112
+ gr.Textbox(label="Aristas [(0,1), (1,2)]", lines=2, value="[(0,1),(1,2)]"),
113
+ gr.Textbox(label="Características por nodo (listas de 7)", lines=4, value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]")
114
+ ]
115
+
116
+ outputs = [
117
+ gr.Text(label="Predicción"),
118
+ gr.Image(label="Visualización del grafo")
119
+ ]
120
 
 
121
  demo = gr.Interface(
122
+ fn=predict_graph,
123
+ inputs=inputs,
124
+ outputs=outputs,
125
+ title="🔬 Clasificador Molecular con GNN (GCN)",
126
+ description=description,
127
+ examples=[
128
+ [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
129
+ [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"]
130
+ ]
131
  )
132
 
133
  demo.launch()