Shangding-Gu committed on
Commit 0b8b900 · 1 Parent(s): 1320a73

Update app.py

Files changed (1)
  1. app.py +105 -3
app.py CHANGED
@@ -1,4 +1,106 @@
- import streamlit as st
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ # import streamlit as st
+
+ # x = st.slider('Select a value')
+ # st.write(x, 'squared is', x * x)
+
+ import sys
+ import os
+ import torch
+ import transformers
+ import json
+
+ assert (
+     "LlamaTokenizer" in transformers._import_structure["models.llama"]
+ ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+
+ if torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+
+ try:
+     if torch.backends.mps.is_available():
+         device = "mps"
+ except:
+     pass
+
+
+ base_model = "/home/v-shanggu/MyCode_server_1503/Llama-X/RL_evol_llama2_chat_7B/checkpoint-40"  # "/path/to/WizardLM13B",
+
+ tokenizer = LlamaTokenizer.from_pretrained(base_model)
+ load_8bit = False
+ if device == "cuda":
+     model = LlamaForCausalLM.from_pretrained(
+         base_model,
+         load_in_8bit=load_8bit,
+         torch_dtype=torch.float16,
+         device_map="auto",
+     )
+ elif device == "mps":
+     model = LlamaForCausalLM.from_pretrained(
+         base_model,
+         device_map={"": device},
+         torch_dtype=torch.float16,
+     )
+
+ model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+ model.config.bos_token_id = 1
+ model.config.eos_token_id = 2
+
+ if not load_8bit:
+     model.half()  # seems to fix bugs for some users.
+ if torch.__version__ >= "2" and sys.platform != "win32":
+     model = torch.compile(model)
+
+ class Call_model():
+     model.eval()
+
+     def evaluate(self, instruction):
+         final_output = self.inference(instruction + "\n\n### Response:")
+         return final_output
+
+     def inference(self,
+                   batch_data,
+                   input=None,
+                   temperature=1,
+                   top_p=0.95,
+                   top_k=40,
+                   num_beams=1,
+                   max_new_tokens=4096,
+                   **kwargs,
+                   ):
+
+         prompts = f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {batch_data} ASSISTANT:"""
+
+         inputs = tokenizer(prompts, return_tensors="pt")
+         input_ids = inputs["input_ids"].to(device)
+         generation_config = GenerationConfig(
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             num_beams=num_beams,
+             **kwargs,
+         )
+         with torch.no_grad():
+             generation_output = model.generate(
+                 input_ids=input_ids,
+                 generation_config=generation_config,
+                 return_dict_in_generate=True,
+                 output_scores=True,
+                 max_new_tokens=max_new_tokens,
+             )
+         s = generation_output.sequences
+         output = tokenizer.batch_decode(s, skip_special_tokens=True)
+         output = output[0].split("ASSISTANT:")[1].strip()
+
+         return output
+
+ if __name__ == "__main__":
+     prompt = input("Please input:")
+     prompt = str(prompt)
+     model_evaluate = Call_model()
+     prompt_state = model_evaluate.evaluate(prompt)
+     print(prompt_state)