Spaces:
Sleeping
Sleeping
File size: 3,935 Bytes
d5f2d19 bcfdd81 d5f2d19 3376df1 2c5abfa d5f2d19 bcfdd81 3376df1 2c5abfa 3376df1 2c5abfa bcfdd81 d5f2d19 bcfdd81 3376df1 bcfdd81 d5f2d19 2c5abfa 3376df1 d5f2d19 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 2f908d5 3376df1 2c5abfa 3376df1 d5f2d19 bcfdd81 3376df1 bcfdd81 d5f2d19 bcfdd81 d5f2d19 bcfdd81 d5f2d19 3376df1 2f908d5 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 2c5abfa 3376df1 bcfdd81 3376df1 2c5abfa 3376df1 bcfdd81 d5f2d19 9c57eba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import ast
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
# ---------- MODELO ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for whole-graph classification.

    Node embeddings from two GCNConv + ReLU layers are averaged per graph with
    ``global_mean_pool`` and passed to a linear layer that yields one score per
    class.

    Args:
        hidden_channels: width of both GCNConv layers.
        in_channels: number of input features per node. Defaults to 7, the
            width the app's input validation expects.
        num_classes: number of output scores per graph (2 by default).
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return unnormalised class scores, one row per graph in ``batch``.

        Args:
            x: node feature matrix, presumably (num_nodes, in_channels).
            edge_index: graph connectivity tensor as produced by
                ``torch_geometric.utils.from_networkx``.
            batch: index of the graph each node belongs to; used by
                ``global_mean_pool`` to average node embeddings per graph.

        Returns:
            Tensor whose last dimension is ``num_classes``.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)
# ---------- CARGA DEL MODELO ENTRENADO ----------
# Load the trained weights once, at import time, onto the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# The checkpoint is a plain state_dict, so restricting unpickling to tensors
# and primitive containers loses nothing. It also stops a tampered checkpoint
# from executing arbitrary code (this is the default from PyTorch 2.6 on).
model.load_state_dict(
    torch.load("model_gcn.pth", map_location=device, weights_only=True)
)
# Inference only: switch every submodule to evaluation mode.
model.eval()
# ---------- FUNCIONES AUXILIARES ----------
def parse_input(num_nodes, edges_str, node_features_str):
    """Build a NetworkX graph from the raw values typed into the Gradio form.

    Args:
        num_nodes: number of nodes; nodes are labelled 0 .. num_nodes - 1.
        edges_str: Python literal holding a list of (u, v) node pairs,
            e.g. "[(0,1),(1,2)]".
        node_features_str: Python literal holding one list of 7 numbers per
            node, in node order.

    Returns:
        An undirected ``nx.Graph`` with ``num_nodes`` nodes. Each node stores
        its feature list in the ``"x"`` attribute.

    Raises:
        gr.Error: if any input is missing or malformed. Gradio shows the
            message to the user.
    """
    try:
        # Reject None, zero, negative and non-integral counts up front.
        # An empty graph cannot be pooled or classified.
        if num_nodes is None or num_nodes < 1 or int(num_nodes) != num_nodes:
            raise ValueError("El número de nodos debe ser un entero positivo.")
        num_nodes = int(num_nodes)

        # SECURITY: these strings come straight from users of a public web UI.
        # ast.literal_eval only accepts Python literals; the previous eval()
        # would have executed arbitrary code.
        edges = [tuple(edge) for edge in ast.literal_eval(edges_str)]
        for edge in edges:
            # networkx silently adds unknown endpoints as new nodes without
            # features, which later breaks the feature matrix. Reject them here.
            if len(edge) != 2 or not all(
                isinstance(v, int) and 0 <= v < num_nodes for v in edge
            ):
                raise ValueError(
                    f"Arista inválida {edge}: debe unir dos nodos "
                    f"entre 0 y {num_nodes - 1}."
                )

        node_features = ast.literal_eval(node_features_str)
        if len(node_features) != num_nodes:
            raise ValueError(
                f"Se esperaban características para {num_nodes} nodos, "
                f"pero se recibieron {len(node_features)}."
            )
        if any(len(f) != 7 for f in node_features):
            raise ValueError("Cada nodo debe tener exactamente 7 características.")
        if not all(isinstance(v, (int, float)) for f in node_features for v in f):
            raise ValueError("Las características deben ser valores numéricos.")
    except Exception as e:
        # UI boundary: surface every parsing/validation problem as a Gradio error.
        raise gr.Error(f"Error en los datos del grafo: {e}") from e

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edges)
    nx.set_node_attributes(G, dict(enumerate(node_features)), "x")
    return G
def draw_graph(G, pred_label):
    """Render ``G`` to a PIL image, colouring every node by the predicted class.

    Args:
        G: graph to draw; node ids are shown as labels.
        pred_label: predicted class index. 1 draws light-green nodes; any
            other value draws light-coral nodes.

    Returns:
        A ``PIL.Image.Image`` read from the PNG rendering.
    """
    color = 'lightgreen' if pred_label == 1 else 'lightcoral'
    pos = nx.spring_layout(G)
    # Use an explicit Figure/Axes instead of pyplot's implicit "current figure".
    # Gradio may serve requests concurrently, and a bare plt.close() could close
    # another request's figure while leaking this one. The default subplot
    # margins also leave room above the drawing for the title.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color=[color] * G.number_of_nodes(),
            edge_color='gray',
            node_size=800,
        )
        ax.set_title("Grafo de entrada")
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the graph described in the form and return a label and a picture.

    Args:
        num_nodes: number of nodes, as entered in the form.
        edges_str: edge list as a Python-literal string.
        node_features_str: per-node feature lists as a Python-literal string.

    Returns:
        Tuple ``(label_text, image)``. ``label_text`` is the Spanish class
        label and ``image`` is the rendering from ``draw_graph``.

    Raises:
        gr.Error: propagated from ``parse_input`` on malformed input.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)
    data = from_networkx(G)
    # Replace whatever dtype from_networkx inferred for "x" with float, which
    # the model's layers need. Iterating G.nodes matches the node order
    # from_networkx used to build edge_index.
    data.x = torch.tensor([G.nodes[v]["x"] for v in G.nodes], dtype=torch.float)
    # All nodes belong to one graph, so every node gets graph index 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1).item()
    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)
# ---------- INTERFAZ GRADIO ----------
# Markdown shown under the app title.
description = """
Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.
✏️ Puedes modificar el grafo:
- Número de nodos
- Aristas (formato: `[(0,1),(1,2)]`)
- Características de cada nodo (7 valores binarios por nodo)
"""

# Values pre-filled in the form; they also form the first example row, so the
# two cannot drift apart.
DEFAULT_NUM_NODES = 3
DEFAULT_EDGES = "[(0,1),(1,2)]"
DEFAULT_FEATURES = "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"

# Form widgets, in the positional order predict_graph expects its arguments.
inputs = [
    gr.Number(label="Número de nodos", value=DEFAULT_NUM_NODES, precision=0),
    gr.Textbox(label="Aristas", value=DEFAULT_EDGES),
    gr.Textbox(label="Características por nodo", value=DEFAULT_FEATURES),
]

# One widget per element of predict_graph's returned tuple.
outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

examples = [
    [DEFAULT_NUM_NODES, DEFAULT_EDGES, DEFAULT_FEATURES],
    [
        4,
        "[(0,1),(1,2),(2,3)]",
        "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]",
    ],
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🧪 Clasificador Molecular con GCN",
    description=description,
    examples=examples,
)
# show_error=True forwards gr.Error messages to the browser.
demo.launch(show_error=True)
|