Spaces:
Sleeping
Sleeping
import ast
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
# ---------- MODEL ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for graph-level classification.

    Node embeddings come from two ``GCNConv`` layers with ReLU activations,
    are averaged per graph with ``global_mean_pool`` and are mapped to class
    logits by a final linear layer.

    Args:
        hidden_channels: Width of both hidden GCN layers.
        in_channels: Number of input features per node. Defaults to 7, the
            per-node feature count this app's input parser enforces.
        num_classes: Number of output logits. Defaults to 2.

    The defaults reproduce the original architecture exactly, so existing
    checkpoints (same ``conv1``/``conv2``/``lin`` parameter names) still load.
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return class logits, one row per graph.

        Args:
            x: Node feature matrix; presumably ``(num_nodes, in_channels)``.
            edge_index: Graph connectivity in the format ``GCNConv`` expects.
            batch: Graph index of every node. ``None`` treats all nodes as a
                single graph (``global_mean_pool`` then pools over all nodes).
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)
# ---------- LOAD THE TRAINED MODEL ----------
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the architecture on `device`, then restore the trained weights from
# the checkpoint file (tensors are mapped directly onto `device`).
model = GCN().to(device)
_state_dict = torch.load("model_gcn.pth", map_location=device)
model.load_state_dict(_state_dict)
del _state_dict  # the weights now live in `model`; drop the extra reference

# The model is only used for inference: switch it to evaluation mode.
model.eval()
# ---------- HELPER FUNCTIONS ----------
def parse_input(num_nodes, edges_str, node_features_str):
    """Build a NetworkX graph from the raw Gradio form fields.

    Args:
        num_nodes: Number of nodes; nodes are labelled ``0 .. num_nodes - 1``.
        edges_str: Python literal holding a list of ``(u, v)`` node pairs,
            e.g. ``"[(0,1),(1,2)]"``.
        node_features_str: Python literal holding one list of exactly 7
            numbers per node, in node order.

    Returns:
        ``nx.Graph`` whose nodes each carry their feature list (as floats)
        under the ``"x"`` attribute.

    Raises:
        gr.Error: If any field is malformed. This is Gradio's user-facing
            error type, so the message is shown in the UI.
    """
    try:
        n = int(num_nodes)
        if n < 1:
            raise ValueError("El número de nodos debe ser al menos 1.")

        # These strings come straight from a public web form. Parse them with
        # ast.literal_eval (plain Python literals only) instead of eval(),
        # which would execute arbitrary code submitted by any visitor.
        edges = [tuple(edge) for edge in ast.literal_eval(edges_str)]
        for edge in edges:
            # Reject edges to unknown nodes: networkx would otherwise create
            # them silently without features and the model input would break.
            if len(edge) != 2 or not all(
                isinstance(v, int) and 0 <= v < n for v in edge
            ):
                raise ValueError(
                    f"Arista inválida {edge}: debe unir dos nodos entre 0 y {n - 1}."
                )

        node_features = ast.literal_eval(node_features_str)
        if len(node_features) != n:
            raise ValueError(
                f"Se esperaban características para {n} nodos, "
                f"se recibieron {len(node_features)}."
            )
        if any(len(f) != 7 for f in node_features):
            raise ValueError("Cada nodo debe tener exactamente 7 características.")
        # Coerce to float here so non-numeric values are reported to the user
        # now rather than failing later during tensor construction.
        features = [[float(v) for v in f] for f in node_features]
    except Exception as e:
        raise gr.Error(f"Error en los datos del grafo: {e}") from e

    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_edges_from(edges)
    nx.set_node_attributes(G, dict(enumerate(features)), "x")
    return G
def draw_graph(G, pred_label):
    """Render ``G`` to a PIL image, colouring nodes by the predicted class.

    Args:
        G: Graph to draw; node labels are shown.
        pred_label: Predicted class. ``1`` colours every node light green,
            any other value light coral.

    Returns:
        ``PIL.Image.Image`` decoded from the PNG rendering.
    """
    pos = nx.spring_layout(G)
    node_color = "lightgreen" if pred_label == 1 else "lightcoral"

    # Draw on an explicit Figure/Axes. Given a fresh pyplot figure and no
    # `ax`, nx.draw() adds a full-bleed (0, 0, 1, 1) axes, so a title set on
    # it sits above the canvas and is clipped from the saved PNG. Explicit
    # objects also avoid depending on pyplot's global "current figure".
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color=node_color,
            edge_color="gray",
            node_size=800,
        )
        ax.set_title("Grafo de entrada")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)  # release the figure even if drawing/saving fails
    buf.seek(0)
    return Image.open(buf)
def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the graph described by the Gradio form fields.

    Args:
        num_nodes: Number of nodes in the graph.
        edges_str: Edge-list literal, e.g. ``"[(0,1),(1,2)]"``.
        node_features_str: Literal with one feature list per node.

    Returns:
        Tuple ``(label_text, image)``: the human-readable prediction and a
        rendering of the input graph coloured by that prediction.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)
    data = from_networkx(G)

    # One feature row per node, iterating G's nodes in order (the same order
    # the original get_node_attributes() call used). Indexing each node fails
    # loudly if one lacks features instead of yielding a too-short tensor.
    data.x = torch.tensor([G.nodes[node]["x"] for node in G.nodes], dtype=torch.float)
    # A single graph: every node belongs to batch element 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)

    with torch.no_grad():
        logits = model(data.x, data.edge_index, data.batch)
    pred = logits.argmax(dim=1).item()

    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)
# ---------- GRADIO INTERFACE ----------
# Markdown intro shown above the form (user-facing text stays in Spanish).
description = """
Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.
✏️ Puedes modificar el grafo:
- Número de nodos
- Aristas (formato: `[(0,1),(1,2)]`)
- Características de cada nodo (7 valores binarios por nodo)
"""

# Form fields, in the positional order predict_graph(num_nodes, edges_str,
# node_features_str) receives them.
inputs = [
    gr.Number(label="Número de nodos", value=3, precision=0),
    gr.Textbox(label="Aristas", value="[(0,1),(1,2)]"),
    gr.Textbox(label="Características por nodo", value="[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"),
]

# One output component per element of predict_graph's (label_text, image) result.
outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🧪 Clasificador Molecular con GCN",
    description=description,
    # NOTE(review): MUTAG node features are commonly a one-hot atom-type
    # encoding; these multi-hot example vectors may be out of distribution
    # for the trained model -- confirm before relying on their predictions.
    examples=[
        [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
        [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"],
    ],
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse `demo` or `predict_graph`) does not launch it.
if __name__ == "__main__":
    demo.launch(show_error=True)