Update pages/01_🦷 Segment.py
Browse files- pages/01_🦷 Segment.py +42 -7
pages/01_🦷 Segment.py
CHANGED
|
@@ -1,27 +1,25 @@
|
|
|
|
|
| 1 |
import shutil
|
| 2 |
-
import os
|
| 3 |
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from sklearn import neighbors
|
| 6 |
from scipy.spatial import distance_matrix
|
| 7 |
from pygco import cut_from_graph
|
|
|
|
| 8 |
import open3d as o3d
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import matplotlib.colors as mcolors
|
| 11 |
from stqdm import stqdm
|
| 12 |
import json
|
| 13 |
-
|
| 14 |
-
import pyvista as pv
|
| 15 |
from stpyvista import stpyvista
|
| 16 |
-
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
| 19 |
-
import torch.nn.functional as F
|
| 20 |
from torch.autograd import Variable
|
| 21 |
-
|
| 22 |
import streamlit as st
|
|
|
|
| 23 |
|
| 24 |
-
from streamlit import session_state as session
|
| 25 |
from PIL import Image
|
| 26 |
|
| 27 |
class TeethApp:
|
|
@@ -896,6 +894,43 @@ class Segment(TeethApp):
|
|
| 896 |
if segment:
|
| 897 |
segmentation_main("ZOUIF2W4_upper.obj")
|
| 898 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
|
| 900 |
|
| 901 |
elif inputs == "Upload Scan":
|
|
|
|
| 1 |
+
from streamlit import session_state as session
|
| 2 |
import shutil
|
|
|
|
| 3 |
|
| 4 |
+
import os
|
| 5 |
import numpy as np
|
| 6 |
from sklearn import neighbors
|
| 7 |
from scipy.spatial import distance_matrix
|
| 8 |
from pygco import cut_from_graph
|
| 9 |
+
import streamlit_ext as ste
|
| 10 |
import open3d as o3d
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
import matplotlib.colors as mcolors
|
| 13 |
from stqdm import stqdm
|
| 14 |
import json
|
|
|
|
|
|
|
| 15 |
from stpyvista import stpyvista
|
|
|
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
|
|
|
| 18 |
from torch.autograd import Variable
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
import streamlit as st
|
| 21 |
+
import pyvista as pv
|
| 22 |
|
|
|
|
| 23 |
from PIL import Image
|
| 24 |
|
| 25 |
class TeethApp:
|
|
|
|
| 894 |
if segment:
|
| 895 |
segmentation_main("ZOUIF2W4_upper.obj")
|
| 896 |
|
| 897 |
+
# Load the JSON file
|
| 898 |
+
with open('ZOUIF2W4_upper.json', 'r') as file:
|
| 899 |
+
labels_data = json.load(file)
|
| 900 |
+
|
| 901 |
+
# Assuming labels_data['labels'] is a list of labels
|
| 902 |
+
labels = labels_data['labels']
|
| 903 |
+
|
| 904 |
+
# Make sure the number of labels matches the number of vertices or faces
|
| 905 |
+
assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
|
| 906 |
+
|
| 907 |
+
# If labels correspond to vertices
|
| 908 |
+
if len(labels) == mesh.n_points:
|
| 909 |
+
mesh.point_data['Labels'] = labels
|
| 910 |
+
# If labels correspond to faces
|
| 911 |
+
elif len(labels) == mesh.n_cells:
|
| 912 |
+
mesh.cell_data['Labels'] = labels
|
| 913 |
+
|
| 914 |
+
# Create a pyvista plotter
|
| 915 |
+
plotter = pv.Plotter()
|
| 916 |
+
|
| 917 |
+
cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
|
| 918 |
+
|
| 919 |
+
colors = cmap(np.linspace(0, 1, 27)) # Generate colors
|
| 920 |
+
|
| 921 |
+
# Convert colors to a format acceptable by PyVista
|
| 922 |
+
colormap = mcolors.ListedColormap(colors)
|
| 923 |
+
|
| 924 |
+
# Add the mesh to the plotter with labels as a scalar field
|
| 925 |
+
#plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
|
| 926 |
+
plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
|
| 927 |
+
|
| 928 |
+
# Show the plot
|
| 929 |
+
#plotter.show()
|
| 930 |
+
## Send to streamlit
|
| 931 |
+
with st.expander("Ground Truth - scroll for zoom", expanded=False):
|
| 932 |
+
stpyvista(plotter)
|
| 933 |
+
|
| 934 |
|
| 935 |
|
| 936 |
elif inputs == "Upload Scan":
|