Spaces:
Build error
Commit · 33e8867
1 Parent(s): 46cc5f0
Update model/openllama.py
Files changed: model/openllama.py (+4 -4)
model/openllama.py
CHANGED
@@ -215,17 +215,17 @@ class OpenLLAMAPEFTModel(nn.Module):
         # # self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map=device_map, offload_folder="offload", offload_state_dict = True)
         # # self.llama_model.to(torch.float16)
         # # try:
-        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True
+        self.llama_model = AutoModelForCausalLM.from_pretrained(vicuna_ckpt_path, torch_dtype=torch.float16, device_map='auto', load_in_8bit=True)
         # # except:
         #     pass
         # finally:
         #     print(self.llama_model.hf_device_map)
         self.llama_model = get_peft_model(self.llama_model, peft_config)
-        delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
-        self.llama_model.load_state_dict(delta_ckpt, strict=False)
+        # delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
+        # self.llama_model.load_state_dict(delta_ckpt, strict=False)
         self.llama_model.print_trainable_parameters()
 
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16
+        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, torch_dtype=torch.float16)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
         self.llama_tokenizer.padding_side = "right"
         print ('Language decoder initialized.')
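For context, the commit closes the unterminated `from_pretrained(...)` calls and comments out the delta-checkpoint load, so the decoder is now built directly from the Vicuna checkpoint in 8-bit with LoRA adapters attached. The snippet below is a minimal standalone sketch of that loading path, not the repository's actual code: `vicuna_ckpt_path` and the `LoraConfig` values are placeholder assumptions (the real `peft_config` is defined elsewhere in openllama.py), and the skipped delta-checkpoint load is omitted as in the commit.

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from peft import LoraConfig, get_peft_model

vicuna_ckpt_path = "path/to/vicuna"  # placeholder checkpoint path

# Placeholder LoRA settings; the actual peft_config lives elsewhere in openllama.py.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)

# Load the base model in 8-bit, sharded across available devices.
llama_model = AutoModelForCausalLM.from_pretrained(
    vicuna_ckpt_path,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=True,
)

# Attach LoRA adapters; the delta-checkpoint load is skipped in this commit.
llama_model = get_peft_model(llama_model, peft_config)
llama_model.print_trainable_parameters()

# Tokenizer setup mirrors the diff: EOS reused as PAD, right-side padding.
llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

print('Language decoder initialized.')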