DesiredName committed (verified)
Commit 1cadecd · Parent: f15cd12

Update app.py

Files changed (1): app.py (+57, -22)
app.py CHANGED
@@ -1,18 +1,65 @@
 from fastapi import FastAPI
 import uvicorn
-from transformers import AutoModel, AutoTokenizer
 
-model_name = "Tap-M/Luna-AI-Llama2-Uncensored"
+#model_name = "Tap-M/Luna-AI-Llama2-Uncensored"
 
-model = AutoModel.from_pretrained(
-    model_name,  # Example model
-    offload_folder="./offload",  # Temporary directory
-    trust_remote_code=True  # Required for some models
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer
+import torch
+
+# Configuration for 4-bit quantization
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",  # Optimized 4-bit precision
+    bnb_4bit_compute_dtype=torch.float16,  # Faster computations
+    bnb_4bit_use_double_quant=True  # Extra memory savings
+)
+
+# Load model and tokenizer
+model_name = "meta-llama/Llama-2-7b-chat-hf"  # or "13b-chat-hf"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=bnb_config,
+    device_map="auto",  # Auto-distribute across GPU/CPU
+    torch_dtype=torch.float16,
+    trust_remote_code=True  # Required for Llama 2
 )
 
-# load tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-tokenizer.pad_token = tokenizer.eos_token
+# Set chat template (critical for chat models)
+tokenizer.chat_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}"
+
+def llama2_chat(prompt, system_prompt="You are a helpful assistant."):
+    # Format as Llama 2 chat
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt}
+    ]
+
+    # Tokenize with chat template
+    inputs = tokenizer.apply_chat_template(
+        messages,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # Stream output tokens
+    streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+    # Generate response
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=1000,
+        temperature=0.7,
+        streamer=streamer
+    )
+
+    # Decode full output
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+
+
+
+
 
 app = FastAPI()
 
@@ -22,19 +69,7 @@ def greet_json():
 
 @app.get("/message")
 async def message(input: str):
-    prompt = "USER:" + input + "\nASSISTANT:"
-
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-
-    output = model.generate(
-        input_ids=inputs["input_ids"],
-        attention_mask=inputs["attention_mask"],
-        max_new_tokens=100,
-    )
-
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-
-    return response
+    return llama2_chat(input)
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
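
For reference, a minimal sketch (not part of the commit) of what the chat template set in app.py produces once rendered. It assumes the same tokenizer object is loaded; tokenize=False returns the formatted prompt string instead of token ids.

# Hypothetical preview snippet; assumes `tokenizer` from app.py above.
preview = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    tokenize=False,
)
print(preview)
# Per the template above, this prints roughly:
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hello! [/INST]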
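
Once the Space is running, the updated /message route can be exercised with a plain GET request. A minimal client sketch, assuming the server is reachable on localhost:7860 (per the uvicorn.run call) and using the requests library, which is not a dependency of app.py:

# Hypothetical client; the query parameter name "input" comes from the endpoint signature.
import requests

resp = requests.get(
    "http://localhost:7860/message",
    params={"input": "Summarize what FastAPI does."},
)
print(resp.json())  # JSON-encoded string returned by llama2_chat()

Because llama2_chat() decodes the full generation (outputs[0] includes the input tokens), the returned text contains the formatted prompt as well as the model's reply.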