# Streamlit App: PyTorch Geometric Structure Visualization
import math

import numpy as np
import plotly.graph_objs as go
import streamlit as st
import torch
from torch_geometric.data import Data


# Function Definitions

def generate_sierpinski_triangle(depth):
    """Build a Sierpinski triangle as a graph.

    The equilateral triangle with corners (0, 0, 0), (1, 0, 0) and
    (0.5, sqrt(3)/2, 0) is recursively split ``depth`` times, keeping the
    three corner sub-triangles at each level.

    Args:
        depth: Number of subdivision levels; 0 yields the base triangle.

    Returns:
        (pos, edge_index): ``pos`` is an (N, 3) float array of unique vertex
        positions; ``edge_index`` is a (2, E) integer array holding the three
        directed sides (a->b, b->c, c->a) of every leaf triangle.

    Raises:
        ValueError: if ``depth`` is negative (the recursion would never stop).
    """
    if depth < 0:
        raise ValueError("depth must be non-negative")

    base = np.array([
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, np.sqrt(3) / 2, 0.0],
    ])

    triangles = []

    def subdivide(v1, v2, v3, level):
        # Collect leaf triangles; each level replaces a triangle by its three
        # corner sub-triangles (the middle one is dropped).
        if level == 0:
            triangles.append((v1, v2, v3))
            return
        m12 = (v1 + v2) / 2
        m23 = (v2 + v3) / 2
        m31 = (v3 + v1) / 2
        subdivide(v1, m12, m31, level - 1)
        subdivide(m12, v2, m23, level - 1)
        subdivide(m31, m23, v3, level - 1)

    subdivide(base[0], base[1], base[2], depth)

    # (3 * T, 3): the corners of every leaf triangle, in triangle order.
    corners = np.array(triangles).reshape(-1, 3)
    # Vertices shared between sub-triangles are the same arrays passed down the
    # recursion, so duplicates are bitwise equal and np.unique merges them.
    # ``inverse`` maps each corner to its row in the deduplicated ``pos``, which
    # preserves the triangle grouping that a plain unique() would lose.
    pos, inverse = np.unique(corners, axis=0, return_inverse=True)
    tri = np.asarray(inverse).reshape(-1, 3)

    # Sides a->b, b->c, c->a of each triangle.
    src = tri.reshape(-1)
    dst = np.roll(tri, -1, axis=1).reshape(-1)
    edge_index = np.vstack([src, dst])
    return pos, edge_index


def generate_spiral(turns, points_per_turn):
    """Build a conical spiral as a path graph.

    ``turns * points_per_turn`` points are spaced evenly in angle from 0 to
    ``2*pi*turns`` (inclusive) while z rises linearly from 0 to 1; the radius
    equals z, so the spiral widens as it climbs.

    Returns:
        (pos, edge_index): (N, 3) positions and a (2, N-1) array linking each
        point to the next one.
    """
    total_points = turns * points_per_turn
    theta = np.linspace(0, 2 * np.pi * turns, total_points)
    z = np.linspace(0, 1, total_points)
    r = z  # Spiral expanding in radius
    pos = np.column_stack((r * np.cos(theta), r * np.sin(theta), z))

    # Edges connect sequential points.
    edge_index = np.array([np.arange(total_points - 1), np.arange(1, total_points)])
    return pos, edge_index


def generate_plant(iterations, angle):
    """Build a branching plant from the L-system F -> F[+F]F[-F]F.

    Turtle interpretation of the expanded string:
      F  move forward one unit along the current heading, adding a vertex and
         an edge from the current vertex to it;
      +/- rotate the heading by +/-``angle`` degrees about the z axis;
      [  push (position, heading, current vertex) on a stack;
      ]  pop that state, so the next F branches from the saved vertex.

    Args:
        iterations: Number of rewriting passes applied to the axiom "F".
        angle: Branching angle in degrees.

    Returns:
        (pos, edge_index): (N, 3) vertex positions (the root is at the origin)
        and a (2, N-1) array of parent->child edges.
    """
    axiom = "F"
    rules = {"F": "F[+F]F[-F]F"}

    def expand_axiom(axiom, rules, iterations):
        for _ in range(iterations):
            axiom = "".join(rules.get(ch, ch) for ch in axiom)
        return axiom

    final_axiom = expand_axiom(axiom, rules, iterations)

    # The turn angles are fixed, so build both rotation matrices once.
    theta = math.radians(angle)
    z_axis = np.array([0.0, 0.0, 1.0])
    turn_left = rotation_matrix_3d(z_axis, theta)
    turn_right = rotation_matrix_3d(z_axis, -theta)

    current_pos = np.zeros(3)
    direction = np.array([0.0, 1.0, 0.0])
    current_idx = 0
    pos_list = [current_pos.copy()]
    edge_list = []
    stack = []

    for command in final_axiom:
        if command == "F":
            next_pos = current_pos + direction
            # The new vertex's index is its position in pos_list. This is not
            # current_idx + 1 once a "]" has jumped back to a branch point.
            new_idx = len(pos_list)
            pos_list.append(next_pos)
            edge_list.append([current_idx, new_idx])
            current_pos, current_idx = next_pos, new_idx
        elif command == "+":
            direction = turn_left @ direction
        elif command == "-":
            direction = turn_right @ direction
        elif command == "[":
            stack.append((current_pos.copy(), direction.copy(), current_idx))
        elif command == "]":
            current_pos, direction, current_idx = stack.pop()

    pos = np.array(pos_list)
    # reshape keeps a (2, 0) shape even if no edges were produced.
    edge_index = np.array(edge_list, dtype=np.int64).reshape(-1, 2).T
    return pos, edge_index


def rotation_matrix_3d(axis, theta):
    """Return the 3x3 matrix rotating by ``theta`` radians about ``axis``.

    Implements Rodrigues' formula
    R = cos(t) I + sin(t) [k]_x + (1 - cos(t)) k k^T, where k is ``axis``
    normalised to unit length.

    Raises:
        ValueError: if ``axis`` has zero length.
    """
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    kx, ky, kz = axis / norm
    c = np.cos(theta)
    s = np.sin(theta)
    t = 1 - c
    return np.array([
        [c + t * kx * kx, t * kx * ky - kz * s, t * kx * kz + ky * s],
        [t * ky * kx + kz * s, c + t * ky * ky, t * ky * kz - kx * s],
        [t * kz * kx - ky * s, t * kz * ky + kx * s, c + t * kz * kz],
    ])


def plot_graph_3d(pos, edge_index):
    """Return a Plotly 3D figure with gray edge lines and red node markers.

    Args:
        pos: (N, 3) array of node coordinates.
        edge_index: (2, E) array of (source, target) node indices.
    """
    pos = np.asarray(pos)
    edge_index = np.asarray(edge_index).reshape(2, -1)
    x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]

    edge_x, edge_y, edge_z = [], [], []
    for src, dst in zip(edge_index[0], edge_index[1]):
        # Plotly breaks a line trace at None, so each edge is drawn as its own
        # separate segment within a single trace.
        edge_x += [x[src], x[dst], None]
        edge_y += [y[src], y[dst], None]
        edge_z += [z[src], z[dst], None]

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=2, color='gray'),
        hoverinfo='none',
        mode='lines',
    )

    node_trace = go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=4, color='red'),
        hoverinfo='none',
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data',  # keep true proportions between axes
        ),
        showlegend=False,
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig


# Main App Code

def main():
    """Render the app: choose a structure in the sidebar, tune it, and plot it."""
    st.title("PyTorch Geometric Structure Visualization")

    structure_type = st.sidebar.selectbox(
        "Select Structure Type",
        ("Sierpinski Triangle", "Spiral", "Plant Structure"),
    )

    if structure_type == "Sierpinski Triangle":
        depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
        pos, edge_index = generate_sierpinski_triangle(depth)
    elif structure_type == "Spiral":
        turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
        points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
        pos, edge_index = generate_spiral(turns, points_per_turn)
    elif structure_type == "Plant Structure":
        iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
        angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
        pos, edge_index = generate_plant(iterations, angle)
    else:
        return

    # Package the generated structure as a PyG graph and report its size.
    data = Data(pos=torch.tensor(pos, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long))
    st.caption(f"{data.num_nodes} nodes, {data.num_edges} edges")

    fig = plot_graph_3d(pos, edge_index)
    st.plotly_chart(fig)


if __name__ == "__main__":
    main()