jetclustering / src /plotting /plot_coordinates.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
1.24 kB
import plotly.express as px
import pandas as pd
import numpy as np
import os
def plot_coordinates(
coords,
pt, # the size of the points
tidx, # truth idx
outdir=None,
filename=None
):
data = {
"X": coords[:, 0].view(-1, 1).detach().cpu().numpy(),
"Y": coords[:, 1].view(-1, 1).detach().cpu().numpy(),
"Z": coords[:, 2].view(-1, 1).detach().cpu().numpy(),
"tIdx": tidx.view(-1, 1).detach().cpu().numpy(),
"pt": pt.view(-1, 1).detach().cpu().numpy(),
}
print([(k, data[k].shape) for k in data])
df = pd.DataFrame(
np.concatenate([data[k] for k in sorted(data.keys())], axis=1),
columns=[k for k in sorted(data.keys())],
)
df["orig_tIdx"] = df["tIdx"]
fig = px.scatter_3d(
df,
x="X",
y="Y",
z="Z",
color="tIdx",
size="pt",
# hover_data=hover_data,
template="plotly_dark",
color_continuous_scale=px.colors.sequential.Rainbow,
# make it opaque a bit
opacity=0.5,
)
fig.update_traces(marker=dict(line=dict(width=0)))
if filename is None or outdir is None:
return fig
fig.write_html(
os.path.join(outdir, filename)
)