File size: 4,471 Bytes
d5f2d19
bcfdd81
 
d5f2d19
3376df1
d5f2d19
bcfdd81
3376df1
 
 
 
 
bcfdd81
 
 
 
 
 
 
d5f2d19
bcfdd81
 
 
 
 
 
 
 
3376df1
 
bcfdd81
 
 
d5f2d19
 
3376df1
 
 
d5f2d19
3376df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f908d5
 
 
3376df1
2f908d5
 
3376df1
 
 
 
 
 
 
 
 
2f908d5
3376df1
 
 
 
 
d5f2d19
bcfdd81
3376df1
bcfdd81
 
d5f2d19
bcfdd81
d5f2d19
bcfdd81
d5f2d19
 
3376df1
2f908d5
3376df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcfdd81
 
3376df1
 
 
 
 
 
 
 
 
bcfdd81
d5f2d19
bcfdd81
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import networkx as nx
import gradio as gr
import matplotlib.pyplot as plt
import io
import base64

# ---------- MODELO GCN ----------

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels=64):
        super().__init__()
        self.conv1 = GCNConv(7, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        return self.lin(x)

# ---------- CARGA DEL MODELO ENTRENADO ----------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
model.eval()

# ---------- FUNCIÓN PARA PARSEAR INPUT ----------

def parse_input(num_nodes, edges_str, node_features_str):
    G = nx.Graph()

    try:
        # Añadimos nodos
        for i in range(num_nodes):
            G.add_node(i)

        # Parseamos aristas
        edges = eval(edges_str)  # formato esperado: [(0,1), (1,2)]
        G.add_edges_from(edges)

        # Parseamos características
        node_features = eval(node_features_str)  # formato: [[1,0,1,0,1,0,1], [...]]
        if len(node_features) != num_nodes or any(len(f) != 7 for f in node_features):
            raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")

        nx.set_node_attributes(G, {i: node_features[i] for i in range(num_nodes)}, "x")

        return G

    except Exception as e:
        raise gr.Error(f"Error al procesar el input: {e}")

# ---------- VISUALIZACIÓN ----------

def draw_graph(G, pred_label):
    import matplotlib.pyplot as plt
    import io

    pos = nx.spring_layout(G)
    node_colors = ['lightgreen' if pred_label == 1 else 'lightcoral'] * G.number_of_nodes()

    plt.figure(figsize=(4, 4))
    nx.draw(G, pos, with_labels=True, node_color=node_colors, edge_color='gray', node_size=800)
    plt.title("Grafo de entrada")

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    plt.close()

    return buf  # Devuelve un objeto tipo archivo

# ---------- FUNCIÓN DE PREDICCIÓN ----------

def predict_graph(num_nodes, edges_str, node_features_str):
    G = parse_input(num_nodes, edges_str, node_features_str)

    data = from_networkx(G)
    data.x = torch.tensor([v for v in nx.get_node_attributes(G, "x").values()], dtype=torch.float)
    data.edge_index = data.edge_index
    data.batch = torch.tensor([0] * data.num_nodes)

    data = data.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1).item()

    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)

# ---------- INTERFAZ GRADIO ----------

description = """
Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.

🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.

✅ Cada nodo debe tener **7 características** (como en MUTAG).  
🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`  
📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
"""

inputs = [
    gr.Number(label="Número de nodos", value=3, precision=0),
    gr.Textbox(label="Aristas [(0,1), (1,2)]", lines=2, value="[(0,1),(1,2)]"),
    gr.Textbox(label="Características por nodo (listas de 7)", lines=4, value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]")
]

outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo")
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🔬 Clasificador Molecular con GNN (GCN)",
    description=description,
    examples=[
        [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
        [4, "[(0,1),(1,2),(2,3)]", "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]"]
    ]
)

demo.launch()