# Spaces: Sleeping  (page-status text captured along with the source; not part of the program)
# Streamlit App: PyTorch Geometric Structure Visualization
import streamlit as st | |
import torch | |
from torch_geometric.data import Data | |
import numpy as np | |
import plotly.graph_objs as go | |
import math # Moved import to the top | |
# Function Definitions
def generate_sierpinski_triangle(depth):
    """Build the vertices and edges of a Sierpinski triangle in the z = 0 plane.

    The starting shape is an equilateral triangle with unit side length. Each
    recursion level replaces a triangle with its three corner sub-triangles,
    leaving the middle one out.

    Args:
        depth: Number of subdivision levels; 0 returns the starting triangle.

    Returns:
        A tuple ``(pos, edge_index)``. ``pos`` is an ``(N, 3)`` float array of
        the unique vertices, in the sorted order produced by ``np.unique``.
        ``edge_index`` is a ``(2, E)`` integer array whose column ``k`` holds
        the ``pos`` row indices of the endpoints of edge ``k``. Every smallest
        triangle contributes its three sides, so ``E == 3 ** (depth + 1)``.

    Raises:
        ValueError: If ``depth`` is negative (the recursion would never reach
            its base case).
    """
    if depth < 0:
        raise ValueError("depth must be non-negative")

    # Corners of the initial equilateral triangle.
    vertices = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0.5, np.sqrt(3) / 2, 0],
    ])

    def recurse_triangle(v1, v2, v3, depth):
        """Return the corners of all smallest triangles, three per triangle."""
        if depth == 0:
            return [v1, v2, v3]
        m12 = (v1 + v2) / 2
        m23 = (v2 + v3) / 2
        m31 = (v3 + v1) / 2
        return (recurse_triangle(v1, m12, m31, depth - 1) +
                recurse_triangle(m12, v2, m23, depth - 1) +
                recurse_triangle(m31, m23, v3, depth - 1))

    corners = np.array(
        recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
    )

    # Deduplicate shared corners, and keep the mapping from every original
    # corner to its row in the deduplicated array. Without that mapping the
    # triangle grouping is lost, because np.unique also sorts the rows.
    pos, inverse = np.unique(corners, axis=0, return_inverse=True)

    # One row of vertex indices (a, b, c) per smallest triangle. The reshape
    # also flattens any extra axis the inverse array may carry.
    triangles = inverse.reshape(-1, 3)

    # Sides of each triangle: a-b, b-c, c-a.
    src = triangles.reshape(-1)
    dst = np.roll(triangles, -1, axis=1).reshape(-1)
    edge_index = np.stack([src, dst])
    return pos, edge_index
def generate_spiral(turns, points_per_turn):
    """Sample a widening 3D spiral and chain the samples into a path.

    Args:
        turns: Number of full revolutions the spiral makes.
        points_per_turn: Multiplied by ``turns`` to get the total sample count.

    Returns:
        A tuple ``(pos, edge_index)``. ``pos`` is an ``(N, 3)`` array of
        ``(x, y, z)`` samples with ``N = turns * points_per_turn``.
        ``edge_index`` is a ``(2, N - 1)`` integer array linking every sample
        to the next one.
    """
    n_points = turns * points_per_turn

    # Height rises linearly from 0 to 1 and is reused as the radius, so the
    # curve widens as it climbs.
    angle = np.linspace(0, 2 * np.pi * turns, n_points)
    height = np.linspace(0, 1, n_points)
    radius = height

    pos = np.stack(
        (radius * np.cos(angle), radius * np.sin(angle), height),
        axis=1,
    )

    # Consecutive samples are joined: (0, 1), (1, 2), ...
    node_ids = np.arange(n_points)
    edge_index = np.vstack((node_ids[:-1], node_ids[1:]))
    return pos, edge_index
def generate_plant(iterations, angle):
    """Trace a branching plant skeleton from a bracketed L-system.

    The axiom ``"F"`` is rewritten ``iterations`` times with the rule
    ``F -> F[+F]F[-F]F``. The resulting string is then interpreted by a
    turtle that starts at the origin heading along +Y:

    * ``F`` steps one unit forward, adding a node and an edge to it.
    * ``+`` / ``-`` turn the heading by ``+angle`` / ``-angle`` about the Z axis.
    * ``[`` saves the current position, heading and node index.
    * ``]`` restores the most recently saved state.

    Args:
        iterations: Number of times the rewrite rule is applied.
        angle: Turn angle in degrees.

    Returns:
        A tuple ``(pos, edge_index)``. ``pos`` is an ``(N, 3)`` float array of
        node positions in creation order. ``edge_index`` is a ``(2, N - 1)``
        integer array; each column links a node to the node drawn from it.
    """
    axiom = "F"
    rules = {"F": "F[+F]F[-F]F"}

    def expand_axiom(axiom, rules, iterations):
        """Apply the rewrite rules to every symbol, ``iterations`` times."""
        for _ in range(iterations):
            axiom = "".join(rules.get(ch, ch) for ch in axiom)
        return axiom

    final_axiom = expand_axiom(axiom, rules, iterations)

    z_axis = np.array([0, 0, 1])
    turn_angles = {"+": np.radians(angle), "-": np.radians(-angle)}
    # Each turn matrix is the same for the whole trace, so build it once, on
    # first use.
    turn_matrices = {}

    stack = []
    current_pos = np.array([0.0, 0.0, 0.0])
    pos_list = [current_pos.copy()]
    edge_list = []
    idx = 0  # index in pos_list of the node the turtle currently sits on
    direction = np.array([0.0, 1.0, 0.0])

    for command in final_axiom:
        if command == "F":
            next_pos = current_pos + direction
            # The new node's index is its slot in pos_list, not idx + 1:
            # after a ']' restores idx to a branch point, the slots after it
            # are already taken by the nodes of the finished branch.
            new_idx = len(pos_list)
            pos_list.append(next_pos.copy())
            edge_list.append([idx, new_idx])
            current_pos = next_pos
            idx = new_idx
        elif command in turn_angles:
            if command not in turn_matrices:
                turn_matrices[command] = rotation_matrix_3d(
                    z_axis, turn_angles[command]
                )
            direction = turn_matrices[command] @ direction
        elif command == "[":
            stack.append((current_pos.copy(), direction.copy(), idx))
        elif command == "]":
            current_pos, direction, idx = stack.pop()

    pos = np.array(pos_list)
    edge_index = np.array(edge_list).T
    return pos, edge_index
def rotation_matrix_3d(axis, theta):
    """Return the 3x3 matrix for a rotation of ``theta`` radians about ``axis``.

    The rotation is right-handed: rotating about +Z by ``pi / 2`` maps the
    +X direction onto +Y.

    Args:
        axis: Three components of the rotation axis (array or plain
            sequence). It does not need to be unit length; it is normalized
            here.
        theta: Rotation angle in radians.

    Returns:
        A ``(3, 3)`` float array ``R`` such that ``R @ v`` is ``v`` rotated
        about ``axis``.

    Raises:
        ValueError: If ``axis`` has zero length, since it then defines no
            direction to rotate about.
    """
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    x, y, z = axis / norm

    c = np.cos(theta)
    s = np.sin(theta)
    t = 1 - c

    return np.array([
        [c + t * x * x, t * x * y - z * s, t * x * z + y * s],
        [t * y * x + z * s, c + t * y * y, t * y * z - x * s],
        [t * z * x - y * s, t * z * y + x * s, c + t * z * z],
    ])
def plot_graph_3d(pos, edge_index):
    """Draw a graph as a 3D Plotly figure with gray edges and red nodes.

    Args:
        pos: 2-D array of node coordinates; columns 0, 1 and 2 are used as
            x, y and z.
        edge_index: 2-D integer array; row 0 holds the source node of each
            edge and row 1 the destination node.

    Returns:
        A ``go.Figure`` holding one line trace for the edges and one marker
        trace for the nodes.
    """
    x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
    n_edges = edge_index.shape[1]
    sources, targets = edge_index[0], edge_index[1]

    def edge_coords(coord):
        """Flatten one coordinate of every edge into start, end, None triples."""
        flat = []
        for k in range(n_edges):
            # The trailing None separates this segment from the next one so
            # that unrelated edges are not joined into a single polyline.
            flat.extend((coord[sources[k]], coord[targets[k]], None))
        return flat

    edge_trace = go.Scatter3d(
        x=edge_coords(x),
        y=edge_coords(y),
        z=edge_coords(z),
        line={"width": 2, "color": "gray"},
        hoverinfo="none",
        mode="lines",
    )
    node_trace = go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode="markers",
        marker={"size": 4, "color": "red"},
        hoverinfo="none",
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        scene={
            "xaxis_title": "X",
            "yaxis_title": "Y",
            "zaxis_title": "Z",
            "aspectmode": "data",
        },
        showlegend=False,
        margin={"l": 0, "r": 0, "b": 0, "t": 0},  # no padding around the scene
    )
    return fig
# Main App Code
def main():
    """Run the Streamlit page: choose a structure, generate it, and plot it.

    The sidebar offers three structure types, each with its own sliders. The
    selected generator returns ``(pos, edge_index)``, which is wrapped in a
    PyTorch Geometric ``Data`` object and drawn with ``plot_graph_3d``.
    """
    st.title("PyTorch Geometric Structure Visualization")
    structure_type = st.sidebar.selectbox(
        "Select Structure Type",
        ("Sierpinski Triangle", "Spiral", "Plant Structure"),
    )

    # Each branch only shows its own controls and produces the geometry.
    if structure_type == "Sierpinski Triangle":
        depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
        pos, edge_index = generate_sierpinski_triangle(depth)
    elif structure_type == "Spiral":
        turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
        points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
        pos, edge_index = generate_spiral(turns, points_per_turn)
    elif structure_type == "Plant Structure":
        iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
        angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
        pos, edge_index = generate_plant(iterations, angle)
    else:
        # Nothing is drawn for an unrecognized selection.
        return

    # The Data object is constructed but not read again in this function.
    data = Data(
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )
    fig = plot_graph_3d(pos, edge_index)
    st.plotly_chart(fig)


# Start the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()