Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 2,767 Bytes
			
			| a24b16a 5349660 a24b16a 5349660 a24b16a 5349660 a24b16a 5349660 a24b16a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | import copy
import re
import torch
import util
class FineTunedModel(torch.nn.Module):
    """Hold trainable deep copies of selected submodules of ``model``.

    Every submodule of ``model`` whose dotted name matches one of the
    ``modules`` regular expressions is deep-copied.  The copies are stored
    in ``self.ft_modules`` and the untouched originals in
    ``self.orig_modules``, both keyed by the submodule's dotted name.

    Used as a context manager, the instance installs the copies into
    ``model`` on entry and puts the originals back on exit::

        with finetuned:
            ...  # ``model`` runs with the fine-tuned copies

    ``parameters()``, ``state_dict()`` and ``load_state_dict()`` are
    overridden so that they only cover the fine-tuned copies.
    """

    def __init__(self,
                 model,
                 modules,
                 frozen_modules=None
                 ):
        """Select and copy the submodules that will be fine-tuned.

        Args:
            model: The ``torch.nn.Module`` to wrap.  ``util.freeze`` is
                called on it first (presumably disabling gradients for all
                of its parameters -- confirm against ``util``).
            modules: A regex string, or an iterable of regex strings.  A
                submodule is selected when ``re.search`` finds any of them
                in its dotted name.
            frozen_modules: Optional regex string or iterable of regex
                strings.  Inside each copy, every submodule whose qualified
                name matches one of them is passed to ``util.freeze`` again.
                Defaults to no extra freezing.
        """
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]
        # ``None`` replaces the former mutable ``[]`` default; a bare string
        # is wrapped the same way ``modules`` is instead of being iterated
        # character by character.
        if frozen_modules is None:
            frozen_modules = []
        elif isinstance(frozen_modules, str):
            frozen_modules = [frozen_modules]

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            # One copy per module, even when several patterns match it.
            if not any(re.search(pattern, module_name) for pattern in modules):
                continue

            ft_module = copy.deepcopy(module)

            self.orig_modules[module_name] = module
            self.ft_modules[module_name] = ft_module
            util.unfreeze(ft_module)
            print(f"=> Finetuning {module_name}")

            # Distinct loop names: the outer ``module`` must not be rebound
            # here, otherwise a later match would copy/store the wrong object.
            for sub_name, sub_module in ft_module.named_modules():
                # ``named_modules`` yields the copy itself under the name
                # "", so its qualified name is ``module_name`` followed by
                # a trailing dot.
                qualified_name = f"{module_name}.{sub_name}"
                if any(re.search(pattern, qualified_name)
                       for pattern in frozen_modules):
                    print(f"=> Freezing {qualified_name}")
                    util.freeze(sub_module)

        # ModuleLists register both sets of modules as children of ``self``.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None):
        """Build an instance from a checkpoint made by ``state_dict()``.

        Args:
            model: The ``torch.nn.Module`` to wrap.
            checkpoint: Either a mapping ``{module_name: state_dict}`` or a
                string path that ``torch.load`` can read into one.
            frozen_modules: Forwarded to the constructor.

        Returns:
            A new instance whose fine-tuned copies hold the checkpoint
            weights.
        """
        if isinstance(checkpoint, str):
            # NOTE(security): ``torch.load`` may unpickle arbitrary objects;
            # only pass paths to checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint)

        # Keys are literal module names, so escape them and anchor both ends:
        # an unescaped "." matches any character and a missing "^" would also
        # select every module whose name merely ends with the key.
        modules = [f"^{re.escape(key)}$" for key in checkpoint]

        # ``cls`` rather than the hard-coded class keeps subclasses working.
        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):
        """Install the fine-tuned copies into ``self.model``.

        Returns:
            ``self``, so ``with ftm as handle:`` binds the instance.
        """
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Put the original modules back into ``self.model``.

        Returns ``None``, so an exception raised inside the ``with`` body
        keeps propagating.
        """
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self, recurse=True):
        """Return the parameters of the fine-tuned copies as a list.

        Args:
            recurse: Forwarded to each copy's ``parameters()``; accepted so
                the override stays call-compatible with
                ``torch.nn.Module.parameters``.
        """
        return [
            parameter
            for ft_module in self.ft_modules.values()
            for parameter in ft_module.parameters(recurse=recurse)
        ]

    def state_dict(self):
        """Return ``{module_name: state_dict}`` for every fine-tuned copy."""
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict):
        """Load per-module state dicts into the fine-tuned copies.

        Args:
            state_dict: Mapping ``{module_name: state_dict}`` as produced by
                ``state_dict()``.

        Raises:
            KeyError: If a key does not name a fine-tuned module.
        """
        for key, sd in state_dict.items():
            try:
                ft_module = self.ft_modules[key]
            except KeyError:
                raise KeyError(
                    f"{key!r} is not a fine-tuned module; "
                    f"available: {sorted(self.ft_modules)}"
                ) from None
            ft_module.load_state_dict(sd)