Commit 0d05e34 · Parent: 93457f4
Update model/openllama.py
model/openllama.py CHANGED (+7 -6)
@@ -170,15 +170,16 @@ class OpenLLAMAPEFTModel(nn.Module):
         print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
 
         self.visual_encoder, self.visual_hidden_size = imagebind_model.imagebind_huge(args)
-        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=torch.device('cpu'))
-        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
         self.visual_encoder.to(self.device)
+        imagebind_ckpt = torch.load(imagebind_ckpt_path, map_location=self.device)
+        self.visual_encoder.load_state_dict(imagebind_ckpt, strict=True)
+
 
         self.iter = 0
 
-        self.image_decoder = LinearLayer(1280, 1024, 4)
+        self.image_decoder = LinearLayer(1280, 1024, 4).to(self.device)
 
-        self.prompt_learner = PromptLearner(1, 4096)
+        self.prompt_learner = PromptLearner(1, 4096).to(self.device)
 
         self.loss_focal = FocalLoss()
         self.loss_dice = BinaryDiceLoss()
@@ -202,11 +203,11 @@ class OpenLLAMAPEFTModel(nn.Module):
             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
         )
 
-        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path,
+        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
 
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16, device_map='auto', offload_folder="offload", offload_state_dict = True)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
         self.llama_tokenizer.padding_side = "right"
         print ('Language decoder initialized.')
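The first hunk reorders the visual-encoder setup: the module is moved to the target device before its checkpoint is read, torch.load now deserializes directly onto that device instead of staging on CPU, and image_decoder and prompt_learner are placed on the device at construction time. A minimal runnable sketch of the same pattern, using a stand-in nn.Linear and a hypothetical encoder_ckpt.pt path rather than the repo's real ImageBind encoder and checkpoint:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for the ImageBind encoder; the real model comes from
# imagebind_model.imagebind_huge(args).
encoder = nn.Linear(1280, 1024)
encoder.to(device)  # place the module first, as the updated code does

# Save-and-reload round trip so the sketch runs end to end.
torch.save(encoder.state_dict(), 'encoder_ckpt.pt')

# map_location=device deserializes tensors straight onto the target device,
# avoiding the transient CPU copy implied by the old map_location=torch.device('cpu').
ckpt = torch.load('encoder_ckpt.pt', map_location=device)
encoder.load_state_dict(ckpt, strict=True)  # strict=True: checkpoint keys must match exactly

Note that load_state_dict copies values into the module's existing tensors, so the final device placement is the same either way; loading with map_location=self.device mainly spares host memory during deserialization.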
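The second hunk is what lets the Space fit in memory: the Vicuna/LLaMA weights are loaded in float16 with device_map='auto', which lets Accelerate shard layers across GPU, CPU RAM, and a disk offload folder. A sketch of that loading pattern with a placeholder checkpoint path; torch_dtype, device_map, offload_folder, and offload_state_dict are real from_pretrained arguments, everything else here is illustrative:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'path/to/vicuna-7b',         # placeholder; the repo passes vicuna_ckpt_path
    torch_dtype=torch.float16,   # halves memory per parameter vs. float32
    device_map='auto',           # Accelerate assigns each layer to GPU, CPU, or disk
    offload_folder='offload',    # layers that fit nowhere else are spilled here
    offload_state_dict=True,     # stage the state dict on disk while dispatching weights
)

Once dispatched this way, Accelerate's forward hooks move activations between devices automatically, so the model can still be wrapped with get_peft_model and used as before.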
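Finally, the tokenizer lines: pad_token is aliased to eos_token because LLaMA checkpoints ship without a dedicated padding token, and padding_side = "right" keeps padding after the prompt tokens. The torch_dtype/device_map/offload kwargs in the tokenizer call look copied over from the model call; tokenizers are CPU-side objects, so those arguments should be inert there. A minimal sketch with a placeholder path:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('path/to/vicuna-7b', use_fast=False)  # placeholder path
tokenizer.pad_token = tokenizer.eos_token  # LLaMA has no pad token of its own
tokenizer.padding_side = 'right'
batch = tokenizer(['a prompt', 'a much longer prompt'], padding=True, return_tensors='pt')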