Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,15 @@ import torch
|
|
2 |
import torch.nn.functional as F
|
3 |
from torch_geometric.nn import GCNConv, global_mean_pool
|
4 |
from torch_geometric.utils import from_networkx
|
|
|
5 |
import networkx as nx
|
6 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
# --- Mismo modelo que en Colab ---
|
9 |
class GCN(torch.nn.Module):
|
10 |
def __init__(self, hidden_channels=64):
|
11 |
super().__init__()
|
@@ -21,20 +26,63 @@ class GCN(torch.nn.Module):
|
|
21 |
x = global_mean_pool(x, batch)
|
22 |
return self.lin(x)
|
23 |
|
24 |
-
#
|
|
|
25 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
26 |
model = GCN().to(device)
|
27 |
model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
|
28 |
model.eval()
|
29 |
|
30 |
-
#
|
31 |
-
|
|
|
32 |
G = nx.Graph()
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
data = from_networkx(G)
|
37 |
-
data.x = torch.tensor(
|
38 |
data.edge_index = data.edge_index
|
39 |
data.batch = torch.tensor([0] * data.num_nodes)
|
40 |
|
@@ -43,15 +91,43 @@ def demo_predict():
|
|
43 |
out = model(data.x, data.edge_index, data.batch)
|
44 |
pred = out.argmax(dim=1).item()
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
# --- Interfaz Gradio ---
|
49 |
demo = gr.Interface(
|
50 |
-
fn=
|
51 |
-
inputs=
|
52 |
-
outputs=
|
53 |
-
title="Clasificador
|
54 |
-
description=
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
|
57 |
demo.launch()
|
|
|
2 |
import torch.nn.functional as F
|
3 |
from torch_geometric.nn import GCNConv, global_mean_pool
|
4 |
from torch_geometric.utils import from_networkx
|
5 |
+
from torch_geometric.data import Data
|
6 |
import networkx as nx
|
7 |
import gradio as gr
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import io
|
10 |
+
import base64
|
11 |
+
|
12 |
+
# ---------- MODELO GCN ----------
|
13 |
|
|
|
14 |
class GCN(torch.nn.Module):
|
15 |
def __init__(self, hidden_channels=64):
|
16 |
super().__init__()
|
|
|
26 |
x = global_mean_pool(x, batch)
|
27 |
return self.lin(x)
|
28 |
|
29 |
+
# ---------- CARGA DEL MODELO ENTRENADO ----------
|
30 |
+
|
31 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
32 |
model = GCN().to(device)
|
33 |
model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
|
34 |
model.eval()
|
35 |
|
36 |
+
# ---------- FUNCIÓN PARA PARSEAR INPUT ----------
|
37 |
+
|
38 |
+
def parse_input(num_nodes, edges_str, node_features_str):
|
39 |
G = nx.Graph()
|
40 |
+
|
41 |
+
try:
|
42 |
+
# Añadimos nodos
|
43 |
+
for i in range(num_nodes):
|
44 |
+
G.add_node(i)
|
45 |
+
|
46 |
+
# Parseamos aristas
|
47 |
+
edges = eval(edges_str) # formato esperado: [(0,1), (1,2)]
|
48 |
+
G.add_edges_from(edges)
|
49 |
+
|
50 |
+
# Parseamos características
|
51 |
+
node_features = eval(node_features_str) # formato: [[1,0,1,0,1,0,1], [...]]
|
52 |
+
if len(node_features) != num_nodes or any(len(f) != 7 for f in node_features):
|
53 |
+
raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")
|
54 |
+
|
55 |
+
nx.set_node_attributes(G, {i: node_features[i] for i in range(num_nodes)}, "x")
|
56 |
+
|
57 |
+
return G
|
58 |
+
|
59 |
+
except Exception as e:
|
60 |
+
raise gr.Error(f"Error al procesar el input: {e}")
|
61 |
+
|
62 |
+
# ---------- VISUALIZACIÓN ----------
|
63 |
+
|
64 |
+
def draw_graph(G, pred_label):
|
65 |
+
pos = nx.spring_layout(G)
|
66 |
+
node_colors = 'lightgreen' if pred_label == 1 else 'lightcoral'
|
67 |
+
plt.figure(figsize=(4, 4))
|
68 |
+
nx.draw(G, pos, with_labels=True, node_color=node_colors, edge_color='gray', node_size=800)
|
69 |
+
plt.title("Grafo de entrada")
|
70 |
+
|
71 |
+
buf = io.BytesIO()
|
72 |
+
plt.savefig(buf, format='png')
|
73 |
+
buf.seek(0)
|
74 |
+
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
75 |
+
plt.close()
|
76 |
+
|
77 |
+
return f"data:image/png;base64,{img_base64}"
|
78 |
+
|
79 |
+
# ---------- FUNCIÓN DE PREDICCIÓN ----------
|
80 |
+
|
81 |
+
def predict_graph(num_nodes, edges_str, node_features_str):
|
82 |
+
G = parse_input(num_nodes, edges_str, node_features_str)
|
83 |
|
84 |
data = from_networkx(G)
|
85 |
+
data.x = torch.tensor([v for v in nx.get_node_attributes(G, "x").values()], dtype=torch.float)
|
86 |
data.edge_index = data.edge_index
|
87 |
data.batch = torch.tensor([0] * data.num_nodes)
|
88 |
|
|
|
91 |
out = model(data.x, data.edge_index, data.batch)
|
92 |
pred = out.argmax(dim=1).item()
|
93 |
|
94 |
+
label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
|
95 |
+
img = draw_graph(G, pred)
|
96 |
+
return label_text, img
|
97 |
+
|
98 |
+
# ---------- INTERFAZ GRADIO ----------
|
99 |
+
|
100 |
+
description = """
|
101 |
+
Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.
|
102 |
+
|
103 |
+
🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.
|
104 |
+
|
105 |
+
✅ Cada nodo debe tener **7 características** (como en MUTAG).
|
106 |
+
🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`
|
107 |
+
📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
|
108 |
+
"""
|
109 |
+
|
110 |
+
inputs = [
|
111 |
+
gr.Number(label="Número de nodos", value=3, precision=0),
|
112 |
+
gr.Textbox(label="Aristas [(0,1), (1,2)]", lines=2, value="[(0,1),(1,2)]"),
|
113 |
+
gr.Textbox(label="Características por nodo (listas de 7)", lines=4, value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]")
|
114 |
+
]
|
115 |
+
|
116 |
+
outputs = [
|
117 |
+
gr.Text(label="Predicción"),
|
118 |
+
gr.Image(label="Visualización del grafo")
|
119 |
+
]
|
120 |
|
|
|
121 |
demo = gr.Interface(
|
122 |
+
fn=predict_graph,
|
123 |
+
inputs=inputs,
|
124 |
+
outputs=outputs,
|
125 |
+
title="🔬 Clasificador Molecular con GNN (GCN)",
|
126 |
+
description=description,
|
127 |
+
examples=[
|
128 |
+
[3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
|
129 |
+
[4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"]
|
130 |
+
]
|
131 |
)
|
132 |
|
133 |
demo.launch()
|