Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from utils import CONFIG | |
| from networks import m2ms, ops | |
| import sys | |
| sys.path.insert(0, './segment-anything') | |
| from segment_anything import sam_model_registry | |
class sam_m2m(nn.Module):
    """Generator pairing a SAM segmentation model with an M2M head.

    The SAM model (the ``'vit_b'`` registry entry) is always executed in
    eval mode and inside ``torch.no_grad()``. Autograd therefore only
    tracks the computation done by the M2M head.

    Args:
        m2m: Name of the head to build; must be listed in ``m2ms.__all__``.

    Raises:
        NotImplementedError: If ``m2m`` is not a known head name.
    """

    def __init__(self, m2m):
        super().__init__()
        if m2m not in m2ms.__all__:
            raise NotImplementedError("Unknown M2M {}".format(m2m))
        head_cls = vars(m2ms)[m2m]
        # NOTE(review): nc=256 presumably matches the channel count of the
        # features returned by the SAM model -- TODO confirm.
        self.m2m = head_cls(nc=256)
        # Built without a checkpoint; weights are presumably loaded by the
        # caller -- verify.
        self.seg_model = sam_model_registry['vit_b'](checkpoint=None)
        self.seg_model.eval()

    def _run_seg_model(self, seg_method, *args, **kwargs):
        """Call ``seg_method`` (a method of ``self.seg_model``) in eval mode without autograd."""
        # Re-assert eval mode on each call: calling .train() on this wrapper
        # also switches its child modules, including seg_model, to training mode.
        self.seg_model.eval()
        with torch.no_grad():
            return seg_method(*args, **kwargs)

    def forward(self, image, guidance):
        """Run the SAM model on ``image``/``guidance`` and refine with the M2M head.

        Returns:
            The output of the M2M head, given (features, image, masks).
        """
        features, coarse_masks = self._run_seg_model(
            self.seg_model.forward_m2m, image, guidance, multimask_output=True
        )
        return self.m2m(features, image, coarse_masks)

    def forward_inference(self, image_dict):
        """Inference path driven by a dict that holds at least an ``"image"`` entry.

        Returns:
            Tuple ``(features, prediction, post_masks)``. ``features`` and
            ``post_masks`` come from the SAM model; ``prediction`` comes from
            the M2M head.
        """
        features, coarse_masks, post_masks = self._run_seg_model(
            self.seg_model.forward_m2m_inference, image_dict, multimask_output=True
        )
        prediction = self.m2m(features, image_dict["image"], coarse_masks)
        return features, prediction, post_masks
def get_generator_m2m(seg, m2m):
    """Build a generator for the given segmentation backbone and M2M head.

    Args:
        seg: Name of the segmentation model; only ``'sam'`` is supported.
        m2m: Name of the M2M head, forwarded to ``sam_m2m``.

    Returns:
        A ``sam_m2m`` instance.

    Raises:
        NotImplementedError: If ``seg`` is not a supported segmentation model.
            Previously an unknown ``seg`` crashed with ``UnboundLocalError``.
    """
    if seg == 'sam':
        return sam_m2m(m2m=m2m)
    raise NotImplementedError("Unknown segmentation model {}".format(seg))