GNN-MUTAG / app.py
AdrianRevi's picture
Update app.py
bcfdd81 verified
raw
history blame
1.92 kB
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
import networkx as nx
import gradio as gr
# --- Same model as the one trained in Colab ---
class GCN(torch.nn.Module):
    """Two-layer GCN for graph-level classification.

    Node embeddings come from two ``GCNConv`` layers, each followed by ReLU.
    They are averaged per graph with ``global_mean_pool`` and mapped to class
    logits by a final linear layer.

    Args:
        hidden_channels: Width of both hidden GCN layers (default 64).
        in_channels: Number of input features per node. Defaults to 7, the
            size of the node vectors this app feeds the model.
        num_classes: Number of output logits per graph (default 2).

    The attribute names ``conv1``, ``conv2`` and ``lin`` determine the
    state_dict keys restored by ``load_state_dict``. Keep them unchanged so
    previously saved checkpoints still load.
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class logits with one row per graph in ``batch``.

        Args:
            x: Node feature matrix, presumably ``(num_nodes, in_channels)``
                as a float tensor.
            edge_index: Graph connectivity as expected by ``GCNConv``.
            batch: Graph index of every node. It tells the pooling step which
                nodes to average together.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)
# --- Load the trained model ---
# Use the GPU when available. map_location below remaps the checkpoint's
# tensors onto whichever device is selected here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# The checkpoint is a state_dict (tensors only). weights_only=True
# restricts unpickling to such data, which is the documented safe way to
# load weights. Passing it explicitly also avoids the FutureWarning that
# recent torch versions emit when the flag is omitted.
model.load_state_dict(
    torch.load("model_gcn.pth", map_location=device, weights_only=True)
)
# Put the model in evaluation mode for inference.
model.eval()
# --- Prediction on a fixed example graph ---
def demo_predict():
    """Classify a small hard-coded example graph with the loaded model.

    Builds a 3-node path graph (edges 0-1 and 1-2) in which every node
    carries the same 7-dimensional feature vector. The graph is run through
    the module-level ``model`` on ``device`` as a batch of one graph.

    Returns:
        "Mutagénico ✅" when the predicted class index is 1, otherwise
        "No mutagénico ❌".
    """
    G = nx.Graph()
    G.add_edges_from([(0, 1), (1, 2)])
    # NOTE(review): MUTAG node features are usually one-hot atom types; this
    # multi-hot vector does not correspond to a single atom. Confirm intent.
    nx.set_node_attributes(
        G, {node: [1, 0, 0, 1, 0, 1, 0] for node in G.nodes}, "x"
    )  # 7-dimensional vector, matching the model's input width
    data = from_networkx(G)
    # The attached features are Python ints, and the GCN layers need a float
    # matrix, so rebuild x explicitly in G.nodes order. That is presumably
    # the node order from_networkx uses when relabelling.
    data.x = torch.tensor(
        [G.nodes[node]["x"] for node in G.nodes], dtype=torch.float
    )
    # Every node belongs to graph 0: this is a batch containing one graph.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(data.x, data.edge_index, data.batch)
    pred = logits.argmax(dim=1).item()
    return "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
# --- Gradio interface ---
# No input components: each submission re-runs demo_predict on its fixed
# example graph and shows the returned label as plain text.
demo = gr.Interface(
    demo_predict,
    [],
    "text",
    description="Este demo usa una red neuronal en grafo entrenada sobre MUTAG para clasificar moléculas como mutagénicas o no.",
    title="Clasificador de Moléculas con GCN",
)
# Start the Gradio app.
demo.launch()