File size: 3,935 Bytes
d5f2d19
bcfdd81
 
d5f2d19
3376df1
2c5abfa
d5f2d19
bcfdd81
3376df1
2c5abfa
3376df1
 
2c5abfa
bcfdd81
 
 
 
 
 
 
d5f2d19
bcfdd81
 
 
 
 
 
 
 
3376df1
 
bcfdd81
 
 
d5f2d19
 
2c5abfa
3376df1
 
d5f2d19
3376df1
 
2c5abfa
3376df1
 
 
2c5abfa
 
3376df1
 
2c5abfa
 
3376df1
2c5abfa
3376df1
 
 
 
 
2c5abfa
3376df1
 
 
2f908d5
 
3376df1
 
 
 
 
 
 
2c5abfa
 
3376df1
 
 
d5f2d19
bcfdd81
3376df1
bcfdd81
 
d5f2d19
bcfdd81
d5f2d19
bcfdd81
d5f2d19
 
3376df1
2f908d5
3376df1
 
 
 
2c5abfa
3376df1
2c5abfa
 
 
 
3376df1
 
 
 
2c5abfa
 
3376df1
 
 
 
 
 
bcfdd81
 
3376df1
 
 
2c5abfa
3376df1
 
 
 
 
bcfdd81
d5f2d19
9c57eba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import ast
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx

# ---------- MODELO ----------

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network producing one logit pair per graph.

    Each node's 7 input features go through two ReLU-activated ``GCNConv``
    layers. The node embeddings are then averaged per graph with
    ``global_mean_pool`` and projected to 2 class logits by a linear layer.
    """

    def __init__(self, hidden_channels=64):
        super().__init__()
        # conv1 / conv2 / lin are the state_dict keys the saved checkpoint is
        # loaded into, so these attribute names must not change.
        self.conv1 = GCNConv(7, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch):
        """Return graph-level logits with one row of 2 values per graph.

        Args:
            x: Node feature matrix; presumably (num_nodes, 7) float -- confirm.
            edge_index: Graph connectivity as consumed by ``GCNConv``.
            batch: Graph index of every node, used to pool nodes per graph.
        """
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.relu(conv(hidden, edge_index))
        return self.lin(global_mean_pool(hidden, batch))

# ---------- CARGA DEL MODELO ENTRENADO ----------

# Build the network on the GPU when available and restore the trained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# The checkpoint is consumed as a plain state_dict (tensors keyed by parameter
# name), so weights_only=True is sufficient. It restricts unpickling to tensor
# data instead of an unrestricted pickle load that could execute code stored
# in the file.
model.load_state_dict(
    torch.load("model_gcn.pth", map_location=device, weights_only=True)
)
# This app only runs inference, so evaluation mode is set once at startup.
model.eval()

# ---------- FUNCIONES AUXILIARES ----------

def parse_input(num_nodes, edges_str, node_features_str):
    """Build a NetworkX graph from the raw values entered in the Gradio form.

    Args:
        num_nodes: Number of nodes. Nodes are labelled ``0 .. num_nodes - 1``.
        edges_str: Python literal with a list of ``(u, v)`` node pairs,
            e.g. ``"[(0,1),(1,2)]"``.
        node_features_str: Python literal with one list of 7 numbers per
            node, given in node order.

    Returns:
        An undirected ``nx.Graph`` whose nodes store their feature vector in
        the ``"x"`` attribute.

    Raises:
        gr.Error: If any input is malformed. Gradio shows the message to the
            user instead of a stack trace.
    """
    try:
        if num_nodes is None or num_nodes < 1 or int(num_nodes) != num_nodes:
            raise ValueError("El número de nodos debe ser un entero positivo.")
        n = int(num_nodes)

        # SECURITY: both strings come straight from a public web form, so they
        # are parsed with ast.literal_eval, which accepts Python literals only.
        # eval() here would let any visitor run arbitrary code on the server.
        edges = [tuple(edge) for edge in ast.literal_eval(edges_str)]
        for edge in edges:
            # An out-of-range endpoint would make add_edges_from create a new
            # node without the "x" feature that predict_graph reads.
            if len(edge) != 2 or not all(
                isinstance(node, int) and 0 <= node < n for node in edge
            ):
                raise ValueError(
                    f"Arista inválida {edge}: debe ser un par (u, v) "
                    f"con nodos entre 0 y {n - 1}."
                )

        node_features = ast.literal_eval(node_features_str)
        if len(node_features) != n:
            raise ValueError(
                f"Se esperaban características para {n} nodos, "
                f"pero se recibieron {len(node_features)}."
            )
        for features in node_features:
            if len(features) != 7:
                raise ValueError("Cada nodo debe tener exactamente 7 características.")
            if not all(isinstance(value, (int, float)) for value in features):
                raise ValueError("Las características deben ser valores numéricos.")

        G = nx.Graph()
        # Nodes are inserted in id order so node i gets the i-th feature row.
        G.add_nodes_from(range(n))
        G.add_edges_from(edges)
        nx.set_node_attributes(G, dict(enumerate(node_features)), "x")
    except Exception as e:
        # UI boundary: turn every parsing or validation failure into a message
        # that the Gradio interface can display to the user.
        raise gr.Error(f"Error en los datos del grafo: {e}") from e
    return G

def draw_graph(G, pred_label):
    """Render ``G`` to an image, colouring every node by the predicted class.

    Args:
        G: Graph to draw.
        pred_label: Predicted class index. ``1`` draws the nodes light green;
            any other value draws them light coral.

    Returns:
        A ``PIL.Image`` containing the PNG rendering of the graph.
    """
    pos = nx.spring_layout(G)
    node_color = 'lightgreen' if pred_label == 1 else 'lightcoral'

    # An explicit Figure avoids pyplot's global "current figure" state. Gradio
    # may serve requests concurrently and pyplot is not thread-safe. A Figure
    # created this way is never registered with pyplot, so it needs no close().
    fig = Figure(figsize=(4, 4))
    # A subplot with default margins leaves room above the axes for the title.
    ax = fig.add_subplot()
    nx.draw(G, pos, ax=ax, with_labels=True, node_color=node_color,
            edge_color='gray', node_size=800)
    ax.set_title("Grafo de entrada")

    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)

def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the graph described by the form inputs and render it.

    Args:
        num_nodes: Number of nodes in the graph.
        edges_str: Edge list literal, e.g. ``"[(0,1),(1,2)]"``.
        node_features_str: Literal with 7 feature values per node.

    Returns:
        A ``(label_text, image)`` tuple: the predicted class as display text
        and a PIL image of the graph.

    Raises:
        gr.Error: Raised by ``parse_input`` when the inputs are malformed.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)

    data = from_networkx(G)
    # Build the feature matrix explicitly, in G.nodes order, as float so it
    # matches the dtype of the model's weights.
    features = nx.get_node_attributes(G, "x")
    data.x = torch.tensor([features[node] for node in G.nodes], dtype=torch.float)
    # The whole input is a single graph, so every node belongs to graph 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)

    data = data.to(device)
    with torch.no_grad():
        logits = model(data.x, data.edge_index, data.batch)
        pred = logits.argmax(dim=1).item()

    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)

# ---------- INTERFAZ GRADIO ----------

# Markdown shown under the title. The two trailing spaces on the list lines
# are deliberate: they force Markdown line breaks.
description = (
    "\n"
    "Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.\n"
    "\n"
    "✏️ Puedes modificar el grafo:  \n"
    "- Número de nodos  \n"
    "- Aristas (formato: `[(0,1),(1,2)]`)  \n"
    "- Características de cada nodo (7 valores binarios por nodo)  \n"
)

# Example rows in input order: (number of nodes, edge list, node features).
# The first row also supplies the form's initial values.
_example_rows = [
    [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
    [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"],
]
_default_nodes, _default_edges, _default_features = _example_rows[0]

inputs = [
    gr.Number(label="Número de nodos", value=_default_nodes, precision=0),
    gr.Textbox(label="Aristas", value=_default_edges),
    gr.Textbox(label="Características por nodo", value=_default_features),
]

outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🧪 Clasificador Molecular con GCN",
    description=description,
    examples=_example_rows,
)

# show_error=True lets gr.Error messages from the parsing step reach the user.
demo.launch(show_error=True)