ssaroya committed
Commit 8083870 · 1 Parent(s): 66acbfe

Update handler.py

Files changed (1)
  1. handler.py +6 -0
handler.py CHANGED
@@ -5,9 +5,11 @@ from typing import Dict, Any
 from gptq import GPTQ
 from utils import find_layers, DEV
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+import os
 
 class EndpointHandler:
     def __init__(self,
+                 path="",
                  model_name="Wizard-Vicuna-13B-Uncensored-GPTQ",
                  checkpoint_path="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors",
                  wbits = 4,
@@ -16,6 +18,10 @@ class EndpointHandler:
                  eval=True,
                  warmup_autotune=True):
 
+        model_name = os.path.join(path, model_name)
+        checkpoint_path = os.path.join(path, checkpoint_path)
+
+
         self.model = self.load_quant(model_name, checkpoint_path, wbits, groupsize, fused_mlp, eval, warmup_autotune)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
         self.model.to(DEV)
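The commit adds an optional path argument whose value is prepended to both model_name and checkpoint_path via os.path.join before the quantized model and tokenizer are loaded. A minimal usage sketch follows, assuming EndpointHandler is imported from this handler.py and that the base directory shown is purely illustrative (the actual location depends on the deployment environment):

    # Sketch only: demonstrates the new `path` argument from this commit.
    # "/repository" is a hypothetical base directory, not part of the commit.
    from handler import EndpointHandler

    handler = EndpointHandler(
        path="/repository",  # joined onto model_name and checkpoint_path via os.path.join
        model_name="Wizard-Vicuna-13B-Uncensored-GPTQ",
        checkpoint_path="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors",
    )

Because path defaults to "", existing callers that omit it keep the previous behavior: os.path.join("", x) returns x unchanged.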