Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update tasks/Model_Loader.py
Browse files- tasks/Model_Loader.py +13 -9
 
    	
        tasks/Model_Loader.py
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 1 | 
         
             
            import torch
         
     | 
| 2 | 
         | 
| 3 | 
         
             
            class M5(torch.nn.Module):
         
     | 
| 4 | 
         
            -
                def __init__(self, num_classes= 
     | 
| 5 | 
         
             
                    super(M5, self).__init__()
         
     | 
| 6 | 
         
             
                    self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
         
     | 
| 7 | 
         
             
                    self.bn1 = torch.nn.BatchNorm1d(32)
         
     | 
| 
         @@ -26,13 +26,17 @@ class M5(torch.nn.Module): 
     | 
|
| 26 | 
         
             
                    x = self.fc1(x)
         
     | 
| 27 | 
         
             
                    return x
         
     | 
| 28 | 
         | 
| 29 | 
         
            -
             
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
            -
             
     | 
| 33 | 
         
            -
             
     | 
| 34 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 35 | 
         | 
| 36 | 
         
             
            if __name__ == "__main__":
         
     | 
| 37 | 
         
            -
                model, device = load_model(" 
     | 
| 38 | 
         
            -
            print("✅ Model successfully loaded!")
         
     | 
| 
         | 
| 
         | 
|
| 1 | 
         
             
            import torch
         
     | 
| 2 | 
         | 
| 3 | 
         
             
            class M5(torch.nn.Module):
         
     | 
| 4 | 
         
            +
                def __init__(self, num_classes=2):  # Ensure it matches dataset labels (chainsaw/environment)
         
     | 
| 5 | 
         
             
                    super(M5, self).__init__()
         
     | 
| 6 | 
         
             
                    self.conv1 = torch.nn.Conv1d(in_channels=1, out_channels=32, kernel_size=80, stride=4)
         
     | 
| 7 | 
         
             
                    self.bn1 = torch.nn.BatchNorm1d(32)
         
     | 
| 
         | 
|
| 26 | 
         
             
                    x = self.fc1(x)
         
     | 
| 27 | 
         
             
                    return x
         
     | 
| 28 | 
         | 
| 29 | 
         
            +
            def load_model(model_path, num_classes=2):
         
     | 
| 30 | 
         
            +
                """
         
     | 
| 31 | 
         
            +
                Load trained M5 model.
         
     | 
| 32 | 
         
            +
                """
         
     | 
| 33 | 
         
            +
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         
     | 
| 34 | 
         
            +
                model = M5(num_classes=num_classes).to(device)
         
     | 
| 35 | 
         
            +
                model.load_state_dict(torch.load(model_path, map_location=device))
         
     | 
| 36 | 
         
            +
                model.eval()  # Set model to evaluation mode
         
     | 
| 37 | 
         
            +
                return model, device
         
     | 
| 38 | 
         | 
| 39 | 
         
             
            if __name__ == "__main__":
         
     | 
| 40 | 
         
            +
                model, device = load_model("quantized_teacher_m5_static.pth")
         
     | 
| 41 | 
         
            +
                print("✅ Model successfully loaded!")
         
     | 
| 42 | 
         
            +
             
     |