Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Streamlit App: PyTorch Geometric Structure Visualization
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
import torch
|
5 |
+
from torch_geometric.data import Data
|
6 |
+
import numpy as np
|
7 |
+
import plotly.graph_objs as go
|
8 |
+
|
9 |
+
st.title("PyTorch Geometric Structure Visualization")
|
10 |
+
|
11 |
+
structure_type = st.sidebar.selectbox(
|
12 |
+
"Select Structure Type",
|
13 |
+
("Sierpinski Triangle", "Spiral", "Plant Structure")
|
14 |
+
)
|
15 |
+
|
16 |
+
if structure_type == "Sierpinski Triangle":
|
17 |
+
depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
|
18 |
+
pos, edge_index = generate_sierpinski_triangle(depth)
|
19 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
20 |
+
fig = plot_graph_3d(pos, edge_index)
|
21 |
+
st.plotly_chart(fig)
|
22 |
+
|
23 |
+
elif structure_type == "Spiral":
|
24 |
+
turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
|
25 |
+
points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
|
26 |
+
pos, edge_index = generate_spiral(turns, points_per_turn)
|
27 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
28 |
+
fig = plot_graph_3d(pos, edge_index)
|
29 |
+
st.plotly_chart(fig)
|
30 |
+
|
31 |
+
elif structure_type == "Plant Structure":
|
32 |
+
iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
|
33 |
+
angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
|
34 |
+
pos, edge_index = generate_plant(iterations, angle)
|
35 |
+
data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
|
36 |
+
fig = plot_graph_3d(pos, edge_index)
|
37 |
+
st.plotly_chart(fig)
|
38 |
+
|
39 |
+
# Function Definitions
|
40 |
+
|
41 |
+
def generate_sierpinski_triangle(depth):
|
42 |
+
# Generate the vertices of the initial triangle
|
43 |
+
vertices = np.array([
|
44 |
+
[0, 0, 0],
|
45 |
+
[1, 0, 0],
|
46 |
+
[0.5, np.sqrt(3)/2, 0]
|
47 |
+
])
|
48 |
+
|
49 |
+
# Function to recursively generate points
|
50 |
+
def recurse_triangle(v1, v2, v3, depth):
|
51 |
+
if depth == 0:
|
52 |
+
return [v1, v2, v3]
|
53 |
+
else:
|
54 |
+
# Calculate midpoints
|
55 |
+
m12 = (v1 + v2) / 2
|
56 |
+
m23 = (v2 + v3) / 2
|
57 |
+
m31 = (v3 + v1) / 2
|
58 |
+
# Recursively subdivide
|
59 |
+
return (recurse_triangle(v1, m12, m31, depth - 1) +
|
60 |
+
recurse_triangle(m12, v2, m23, depth - 1) +
|
61 |
+
recurse_triangle(m31, m23, v3, depth - 1))
|
62 |
+
|
63 |
+
points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
|
64 |
+
pos = np.array(points)
|
65 |
+
# Create edges between points
|
66 |
+
edge_index = []
|
67 |
+
for i in range(0, len(pos), 3):
|
68 |
+
edge_index.extend([
|
69 |
+
[i, i+1],
|
70 |
+
[i+1, i+2],
|
71 |
+
[i+2, i]
|
72 |
+
])
|
73 |
+
edge_index = np.array(edge_index).T
|
74 |
+
return pos, edge_index
|
75 |
+
|
76 |
+
def generate_spiral(turns, points_per_turn):
|
77 |
+
total_points = turns * points_per_turn
|
78 |
+
theta_max = 2 * np.pi * turns
|
79 |
+
theta = np.linspace(0, theta_max, total_points)
|
80 |
+
z = np.linspace(0, 1, total_points)
|
81 |
+
r = z # Spiral expanding in radius
|
82 |
+
x = r * np.cos(theta)
|
83 |
+
y = r * np.sin(theta)
|
84 |
+
pos = np.vstack((x, y, z)).T
|
85 |
+
# Edges connect sequential points
|
86 |
+
edge_index = np.array([np.arange(total_points - 1), np.arange(1, total_points)])
|
87 |
+
return pos, edge_index
|
88 |
+
|
89 |
+
def generate_plant(iterations, angle):
|
90 |
+
axiom = "F"
|
91 |
+
rules = {"F": "F[+F]F[-F]F"}
|
92 |
+
import math
|
93 |
+
|
94 |
+
def expand_axiom(axiom, rules, iterations):
|
95 |
+
for _ in range(iterations):
|
96 |
+
new_axiom = ""
|
97 |
+
for ch in axiom:
|
98 |
+
new_axiom += rules.get(ch, ch)
|
99 |
+
axiom = new_axiom
|
100 |
+
return axiom
|
101 |
+
|
102 |
+
final_axiom = expand_axiom(axiom, rules, iterations)
|
103 |
+
|
104 |
+
stack = []
|
105 |
+
pos_list = []
|
106 |
+
edge_list = []
|
107 |
+
current_pos = np.array([0, 0, 0])
|
108 |
+
pos_list.append(current_pos.copy())
|
109 |
+
idx = 0
|
110 |
+
direction = np.array([0, 1, 0])
|
111 |
+
|
112 |
+
for command in final_axiom:
|
113 |
+
if command == 'F':
|
114 |
+
next_pos = current_pos + direction
|
115 |
+
pos_list.append(next_pos.copy())
|
116 |
+
edge_list.append([idx, idx + 1])
|
117 |
+
current_pos = next_pos
|
118 |
+
idx += 1
|
119 |
+
elif command == '+':
|
120 |
+
theta = np.radians(angle)
|
121 |
+
rotation_matrix = rotation_matrix_3d(np.array([0, 0, 1]), theta)
|
122 |
+
direction = rotation_matrix @ direction
|
123 |
+
elif command == '-':
|
124 |
+
theta = np.radians(-angle)
|
125 |
+
rotation_matrix = rotation_matrix_3d(np.array([0, 0, 1]), theta)
|
126 |
+
direction = rotation_matrix @ direction
|
127 |
+
elif command == '[':
|
128 |
+
stack.append((current_pos.copy(), direction.copy(), idx))
|
129 |
+
elif command == ']':
|
130 |
+
current_pos, direction, idx = stack.pop()
|
131 |
+
pos_list.append(current_pos.copy())
|
132 |
+
idx += 1
|
133 |
+
|
134 |
+
pos = np.array(pos_list)
|
135 |
+
edge_index = np.array(edge_list).T
|
136 |
+
return pos, edge_index
|
137 |
+
|
138 |
+
def rotation_matrix_3d(axis, theta):
|
139 |
+
# Return the rotation matrix associated with rotation about the given axis by theta radians.
|
140 |
+
axis = axis / np.linalg.norm(axis)
|
141 |
+
a = np.cos(theta / 2)
|
142 |
+
b, c, d = -axis * np.sin(theta / 2)
|
143 |
+
return np.array([[a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
|
144 |
+
[2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
|
145 |
+
[2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]])
|
146 |
+
|
147 |
+
def plot_graph_3d(pos, edge_index):
|
148 |
+
x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
|
149 |
+
edge_x = []
|
150 |
+
edge_y = []
|
151 |
+
edge_z = []
|
152 |
+
|
153 |
+
for i in range(edge_index.shape[1]):
|
154 |
+
src = edge_index[0, i]
|
155 |
+
dst = edge_index[1, i]
|
156 |
+
edge_x += [x[src], x[dst], None]
|
157 |
+
edge_y += [y[src], y[dst], None]
|
158 |
+
edge_z += [z[src], z[dst], None]
|
159 |
+
|
160 |
+
edge_trace = go.Scatter3d(
|
161 |
+
x=edge_x, y=edge_y, z=edge_z,
|
162 |
+
line=dict(width=2, color='gray'),
|
163 |
+
hoverinfo='none',
|
164 |
+
mode='lines')
|
165 |
+
|
166 |
+
node_trace = go.Scatter3d(
|
167 |
+
x=x, y=y, z=z,
|
168 |
+
mode='markers',
|
169 |
+
marker=dict(
|
170 |
+
size=4,
|
171 |
+
color='red',
|
172 |
+
),
|
173 |
+
hoverinfo='none'
|
174 |
+
)
|
175 |
+
|
176 |
+
fig = go.Figure(data=[edge_trace, node_trace])
|
177 |
+
fig.update_layout(
|
178 |
+
scene=dict(
|
179 |
+
xaxis_title='X',
|
180 |
+
yaxis_title='Y',
|
181 |
+
zaxis_title='Z',
|
182 |
+
aspectmode='data'
|
183 |
+
),
|
184 |
+
showlegend=False
|
185 |
+
)
|
186 |
+
return fig
|