AdrianRevi commited on
Commit
bcfdd81
·
verified ·
1 Parent(s): 18de874

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -17
app.py CHANGED
@@ -1,31 +1,57 @@
1
- import gradio as gr
2
  import torch
3
- from torch_geometric.data import Data
 
4
  from torch_geometric.utils import from_networkx
5
  import networkx as nx
 
 
 
 
 
 
 
 
 
6
 
7
- # Usamos el modelo GCN previamente entrenado (podrías cambiar por GAT si lo prefieres)
 
 
 
 
 
 
 
 
 
 
 
8
  model.eval()
9
 
10
- def predict_mutagenicity():
11
- # Creamos un grafo de prueba simple (3 nodos conectados)
12
  G = nx.Graph()
13
  G.add_edges_from([(0, 1), (1, 2)])
14
- nx.set_node_attributes(G, {i: [1, 0, 0, 1, 0, 1, 0] for i in G.nodes}, "x") # vector ficticio
15
 
16
- # Convertimos a objeto PyG
17
- pyg_data = from_networkx(G)
18
- pyg_data.x = torch.tensor(list(nx.get_node_attributes(G, 'x').values()), dtype=torch.float)
19
- pyg_data.edge_index = pyg_data.edge_index
20
- pyg_data.batch = torch.tensor([0] * pyg_data.num_nodes)
21
 
22
- pyg_data = pyg_data.to(device)
23
  with torch.no_grad():
24
- out = model(pyg_data.x, pyg_data.edge_index, pyg_data.batch)
25
  pred = out.argmax(dim=1).item()
26
 
27
- return "Mutagénico" if pred == 1 else "No mutagénico"
 
 
 
 
 
 
 
 
 
28
 
29
- gr.Interface(fn=predict_mutagenicity, inputs=[], outputs="text",
30
- title="Clasificador de Moléculas con GNN",
31
- description="Demo simple de GCN sobre grafos moleculares (MUTAG)").launch()
 
 
1
  import torch
2
+ import torch.nn.functional as F
3
+ from torch_geometric.nn import GCNConv, global_mean_pool
4
  from torch_geometric.utils import from_networkx
5
  import networkx as nx
6
+ import gradio as gr
7
+
8
+ # --- Mismo modelo que en Colab ---
9
+ class GCN(torch.nn.Module):
10
+ def __init__(self, hidden_channels=64):
11
+ super().__init__()
12
+ self.conv1 = GCNConv(7, hidden_channels)
13
+ self.conv2 = GCNConv(hidden_channels, hidden_channels)
14
+ self.lin = torch.nn.Linear(hidden_channels, 2)
15
 
16
+ def forward(self, x, edge_index, batch):
17
+ x = self.conv1(x, edge_index)
18
+ x = F.relu(x)
19
+ x = self.conv2(x, edge_index)
20
+ x = F.relu(x)
21
+ x = global_mean_pool(x, batch)
22
+ return self.lin(x)
23
+
24
+ # --- Carga de modelo entrenado ---
25
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
+ model = GCN().to(device)
27
+ model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
28
  model.eval()
29
 
30
+ # --- Función de predicción sobre un grafo ejemplo ---
31
+ def demo_predict():
32
  G = nx.Graph()
33
  G.add_edges_from([(0, 1), (1, 2)])
34
+ nx.set_node_attributes(G, {i: [1, 0, 0, 1, 0, 1, 0] for i in G.nodes}, "x") # vector de 7 dimensiones
35
 
36
+ data = from_networkx(G)
37
+ data.x = torch.tensor(list(nx.get_node_attributes(G, "x").values()), dtype=torch.float)
38
+ data.edge_index = data.edge_index
39
+ data.batch = torch.tensor([0] * data.num_nodes)
 
40
 
41
+ data = data.to(device)
42
  with torch.no_grad():
43
+ out = model(data.x, data.edge_index, data.batch)
44
  pred = out.argmax(dim=1).item()
45
 
46
+ return "Mutagénico" if pred == 1 else "No mutagénico"
47
+
48
+ # --- Interfaz Gradio ---
49
+ demo = gr.Interface(
50
+ fn=demo_predict,
51
+ inputs=[],
52
+ outputs="text",
53
+ title="Clasificador de Moléculas con GCN",
54
+ description="Este demo usa una red neuronal en grafo entrenada sobre MUTAG para clasificar moléculas como mutagénicas o no."
55
+ )
56
 
57
+ demo.launch()