GNN-MUTAG / app.py
AdrianRevi's picture
Update app.py
9c57eba verified
import ast
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
# ---------- MODEL ----------
class GCN(torch.nn.Module):
    """Graph classifier: GCNConv -> ReLU -> GCNConv -> ReLU -> mean pool -> Linear.

    Args:
        hidden_channels: Width of both GCN layers (and input of the final linear layer).
        in_channels: Size of each node's feature vector. The default of 7 matches
            the per-node feature count the rest of this app enforces.
        num_classes: Number of output logits per graph.

    The submodule names (``conv1``, ``conv2``, ``lin``) define the state_dict keys
    of a saved checkpoint, so they must not be renamed.
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return one row of class logits per graph in the batch.

        Args:
            x: Node feature matrix (one row per node).
            edge_index: Graph connectivity in PyG ``edge_index`` format.
            batch: Graph id of every node; ``None`` treats all nodes as a
                single graph (``global_mean_pool`` then averages over all nodes).
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)
# ---------- LOAD THE TRAINED MODEL ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# The checkpoint is fed straight into load_state_dict, i.e. it only needs to
# hold tensors; weights_only=True restricts unpickling to such data instead of
# allowing arbitrary Python objects (and silences torch's FutureWarning).
model.load_state_dict(
    torch.load("model_gcn.pth", map_location=device, weights_only=True)
)
# Inference only: disable training-time behaviour.
model.eval()
# ---------- HELPER FUNCTIONS ----------
def parse_input(num_nodes, edges_str, node_features_str):
    """Build and validate an undirected NetworkX graph from the raw form fields.

    Args:
        num_nodes: Number of nodes; nodes are labelled ``0 .. num_nodes - 1``.
        edges_str: Python-literal list of ``(u, v)`` pairs, e.g. ``"[(0,1),(1,2)]"``.
        node_features_str: Python-literal list holding one 7-value feature list
            per node, in node order.

    Returns:
        ``nx.Graph`` whose nodes carry their feature list under the ``"x"`` attribute.

    Raises:
        gr.Error: if any field cannot be parsed or fails validation; the message
            is shown to the user in the Gradio UI.
    """
    try:
        # Reject missing, fractional or non-positive counts: an empty graph has
        # nothing to pool or classify.
        if num_nodes is None or num_nodes != int(num_nodes) or num_nodes < 1:
            raise ValueError("El número de nodos debe ser un entero positivo.")
        num_nodes = int(num_nodes)

        # SECURITY: these strings come from a public web form. ast.literal_eval
        # only evaluates Python literals, whereas eval() would run arbitrary code.
        edges = ast.literal_eval(edges_str)
        node_features = ast.literal_eval(node_features_str)

        # Every edge must join two existing nodes; otherwise add_edges_from
        # would silently create extra nodes that have no features.
        for edge in edges:
            if (
                not isinstance(edge, (tuple, list))
                or len(edge) != 2
                or not all(isinstance(v, int) and 0 <= v < num_nodes for v in edge)
            ):
                raise ValueError(
                    f"Arista inválida {edge!r}: debe ser un par de nodos entre 0 y {num_nodes - 1}."
                )

        if len(node_features) != num_nodes:
            raise ValueError(
                f"Se esperaban características para {num_nodes} nodos, "
                f"pero se recibieron {len(node_features)}."
            )
        if any(len(f) != 7 for f in node_features):
            raise ValueError("Cada nodo debe tener exactamente 7 características.")
        if not all(isinstance(v, (int, float)) for f in node_features for v in f):
            raise ValueError("Las características deben ser valores numéricos.")

        G = nx.Graph()
        G.add_nodes_from(range(num_nodes))
        G.add_edges_from(tuple(edge) for edge in edges)
        nx.set_node_attributes(G, dict(enumerate(node_features)), "x")
        return G
    except Exception as e:
        # Surface every parsing/validation problem as a user-visible Gradio error.
        raise gr.Error(f"Error en los datos del grafo: {e}") from e
def draw_graph(G, pred_label):
    """Render ``G`` as a PNG and return it as a PIL image.

    Args:
        G: Graph to draw (node ids are shown as labels).
        pred_label: Predicted class index; ``1`` paints every node light green,
            any other value light coral.

    Returns:
        ``PIL.Image.Image`` containing the rendered figure.
    """
    color = 'lightgreen' if pred_label == 1 else 'lightcoral'
    # Use an explicit Figure/Axes instead of pyplot's implicit "current figure":
    # - plt.close(fig) releases exactly this figure, and the finally block does
    #   so even if layout/drawing raises (otherwise the figure leaks);
    # - passing ax keeps nx.draw from adding a full-canvas axes, which pushed
    #   the title outside the saved image.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        pos = nx.spring_layout(G)
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color=[color] * G.number_of_nodes(),
            edge_color='gray',
            node_size=800,
        )
        ax.set_title("Grafo de entrada")
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the graph described by the form fields and draw it.

    Args:
        num_nodes, edges_str, node_features_str: Raw form values, forwarded
            unchanged to ``parse_input`` (which raises ``gr.Error`` on bad input).

    Returns:
        Tuple ``(label_text, image)``: the human-readable prediction and the
        rendered graph.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)
    data = from_networkx(G)
    # from_networkx numbers nodes in G.nodes order, so build the feature matrix
    # in that same order, as float for the GCN layers.
    data.x = torch.tensor([G.nodes[n]["x"] for n in G.nodes], dtype=torch.float)
    # A single input graph: every node belongs to graph 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)

    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1).item()

    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)
# ---------- GRADIO INTERFACE ----------
description = """
Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.
✏️ Puedes modificar el grafo:
- Número de nodos
- Aristas (formato: `[(0,1),(1,2)]`)
- Características de cada nodo (7 valores binarios por nodo)
"""

# Preset rows for the examples table: (number of nodes, edges, node features).
# The first preset also seeds the form's initial values.
_examples = [
    [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
    [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"],
]
_default_nodes, _default_edges, _default_features = _examples[0]

inputs = [
    gr.Number(label="Número de nodos", value=_default_nodes, precision=0),
    gr.Textbox(label="Aristas", value=_default_edges),
    gr.Textbox(label="Características por nodo", value=_default_features),
]
outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🧪 Clasificador Molecular con GCN",
    description=description,
    examples=_examples,
)
# show_error=True forwards raised errors (e.g. gr.Error from parsing) to the UI.
demo.launch(show_error=True)