AdrianRevi commited on
Commit
2c5abfa
·
verified ·
1 Parent(s): 2f908d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -31
app.py CHANGED
@@ -3,13 +3,14 @@ import torch.nn.functional as F
3
  from torch_geometric.nn import GCNConv, global_mean_pool
4
  from torch_geometric.utils import from_networkx
5
  from torch_geometric.data import Data
 
6
  import networkx as nx
7
  import gradio as gr
8
  import matplotlib.pyplot as plt
 
9
  import io
10
- import base64
11
 
12
- # ---------- MODELO GCN ----------
13
 
14
  class GCN(torch.nn.Module):
15
  def __init__(self, hidden_channels=64):
@@ -33,38 +34,32 @@ model = GCN().to(device)
33
  model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
34
  model.eval()
35
 
36
- # ---------- FUNCIÓN PARA PARSEAR INPUT ----------
37
 
38
  def parse_input(num_nodes, edges_str, node_features_str):
39
  G = nx.Graph()
40
 
41
  try:
42
- # Añadimos nodos
43
  for i in range(num_nodes):
44
  G.add_node(i)
45
 
46
- # Parseamos aristas
47
- edges = eval(edges_str) # formato esperado: [(0,1), (1,2)]
48
  G.add_edges_from(edges)
49
 
50
- # Parseamos características
51
- node_features = eval(node_features_str) # formato: [[1,0,1,0,1,0,1], [...]]
52
  if len(node_features) != num_nodes or any(len(f) != 7 for f in node_features):
53
- raise ValueError("Las características deben ser listas de longitud 7 para cada nodo.")
54
 
55
  nx.set_node_attributes(G, {i: node_features[i] for i in range(num_nodes)}, "x")
56
-
57
  return G
58
 
59
  except Exception as e:
60
- raise gr.Error(f"Error al procesar el input: {e}")
61
-
62
- # ---------- VISUALIZACIÓN ----------
63
 
64
  def draw_graph(G, pred_label):
65
- import matplotlib.pyplot as plt
66
- import io
67
-
68
  pos = nx.spring_layout(G)
69
  node_colors = ['lightgreen' if pred_label == 1 else 'lightcoral'] * G.number_of_nodes()
70
 
@@ -74,12 +69,9 @@ def draw_graph(G, pred_label):
74
 
75
  buf = io.BytesIO()
76
  plt.savefig(buf, format='png')
77
- buf.seek(0)
78
  plt.close()
79
-
80
- return buf # Devuelve un objeto tipo archivo
81
-
82
- # ---------- FUNCIÓN DE PREDICCIÓN ----------
83
 
84
  def predict_graph(num_nodes, edges_str, node_features_str):
85
  G = parse_input(num_nodes, edges_str, node_features_str)
@@ -100,19 +92,18 @@ def predict_graph(num_nodes, edges_str, node_features_str):
100
  # ---------- INTERFAZ GRADIO ----------
101
 
102
  description = """
103
- Este clasificador usa un modelo GCN entrenado sobre el dataset **MUTAG** para predecir si una molécula (representada como grafo) es mutagénica o no.
104
-
105
- 🔹 Puedes definir tu propio grafo ingresando el número de nodos, las aristas y las características de cada nodo.
106
 
107
- Cada nodo debe tener **7 características** (como en MUTAG).
108
- 🔗 Las aristas deben estar en formato Python: `[(0, 1), (1, 2)]`
109
- 📊 Las características deben ser una lista de listas: `[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], ...]`
 
110
  """
111
 
112
  inputs = [
113
  gr.Number(label="Número de nodos", value=3, precision=0),
114
- gr.Textbox(label="Aristas [(0,1), (1,2)]", lines=2, value="[(0,1),(1,2)]"),
115
- gr.Textbox(label="Características por nodo (listas de 7)", lines=4, value="[[1,0,0,1,0,1,0], [0,1,1,0,1,0,1], [1,1,0,0,1,0,1]]")
116
  ]
117
 
118
  outputs = [
@@ -124,7 +115,7 @@ demo = gr.Interface(
124
  fn=predict_graph,
125
  inputs=inputs,
126
  outputs=outputs,
127
- title="🔬 Clasificador Molecular con GNN (GCN)",
128
  description=description,
129
  examples=[
130
  [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
@@ -132,4 +123,4 @@ demo = gr.Interface(
132
  ]
133
  )
134
 
135
- demo.launch()
 
3
  from torch_geometric.nn import GCNConv, global_mean_pool
4
  from torch_geometric.utils import from_networkx
5
  from torch_geometric.data import Data
6
+
7
  import networkx as nx
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
+ from PIL import Image
11
  import io
 
12
 
13
+ # ---------- MODELO ----------
14
 
15
  class GCN(torch.nn.Module):
16
  def __init__(self, hidden_channels=64):
 
34
  model.load_state_dict(torch.load("model_gcn.pth", map_location=device))
35
  model.eval()
36
 
37
+ # ---------- FUNCIONES AUXILIARES ----------
38
 
39
  def parse_input(num_nodes, edges_str, node_features_str):
40
  G = nx.Graph()
41
 
42
  try:
43
+ # Añadir nodos
44
  for i in range(num_nodes):
45
  G.add_node(i)
46
 
47
+ # Parsear aristas
48
+ edges = eval(edges_str)
49
  G.add_edges_from(edges)
50
 
51
+ # Parsear características
52
+ node_features = eval(node_features_str)
53
  if len(node_features) != num_nodes or any(len(f) != 7 for f in node_features):
54
+ raise ValueError("Cada nodo debe tener exactamente 7 características.")
55
 
56
  nx.set_node_attributes(G, {i: node_features[i] for i in range(num_nodes)}, "x")
 
57
  return G
58
 
59
  except Exception as e:
60
+ raise gr.Error(f"Error en los datos del grafo: {e}")
 
 
61
 
62
  def draw_graph(G, pred_label):
 
 
 
63
  pos = nx.spring_layout(G)
64
  node_colors = ['lightgreen' if pred_label == 1 else 'lightcoral'] * G.number_of_nodes()
65
 
 
69
 
70
  buf = io.BytesIO()
71
  plt.savefig(buf, format='png')
 
72
  plt.close()
73
+ buf.seek(0)
74
+ return Image.open(buf)
 
 
75
 
76
  def predict_graph(num_nodes, edges_str, node_features_str):
77
  G = parse_input(num_nodes, edges_str, node_features_str)
 
92
  # ---------- INTERFAZ GRADIO ----------
93
 
94
  description = """
95
+ Clasificador molecular basado en **Redes Neuronales en Grafo (GNN)** entrenado sobre el dataset **MUTAG**.
 
 
96
 
97
+ ✏️ Puedes modificar el grafo:
98
+ - Número de nodos
99
+ - Aristas (formato: `[(0,1),(1,2)]`)
100
+ - Características de cada nodo (7 valores binarios por nodo)
101
  """
102
 
103
  inputs = [
104
  gr.Number(label="Número de nodos", value=3, precision=0),
105
+ gr.Textbox(label="Aristas", value="[(0,1),(1,2)]"),
106
+ gr.Textbox(label="Características por nodo", value="[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]")
107
  ]
108
 
109
  outputs = [
 
115
  fn=predict_graph,
116
  inputs=inputs,
117
  outputs=outputs,
118
+ title="🧪 Clasificador Molecular con GCN",
119
  description=description,
120
  examples=[
121
  [3, "[(0,1),(1,2)]", "[[1,0,0,1,0,1,0],[0,1,1,0,1,0,1],[1,1,0,0,1,0,1]]"],
 
123
  ]
124
  )
125
 
126
+ demo.launch(show_error=True, cache_examples=False)