awacke1's picture
Update app.py
4805c8c verified
raw
history blame
6.78 kB
# Streamlit App: PyTorch Geometric Structure Visualization
import streamlit as st
import torch
from torch_geometric.data import Data
import numpy as np
import plotly.graph_objs as go
import math # Moved import to the top
# Function Definitions
def generate_sierpinski_triangle(depth):
    """Build the vertex/edge graph of a Sierpinski triangle.

    The triangle with corners (0, 0, 0), (1, 0, 0) and (0.5, sqrt(3)/2, 0)
    is split into its three corner sub-triangles ``depth`` times. Every leaf
    triangle contributes its three sides as edges.

    Args:
        depth: Number of subdivision steps; 0 yields the single outer
            triangle.

    Returns:
        ``(pos, edge_index)`` where ``pos`` is an ``(N, 3)`` float array of
        unique vertices (lexicographically sorted by ``np.unique``) and
        ``edge_index`` is a ``(2, 3 * 3**depth)`` integer array whose columns
        are ``(src, dst)`` row indices into ``pos``.

    Raises:
        ValueError: If ``depth`` is negative (the recursion would never reach
            its base case).
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    # Vertices of the initial triangle.
    vertices = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0.5, np.sqrt(3) / 2, 0],
    ])

    def recurse_triangle(v1, v2, v3, depth):
        # Flat list of corners, three consecutive entries per leaf triangle.
        if depth == 0:
            return [v1, v2, v3]
        m12 = (v1 + v2) / 2
        m23 = (v2 + v3) / 2
        m31 = (v3 + v1) / 2
        return (recurse_triangle(v1, m12, m31, depth - 1) +
                recurse_triangle(m12, v2, m23, depth - 1) +
                recurse_triangle(m31, m23, v3, depth - 1))

    corners = np.array(
        recurse_triangle(vertices[0], vertices[1], vertices[2], depth))

    # A corner shared by two sub-triangles is computed once and passed to
    # both recursive calls, so exact de-duplication is safe. ``inverse`` maps
    # every corner back to its row in ``pos``, which keeps track of which
    # unique vertices belong to which leaf triangle.
    pos, inverse = np.unique(corners, axis=0, return_inverse=True)
    # reshape(-1, ...) tolerates NumPy versions that return ``inverse`` with
    # an extra trailing axis.
    triangles = inverse.reshape(-1, 3)

    # Three sides per triangle, kept grouped by triangle: (a,b), (b,c), (c,a).
    sides = np.stack(
        [triangles[:, [0, 1]], triangles[:, [1, 2]], triangles[:, [2, 0]]],
        axis=1,
    ).reshape(-1, 2)
    edge_index = sides.T
    return pos, edge_index
def generate_spiral(turns, points_per_turn):
    """Sample a widening spiral around the z axis as a path graph.

    Args:
        turns: Number of full revolutions.
        points_per_turn: Samples per revolution.

    Returns:
        ``(pos, edge_index)``: ``pos`` is a ``(turns * points_per_turn, 3)``
        float array; ``edge_index`` is a ``(2, n - 1)`` integer array linking
        each sample to the next one.
    """
    n_samples = turns * points_per_turn
    # Height rises linearly from 0 to 1. The radius equals the height, so
    # the spiral widens as it climbs.
    height = np.linspace(0, 1, n_samples)
    angle = np.linspace(0, 2 * np.pi * turns, n_samples)
    radius = height
    pos = np.column_stack(
        (radius * np.cos(angle), radius * np.sin(angle), height))

    # Consecutive samples form the only edges: i -> i + 1.
    node_ids = np.arange(n_samples)
    edge_index = np.vstack((node_ids[:-1], node_ids[1:]))
    return pos, edge_index
def generate_plant(iterations, angle):
    """Grow a branching plant graph from a bracketed L-system.

    The axiom ``"F"`` is rewritten ``iterations`` times with the rule
    ``F -> F[+F]F[-F]F``. The resulting string is then walked turtle-style:

    * ``F`` steps one unit forward, adding a node and an edge from the
      turtle's current node;
    * ``+`` / ``-`` turn by ``+angle`` / ``-angle`` degrees about the z axis;
    * ``[`` saves the turtle state and ``]`` restores it, starting a branch.

    Args:
        iterations: Number of rewriting passes (0 keeps the axiom ``"F"``).
        angle: Turn angle in degrees.

    Returns:
        ``(pos, edge_index)`` where ``pos`` is an ``(N, 3)`` float array of
        turtle positions (node 0 is the origin; all points have z = 0 since
        turns are about z and the start direction is +y), and ``edge_index``
        is a ``(2, N - 1)`` integer array of parent -> child segments.
    """
    axiom = "F"
    rules = {"F": "F[+F]F[-F]F"}

    def expand_axiom(axiom, rules, iterations):
        # Apply the rewriting rules to every symbol, ``iterations`` times.
        for _ in range(iterations):
            axiom = "".join(rules.get(ch, ch) for ch in axiom)
        return axiom

    final_axiom = expand_axiom(axiom, rules, iterations)

    # The turn matrices depend only on ``angle``; build them once.
    z_axis = np.array([0, 0, 1])
    turn_left = rotation_matrix_3d(z_axis, np.radians(angle))
    turn_right = rotation_matrix_3d(z_axis, np.radians(-angle))

    stack = []
    pos_list = [np.array([0.0, 0.0, 0.0])]
    edge_list = []
    current_pos = pos_list[0].copy()
    idx = 0  # row in pos_list of the node the turtle currently sits on
    direction = np.array([0.0, 1.0, 0.0])
    for command in final_axiom:
        if command == 'F':
            next_pos = current_pos + direction
            # The new node is always appended at the end of pos_list. After
            # a ']' the turtle's ``idx`` is an earlier branch point, so the
            # new node's index must come from len(pos_list), not idx + 1.
            new_idx = len(pos_list)
            pos_list.append(next_pos.copy())
            edge_list.append([idx, new_idx])
            current_pos = next_pos
            idx = new_idx
        elif command == '+':
            direction = turn_left @ direction
        elif command == '-':
            direction = turn_right @ direction
        elif command == '[':
            stack.append((current_pos.copy(), direction.copy(), idx))
        elif command == ']':
            current_pos, direction, idx = stack.pop()
    pos = np.array(pos_list)
    # reshape keeps a (2, E) shape even when no edge was produced.
    edge_index = np.array(edge_list, dtype=np.int64).reshape(-1, 2).T
    return pos, edge_index
def rotation_matrix_3d(axis, theta):
    """Return the 3x3 matrix that rotates vectors by ``theta`` about ``axis``.

    Implements the Rodrigues formula
    ``R = cos(t) I + sin(t) [u]_x + (1 - cos(t)) u u^T``, where ``u`` is the
    normalised axis. Positive ``theta`` is a right-handed rotation, e.g.
    +pi/2 about +z maps +x onto +y.

    Args:
        axis: Length-3 rotation axis of any non-zero length; it is normalised.
        theta: Rotation angle in radians.

    Returns:
        A ``(3, 3)`` float ndarray.

    Raises:
        ValueError: If ``axis`` has zero length. Without this check, the
            normalisation would silently produce a NaN matrix.
    """
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    ux, uy, uz = axis / norm
    c = np.cos(theta)
    s = np.sin(theta)
    t = 1 - c
    return np.array([
        [c + t * ux * ux, t * ux * uy - uz * s, t * ux * uz + uy * s],
        [t * uy * ux + uz * s, c + t * uy * uy, t * uy * uz - ux * s],
        [t * uz * ux - uy * s, t * uz * uy + ux * s, c + t * uz * uz],
    ])
def plot_graph_3d(pos, edge_index):
    """Render a graph as an interactive 3-D Plotly figure.

    Args:
        pos: ``(N, 3)`` array of node coordinates.
        edge_index: ``(2, E)`` array whose columns are ``(src, dst)`` indices
            into ``pos``.

    Returns:
        A Plotly ``Figure`` with grey edge segments and red node markers. The
        axes are labelled X/Y/Z, the scene keeps the data's aspect ratio, and
        there is no legend and no margin.
    """
    xs = pos[:, 0]
    ys = pos[:, 1]
    zs = pos[:, 2]

    # All edges go into one line trace. The None after each segment breaks
    # the polyline so separate edges are not joined together.
    seg_x, seg_y, seg_z = [], [], []
    for src, dst in zip(edge_index[0], edge_index[1]):
        seg_x.extend((xs[src], xs[dst], None))
        seg_y.extend((ys[src], ys[dst], None))
        seg_z.extend((zs[src], zs[dst], None))

    edges = go.Scatter3d(
        x=seg_x, y=seg_y, z=seg_z,
        mode='lines',
        line=dict(width=2, color='gray'),
        hoverinfo='none',
    )
    nodes = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers',
        marker=dict(size=4, color='red'),
        hoverinfo='none',
    )

    scene_cfg = dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='data',
    )
    fig = go.Figure(data=[edges, nodes])
    fig.update_layout(
        scene=scene_cfg,
        showlegend=False,
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig
# Main App Code
def main():
    """Streamlit entry point: choose a structure, build it, and plot it.

    The sidebar picks the generator and its parameters. The resulting arrays
    are packed into a PyTorch Geometric ``Data`` object, whose node and edge
    counts are shown under the chart, and drawn with ``plot_graph_3d``.
    """
    st.title("PyTorch Geometric Structure Visualization")
    structure_type = st.sidebar.selectbox(
        "Select Structure Type",
        ("Sierpinski Triangle", "Spiral", "Plant Structure")
    )

    # Only the parameter widgets and the generator differ per structure.
    # Conversion and plotting below are shared.
    if structure_type == "Sierpinski Triangle":
        depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
        pos, edge_index = generate_sierpinski_triangle(depth)
    elif structure_type == "Spiral":
        turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
        points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
        pos, edge_index = generate_spiral(turns, points_per_turn)
    else:  # "Plant Structure", the only remaining selectbox option
        iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
        angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
        pos, edge_index = generate_plant(iterations, angle)

    data = Data(
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
    )
    fig = plot_graph_3d(pos, edge_index)
    st.plotly_chart(fig)
    st.caption(f"{data.num_nodes} nodes, {data.num_edges} edges")
# Launch the app only when this file is executed as the main script,
# not when it is imported.
if __name__ == "__main__":
    main()