ssaroya committed on
Commit 66acbfe · 1 Parent(s): adcfac4

Create handler.py

Files changed (1)
  1. handler.py +83 -0
handler.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import transformers
+ import quant
+ from typing import Dict, Any
+ from gptq import GPTQ
+ from utils import find_layers, DEV
+ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+
+ class EndpointHandler:
+     def __init__(self,
+                  model_name="Wizard-Vicuna-13B-Uncensored-GPTQ",
+                  checkpoint_path="Wizard-Vicuna-13B-Uncensored-GPTQ/Wizard-Vicuna-13B-Uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors",
+                  wbits=4,
+                  groupsize=128,
+                  fused_mlp=True,
+                  eval=True,
+                  warmup_autotune=True):
+
+         self.model = self.load_quant(model_name, checkpoint_path, wbits, groupsize, fused_mlp, eval, warmup_autotune)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+         self.model.to(DEV)
+
+     def load_quant(self, model, checkpoint, wbits, groupsize, fused_mlp, eval, warmup_autotune):
+         config = LlamaConfig.from_pretrained(model)
+
+         # Skip random weight init: every parameter is overwritten by the
+         # quantized checkpoint, so initialization would only waste time.
+         def noop(*args, **kwargs):
+             pass
+
+         torch.nn.init.kaiming_uniform_ = noop
+         torch.nn.init.uniform_ = noop
+         torch.nn.init.normal_ = noop
+
+         # Build the model skeleton in fp16 without initializing weights.
+         torch.set_default_dtype(torch.half)
+         transformers.modeling_utils._init_weights = False
+         model = LlamaForCausalLM(config)
+         torch.set_default_dtype(torch.float)
+         if eval:
+             model = model.eval()
+
+         # Replace every linear layer except the LM head with a quantized one.
+         layers = find_layers(model)
+         for name in ['lm_head']:
+             if name in layers:
+                 del layers[name]
+         quant.make_quant_linear(model, layers, wbits, groupsize)
+
+         del layers
+
+         print('Loading model ...')
+         if checkpoint.endswith('.safetensors'):
+             from safetensors.torch import load_file as safe_load
+             model.load_state_dict(safe_load(checkpoint), strict=False)
+         else:
+             model.load_state_dict(torch.load(checkpoint), strict=False)
+
+         # Optional fused kernels, plus a kernel-autotune warmup that trades
+         # startup time for lower first-request latency.
+         if eval:
+             quant.make_quant_attn(model)
+             quant.make_quant_norm(model)
+             if fused_mlp:
+                 quant.make_fused_mlp(model)
+         if warmup_autotune:
+             quant.autotune_warmup_linear(model, transpose=not eval)
+             if eval and fused_mlp:
+                 quant.autotune_warmup_fused(model)
+         model.seqlen = 2048
+         print('Done.')
+
+         return model
+
+     def __call__(self, data: Any) -> Dict[str, str]:
+         # Hugging Face custom-handler payloads arrive as {"inputs": ...};
+         # fall back to the raw value if a bare string is passed in.
+         input_text = data.pop("inputs", data) if isinstance(data, dict) else data
+         input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(DEV)
+
+         with torch.no_grad():
+             generated_ids = self.model.generate(
+                 input_ids,
+                 do_sample=True,
+                 min_length=50,
+                 max_length=200,  # total length, prompt tokens included
+                 top_p=0.95,
+                 temperature=0.8,
+             )
+         generated_text = self.tokenizer.decode(generated_ids[0])
+
+         return {'generated_text': generated_text}