File size: 2,056 Bytes
491eded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

import numpy as np
from modules.bbox_gen.utils.mesh import change_pcd_range


class BoundsTokenizerDiag:
    """Turn bounding boxes into a flat integer token sequence and back.

    Every box contributes six coordinate tokens (2 vertices x 3 coordinates).
    ``encode`` wraps the coordinate tokens in BOS/EOS ids; ``decode`` drops
    every BOS/EOS/PAD id it encounters before rebuilding the boxes.
    """

    def __init__(self, bins, BOS_id, EOS_id, PAD_id):
        """
        Store the quantization resolution and the special token ids

        Args:
            bins: number of quantization bins per coordinate
            BOS_id: token id placed at the start of an encoded sequence
            EOS_id: token id placed at the end of an encoded sequence
            PAD_id: token id treated as padding (only filtered out by decode)
        """
        # Quantization resolution.
        self.bins = bins
        # Special ids; decode() removes all three from its input.
        self.BOS_id = BOS_id
        self.EOS_id = EOS_id
        self.PAD_id = PAD_id

    def _unit_range(self):
        """
        Build the (low, high) interval used on the normalized side

        Returns:
            tuple (0.5 / bins, 1 - 0.5 / bins)
        """
        half_step = 0.5 / self.bins
        return (half_step, 1 - half_step)

    def encode(self, data_dict, coord_rg=(-1, 1)):
        """
        Quantize the boxes in ``data_dict`` into a token sequence

        Args:
            data_dict: dictionary whose "bounds" entry holds the boxes; it is
                reshaped to (-1, 6), i.e. six values per box (the original
                code comments its shape as (s, 2, 3))
            coord_rg: (low, high) range of the incoming coordinate values
        Returns:
            1-D numpy array laid out as [BOS_id, six tokens per box ..., EOS_id]
        """
        # One row of six coordinates per box.
        flat_boxes = data_dict["bounds"].reshape(-1, 6)

        # change_pcd_range is a project helper; presumably it linearly remaps
        # values from `from_rg` to `to_rg` -- confirm against its definition.
        normalized = change_pcd_range(
            flat_boxes, from_rg=coord_rg, to_rg=self._unit_range()
        )

        # Scale by the bin count; the int32 cast discards the fractional part.
        bin_ids = (normalized * self.bins).astype(np.int32)

        return np.array([self.BOS_id, *bin_ids.flatten().tolist(), self.EOS_id])

    def decode(self, tokens, coord_rg=(-1, 1)):
        """
        Rebuild bounding boxes from a token sequence

        Args:
            tokens: iterable of token ids; BOS/EOS/PAD ids are ignored
                wherever they occur
            coord_rg: (low, high) range the decoded coordinates are mapped to
        Returns:
            bounding box array [N, 2, 3] (assuming change_pcd_range keeps the
            shape of its input)
        Raises:
            ValueError: if the number of remaining tokens is not a multiple
                of six
        """
        bos, eos, pad = self.BOS_id, self.EOS_id, self.PAD_id

        # Keep only coordinate tokens.
        valid_tokens = [
            tok for tok in tokens if tok != bos and tok != eos and tok != pad
        ]

        # 2 vertices per box, 3 coordinates per vertex.
        tokens_per_box = 2 * 3
        if len(valid_tokens) % tokens_per_box != 0:
            raise ValueError(f"Invalid token count: {len(valid_tokens)}")

        # Token k becomes k / bins before being mapped back to coord_rg.
        unit_points = np.array(valid_tokens).reshape(-1, 2, 3) / self.bins

        return change_pcd_range(
            unit_points, from_rg=self._unit_range(), to_rg=coord_rg
        )