MudassirFayaz committed
Commit 4431147 · verified · 1 Parent(s): 997e24d

Update app.py

Files changed (1)
  1. app.py +173 -4
app.py CHANGED
@@ -1,19 +1,188 @@
- from datasets import load_dataset
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel

-
  base_model = AutoModelForCausalLM.from_pretrained(
      'meta-llama/Llama-2-7b-chat-hf',
      trust_remote_code=True,
      device_map="auto",
-     torch_dtype=torch.float16,  # optional if you have enough VRAM
  )
  tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')

  model = PeftModel.from_pretrained(base_model, 'FinGPT/fingpt-forecaster_dow30_llama2-7b_lora')
  model = model.eval()


- if __name__ == "__app__":
      demo.queue(max_size=20).launch()
 
+ import os
+ from threading import Thread
+ from typing import Iterator, List, Tuple
+
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel
+ import gradio as gr
+ from gradio import Blocks
+ from transformers import TextIteratorStreamer

+ # Load the base model and tokenizer
  base_model = AutoModelForCausalLM.from_pretrained(
      'meta-llama/Llama-2-7b-chat-hf',
      trust_remote_code=True,
      device_map="auto",
+     torch_dtype=torch.float16,
  )
  tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')

+ # Load the finetuned model
  model = PeftModel.from_pretrained(base_model, 'FinGPT/fingpt-forecaster_dow30_llama2-7b_lora')
  model = model.eval()

+ # Define constants
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ # FastAPI setup
+ app = FastAPI()
+
+ class ChatRequest(BaseModel):
+     message: str
+     chat_history: List[Tuple[str, str]] = []
+     system_prompt: str = ""
+     max_new_tokens: int = 1024
+     temperature: float = 0.6
+     top_p: float = 0.9
+     top_k: int = 50
+     repetition_penalty: float = 1.2
+
+ @app.post("/chat/")
+ async def chat(request: ChatRequest):
+     try:
+         response = await generate_response(
+             request.message,
+             request.chat_history,
+             request.system_prompt,
+             request.max_new_tokens,
+             request.temperature,
+             request.top_p,
+             request.top_k,
+             request.repetition_penalty
+         )
+         return {"response": response}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ async def generate_response(
+     message: str,
+     chat_history: List[Tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> str:
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = {
+         "input_ids": input_ids,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "top_p": top_p,
+         "top_k": top_k,
+         "temperature": temperature,
+         "num_beams": 1,
+         "repetition_penalty": repetition_penalty,
+     }
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+     return "".join(outputs)
+
+ # Gradio setup
+ def generate(
+     message: str,
+     chat_history: List[Tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     return generate_response(
+         message,
+         chat_history,
+         system_prompt,
+         max_new_tokens,
+         temperature,
+         top_p,
+         top_k,
+         repetition_penalty
+     )
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(label="System prompt", lines=6),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+ )
+
+ with Blocks(css="style.css") as demo:
+     gr.Markdown("# Llama-2 7B Chat")
+     gr.Markdown("""
+     This Space demonstrates the Llama-2 7B Chat model by Meta, fine-tuned for chat instructions.
+     Feel free to chat with the model here or use the API to integrate it into your applications.
+     """)
+     chat_interface.render()
+     gr.Markdown("---")
+     gr.Markdown("This demo is governed by the original [license](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/LICENSE.txt).")

+ if __name__ == "__main__":
      demo.queue(max_size=20).launch()
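
For reference, the ChatRequest model above defines the JSON body that the new /chat/ endpoint accepts, and the handler returns the generated text under a "response" key. Below is a minimal client sketch, assuming the FastAPI app is served separately (for example with `uvicorn app:app --port 8000`; the `__main__` block in this commit only launches the Gradio demo) and that the `requests` package is available on the client side. The host, port, and payload values are illustrative assumptions, not part of the commit.

import requests  # assumption: requests is installed in the client environment

# Illustrative payload mirroring the ChatRequest fields defined in app.py
payload = {
    "message": "What is the outlook for AAPL next week?",
    "chat_history": [],
    "system_prompt": "",
    "max_new_tokens": 512,
    "temperature": 0.6,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.2,
}

# Assumption: the API is reachable at localhost:8000 (e.g. started with `uvicorn app:app --port 8000`)
resp = requests.post("http://localhost:8000/chat/", json=payload)
resp.raise_for_status()
print(resp.json()["response"])

Because generate_response collects the full streamer output before returning, the endpoint replies with the complete generated text rather than a token stream.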