DesiredName committed on
Commit
8a5a310
·
verified ·
1 Parent(s): 1f2a2bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -14
app.py CHANGED
@@ -1,23 +1,16 @@
1
  from fastapi import FastAPI
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  import uvicorn
4
 
5
- bnb_config = BitsAndBytesConfig(
6
- load_in_4bit=True, # Enable 4-bit quantization
7
- bnb_4bit_quant_type="nf4", # Use normalized float 4
8
- bnb_4bit_compute_dtype="float16", # Faster computations
9
- bnb_4bit_use_double_quant=True # Extra compression
10
- )
11
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
- "TheBloke/Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
14
- quantization_config=bnb_config,
15
- device_map="auto", # Auto-distribute across CPU/GPU
16
- trust_remote_code=True # Required for Qwen!
17
  )
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(
20
- "TheBloke/Wizard-Vicuna-13B-Uncensored-SuperHOT-8K-GPTQ",
21
  trust_remote_code=True
22
  )
23
 
@@ -30,8 +23,18 @@ def greet_json():
30
  @app.get("/message")
31
  async def message(input: str):
32
  inputs = tokenizer(input, return_tensors="pt", padding=True, truncation=True)
33
- output = model.generate(**inputs, max_length=50, temperature=0.3)
34
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
35
 
36
  if __name__ == "__main__":
37
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
+ from transformers import AutoModel, AutoTokenizer
3
  import uvicorn
4
 
5
+ model_name = "TheBloke/Wizard-Vicuna-13B-Uncensored-GGUF"
 
 
 
 
 
6
 
7
  model = AutoModelForCausalLM.from_pretrained(
8
+ model_name,
9
+ trust_remote_code=True
 
 
10
  )
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(
13
+ model_name,
14
  trust_remote_code=True
15
  )
16
 
 
23
@app.get("/message")
async def message(input: str):
    """Generate a greedy-decoded completion for the `input` query parameter.

    Returns the decoded text (prompt included) as a plain string.
    NOTE: `input` shadows the builtin, but it is the public query-parameter
    name seen by API clients, so it is kept for backward compatibility.
    """
    # Tokenize the prompt; padding/truncation keep the tensors well-formed.
    inputs = tokenizer(input, return_tensors="pt", padding=True, truncation=True)

    # Greedy decoding. FIX: dropped `temperature=0.0` — with do_sample=False
    # transformers ignores it and emits a warning on every request.
    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # avoids pad-token warnings
        max_new_tokens=100,
        do_sample=False,
    )

    # Decode the whole sequence (prompt + completion) back to text.
    return tokenizer.decode(output[0], skip_special_tokens=True)
38
 
39
def _serve() -> None:
    # Bind all interfaces on 7860 — the port Hugging Face Spaces expects.
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    # Only start the server when executed as a script, not on import.
    _serve()