Spaces:
Runtime error
Runtime error
Commit
·
4b9278e
1
Parent(s):
c2ca0ca
Update model/openllama.py
Browse files- model/openllama.py +7 -5
model/openllama.py
CHANGED
|
@@ -165,6 +165,8 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
| 165 |
max_tgt_len = args['max_tgt_len']
|
| 166 |
stage = args['stage']
|
| 167 |
|
|
|
|
|
|
|
| 168 |
print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
|
| 169 |
|
| 170 |
self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
|
|
@@ -173,12 +175,12 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
| 173 |
|
| 174 |
self.iter = 0
|
| 175 |
|
| 176 |
-
self.image_decoder = LinearLayer(1280, 1024, 4)
|
| 177 |
|
| 178 |
-
self.prompt_learner = PromptLearner(1, 4096)
|
| 179 |
|
| 180 |
-
self.loss_focal = FocalLoss()
|
| 181 |
-
self.loss_dice = BinaryDiceLoss()
|
| 182 |
|
| 183 |
|
| 184 |
# free vision encoder
|
|
@@ -213,7 +215,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
| 213 |
)
|
| 214 |
|
| 215 |
self.max_tgt_len = max_tgt_len
|
| 216 |
-
|
| 217 |
|
| 218 |
|
| 219 |
def rot90_img(self,x,k):
|
|
|
|
| 165 |
max_tgt_len = args['max_tgt_len']
|
| 166 |
stage = args['stage']
|
| 167 |
|
| 168 |
+
self.device = torch.cuda.current_device()
|
| 169 |
+
|
| 170 |
print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
|
| 171 |
|
| 172 |
self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
|
|
|
|
| 175 |
|
| 176 |
self.iter = 0
|
| 177 |
|
| 178 |
+
self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
|
| 179 |
|
| 180 |
+
self.prompt_learner = PromptLearner(1, 4096).to(self.device)
|
| 181 |
|
| 182 |
+
self.loss_focal = FocalLoss().to(self.device)
|
| 183 |
+
self.loss_dice = BinaryDiceLoss().to(self.device)
|
| 184 |
|
| 185 |
|
| 186 |
# free vision encoder
|
|
|
|
| 215 |
)
|
| 216 |
|
| 217 |
self.max_tgt_len = max_tgt_len
|
| 218 |
+
|
| 219 |
|
| 220 |
|
| 221 |
def rot90_img(self,x,k):
|