File size: 1,918 Bytes
d5f2d19
bcfdd81
 
d5f2d19
 
bcfdd81
 
 
 
 
 
 
 
 
d5f2d19
bcfdd81
 
 
 
 
 
 
 
 
 
 
 
d5f2d19
 
bcfdd81
 
d5f2d19
 
bcfdd81
d5f2d19
bcfdd81
 
 
 
d5f2d19
bcfdd81
d5f2d19
bcfdd81
d5f2d19
 
bcfdd81
 
 
 
 
 
 
 
 
 
d5f2d19
bcfdd81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
import networkx as nx
import gradio as gr

# --- Same model architecture as the one trained in Colab ---
class GCN(torch.nn.Module):
    """Two-layer GCN followed by mean pooling and a linear graph classifier.

    The layer sizes must match the ones used at training time so that the
    saved ``state_dict`` loads without shape mismatches; the defaults
    reproduce the original hard-coded sizes (7 input features, 2 classes).

    Args:
        hidden_channels: Width of both GCN layers and of the pooled embedding.
        in_channels: Number of input features per node (keyword-only).
        num_classes: Number of logits produced per graph (keyword-only).
    """

    def __init__(self, hidden_channels=64, *, in_channels=7, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute graph-level logits.

        Args:
            x: Node feature matrix; presumably (num_nodes, in_channels) float
                tensor — confirm against callers.
            edge_index: Connectivity tensor in the format ``GCNConv`` expects.
            batch: Per-node graph id consumed by ``global_mean_pool`` to
                group nodes into graphs.

        Returns:
            Output of the final linear layer: one row of ``num_classes``
            logits per pooled graph.
        """
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # Average node embeddings within each graph to get one vector per graph.
        x = global_mean_pool(x, batch)
        return self.lin(x)

# --- Load the trained model ---
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# map_location remaps stored tensors onto the chosen device, so a checkpoint
# saved on a GPU can still be restored on a CPU-only host.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load("model_gcn.pth", map_location=device)

model = GCN().to(device)
model.load_state_dict(state_dict)
# Switch the module to evaluation mode for inference.
model.eval()

# --- Prediction on a fixed example graph ---
def demo_predict():
    """Classify a hard-coded 3-node path graph (0-1-2) with the loaded model.

    Every node gets the same 7-dimensional feature vector. The function takes
    no inputs, so it returns the same label on every call.

    Returns:
        "Mutagénico ✅" when the predicted class is 1, otherwise
        "No mutagénico ❌".
    """
    G = nx.Graph()
    G.add_edges_from([(0, 1), (1, 2)])
    # 7-dimensional feature vector, identical for every node.
    nx.set_node_attributes(G, {i: [1, 0, 0, 1, 0, 1, 0] for i in G.nodes}, "x")

    data = from_networkx(G)
    # Rebuild x explicitly as float (the GCN layers need floating-point
    # features); values are taken in G.nodes order.
    data.x = torch.tensor(list(nx.get_node_attributes(G, "x").values()), dtype=torch.float)
    # A single graph: every node belongs to graph 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)

    # Moves every tensor attribute (x, edge_index, batch) to the model's device.
    data = data.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
        # One graph in the batch, so argmax yields a single class index.
        pred = out.argmax(dim=1).item()

    return "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"

# --- Gradio interface ---
# No input components: the button simply runs demo_predict on its fixed graph.
demo = gr.Interface(
    fn=demo_predict,
    inputs=[],
    outputs="text",
    title="Clasificador de Moléculas con GCN",
    description="Este demo usa una red neuronal en grafo entrenada sobre MUTAG para clasificar moléculas como mutagénicas o no."
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or another app) does not launch the UI.
if __name__ == "__main__":
    demo.launch()