Spaces:
Sleeping
Sleeping
File size: 4,471 Bytes
d5f2d19 bcfdd81 d5f2d19 3376df1 d5f2d19 bcfdd81 3376df1 bcfdd81 d5f2d19 bcfdd81 3376df1 bcfdd81 d5f2d19 3376df1 d5f2d19 3376df1 2f908d5 3376df1 2f908d5 3376df1 2f908d5 3376df1 d5f2d19 bcfdd81 3376df1 bcfdd81 d5f2d19 bcfdd81 d5f2d19 bcfdd81 d5f2d19 3376df1 2f908d5 3376df1 bcfdd81 3376df1 bcfdd81 d5f2d19 bcfdd81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import ast
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
# ---------- GCN MODEL ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for whole-graph classification.

    Node embeddings from two ReLU-activated ``GCNConv`` layers are averaged per
    graph with ``global_mean_pool`` and mapped to class logits by a linear head.

    Args:
        hidden_channels: Width of both GCN layers and of the pooled embedding.
        in_channels: Number of input features per node. The default of 7 is
            what this app feeds the model; it must match the checkpoint.
        num_classes: Number of logits produced per graph. The default of 2 is
            the binary mutagenic / non-mutagenic task.
    """

    def __init__(self, hidden_channels=64, *, in_channels=7, num_classes=2):
        super().__init__()
        # Registration order (conv1, conv2, lin) fixes the state_dict keys the
        # saved checkpoint is loaded against.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class logits for every graph in the batch.

        Args:
            x: Node feature matrix, presumably ``(num_nodes, in_channels)``.
            edge_index: Connectivity in PyG COO layout, presumably
                ``(2, num_edges)``.
            batch: Vector mapping each node to the index of its graph.

        Returns:
            Logits of shape ``(num_graphs, num_classes)``.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Collapse node embeddings into one vector per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)
# ---------- LOADING THE TRAINED MODEL ----------
# Prefer the GPU when torch can see one; fall back to the CPU otherwise.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")

# Build the architecture on `device`, then restore the trained weights.
# map_location remaps tensors saved on another device onto `device`.
# NOTE(review): torch.load unpickles the file. That is acceptable only because
# model_gcn.pth ships with the app; consider weights_only=True on torch >= 1.13.
model = GCN()
model.to(device)  # nn.Module.to moves parameters in place
model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
model.eval()
# ---------- INPUT PARSING ----------
def parse_input(num_nodes, edges_str, node_features_str):
    """Build an undirected NetworkX graph from the raw Gradio form values.

    Args:
        num_nodes: Number of nodes; they are labelled ``0 .. num_nodes - 1``.
            Must be a positive whole number.
        edges_str: Python-literal list of node-index pairs, e.g.
            ``"[(0,1), (1,2)]"``. Every index must be an ``int`` in
            ``[0, num_nodes)``.
        node_features_str: Python-literal list holding one list of 7 numbers
            per node, e.g. ``"[[1,0,1,0,1,0,1], [...]]"``.

    Returns:
        ``nx.Graph`` whose node ``i`` stores ``node_features[i]`` under the
        ``"x"`` attribute.

    Raises:
        gr.Error: If any input cannot be parsed or fails validation; Gradio
            shows the message to the user.
    """
    try:
        n = int(num_nodes)
        # Reject fractional or non-positive counts instead of silently truncating.
        if n < 1 or n != num_nodes:
            raise ValueError("El número de nodos debe ser un entero positivo.")

        # ast.literal_eval only evaluates Python literals, so unlike eval() it
        # cannot execute code typed into the public web form.
        edges = ast.literal_eval(edges_str)  # expected format: [(0,1), (1,2)]
        for u, v in edges:
            # An out-of-range index would make networkx create an extra node
            # that has no "x" feature vector.
            if not all(isinstance(i, int) and 0 <= i < n for i in (u, v)):
                raise ValueError(
                    f"Arista inválida ({u!r}, {v!r}): los índices deben ser "
                    f"enteros entre 0 y {n - 1}."
                )

        node_features = ast.literal_eval(node_features_str)  # format: [[1,0,1,0,1,0,1], [...]]
        if len(node_features) != n or any(len(f) != 7 for f in node_features):
            raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")
        if not all(isinstance(v, (int, float)) for f in node_features for v in f):
            raise ValueError("Las características deben ser valores numéricos.")
    except Exception as e:
        # UI boundary: surface any parsing/validation failure as a Gradio error.
        raise gr.Error(f"Error al procesar el input: {e}") from e

    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_edges_from(edges)
    nx.set_node_attributes(G, {i: node_features[i] for i in range(n)}, "x")
    return G
# ---------- VISUALIZATION ----------
def draw_graph(G, pred_label):
    """Draw ``G`` and return the picture as an in-memory PNG.

    All nodes share one colour chosen from the prediction: light green when
    ``pred_label == 1``, light coral otherwise.

    Args:
        G: NetworkX graph to render; node labels are drawn.
        pred_label: Predicted class index from the model.

    Returns:
        ``io.BytesIO`` containing the PNG bytes, rewound to the start.
    """
    # spring_layout starts from random positions, so layouts can differ per call.
    pos = nx.spring_layout(G)
    node_colors = ["lightgreen" if pred_label == 1 else "lightcoral"] * G.number_of_nodes()
    fig = plt.figure(figsize=(4, 4))
    try:
        # nx.draw renders onto pyplot's current figure, i.e. the one just created.
        nx.draw(G, pos, with_labels=True, node_color=node_colors, edge_color="gray", node_size=800)
        plt.title("Grafo de entrada")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Close this specific figure even if drawing fails. pyplot keeps every
        # open figure alive, so leaked figures would pile up across requests.
        plt.close(fig)
    buf.seek(0)
    return buf
# ---------- PREDICTION ----------
def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the user-defined graph with the loaded GCN and render it.

    Args:
        num_nodes: Node count from the form, passed to ``parse_input``.
        edges_str: Edge-list text from the form, passed to ``parse_input``.
        node_features_str: Feature-list text from the form, passed to
            ``parse_input``.

    Returns:
        Tuple ``(label_text, image)``. ``label_text`` is the class label shown
        in the prediction box. ``image`` is the rendered graph as an
        ``(H, W, 4)`` uint8 RGBA array, a type ``gr.Image`` can display.

    Raises:
        gr.Error: Propagated from ``parse_input`` for malformed input.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)
    data = from_networkx(G)
    # from_networkx already converts the "x" attribute, but with a dtype taken
    # from the Python values (integers for 0/1 features). The GCN needs float32,
    # in the same node order that from_networkx uses (G.nodes order).
    data.x = torch.tensor([G.nodes[node]["x"] for node in G.nodes], dtype=torch.float)
    # A single graph: every node belongs to graph index 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1).item()
    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    # gr.Image cannot display an io.BytesIO, so decode the PNG first.
    # plt.imread returns floats in [0, 1]; scale them to uint8.
    image = (plt.imread(draw_graph(G, pred), format="png") * 255).round().astype("uint8")
    return label_text, image
# ---------- GRADIO INTERFACE ----------
# Markdown shown under the title; it explains the expected input formats.
description = """
Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.
🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.
✅ Cada nodo debe tener **7 características** (como en MUTAG).
🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`
📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
"""

# Input widgets, in the positional order predict_graph expects:
# (num_nodes, edges_str, node_features_str).
_num_nodes_field = gr.Number(label="Número de nodos", value=3, precision=0)
_edges_field = gr.Textbox(
    label="Aristas [(0,1), (1,2)]",
    lines=2,
    value="[(0,1),(1,2)]",
)
_features_field = gr.Textbox(
    label="Características por nodo (listas de 7)",
    lines=4,
    value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]",
)
inputs = [_num_nodes_field, _edges_field, _features_field]

# Output widgets, matching predict_graph's (label_text, image) return value.
_label_field = gr.Text(label="Predicción")
_image_field = gr.Image(label="Visualización del grafo")
outputs = [_label_field, _image_field]

# Clickable sample inputs; each row follows the same order as `inputs`.
_examples = [
    [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
    [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"],
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🔬 Clasificador Molecular con GNN (GCN)",
    description=description,
    examples=_examples,
)
demo.launch()
|