# GNN-MUTAG / app.py
# Author: AdrianRevi · last commit 2f908d5 ("Update app.py", verified) · 4.47 kB
import ast
import base64
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
# ---------- GCN MODEL ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for graph-level classification.

    Node embeddings from two ``GCNConv`` layers (each followed by ReLU) are
    averaged per graph with ``global_mean_pool`` and mapped to one logit per
    class by a final linear layer.

    Args:
        hidden_channels: Width of both hidden GCN layers.
        in_channels: Number of input features per node. Defaults to 7, the
            feature width this app's inputs use.
        num_classes: Number of output logits per graph. Defaults to 2.

    The defaults reproduce the original architecture exactly, so existing
    checkpoints (``state_dict`` keys and shapes) still load.
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return raw (unnormalized) class logits, one row per graph.

        Args:
            x: Node feature tensor, one row per node.
            edge_index: Graph connectivity, passed unchanged to both GCN layers.
            batch: Graph index of each node, used for per-graph pooling.
                ``None`` means all nodes belong to a single graph.
        """
        if batch is None:
            # Single-graph input: assign every node to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # one pooled embedding per graph
        return self.lin(x)
# ---------- LOAD THE TRAINED MODEL ----------
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the architecture on the chosen device, then fill it with the trained
# weights; map_location lets a GPU-saved checkpoint load on a CPU-only host.
model = GCN().to(device)
model.load_state_dict(
    torch.load("model_gcn.pth", map_location=device)
)
model.eval()  # inference mode for the rest of the app's lifetime
# ---------- INPUT PARSING ----------
def parse_input(num_nodes, edges_str, node_features_str):
    """Build a NetworkX graph from the three raw Gradio inputs.

    Security: the text inputs come from anonymous web users, so they are parsed
    with ``ast.literal_eval`` (Python literals only). The previous ``eval`` call
    would execute arbitrary code on the server.

    Args:
        num_nodes: Number of nodes; nodes are labelled ``0 .. num_nodes - 1``.
            Must be a positive whole number (a float such as ``3.0`` is accepted).
        edges_str: Python literal list of node-index pairs, e.g.
            ``"[(0, 1), (1, 2)]"``. Every index must lie in ``range(num_nodes)``.
        node_features_str: Python literal list holding one list of 7 numbers
            per node, e.g. ``"[[1, 0, 0, 1, 0, 1, 0], ...]"``.

    Returns:
        An undirected ``nx.Graph`` with exactly ``num_nodes`` nodes, each of
        which stores its feature list in the ``"x"`` node attribute.

    Raises:
        gr.Error: if any input is malformed; Gradio shows the message in the UI.
    """
    num_features = 7  # must match the input width of the model's first GCNConv layer
    G = nx.Graph()
    try:
        if num_nodes is None or num_nodes != int(num_nodes) or num_nodes < 1:
            raise ValueError("El número de nodos debe ser un entero positivo.")
        num_nodes = int(num_nodes)
        G.add_nodes_from(range(num_nodes))

        edges = ast.literal_eval(edges_str)  # expected format: [(0,1), (1,2)]
        for edge in edges:
            # add_edges_from silently adds unknown endpoints as new nodes; those
            # would have no "x" features, so reject out-of-range indices here.
            if len(edge) != 2 or not all(
                isinstance(v, int) and 0 <= v < num_nodes for v in edge
            ):
                raise ValueError(
                    f"Arista inválida {edge!r}: debe ser un par de índices "
                    f"entre 0 y {num_nodes - 1}."
                )
        G.add_edges_from(edges)

        node_features = ast.literal_eval(node_features_str)  # format: [[1,0,1,0,1,0,1], [...]]
        if len(node_features) != num_nodes or any(
            len(f) != num_features for f in node_features
        ):
            raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")
        if not all(isinstance(v, (int, float)) for f in node_features for v in f):
            raise ValueError("Las características deben ser valores numéricos.")
        nx.set_node_attributes(G, dict(enumerate(node_features)), "x")
    except Exception as e:
        # UI boundary: surface any parsing/validation failure to the user.
        raise gr.Error(f"Error al procesar el input: {e}") from e
    return G
# ---------- VISUALIZATION ----------
def draw_graph(G, pred_label):
    """Render ``G`` with matplotlib and return the picture as an RGBA array.

    Every node gets the same colour, chosen from the prediction: light green
    when ``pred_label == 1``, light coral otherwise.

    Args:
        G: The NetworkX graph to draw.
        pred_label: Predicted class index returned by the model.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with shape ``(height, width, 4)``.
        ``gr.Image`` outputs accept NumPy arrays, PIL images or file paths. They
        do not accept the ``io.BytesIO`` buffer this function used to return, so
        that value made every prediction fail.
    """
    pos = nx.spring_layout(G)
    node_color = "lightgreen" if pred_label == 1 else "lightcoral"
    # Use a regular subplot: on an empty figure, nx.draw would create axes that
    # fill the whole figure, which pushes the title outside the image.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color=[node_color] * G.number_of_nodes(),
            edge_color="gray",
            node_size=800,
        )
        ax.set_title("Grafo de entrada")
        fig.canvas.draw()
        # Copy the pixels: the canvas buffer is released with the figure.
        image = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always free the figure so a long-running server does not leak memory.
        plt.close(fig)
    return image
# ---------- PREDICTION ----------
def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the user-defined graph and return the label plus a drawing.

    Args:
        num_nodes: Raw Gradio input, forwarded to ``parse_input``.
        edges_str: Raw Gradio input, forwarded to ``parse_input``.
        node_features_str: Raw Gradio input, forwarded to ``parse_input``.

    Returns:
        ``(label_text, image)``: the human-readable prediction shown in the UI
        and the result of ``draw_graph`` for the graph coloured by class.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)
    data = from_networkx(G)
    # Rebuild x explicitly as float: integer-valued features could otherwise
    # keep an integer dtype. Both this and from_networkx follow G's node order
    # (nodes 0..n-1 are inserted first by parse_input).
    # NOTE(review): relies on that ordering.
    data.x = torch.tensor(
        list(nx.get_node_attributes(G, "x").values()), dtype=torch.float
    )
    # The whole input is a single graph, so every node belongs to graph 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)
    with torch.no_grad():
        logits = model(data.x, data.edge_index, data.batch)
    pred = logits.argmax(dim=1).item()
    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)
# ---------- GRADIO INTERFACE ----------
# Markdown shown under the title (user-facing text, kept in Spanish).
description = """
Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.
🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.
✅ Cada nodo debe tener **7 características** (como en MUTAG).
🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`
📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
"""

# One input component per predict_graph parameter, in the same order.
inputs = [
    gr.Number(label="Número de nodos", value=3, precision=0),
    gr.Textbox(label="Aristas [(0,1), (1,2)]", lines=2, value="[(0,1),(1,2)]"),
    gr.Textbox(
        label="Características por nodo (listas de 7)",
        lines=4,
        value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]",
    ),
]

# One output component per element of predict_graph's return tuple.
outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

# Preset rows for the examples table: [num_nodes, edges_str, node_features_str].
# NOTE(review): MUTAG node features are presumably one-hot atom types; these
# example vectors are not one-hot. Confirm they are meaningful for the model.
_examples = [
    [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
    [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"],
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    examples=_examples,
    title="🔬 Clasificador Molecular con GNN (GCN)",
    description=description,
)
demo.launch()