awacke1 commited on
Commit
4805c8c
·
verified ·
1 Parent(s): 3e04947

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -44
app.py CHANGED
@@ -5,36 +5,7 @@ import torch
5
  from torch_geometric.data import Data
6
  import numpy as np
7
  import plotly.graph_objs as go
8
-
9
- st.title("PyTorch Geometric Structure Visualization")
10
-
11
- structure_type = st.sidebar.selectbox(
12
- "Select Structure Type",
13
- ("Sierpinski Triangle", "Spiral", "Plant Structure")
14
- )
15
-
16
- if structure_type == "Sierpinski Triangle":
17
- depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
18
- pos, edge_index = generate_sierpinski_triangle(depth)
19
- data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
20
- fig = plot_graph_3d(pos, edge_index)
21
- st.plotly_chart(fig)
22
-
23
- elif structure_type == "Spiral":
24
- turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
25
- points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
26
- pos, edge_index = generate_spiral(turns, points_per_turn)
27
- data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
28
- fig = plot_graph_3d(pos, edge_index)
29
- st.plotly_chart(fig)
30
-
31
- elif structure_type == "Plant Structure":
32
- iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
33
- angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
34
- pos, edge_index = generate_plant(iterations, angle)
35
- data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
36
- fig = plot_graph_3d(pos, edge_index)
37
- st.plotly_chart(fig)
38
 
39
  # Function Definitions
40
 
@@ -62,13 +33,16 @@ def generate_sierpinski_triangle(depth):
62
 
63
  points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
64
  pos = np.array(points)
 
 
65
  # Create edges between points
66
  edge_index = []
67
  for i in range(0, len(pos), 3):
 
68
  edge_index.extend([
69
- [i, i+1],
70
- [i+1, i+2],
71
- [i+2, i]
72
  ])
73
  edge_index = np.array(edge_index).T
74
  return pos, edge_index
@@ -89,7 +63,6 @@ def generate_spiral(turns, points_per_turn):
89
  def generate_plant(iterations, angle):
90
  axiom = "F"
91
  rules = {"F": "F[+F]F[-F]F"}
92
- import math
93
 
94
  def expand_axiom(axiom, rules, iterations):
95
  for _ in range(iterations):
@@ -104,10 +77,10 @@ def generate_plant(iterations, angle):
104
  stack = []
105
  pos_list = []
106
  edge_list = []
107
- current_pos = np.array([0, 0, 0])
108
  pos_list.append(current_pos.copy())
109
  idx = 0
110
- direction = np.array([0, 1, 0])
111
 
112
  for command in final_axiom:
113
  if command == 'F':
@@ -128,8 +101,6 @@ def generate_plant(iterations, angle):
128
  stack.append((current_pos.copy(), direction.copy(), idx))
129
  elif command == ']':
130
  current_pos, direction, idx = stack.pop()
131
- pos_list.append(current_pos.copy())
132
- idx += 1
133
 
134
  pos = np.array(pos_list)
135
  edge_index = np.array(edge_list).T
@@ -138,11 +109,19 @@ def generate_plant(iterations, angle):
138
  def rotation_matrix_3d(axis, theta):
139
  # Return the rotation matrix associated with rotation about the given axis by theta radians.
140
  axis = axis / np.linalg.norm(axis)
141
- a = np.cos(theta / 2)
142
- b, c, d = -axis * np.sin(theta / 2)
143
- return np.array([[a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
144
- [2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
145
- [2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]])
 
 
 
 
 
 
 
 
146
 
147
  def plot_graph_3d(pos, edge_index):
148
  x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
@@ -181,6 +160,43 @@ def plot_graph_3d(pos, edge_index):
181
  zaxis_title='Z',
182
  aspectmode='data'
183
  ),
184
- showlegend=False
 
185
  )
186
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from torch_geometric.data import Data
6
  import numpy as np
7
  import plotly.graph_objs as go
8
+ import math # Moved import to the top
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Function Definitions
11
 
 
33
 
34
  points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
35
  pos = np.array(points)
36
+ # Remove duplicate points
37
+ pos = np.unique(pos, axis=0)
38
  # Create edges between points
39
  edge_index = []
40
  for i in range(0, len(pos), 3):
41
+ idx = i % len(pos)
42
  edge_index.extend([
43
+ [idx, (idx+1)%len(pos)],
44
+ [(idx+1)%len(pos), (idx+2)%len(pos)],
45
+ [(idx+2)%len(pos), idx]
46
  ])
47
  edge_index = np.array(edge_index).T
48
  return pos, edge_index
 
63
  def generate_plant(iterations, angle):
64
  axiom = "F"
65
  rules = {"F": "F[+F]F[-F]F"}
 
66
 
67
  def expand_axiom(axiom, rules, iterations):
68
  for _ in range(iterations):
 
77
  stack = []
78
  pos_list = []
79
  edge_list = []
80
+ current_pos = np.array([0.0, 0.0, 0.0])
81
  pos_list.append(current_pos.copy())
82
  idx = 0
83
+ direction = np.array([0.0, 1.0, 0.0])
84
 
85
  for command in final_axiom:
86
  if command == 'F':
 
101
  stack.append((current_pos.copy(), direction.copy(), idx))
102
  elif command == ']':
103
  current_pos, direction, idx = stack.pop()
 
 
104
 
105
  pos = np.array(pos_list)
106
  edge_index = np.array(edge_list).T
 
109
  def rotation_matrix_3d(axis, theta):
110
  # Return the rotation matrix associated with rotation about the given axis by theta radians.
111
  axis = axis / np.linalg.norm(axis)
112
+ a = np.cos(theta)
113
+ b, c, d = axis * np.sin(theta)
114
+ return np.array([
115
+ [a + (1 - a) * axis[0] * axis[0],
116
+ (1 - a) * axis[0] * axis[1] - axis[2] * np.sin(theta),
117
+ (1 - a) * axis[0] * axis[2] + axis[1] * np.sin(theta)],
118
+ [(1 - a) * axis[1] * axis[0] + axis[2] * np.sin(theta),
119
+ a + (1 - a) * axis[1] * axis[1],
120
+ (1 - a) * axis[1] * axis[2] - axis[0] * np.sin(theta)],
121
+ [(1 - a) * axis[2] * axis[0] - axis[1] * np.sin(theta),
122
+ (1 - a) * axis[2] * axis[1] + axis[0] * np.sin(theta),
123
+ a + (1 - a) * axis[2] * axis[2]]
124
+ ])
125
 
126
  def plot_graph_3d(pos, edge_index):
127
  x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
 
160
  zaxis_title='Z',
161
  aspectmode='data'
162
  ),
163
+ showlegend=False,
164
+ margin=dict(l=0, r=0, b=0, t=0) # Optional: adjust margins
165
  )
166
  return fig
167
+
168
+ # Main App Code
169
+
170
+ def main():
171
+ st.title("PyTorch Geometric Structure Visualization")
172
+
173
+ structure_type = st.sidebar.selectbox(
174
+ "Select Structure Type",
175
+ ("Sierpinski Triangle", "Spiral", "Plant Structure")
176
+ )
177
+
178
+ if structure_type == "Sierpinski Triangle":
179
+ depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
180
+ pos, edge_index = generate_sierpinski_triangle(depth)
181
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
182
+ fig = plot_graph_3d(pos, edge_index)
183
+ st.plotly_chart(fig)
184
+
185
+ elif structure_type == "Spiral":
186
+ turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
187
+ points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
188
+ pos, edge_index = generate_spiral(turns, points_per_turn)
189
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
190
+ fig = plot_graph_3d(pos, edge_index)
191
+ st.plotly_chart(fig)
192
+
193
+ elif structure_type == "Plant Structure":
194
+ iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
195
+ angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
196
+ pos, edge_index = generate_plant(iterations, angle)
197
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
198
+ fig = plot_graph_3d(pos, edge_index)
199
+ st.plotly_chart(fig)
200
+
201
+ if __name__ == "__main__":
202
+ main()