CodCodingCode committed
Commit 4ce9b14 · verified · 1 parent: 6e1da67

Create handler.py

Files changed (1):
  handler.py +106 -0
handler.py ADDED
@@ -0,0 +1,106 @@
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os


class EndpointHandler:
    def __init__(self, path=""):
        # Look for the checkpoint-100 folder
        checkpoint_path = None

        if not path or path == "/repository":
            base_path = "."
        else:
            base_path = path

        # Check different possible locations
        possible_paths = [
            os.path.join(base_path, "checkpoint-100"),
            os.path.join(".", "checkpoint-100"),
            os.path.join("/repository", "checkpoint-100"),
            "checkpoint-100",
        ]

        for check_path in possible_paths:
            if os.path.exists(check_path) and os.path.isdir(check_path):
                # Verify it contains model files
                files = os.listdir(check_path)
                if any(f in files for f in ["config.json", "pytorch_model.bin", "model.safetensors"]):
                    checkpoint_path = check_path
                    break

        if checkpoint_path is None:
            print(f"Available files in base path: {os.listdir(base_path) if os.path.exists(base_path) else 'Path does not exist'}")
            raise ValueError("Could not find checkpoint-100 folder with model files")

        print(f"Loading model from: {checkpoint_path}")
        print(f"Files in checkpoint: {os.listdir(checkpoint_path)}")

        # Load model and tokenizer from checkpoint-100
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        # Set pad token if it does not exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:str): a string to be generated from
            parameters (:dict): generation parameters
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # Get the input text
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle string input directly
        if isinstance(inputs, str):
            input_text = inputs
        else:
            input_text = str(inputs)

        # Set default parameters
        max_new_tokens = parameters.get("max_new_tokens", 1000)
        temperature = parameters.get("temperature", 0.1)
        do_sample = parameters.get("do_sample", True)
        top_p = parameters.get("top_p", 0.9)
        return_full_text = parameters.get("return_full_text", False)

        # Tokenize the input (returns a BatchEncoding with input_ids and attention_mask)
        model_inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.model.device)

        # Generate text
        with torch.no_grad():
            generated_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode the generated text
        if return_full_text:
            generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        else:
            # Only return the newly generated part
            new_tokens = generated_ids[0][model_inputs["input_ids"].shape[1]:]
            generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return [{"generated_text": generated_text}]
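
A handler like this follows the Hugging Face Inference Endpoints custom-handler convention: a class named EndpointHandler whose __call__ receives a dict with "inputs" and optional "parameters" and returns a serializable list. Below is a minimal local smoke test, a sketch only; the prompt and parameter values are illustrative, and it assumes the checkpoint-100 folder is present relative to the working directory.

# Minimal local smoke test for EndpointHandler (hypothetical prompt/params;
# assumes checkpoint-100 is reachable from the current directory).
from handler import EndpointHandler

handler = EndpointHandler(path=".")
payload = {
    "inputs": "Write a Python function that reverses a string.",
    "parameters": {"max_new_tokens": 128, "temperature": 0.2, "return_full_text": False},
}
result = handler(payload)
print(result[0]["generated_text"])

When the repository is deployed as an Inference Endpoint, the same JSON body is sent as the HTTP request payload, and the returned list is serialized back to the client as JSON.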