Spaces:
Runtime error
Runtime error
Commit
·
1b250fd
1
Parent(s):
aa620b9
Update model/openllama.py
Browse files- model/openllama.py +1 -1
model/openllama.py
CHANGED
|
@@ -199,7 +199,7 @@ class OpenLLAMAPEFTModel(nn.Module):
|
|
| 199 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
| 200 |
)
|
| 201 |
|
| 202 |
-
self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True)
|
| 203 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
| 204 |
self.llama_model.print_trainable_parameters()
|
| 205 |
|
|
|
|
| 199 |
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
| 200 |
)
|
| 201 |
|
| 202 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
|
| 203 |
self.llama_model = get_peft_model(self.llama_model, peft_config)
|
| 204 |
self.llama_model.print_trainable_parameters()
|
| 205 |
|