Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	| """ | |
A two-view sparse feature matching pipeline applied to image triplets.

If a triplet is found, the extractor runs on all three images and the
matcher/filter/solver then run on each of the three image pairs.
Losses and metrics are accumulated accordingly.
If no triplet is found, this falls back to two_view_pipeline.py.
| """ | |
| import torch | |
| from ..utils.misc import get_twoview, stack_twoviews, unstack_twoviews | |
| from .two_view_pipeline import TwoViewPipeline | |
def has_triplet(data):
    """Return True when ``data`` carries a third view.

    Only the presence of the ``"view2"`` key is tested; the first two views
    are not checked here (per the original note, they are already covered
    by the pipeline's required keys).
    """
    # Membership on the mapping itself; no need to build a keys view.
    return "view2" in data
class TripletPipeline(TwoViewPipeline):
    """Two-view pipeline that also handles image triplets.

    When the input contains a third view (see ``has_triplet``), the extractor
    is run on views 0, 1 and 2 and the matcher/filter/solver stages are run
    on the pairs ``0to1``, ``0to2`` and ``1to2``. Otherwise every method
    defers to ``TwoViewPipeline``.

    Config:
        batch_triplets: if true, the three pairs are stacked and processed in
            a single pass; if false, they are processed one pair at a time.
    """

    default_conf = {"batch_triplets": True, **TwoViewPipeline.default_conf}

    def _forward(self, data):
        """Run extraction on three views and matching on the three pairs.

        Falls back to ``TwoViewPipeline._forward`` when ``data`` has no third
        view. The returned dict holds the per-view extractor outputs with the
        view index appended to each key ("<key>0", "<key>1", "<key>2") plus
        the two-view outputs.
        """
        if not has_triplet(data):
            return super()._forward(data)

        # the two-view outputs are stored in
        # pred['0to1'], pred['0to2'], pred['1to2']
        # Running ground truth inside forward is not supported for triplets.
        assert not self.conf.run_gt_in_forward

        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred2 = self.extract_view(data, "2")

        # Suffix each extractor key with the index of the view it came from.
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
            **{k + "2": v for k, v in pred2.items()},
        }

        def predict_twoview(pred, data):
            """Apply the configured matcher, filter and solver in order.

            Each stage receives the input data merged with everything
            predicted so far, and its outputs are merged into ``pred``.
            """
            # NOTE: every stage reads the ``data`` argument. The previous
            # version read ``m_data`` from the enclosing scope for the filter
            # and solver, which only worked because callers happened to bind
            # that name right before the call.
            if self.conf.matcher.name:
                pred = {**pred, **self.matcher({**data, **pred})}
            if self.conf.filter.name:
                pred = {**pred, **self.filter({**data, **pred})}
            if self.conf.solver.name:
                pred = {**pred, **self.solver({**data, **pred})}
            return pred

        if self.conf.batch_triplets:
            # assumes dim 0 of data["image1"] is the batch dimension
            # -- TODO confirm against the dataset
            B = data["image1"].shape[0]
            # stack on batch dimension
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)
            # forward pass
            m_pred = predict_twoview(m_pred, m_data)
            # unstack
            pred = {**pred, **unstack_twoviews(m_pred, B)}
        else:
            for idx in ["0to1", "0to2", "1to2"]:
                m_data = get_twoview(data, idx)
                m_pred = get_twoview(pred, idx)
                pred[idx] = predict_twoview(m_pred, m_data)
        return pred

    def loss(self, pred, data):
        """Compute losses and metrics over the three pairs of a triplet.

        Falls back to ``TwoViewPipeline.loss`` when ``data`` has no third
        view. In batched mode the parent loss is evaluated once on the
        stacked pairs. Otherwise it is evaluated per pair; losses sharing a
        key are summed and metrics sharing a key are concatenated along
        dim 0.

        Returns:
            A ``(losses, metrics)`` tuple of dicts.
        """
        if not has_triplet(data):
            return super().loss(pred, data)

        if self.conf.batch_triplets:
            m_data = stack_twoviews(data)
            m_pred = stack_twoviews(pred)
            losses, metrics = super().loss(m_pred, m_data)
        else:
            losses = {}
            metrics = {}
            for idx in ["0to1", "0to2", "1to2"]:
                data_i = get_twoview(data, idx)
                pred_i = pred[idx]
                losses_i, metrics_i = super().loss(pred_i, data_i)
                for k, v in losses_i.items():
                    if k in losses:
                        losses[k] = losses[k] + v
                    else:
                        losses[k] = v
                for k, v in metrics_i.items():
                    if k in metrics:
                        metrics[k] = torch.cat([metrics[k], v], 0)
                    else:
                        metrics[k] = v
        return losses, metrics
 
			
