freeCS-dot-org committed on
Commit
436bf67
·
verified ·
1 Parent(s): f4f5376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -33
app.py CHANGED
@@ -33,7 +33,7 @@ class ConversationManager:
33
  def __init__(self):
34
  self.user_history = [] # For displaying to user (with markdown)
35
  self.model_history = [] # For feeding back to model (with original tags)
36
-
37
  def add_exchange(self, user_message, assistant_response, formatted_response):
38
  self.model_history.append((user_message, assistant_response))
39
  self.user_history.append((user_message, formatted_response))
@@ -42,13 +42,16 @@ class ConversationManager:
42
  print(f"User: {user_message}")
43
  print(f"Assistant (Original): {assistant_response}")
44
  print(f"Assistant (Formatted): {formatted_response}")
45
-
46
  def get_model_history(self):
47
  return self.model_history
48
-
49
  def get_user_history(self):
50
  return self.user_history
51
 
 
 
 
52
 
53
  device = "cuda" # for GPU usage or "cpu" for CPU usage
54
 
@@ -72,52 +75,49 @@ def format_response(response):
72
  @spaces.GPU()
73
  def stream_chat(
74
  message: str,
75
- history: list,
76
  system_prompt: str,
77
  temperature: float = 0.2,
78
  max_new_tokens: int = 4096,
79
  top_p: float = 1.0,
80
  top_k: int = 1,
81
  penalty: float = 1.1,
82
- conversation_manager: ConversationManager = None # Pass the manager as argument
83
  ):
84
- # Initialize the conversation manager for the first time only
85
- if conversation_manager is None:
86
- conversation_manager = ConversationManager()
87
 
88
  print(f'\nNew Chat Request:')
89
  print(f'Message: {message}')
90
- print(f'History from UI: {history}')
91
  print(f'System Prompt: {system_prompt}')
92
  print(f'Parameters: temp={temperature}, max_tokens={max_new_tokens}, top_p={top_p}, top_k={top_k}, penalty={penalty}')
93
-
94
  model_history = conversation_manager.get_model_history()
95
  print(f'Model History: {model_history}')
96
-
97
  conversation = []
98
  for prompt, answer in model_history:
99
  conversation.extend([
100
  {"role": "user", "content": prompt},
101
  {"role": "assistant", "content": answer},
102
  ])
103
-
104
  conversation.append({"role": "user", "content": message})
105
  print(f'\nFormatted Conversation for Model:')
106
  print(conversation)
107
-
108
  input_ids = tokenizer.apply_chat_template(
109
- conversation,
110
- add_generation_prompt=True,
111
  return_tensors="pt"
112
  ).to(model.device)
113
-
114
  streamer = TextIteratorStreamer(
115
- tokenizer,
116
- timeout=60.0,
117
- skip_prompt=True,
118
  skip_special_tokens=True
119
  )
120
-
121
  generate_kwargs = dict(
122
  input_ids=input_ids,
123
  max_new_tokens=max_new_tokens,
@@ -129,43 +129,59 @@ def stream_chat(
129
  eos_token_id=[end_of_sentence],
130
  streamer=streamer,
131
  )
132
-
133
  buffer = ""
134
  original_response = ""
135
-
136
  with torch.no_grad():
137
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
138
  thread.start()
139
-
140
  for new_text in streamer:
141
  buffer += new_text
142
  original_response += new_text
143
-
144
  formatted_buffer = format_response(buffer)
145
-
146
  if thread.is_alive() is False:
147
  print(f'\nGeneration Complete:')
148
  print(f'Original Response: {original_response}')
149
  print(f'Formatted Response: {formatted_buffer}')
150
-
151
  conversation_manager.add_exchange(
152
  message,
153
  original_response, # Original for model
154
- formatted_buffer # Formatted for user
155
  )
156
-
157
- yield formatted_buffer
 
 
 
 
 
 
 
158
 
159
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
160
- conversation_manager_session_state = gr.State(ConversationManager())
161
 
162
  with gr.Blocks(css=CSS, theme="soft") as demo:
163
  gr.HTML(TITLE)
164
  gr.DuplicateButton(
165
- value="Duplicate Space for private use",
166
  elem_classes="duplicate-button"
167
  )
168
- gr.ChatInterface(
 
 
 
 
 
 
 
 
 
 
169
  fn=stream_chat,
170
  chatbot=chatbot,
171
  fill_height=True,
@@ -175,6 +191,7 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
175
  render=False
176
  ),
177
  additional_inputs=[
 
178
  gr.Textbox(
179
  value="",
180
  label="System Prompt",
@@ -220,7 +237,6 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
220
  label="Repetition penalty",
221
  render=False,
222
  ),
223
- conversation_manager_session_state, # Add the state to the input
224
  ],
225
  examples=[
226
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
33
  def __init__(self):
34
  self.user_history = [] # For displaying to user (with markdown)
35
  self.model_history = [] # For feeding back to model (with original tags)
36
+
37
  def add_exchange(self, user_message, assistant_response, formatted_response):
38
  self.model_history.append((user_message, assistant_response))
39
  self.user_history.append((user_message, formatted_response))
 
42
  print(f"User: {user_message}")
43
  print(f"Assistant (Original): {assistant_response}")
44
  print(f"Assistant (Formatted): {formatted_response}")
45
+
46
  def get_model_history(self):
47
  return self.model_history
48
+
49
  def get_user_history(self):
50
  return self.user_history
51
 
52
+ def clear(self):
53
+ self.user_history = []
54
+ self.model_history = []
55
 
56
  device = "cuda" # for GPU usage or "cpu" for CPU usage
57
 
 
75
  @spaces.GPU()
76
  def stream_chat(
77
  message: str,
78
+ history_state: gr.State, # Access the internal history state
79
  system_prompt: str,
80
  temperature: float = 0.2,
81
  max_new_tokens: int = 4096,
82
  top_p: float = 1.0,
83
  top_k: int = 1,
84
  penalty: float = 1.1,
 
85
  ):
86
+ conversation_manager = history_state
 
 
87
 
88
  print(f'\nNew Chat Request:')
89
  print(f'Message: {message}')
90
+ print(f'History from UI: {conversation_manager.get_user_history()}')
91
  print(f'System Prompt: {system_prompt}')
92
  print(f'Parameters: temp={temperature}, max_tokens={max_new_tokens}, top_p={top_p}, top_k={top_k}, penalty={penalty}')
93
+
94
  model_history = conversation_manager.get_model_history()
95
  print(f'Model History: {model_history}')
96
+
97
  conversation = []
98
  for prompt, answer in model_history:
99
  conversation.extend([
100
  {"role": "user", "content": prompt},
101
  {"role": "assistant", "content": answer},
102
  ])
103
+
104
  conversation.append({"role": "user", "content": message})
105
  print(f'\nFormatted Conversation for Model:')
106
  print(conversation)
107
+
108
  input_ids = tokenizer.apply_chat_template(
109
+ conversation,
110
+ add_generation_prompt=True,
111
  return_tensors="pt"
112
  ).to(model.device)
113
+
114
  streamer = TextIteratorStreamer(
115
+ tokenizer,
116
+ timeout=60.0,
117
+ skip_prompt=True,
118
  skip_special_tokens=True
119
  )
120
+
121
  generate_kwargs = dict(
122
  input_ids=input_ids,
123
  max_new_tokens=max_new_tokens,
 
129
  eos_token_id=[end_of_sentence],
130
  streamer=streamer,
131
  )
132
+
133
  buffer = ""
134
  original_response = ""
135
+
136
  with torch.no_grad():
137
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
138
  thread.start()
139
+
140
  for new_text in streamer:
141
  buffer += new_text
142
  original_response += new_text
143
+
144
  formatted_buffer = format_response(buffer)
145
+
146
  if thread.is_alive() is False:
147
  print(f'\nGeneration Complete:')
148
  print(f'Original Response: {original_response}')
149
  print(f'Formatted Response: {formatted_buffer}')
150
+
151
  conversation_manager.add_exchange(
152
  message,
153
  original_response, # Original for model
154
+ formatted_buffer # Formatted for user
155
  )
156
+
157
+ yield formatted_buffer, conversation_manager
158
+
159
+ def clear_chat(history_state: gr.State):
160
+ history_state.clear()
161
+ return None, history_state
162
+
163
+ # Initialize the conversation manager outside of the function
164
+ conversation_manager = ConversationManager()
165
 
166
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
167
 
168
  with gr.Blocks(css=CSS, theme="soft") as demo:
169
  gr.HTML(TITLE)
170
  gr.DuplicateButton(
171
+ value="Duplicate Space for private use",
172
  elem_classes="duplicate-button"
173
  )
174
+
175
+ # Pass the initial state to the ChatInterface
176
+ history_state = gr.State(conversation_manager)
177
+
178
+ clear_inputs_button = gr.ClearButton(
179
+ value="Clear Chat",
180
+ components=[chatbot],
181
+ )
182
+ clear_inputs_button.click(fn=clear_chat, inputs=[history_state], outputs=[chatbot, history_state])
183
+
184
+ chat_interface = gr.ChatInterface(
185
  fn=stream_chat,
186
  chatbot=chatbot,
187
  fill_height=True,
 
191
  render=False
192
  ),
193
  additional_inputs=[
194
+ history_state, # Pass the state to the ChatInterface
195
  gr.Textbox(
196
  value="",
197
  label="System Prompt",
 
237
  label="Repetition penalty",
238
  render=False,
239
  ),
 
240
  ],
241
  examples=[
242
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],