Commit 0b8b900 · Update app.py
Parent(s): 1320a73

app.py CHANGED
@@ -1,4 +1,106 @@
-import streamlit as st
-
-x = st.slider('Select a value')
-st.write(x, 'squared is', x * x)
+# import streamlit as st
+
+# x = st.slider('Select a value')
+# st.write(x, 'squared is', x * x)
+
+import sys
+import os
+import torch
+import transformers
+import json
+
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+
+# Pick the best available device: CUDA, then Apple MPS, then CPU.
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except Exception:  # older torch builds have no MPS backend attribute
+    pass
+
+
+base_model = "/home/v-shanggu/MyCode_server_1503/Llama-X/RL_evol_llama2_chat_7B/checkpoint-40"  # "/path/to/WizardLM13B",
+
+tokenizer = LlamaTokenizer.from_pretrained(base_model)
+load_8bit = False
+if device == "cuda":
+    model = LlamaForCausalLM.from_pretrained(
+        base_model,
+        load_in_8bit=load_8bit,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
+elif device == "mps":
+    model = LlamaForCausalLM.from_pretrained(
+        base_model,
+        device_map={"": device},
+        torch_dtype=torch.float16,
+    )
+else:
+    # CPU fallback so `model` is always defined (the original only loaded on cuda/mps).
+    model = LlamaForCausalLM.from_pretrained(
+        base_model,
+        device_map={"": device},
+        torch_dtype=torch.float32,
+    )
+
+model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+model.config.bos_token_id = 1
+model.config.eos_token_id = 2
+
+if not load_8bit and device != "cpu":
+    model.half()  # seems to fix bugs for some users.
+if torch.__version__ >= "2" and sys.platform != "win32":
+    model = torch.compile(model)
+
+model.eval()
+
+
+class Call_model:
+    def evaluate(self, instruction):
+        # The Alpaca-style "### Response:" suffix is appended before the chat
+        # template in inference() wraps everything into the USER turn.
+        final_output = self.inference(instruction + "\n\n### Response:")
+        return final_output
+
+    def inference(
+        self,
+        batch_data,
+        input=None,
+        temperature=1,
+        top_p=0.95,
+        top_k=40,
+        num_beams=1,
+        max_new_tokens=4096,
+        **kwargs,
+    ):
+        prompts = f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {batch_data} ASSISTANT:"""
+
+        inputs = tokenizer(prompts, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        # NOTE: sampling is off by default, so temperature/top_p/top_k only take
+        # effect if do_sample=True is passed through **kwargs.
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences
+        output = tokenizer.batch_decode(s, skip_special_tokens=True)
+        output = output[0].split("ASSISTANT:")[1].strip()
+
+        return output
+
+
+if __name__ == "__main__":
+    prompt = input("Please input:")
+    prompt = str(prompt)
+    model_evaluate = Call_model()
+    prompt_state = model_evaluate.evaluate(prompt)
+    print(prompt_state)
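The commit comments out the original Streamlit UI and reads the prompt from stdin instead, which a Space's web view will not surface. Below is a minimal sketch of how the same Call_model class could be wired back into a Streamlit front end, assuming it is appended to the bottom of app.py after the model has loaded; the widget labels and button flow are illustrative, not part of the commit.

# Hypothetical Streamlit glue, not part of the commit; assumes it runs at the
# bottom of app.py so Call_model (and the loaded model) are already defined.
import streamlit as st

st.title("LLaMA chat demo")  # illustrative title
user_prompt = st.text_area("Enter an instruction")

if st.button("Generate") and user_prompt:
    responder = Call_model()
    with st.spinner("Generating..."):
        # evaluate() appends "\n\n### Response:" and runs generation
        answer = responder.evaluate(user_prompt)
    st.write(answer)

In a real Space the checkpoint load would also want to sit behind st.cache_resource, since Streamlit re-runs the whole script on every interaction and would otherwise reload the model each time.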