Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,36 +5,7 @@ import torch
|
|
5 |
from torch_geometric.data import Data
|
6 |
import numpy as np
|
7 |
import plotly.graph_objs as go
|
8 |
-
|
9 |
-
st.title("PyTorch Geometric Structure Visualization")
|
10 |
-
|
11 |
-
structure_type = st.sidebar.selectbox(
|
12 |
-
"Select Structure Type",
|
13 |
-
("Sierpinski Triangle", "Spiral", "Plant Structure")
|
14 |
-
)
|
15 |
-
|
16 |
-
if structure_type == "Sierpinski Triangle":
|
17 |
-
depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
|
18 |
-
pos, edge_index = generate_sierpinski_triangle(depth)
|
19 |
-
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
20 |
-
fig = plot_graph_3d(pos, edge_index)
|
21 |
-
st.plotly_chart(fig)
|
22 |
-
|
23 |
-
elif structure_type == "Spiral":
|
24 |
-
turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
|
25 |
-
points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
|
26 |
-
pos, edge_index = generate_spiral(turns, points_per_turn)
|
27 |
-
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
28 |
-
fig = plot_graph_3d(pos, edge_index)
|
29 |
-
st.plotly_chart(fig)
|
30 |
-
|
31 |
-
elif structure_type == "Plant Structure":
|
32 |
-
iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
|
33 |
-
angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
|
34 |
-
pos, edge_index = generate_plant(iterations, angle)
|
35 |
-
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
36 |
-
fig = plot_graph_3d(pos, edge_index)
|
37 |
-
st.plotly_chart(fig)
|
38 |
|
39 |
# Function Definitions
|
40 |
|
@@ -62,13 +33,16 @@ def generate_sierpinski_triangle(depth):
|
|
62 |
|
63 |
points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
|
64 |
pos = np.array(points)
|
|
|
|
|
65 |
# Create edges between points
|
66 |
edge_index = []
|
67 |
for i in range(0, len(pos), 3):
|
|
|
68 |
edge_index.extend([
|
69 |
-
[
|
70 |
-
[
|
71 |
-
[
|
72 |
])
|
73 |
edge_index = np.array(edge_index).T
|
74 |
return pos, edge_index
|
@@ -89,7 +63,6 @@ def generate_spiral(turns, points_per_turn):
|
|
89 |
def generate_plant(iterations, angle):
|
90 |
axiom = "F"
|
91 |
rules = {"F": "F[+F]F[-F]F"}
|
92 |
-
import math
|
93 |
|
94 |
def expand_axiom(axiom, rules, iterations):
|
95 |
for _ in range(iterations):
|
@@ -104,10 +77,10 @@ def generate_plant(iterations, angle):
|
|
104 |
stack = []
|
105 |
pos_list = []
|
106 |
edge_list = []
|
107 |
-
current_pos = np.array([0, 0, 0])
|
108 |
pos_list.append(current_pos.copy())
|
109 |
idx = 0
|
110 |
-
direction = np.array([0, 1, 0])
|
111 |
|
112 |
for command in final_axiom:
|
113 |
if command == 'F':
|
@@ -128,8 +101,6 @@ def generate_plant(iterations, angle):
|
|
128 |
stack.append((current_pos.copy(), direction.copy(), idx))
|
129 |
elif command == ']':
|
130 |
current_pos, direction, idx = stack.pop()
|
131 |
-
pos_list.append(current_pos.copy())
|
132 |
-
idx += 1
|
133 |
|
134 |
pos = np.array(pos_list)
|
135 |
edge_index = np.array(edge_list).T
|
@@ -138,11 +109,19 @@ def generate_plant(iterations, angle):
|
|
138 |
def rotation_matrix_3d(axis, theta):
|
139 |
# Return the rotation matrix associated with rotation about the given axis by theta radians.
|
140 |
axis = axis / np.linalg.norm(axis)
|
141 |
-
a = np.cos(theta
|
142 |
-
b, c, d =
|
143 |
-
return np.array([
|
144 |
-
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def plot_graph_3d(pos, edge_index):
|
148 |
x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
|
@@ -181,6 +160,43 @@ def plot_graph_3d(pos, edge_index):
|
|
181 |
zaxis_title='Z',
|
182 |
aspectmode='data'
|
183 |
),
|
184 |
-
showlegend=False
|
|
|
185 |
)
|
186 |
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from torch_geometric.data import Data
|
6 |
import numpy as np
|
7 |
import plotly.graph_objs as go
|
8 |
+
import math # Moved import to the top
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Function Definitions
|
11 |
|
|
|
33 |
|
34 |
points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
|
35 |
pos = np.array(points)
|
36 |
+
# Remove duplicate points
|
37 |
+
pos = np.unique(pos, axis=0)
|
38 |
# Create edges between points
|
39 |
edge_index = []
|
40 |
for i in range(0, len(pos), 3):
|
41 |
+
idx = i % len(pos)
|
42 |
edge_index.extend([
|
43 |
+
[idx, (idx+1)%len(pos)],
|
44 |
+
[(idx+1)%len(pos), (idx+2)%len(pos)],
|
45 |
+
[(idx+2)%len(pos), idx]
|
46 |
])
|
47 |
edge_index = np.array(edge_index).T
|
48 |
return pos, edge_index
|
|
|
63 |
def generate_plant(iterations, angle):
|
64 |
axiom = "F"
|
65 |
rules = {"F": "F[+F]F[-F]F"}
|
|
|
66 |
|
67 |
def expand_axiom(axiom, rules, iterations):
|
68 |
for _ in range(iterations):
|
|
|
77 |
stack = []
|
78 |
pos_list = []
|
79 |
edge_list = []
|
80 |
+
current_pos = np.array([0.0, 0.0, 0.0])
|
81 |
pos_list.append(current_pos.copy())
|
82 |
idx = 0
|
83 |
+
direction = np.array([0.0, 1.0, 0.0])
|
84 |
|
85 |
for command in final_axiom:
|
86 |
if command == 'F':
|
|
|
101 |
stack.append((current_pos.copy(), direction.copy(), idx))
|
102 |
elif command == ']':
|
103 |
current_pos, direction, idx = stack.pop()
|
|
|
|
|
104 |
|
105 |
pos = np.array(pos_list)
|
106 |
edge_index = np.array(edge_list).T
|
|
|
109 |
def rotation_matrix_3d(axis, theta):
|
110 |
# Return the rotation matrix associated with rotation about the given axis by theta radians.
|
111 |
axis = axis / np.linalg.norm(axis)
|
112 |
+
a = np.cos(theta)
|
113 |
+
b, c, d = axis * np.sin(theta)
|
114 |
+
return np.array([
|
115 |
+
[a + (1 - a) * axis[0] * axis[0],
|
116 |
+
(1 - a) * axis[0] * axis[1] - axis[2] * np.sin(theta),
|
117 |
+
(1 - a) * axis[0] * axis[2] + axis[1] * np.sin(theta)],
|
118 |
+
[(1 - a) * axis[1] * axis[0] + axis[2] * np.sin(theta),
|
119 |
+
a + (1 - a) * axis[1] * axis[1],
|
120 |
+
(1 - a) * axis[1] * axis[2] - axis[0] * np.sin(theta)],
|
121 |
+
[(1 - a) * axis[2] * axis[0] - axis[1] * np.sin(theta),
|
122 |
+
(1 - a) * axis[2] * axis[1] + axis[0] * np.sin(theta),
|
123 |
+
a + (1 - a) * axis[2] * axis[2]]
|
124 |
+
])
|
125 |
|
126 |
def plot_graph_3d(pos, edge_index):
|
127 |
x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
|
|
|
160 |
zaxis_title='Z',
|
161 |
aspectmode='data'
|
162 |
),
|
163 |
+
showlegend=False,
|
164 |
+
margin=dict(l=0, r=0, b=0, t=0) # Optional: adjust margins
|
165 |
)
|
166 |
return fig
|
167 |
+
|
168 |
+
# Main App Code
|
169 |
+
|
170 |
+
def main():
|
171 |
+
st.title("PyTorch Geometric Structure Visualization")
|
172 |
+
|
173 |
+
structure_type = st.sidebar.selectbox(
|
174 |
+
"Select Structure Type",
|
175 |
+
("Sierpinski Triangle", "Spiral", "Plant Structure")
|
176 |
+
)
|
177 |
+
|
178 |
+
if structure_type == "Sierpinski Triangle":
|
179 |
+
depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
|
180 |
+
pos, edge_index = generate_sierpinski_triangle(depth)
|
181 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
182 |
+
fig = plot_graph_3d(pos, edge_index)
|
183 |
+
st.plotly_chart(fig)
|
184 |
+
|
185 |
+
elif structure_type == "Spiral":
|
186 |
+
turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
|
187 |
+
points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
|
188 |
+
pos, edge_index = generate_spiral(turns, points_per_turn)
|
189 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
190 |
+
fig = plot_graph_3d(pos, edge_index)
|
191 |
+
st.plotly_chart(fig)
|
192 |
+
|
193 |
+
elif structure_type == "Plant Structure":
|
194 |
+
iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
|
195 |
+
angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
|
196 |
+
pos, edge_index = generate_plant(iterations, angle)
|
197 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
198 |
+
fig = plot_graph_3d(pos, edge_index)
|
199 |
+
st.plotly_chart(fig)
|
200 |
+
|
201 |
+
if __name__ == "__main__":
|
202 |
+
main()
|