Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -20,7 +20,7 @@ class ModelWrapper:
|
|
| 20 |
super().__init__()
|
| 21 |
torch.set_grad_enabled(False)
|
| 22 |
|
| 23 |
-
self.DTYPE =
|
| 24 |
self.device = accelerator.device
|
| 25 |
|
| 26 |
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
|
@@ -49,7 +49,7 @@ class ModelWrapper:
|
|
| 49 |
|
| 50 |
def create_generator(self, model_id, checkpoint_path):
|
| 51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
| 52 |
-
state_dict = torch.load(checkpoint_path
|
| 53 |
generator.load_state_dict(state_dict, strict=True)
|
| 54 |
generator.requires_grad_(False)
|
| 55 |
return generator
|
|
|
|
| 20 |
super().__init__()
|
| 21 |
torch.set_grad_enabled(False)
|
| 22 |
|
| 23 |
+
self.DTYPE = torch.float16
|
| 24 |
self.device = accelerator.device
|
| 25 |
|
| 26 |
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
|
|
|
| 49 |
|
| 50 |
def create_generator(self, model_id, checkpoint_path):
|
| 51 |
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
| 52 |
+
state_dict = torch.load(checkpoint_path)
|
| 53 |
generator.load_state_dict(state_dict, strict=True)
|
| 54 |
generator.requires_grad_(False)
|
| 55 |
return generator
|