Spaces:
Sleeping
Sleeping
""" | |
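This file contains the code to build an interactive 3D surface plot of detectability, distortion, and Euclidean distance, highlighting the best trade-off ("sweet spot").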
""" | |
import numpy as np | |
import plotly.graph_objects as go | |
from scipy.interpolate import griddata | |
def _normalize(values):
    """Min-max scale a 1-D float array to the range [0, 1].

    A constant array (zero range) is mapped to all zeros instead of being
    divided by zero; a 0/0 would otherwise yield NaN and poison every
    composite score computed from it.
    """
    span = values.max() - values.min()
    if span == 0:
        return np.zeros_like(values)
    return (values - values.min()) / span


def gen_three_D_plot(detectability_val, distortion_val, euclidean_val,
                     grid_points=30, contour_size=0.1):
    """
    Generate a 3D surface plot of Euclidean distance over detectability and
    distortion, with the "sweet spot" highlighted.

    Each metric is min-max normalized to [0, 1] and combined into a composite
    score that rewards detectability and penalizes distortion and Euclidean
    distance. The sample with the highest composite score (the "sweet spot")
    is drawn as a red marker. The Euclidean distances are resampled onto a
    regular ``grid_points`` x ``grid_points`` grid spanning the observed
    detectability/distortion ranges using nearest-neighbour interpolation
    (``griddata(..., method='nearest')``) and rendered as a surface with
    z-contour lines.

    Args:
        detectability_val (list or array): 1-D sequence of detectability scores.
        distortion_val (list or array): 1-D sequence of distortion scores, the
            same length as ``detectability_val``.
        euclidean_val (list or array): 1-D sequence of Euclidean distances, the
            same length as ``detectability_val``.
        grid_points (int, optional): Number of samples along each axis of the
            interpolation grid. Defaults to 30.
        contour_size (float, optional): Spacing between z-contour lines, in the
            units of ``euclidean_val``. Defaults to 0.1.

    Returns:
        plotly.graph_objects.Figure: A figure containing the interpolated
        surface (with z-contours, "Plasma" colorscale) and a "Sweet Spot"
        marker trace.

    Raises:
        ValueError: If the inputs are not non-empty 1-D sequences of equal
            length, if ``grid_points`` is less than 2, or if ``griddata``
            produces no finite values.

    Example:
        detectability_vals = [0.1, 0.3, 0.5, 0.7, 0.9]
        distortion_vals = [0.2, 0.4, 0.6, 0.8, 1.0]
        euclidean_vals = [0.5, 0.3, 0.2, 0.4, 0.6]
        fig = gen_three_D_plot(detectability_vals, distortion_vals, euclidean_vals)
        fig.show()

    Notes:
        - composite_score = norm_detectability - (norm_distortion + norm_euclidean)
        - A metric whose values are all identical normalizes to all zeros, so it
          does not influence which sample is chosen as the sweet spot.
        - On ties, the first sample with the maximal score is chosen
          (``np.argmax`` semantics).
        - Nearest-neighbour interpolation yields a piecewise-constant (stepped)
          surface. Unlike linear interpolation, which needs a Delaunay
          triangulation, it also works when the (detectability, distortion)
          points are collinear, as in the example above.
    """
    detectability = np.asarray(detectability_val, dtype=float)
    distortion = np.asarray(distortion_val, dtype=float)
    euclidean = np.asarray(euclidean_val, dtype=float)

    # Validate up front so bad input fails with a clear message instead of a
    # broadcasting error or a silently wrong sweet spot.
    if not (detectability.ndim == distortion.ndim == euclidean.ndim == 1):
        raise ValueError("Inputs must be 1-D sequences.")
    if not (detectability.size == distortion.size == euclidean.size):
        raise ValueError("Inputs must all have the same length.")
    if detectability.size == 0:
        raise ValueError("Inputs must not be empty.")
    if grid_points < 2:
        raise ValueError("grid_points must be at least 2.")

    # Composite score: maximize detectability, minimize distortion and Euclidean distance
    composite_score = _normalize(detectability) - (
        _normalize(distortion) + _normalize(euclidean)
    )
    sweet_spot_index = int(np.argmax(composite_score))

    sweet_spot_detectability = detectability[sweet_spot_index]
    sweet_spot_distortion = distortion[sweet_spot_index]
    sweet_spot_euclidean = euclidean[sweet_spot_index]

    # Regular grid spanning the observed detectability/distortion ranges
    x_grid, y_grid = np.meshgrid(
        np.linspace(detectability.min(), detectability.max(), grid_points),
        np.linspace(distortion.min(), distortion.max(), grid_points),
    )

    # Resample Euclidean distances onto the grid (nearest-neighbour)
    z_grid = griddata((detectability, distortion), euclidean, (x_grid, y_grid), method='nearest')
    # griddata never returns None; an all-NaN grid is the real failure mode.
    if z_grid is None or np.isnan(z_grid).all():
        raise ValueError("griddata could not generate a valid interpolation. Check your input data.")

    # Surface with z-contours using the Plasma color scale
    fig = go.Figure(data=go.Surface(
        z=z_grid,
        x=x_grid,
        y=y_grid,
        contours={
            "z": {
                "show": True,
                "start": float(euclidean.min()),
                "end": float(euclidean.max()),
                "size": contour_size,
                "usecolormap": True,
            }
        },
        colorscale='Plasma'
    ))

    # Marker for the sweet spot
    fig.add_trace(go.Scatter3d(
        x=[sweet_spot_detectability],
        y=[sweet_spot_distortion],
        z=[sweet_spot_euclidean],
        mode='markers+text',
        marker=dict(size=10, color='red', symbol='circle'),
        text=["Sweet Spot"],
        textposition="top center"
    ))

    # Axis labels and tight margins
    fig.update_layout(
        scene=dict(
            xaxis_title='Detectability Score',
            yaxis_title='Distortion Score',
            zaxis_title='Euclidean Distance'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )
    return fig
if __name__ == "__main__":
    # Five paired sample observations, passed by parameter name, used only
    # to render a demonstration figure.
    demo_inputs = {
        "detectability_val": [0.1, 0.3, 0.5, 0.7, 0.9],
        "distortion_val": [0.2, 0.4, 0.6, 0.8, 1.0],
        "euclidean_val": [0.5, 0.3, 0.2, 0.4, 0.6],
    }
    fig = gen_three_D_plot(**demo_inputs)
    # Open the figure in plotly's default renderer.
    fig.show()