File size: 3,504 Bytes
4c954ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import copy
import math
import numpy as np
import torch
import json

class WireframeGraph:
    """A wireframe graph: scored vertices joined by weighted edges in a frame.

    Attributes:
        vertices: per-vertex coordinates; column 0 is x and column 1 is y
            (see ``rescale``). Presumably shaped ``(N, 2)`` — TODO confirm.
        v_confidences: per-vertex scores (JSON key ``'vertices-score'``).
        edges: vertex-index pairs; ``edges[:, 0]`` and ``edges[:, 1]`` index
            into ``vertices``.
        weights: per-edge scores (JSON key ``'edges-weights'``), compared
            against a threshold in ``line_segments``.
        frame_width, frame_height: size of the frame the coordinates refer to.
    """

    def __init__(self,
                 vertices: torch.Tensor,
                 v_confidences: torch.Tensor,
                 edges: torch.Tensor,
                 edge_weights: torch.Tensor,
                 frame_width: int,
                 frame_height: int):
        self.vertices = vertices
        self.v_confidences = v_confidences
        self.edges = edges
        # Stored under the shorter name ``weights``; the rest of the class
        # reads it from there.
        self.weights = edge_weights
        self.frame_width = frame_width
        self.frame_height = frame_height

    @classmethod
    def xyxy2indices(cls, junctions, lines):
        """Map each segment's endpoints to the indices of the nearest junctions.

        Args:
            junctions: ``(N, 2)`` tensor of junction coordinates.
            lines: ``(M, 4)`` tensor of segments laid out as ``(x1, y1, x2, y2)``.

        Returns:
            ``(M, 2)`` tensor whose row ``i`` holds the indices of the
            junctions closest (Euclidean distance) to the first and second
            endpoint of ``lines[i]``.
        """
        # Broadcast to an (M, N) distance table for each endpoint.
        dist1 = torch.norm(junctions[None, :, :] - lines[:, None, :2], dim=-1)
        dist2 = torch.norm(junctions[None, :, :] - lines[:, None, 2:], dim=-1)
        idx1 = torch.argmin(dist1, dim=-1)
        idx2 = torch.argmin(dist2, dim=-1)
        return torch.stack((idx1, idx2), dim=-1)

    @classmethod
    def load_json(cls, fname):
        """Build a graph from a JSON file in the layout produced by ``jsonize``.

        The file must contain the keys ``'vertices'``, ``'vertices-score'``,
        ``'edges'``, ``'edges-weights'``, ``'height'`` and ``'width'``.

        Args:
            fname: path of the JSON file.

        Returns:
            An instance of ``cls``, so subclasses load as themselves.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if a required key is missing.
        """
        with open(fname, 'r', encoding='utf-8') as f:
            data = json.load(f)

        vertices = torch.tensor(data['vertices'])
        v_confidences = torch.tensor(data['vertices-score'])
        # Edges are used as indices into ``vertices``, so force an integer
        # dtype; an empty JSON list would otherwise become a float tensor.
        edges = torch.tensor(data['edges'], dtype=torch.long)
        edge_weights = torch.tensor(data['edges-weights'])
        # ``torch.tensor([])`` is 1-D with shape (0,). Give empty graphs the
        # 2-D shape that the column indexing in ``line_segments``/``rescale``
        # expects (assumes 2 columns per vertex / per edge).
        if vertices.numel() == 0:
            vertices = vertices.reshape(0, 2)
        if edges.numel() == 0:
            edges = edges.reshape(0, 2)
        height = data['height']
        width = data['width']

        return cls(vertices, v_confidences, edges, edge_weights, width, height)

    @property
    def is_empty(self):
        """True if any stored attribute is ``None`` (graph not fully populated)."""
        return any(value is None for value in self.__dict__.values())

    @property
    def num_vertices(self):
        """Number of vertices (rows of ``vertices``), or 0 if ``is_empty``."""
        if self.is_empty:
            return 0
        return self.vertices.shape[0]

    @property
    def num_edges(self):
        """Number of edges (rows of ``edges``), or 0 if ``is_empty``."""
        if self.is_empty:
            return 0
        return self.edges.shape[0]

    def line_segments(self, threshold=0.05, device=None, to_np=False):
        """Return the edges whose weight exceeds ``threshold`` as segments.

        Args:
            threshold: an edge is kept when ``weight > threshold`` (strict).
            device: if given, the result is moved to this device.
            to_np: if True, return a NumPy array (moved to CPU first).

        Returns:
            One row per kept edge: the coordinates of its first vertex, then
            its second vertex, then its weight — ``[x1, y1, x2, y2, w]`` for
            2-D vertices.
        """
        is_valid = self.weights > threshold
        p1 = self.vertices[self.edges[is_valid, 0]]
        p2 = self.vertices[self.edges[is_valid, 1]]
        ps = self.weights[is_valid]

        lines = torch.cat((p1, p2, ps[:, None]), dim=-1)
        if device is not None:
            lines = lines.to(device)
        if to_np:
            lines = lines.cpu().numpy()
        return lines

    def rescale(self, image_width, image_height):
        """Rescale vertex coordinates from the current frame to a new size.

        Column 0 (x) is scaled by ``image_width / frame_width`` and column 1
        (y) by ``image_height / frame_height``. The stored frame size is then
        updated. Floating-point vertices are modified in place. Integer
        vertices are first converted to float32, because an in-place multiply
        by a float scale cannot be cast back to an integer dtype.

        Raises:
            ZeroDivisionError: if the current frame width or height is 0.
        """
        scale_x = float(image_width) / float(self.frame_width)
        scale_y = float(image_height) / float(self.frame_height)

        if not self.vertices.is_floating_point():
            self.vertices = self.vertices.float()
        self.vertices[:, 0] *= scale_x
        self.vertices[:, 1] *= scale_y
        self.frame_width = image_width
        self.frame_height = image_height

    def jsonize(self):
        """Return a JSON-serializable dict in the layout ``load_json`` reads."""
        return {
            'vertices': self.vertices.cpu().tolist(),
            'vertices-score': self.v_confidences.cpu().tolist(),
            'edges': self.edges.cpu().tolist(),
            'edges-weights': self.weights.cpu().tolist(),
            'height': self.frame_height,
            'width': self.frame_width,
        }

    def __repr__(self) -> str:
        return "WireframeGraph\n" + \
               "Vertices: {}\n".format(self.num_vertices) + \
               "Edges: {}\n".format(self.num_edges) + \
               "Frame size (HxW): {}x{}".format(self.frame_height, self.frame_width)

if __name__ == "__main__":
    # Manual check: load the sample wireframe JSON and print its summary.
    sample_path = 'NeuS/public_data/bmvs_clock/hawp/000.json'
    print(WireframeGraph.load_json(sample_path))