awacke1 commited on
Commit
25eca49
·
verified ·
1 Parent(s): 5e25ae8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Streamlit App: PyTorch Geometric Structure Visualization
2
+
3
+ import streamlit as st
4
+ import torch
5
+ from torch_geometric.data import Data
6
+ import numpy as np
7
+ import plotly.graph_objs as go
8
+
9
+ st.title("PyTorch Geometric Structure Visualization")
10
+
11
+ structure_type = st.sidebar.selectbox(
12
+ "Select Structure Type",
13
+ ("Sierpinski Triangle", "Spiral", "Plant Structure")
14
+ )
15
+
16
+ if structure_type == "Sierpinski Triangle":
17
+ depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
18
+ pos, edge_index = generate_sierpinski_triangle(depth)
19
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
20
+ fig = plot_graph_3d(pos, edge_index)
21
+ st.plotly_chart(fig)
22
+
23
+ elif structure_type == "Spiral":
24
+ turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
25
+ points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
26
+ pos, edge_index = generate_spiral(turns, points_per_turn)
27
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
28
+ fig = plot_graph_3d(pos, edge_index)
29
+ st.plotly_chart(fig)
30
+
31
+ elif structure_type == "Plant Structure":
32
+ iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
33
+ angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
34
+ pos, edge_index = generate_plant(iterations, angle)
35
+ data = Data(pos=torch.tensor(pos, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long))
36
+ fig = plot_graph_3d(pos, edge_index)
37
+ st.plotly_chart(fig)
38
+
39
+ # Function Definitions
40
+
41
+ def generate_sierpinski_triangle(depth):
42
+ # Generate the vertices of the initial triangle
43
+ vertices = np.array([
44
+ [0, 0, 0],
45
+ [1, 0, 0],
46
+ [0.5, np.sqrt(3)/2, 0]
47
+ ])
48
+
49
+ # Function to recursively generate points
50
+ def recurse_triangle(v1, v2, v3, depth):
51
+ if depth == 0:
52
+ return [v1, v2, v3]
53
+ else:
54
+ # Calculate midpoints
55
+ m12 = (v1 + v2) / 2
56
+ m23 = (v2 + v3) / 2
57
+ m31 = (v3 + v1) / 2
58
+ # Recursively subdivide
59
+ return (recurse_triangle(v1, m12, m31, depth - 1) +
60
+ recurse_triangle(m12, v2, m23, depth - 1) +
61
+ recurse_triangle(m31, m23, v3, depth - 1))
62
+
63
+ points = recurse_triangle(vertices[0], vertices[1], vertices[2], depth)
64
+ pos = np.array(points)
65
+ # Create edges between points
66
+ edge_index = []
67
+ for i in range(0, len(pos), 3):
68
+ edge_index.extend([
69
+ [i, i+1],
70
+ [i+1, i+2],
71
+ [i+2, i]
72
+ ])
73
+ edge_index = np.array(edge_index).T
74
+ return pos, edge_index
75
+
76
+ def generate_spiral(turns, points_per_turn):
77
+ total_points = turns * points_per_turn
78
+ theta_max = 2 * np.pi * turns
79
+ theta = np.linspace(0, theta_max, total_points)
80
+ z = np.linspace(0, 1, total_points)
81
+ r = z # Spiral expanding in radius
82
+ x = r * np.cos(theta)
83
+ y = r * np.sin(theta)
84
+ pos = np.vstack((x, y, z)).T
85
+ # Edges connect sequential points
86
+ edge_index = np.array([np.arange(total_points - 1), np.arange(1, total_points)])
87
+ return pos, edge_index
88
+
89
+ def generate_plant(iterations, angle):
90
+ axiom = "F"
91
+ rules = {"F": "F[+F]F[-F]F"}
92
+ import math
93
+
94
+ def expand_axiom(axiom, rules, iterations):
95
+ for _ in range(iterations):
96
+ new_axiom = ""
97
+ for ch in axiom:
98
+ new_axiom += rules.get(ch, ch)
99
+ axiom = new_axiom
100
+ return axiom
101
+
102
+ final_axiom = expand_axiom(axiom, rules, iterations)
103
+
104
+ stack = []
105
+ pos_list = []
106
+ edge_list = []
107
+ current_pos = np.array([0, 0, 0])
108
+ pos_list.append(current_pos.copy())
109
+ idx = 0
110
+ direction = np.array([0, 1, 0])
111
+
112
+ for command in final_axiom:
113
+ if command == 'F':
114
+ next_pos = current_pos + direction
115
+ pos_list.append(next_pos.copy())
116
+ edge_list.append([idx, idx + 1])
117
+ current_pos = next_pos
118
+ idx += 1
119
+ elif command == '+':
120
+ theta = np.radians(angle)
121
+ rotation_matrix = rotation_matrix_3d(np.array([0, 0, 1]), theta)
122
+ direction = rotation_matrix @ direction
123
+ elif command == '-':
124
+ theta = np.radians(-angle)
125
+ rotation_matrix = rotation_matrix_3d(np.array([0, 0, 1]), theta)
126
+ direction = rotation_matrix @ direction
127
+ elif command == '[':
128
+ stack.append((current_pos.copy(), direction.copy(), idx))
129
+ elif command == ']':
130
+ current_pos, direction, idx = stack.pop()
131
+ pos_list.append(current_pos.copy())
132
+ idx += 1
133
+
134
+ pos = np.array(pos_list)
135
+ edge_index = np.array(edge_list).T
136
+ return pos, edge_index
137
+
138
+ def rotation_matrix_3d(axis, theta):
139
+ # Return the rotation matrix associated with rotation about the given axis by theta radians.
140
+ axis = axis / np.linalg.norm(axis)
141
+ a = np.cos(theta / 2)
142
+ b, c, d = -axis * np.sin(theta / 2)
143
+ return np.array([[a*a + b*b - c*c - d*d, 2*(b*c - a*d), 2*(b*d + a*c)],
144
+ [2*(b*c + a*d), a*a + c*c - b*b - d*d, 2*(c*d - a*b)],
145
+ [2*(b*d - a*c), 2*(c*d + a*b), a*a + d*d - b*b - c*c]])
146
+
147
+ def plot_graph_3d(pos, edge_index):
148
+ x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
149
+ edge_x = []
150
+ edge_y = []
151
+ edge_z = []
152
+
153
+ for i in range(edge_index.shape[1]):
154
+ src = edge_index[0, i]
155
+ dst = edge_index[1, i]
156
+ edge_x += [x[src], x[dst], None]
157
+ edge_y += [y[src], y[dst], None]
158
+ edge_z += [z[src], z[dst], None]
159
+
160
+ edge_trace = go.Scatter3d(
161
+ x=edge_x, y=edge_y, z=edge_z,
162
+ line=dict(width=2, color='gray'),
163
+ hoverinfo='none',
164
+ mode='lines')
165
+
166
+ node_trace = go.Scatter3d(
167
+ x=x, y=y, z=z,
168
+ mode='markers',
169
+ marker=dict(
170
+ size=4,
171
+ color='red',
172
+ ),
173
+ hoverinfo='none'
174
+ )
175
+
176
+ fig = go.Figure(data=[edge_trace, node_trace])
177
+ fig.update_layout(
178
+ scene=dict(
179
+ xaxis_title='X',
180
+ yaxis_title='Y',
181
+ zaxis_title='Z',
182
+ aspectmode='data'
183
+ ),
184
+ showlegend=False
185
+ )
186
+ return fig