"""Gradio demo that classifies a user-supplied molecular graph with a trained GCN.

The user enters a node count, an edge list and a per-node feature matrix as
text. The graph is validated, run through the trained model, and the predicted
label is returned together with a rendering of the graph.
"""

import ast
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn.functional as F
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx

# Width of each node feature vector; must match the first GCNConv layer.
NUM_NODE_FEATURES = 7
# Number of output classes produced by the final linear layer.
NUM_CLASSES = 2


# ---------- MODEL ----------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network followed by a linear classifier."""

    def __init__(self, hidden_channels=64):
        super().__init__()
        self.conv1 = GCNConv(NUM_NODE_FEATURES, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, NUM_CLASSES)

    def forward(self, x, edge_index, batch):
        """Return one row of class logits per graph in the batch."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Average node embeddings per graph before classification.
        x = global_mean_pool(x, batch)
        return self.lin(x)


# ---------- LOAD THE TRAINED MODEL ----------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
# NOTE(review): torch.load unpickles the checkpoint; only ship a trusted file.
model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
model.eval()


# ---------- HELPER FUNCTIONS ----------
def _coerce_num_nodes(num_nodes):
    """Return *num_nodes* as a positive int or raise ValueError."""
    if num_nodes is None or isinstance(num_nodes, bool):
        raise ValueError("El número de nodos debe ser un entero positivo.")
    count = int(num_nodes)
    # Rejects fractional values such as 2.5 as well as zero/negative counts.
    if count != num_nodes or count < 1:
        raise ValueError("El número de nodos debe ser un entero positivo.")
    return count


def _parse_literal(text, field_name):
    """Parse *text* as a plain Python literal (lists, tuples, numbers)."""
    if not isinstance(text, str) or not text.strip():
        raise ValueError(f"El campo '{field_name}' está vacío.")
    # SECURITY: this text comes straight from a web form. The previous
    # implementation used eval(), which executes arbitrary code; literal_eval
    # only accepts literals, so expressions are rejected on purpose.
    return ast.literal_eval(text.strip())


def _validate_edges(edges, num_nodes):
    """Return *edges* as a list of (u, v) int pairs with endpoints in range."""
    if not isinstance(edges, (list, tuple)):
        raise ValueError("Las aristas deben ser una lista de pares, p. ej. [(0,1),(1,2)].")
    checked = []
    for edge in edges:
        if not isinstance(edge, (list, tuple)) or len(edge) != 2:
            raise ValueError(f"Arista inválida: {edge!r}.")
        for endpoint in edge:
            is_index = isinstance(endpoint, int) and not isinstance(endpoint, bool)
            # An out-of-range endpoint would silently create a node that has
            # no feature vector, so it is rejected here.
            if not is_index or not 0 <= endpoint < num_nodes:
                raise ValueError(
                    f"La arista {edge!r} referencia un nodo fuera del rango 0..{num_nodes - 1}."
                )
        checked.append((edge[0], edge[1]))
    return checked


def _validate_features(node_features, num_nodes):
    """Return one numeric feature row per node or raise ValueError."""
    if not isinstance(node_features, (list, tuple)) or len(node_features) != num_nodes:
        raise ValueError(f"Se esperaban características para {num_nodes} nodos.")
    rows = []
    for row in node_features:
        if not isinstance(row, (list, tuple)) or len(row) != NUM_NODE_FEATURES:
            raise ValueError(
                f"Cada nodo debe tener exactamente {NUM_NODE_FEATURES} características."
            )
        if not all(isinstance(v, (int, float)) for v in row):
            raise ValueError("Las características deben ser valores numéricos.")
        rows.append(list(row))
    return rows


def parse_input(num_nodes, edges_str, node_features_str):
    """Build a NetworkX graph from the raw form values.

    Args:
        num_nodes: Number of nodes; nodes are labelled ``0..num_nodes-1``.
        edges_str: Text of a Python literal list of ``(u, v)`` pairs.
        node_features_str: Text of a Python literal list holding one
            7-element numeric list per node.

    Returns:
        An undirected ``nx.Graph`` whose nodes carry their feature list in
        the ``"x"`` attribute.

    Raises:
        gr.Error: If any of the inputs cannot be parsed or is inconsistent.
    """
    try:
        count = _coerce_num_nodes(num_nodes)
        edges = _validate_edges(_parse_literal(edges_str, "Aristas"), count)
        features = _validate_features(
            _parse_literal(node_features_str, "Características por nodo"), count
        )

        G = nx.Graph()
        G.add_nodes_from(range(count))
        G.add_edges_from(edges)
        nx.set_node_attributes(G, dict(enumerate(features)), "x")
        return G
    except Exception as e:
        # Boundary with the UI: surface every parsing problem as a form error.
        raise gr.Error(f"Error en los datos del grafo: {e}") from e


def draw_graph(G, pred_label):
    """Render *G* as a PNG image, coloured by the predicted label.

    Nodes are drawn light green when ``pred_label == 1`` and light coral
    otherwise. Returns a ``PIL.Image`` backed by an in-memory buffer.
    """
    pos = nx.spring_layout(G)
    node_color = 'lightgreen' if pred_label == 1 else 'lightcoral'

    # Use an explicit figure instead of pyplot's implicit "current figure" so
    # the figure is always closed, even if drawing fails.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        nx.draw(
            G,
            pos,
            ax=ax,
            with_labels=True,
            node_color=node_color,
            edge_color='gray',
            node_size=800,
        )
        ax.set_title("Grafo de entrada")
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def predict_graph(num_nodes, edges_str, node_features_str):
    """Classify the graph described by the form inputs.

    Returns:
        A ``(label_text, image)`` tuple: the human-readable prediction and a
        rendering of the input graph.

    Raises:
        gr.Error: Propagated from ``parse_input`` on invalid input.
    """
    G = parse_input(num_nodes, edges_str, node_features_str)

    data = from_networkx(G)
    # Rebuild x explicitly as float, in node order 0..n-1, since the model's
    # layers operate on floating-point features.
    data.x = torch.tensor(
        [G.nodes[i]["x"] for i in range(G.number_of_nodes())],
        dtype=torch.float,
    )
    # Single graph: every node belongs to batch element 0.
    data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
    data = data.to(device)

    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1).item()

    label_text = "Mutagénico ✅" if pred == 1 else "No mutagénico ❌"
    return label_text, draw_graph(G, pred)


# ---------- GRADIO INTERFACE ----------
description = """
Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.

✏️ Puedes modificar el grafo:
- Número de nodos
- Aristas (formato: `[(0,1),(1,2)]`)
- Características de cada nodo (7 valores binarios por nodo)
"""

inputs = [
    gr.Number(label="Número de nodos", value=3, precision=0),
    gr.Textbox(label="Aristas", value="[(0,1),(1,2)]"),
    gr.Textbox(
        label="Características por nodo",
        value="[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]",
    ),
]

outputs = [
    gr.Text(label="Predicción"),
    gr.Image(label="Visualización del grafo"),
]

demo = gr.Interface(
    fn=predict_graph,
    inputs=inputs,
    outputs=outputs,
    title="🧪 Clasificador Molecular con GCN",
    description=description,
    examples=[
        [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
        [
            4,
            "[(0,1),(1,2),(2,3)]",
            "[[1,1,0,1,0,0,1],[0,0,1,1,1,0,0],[1,0,1,0,1,1,0],[0,1,0,1,1,0,1]]",
        ],
    ],
)

demo.launch(show_error=True)