Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
from torch_geometric.nn import GCNConv, global_mean_pool | |
from torch_geometric.utils import from_networkx | |
import networkx as nx | |
import gradio as gr | |
# --- Same model as in the Colab notebook ---
class GCN(torch.nn.Module):
    """Two-layer GCN followed by mean pooling and a linear graph classifier.

    The layer layout (conv1 -> conv2 -> lin) and widths must match any
    checkpoint loaded into this module, otherwise ``load_state_dict`` fails on
    missing/unexpected keys or shape mismatches.

    Args:
        hidden_channels: Width of both GCN layers.
        in_channels: Number of input features per node (default 7, the width
            the original hard-coded model used).
        num_classes: Number of output logits per graph (default 2).
    """

    def __init__(self, hidden_channels=64, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Return class logits for each graph in the batch.

        Args:
            x: Node feature matrix, nodes along dim 0; presumably shape
                (num_nodes, in_channels) — TODO confirm against callers.
            edge_index: Graph connectivity in PyG ``edge_index`` format.
            batch: Graph id for every node. ``None`` treats all nodes as a
                single graph.

        Returns:
            Logits with one row per graph and ``num_classes`` columns.
        """
        if batch is None:
            # No batch vector given: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)
# --- Load the trained model ---
MODEL_PATH = "model_gcn.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# The checkpoint is consumed directly by load_state_dict, i.e. it is a plain
# state_dict of tensors. weights_only=True restricts unpickling to such data
# instead of allowing arbitrary objects to be reconstructed from the file.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
# Inference only: put all submodules in evaluation mode.
model.eval()
# --- Prediction on an example graph ---
def demo_predict():
    """Classify a small, fixed example graph with the loaded GCN.

    Builds a 3-node path graph (0-1-2) in which every node carries the same
    7-dimensional feature vector. The graph is run through the model as a
    single-graph batch, and the predicted class is mapped to a label.

    Returns:
        "Mutagénico ✅" when the predicted class is 1, otherwise
        "No mutagénico ❌".
    """
    G = nx.Graph()
    G.add_edges_from([(0, 1), (1, 2)])
    # 7-dimensional node features, matching the model's input width.
    # NOTE(review): if the model was trained on MUTAG's one-hot atom types,
    # this multi-hot vector is likely out of distribution — confirm intent.
    nx.set_node_attributes(G, {i: [1, 0, 0, 1, 0, 1, 0] for i in G.nodes}, "x")

    data = from_networkx(G)
    # Rebuild x as a float tensor in G.nodes order (the order from_networkx
    # indexes nodes by). The raw integer attributes may otherwise be converted
    # to an integer tensor, which the float GCN layers cannot consume.
    data.x = torch.tensor([G.nodes[n]["x"] for n in G.nodes], dtype=torch.float)
    # Every node belongs to graph 0: a batch containing a single graph.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)

    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    pred = out.argmax(dim=1).item()
    return "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
# --- Gradio interface ---
demo = gr.Interface(
    fn=demo_predict,
    inputs=[],
    outputs="text",
    title="Clasificador de Moléculas con GCN",
    description="Este demo usa una red neuronal en grafo entrenada sobre MUTAG para clasificar moléculas como mutagénicas o no."
)

# Start the server only when this file is executed as a script
# (e.g. `python app.py`). Merely importing the module (tests, Gradio's
# reload mode) should not launch and block on a web server.
if __name__ == "__main__":
    demo.launch()