File size: 1,240 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import plotly.express as px
import pandas as pd
import numpy as np
import os

def plot_coordinates(
    coords,
    pt,  # the size of the points
    tidx,  # truth idx
    outdir=None,
    filename=None
):
    """Build a 3D scatter plot of point coordinates coloured by truth index.

    Args:
        coords: Tensor whose first three columns (``coords[:, 0]``,
            ``coords[:, 1]``, ``coords[:, 2]``) are plotted as X, Y and Z.
            Must support ``.detach().cpu().numpy()`` (e.g. a torch tensor).
        pt: Tensor of per-point marker sizes; flattened to one value per
            point.
        tidx: Tensor of per-point truth indices, used for the continuous
            colour scale; flattened to one value per point.
        outdir: Directory in which to write an HTML copy of the figure.
            Created if it does not exist.
        filename: Name of the HTML file inside ``outdir``.

    Returns:
        The plotly figure. If both ``outdir`` and ``filename`` are given, the
        figure is also written to ``os.path.join(outdir, filename)`` before
        being returned.
    """
    def _flat_numpy(t):
        # reshape (unlike view) also works on non-contiguous inputs such as
        # transposed tensors; detach/cpu make it safe for autograd/GPU tensors.
        return t.reshape(-1).detach().cpu().numpy()

    # Build the frame column-by-column so each column keeps its own dtype
    # (e.g. integer truth indices are not coerced to float).
    df = pd.DataFrame(
        {
            "X": _flat_numpy(coords[:, 0]),
            "Y": _flat_numpy(coords[:, 1]),
            "Z": _flat_numpy(coords[:, 2]),
            "tIdx": _flat_numpy(tidx),
            "pt": _flat_numpy(pt),
        }
    )

    fig = px.scatter_3d(
        df,
        x="X",
        y="Y",
        z="Z",
        color="tIdx",
        size="pt",
        template="plotly_dark",
        color_continuous_scale=px.colors.sequential.Rainbow,
        # Semi-transparent markers so overlapping points stay visible.
        opacity=0.5,
    )
    # Remove marker outlines, which clutter dense 3D scatters.
    fig.update_traces(marker=dict(line=dict(width=0)))

    if filename is not None and outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        fig.write_html(os.path.join(outdir, filename))
    return fig