xiezhe22 committed
Commit ff86f2b · 1 Parent(s): b6fa5ca

Add text streamer

Files changed (1)
  1. app.py +35 -26
app.py CHANGED
@@ -4,11 +4,12 @@ import pandas as pd
 import numpy as np
 import torch
 import subprocess
+from threading import Thread
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     AutoProcessor,
-    TextStreamer
+    TextIteratorStreamer
 )
 
 # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
@@ -99,9 +100,10 @@ def preview_csv(csv_file):
 
     # Create plot with first column as default
    first_column = column_choices[0]
+    df_with_index["_internal_idx"] = np.arange(len(df[first_column].values))
    plot = gr.LinePlot(
        df_with_index,
-        # x="index",
+        x="_internal_idx",
        y=first_column,
        title=f"Time Series: {first_column}"
    )
@@ -127,10 +129,11 @@ def update_plot(csv_file, selected_column):
        return gr.LinePlot(value=pd.DataFrame())
 
    df_with_index = df.copy()
+    df_with_index["_internal_idx"] = np.arange(len(df[selected_column].values))
 
    plot = gr.LinePlot(
        df_with_index,
-        # x="index",
+        x="_internal_idx",
        y=selected_column,
        title=f"Time Series: {selected_column}"
    )
@@ -190,30 +193,36 @@ def infer_chatts_stream(prompt: str, csv_file):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
        # Generate with streaming
-        # streamer = TextStreamer(tokenizer)
-        generated_text = ""
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7,
-                pad_token_id=tokenizer.eos_token_id
-            )
-
-        # Decode the generated text
-        full_generated = tokenizer.decode(
-            outputs[0][inputs["input_ids"].shape[-1]:],
-            skip_special_tokens=True
-        )
-
-        # Simulate streaming by yielding character by character
-        for i, char in enumerate(full_generated):
-            generated_text += char
-            if i % 5 == 0:  # Update every 5 characters for smoother streaming
-                yield generated_text
+        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+        inputs.update({
+            "max_new_tokens": 512,
+            "streamer": streamer,
+            "temperature": 0.3
+        })
+        thread = Thread(
+            target=model.generate,
+            kwargs=inputs
+        )
+        thread.start()
+
+        model_output = ""
+        for new_text in streamer:
+            model_output += new_text
+            yield model_output
+
+        # # Decode the generated text
+        # full_generated = tokenizer.decode(
+        #     outputs[0][inputs["input_ids"].shape[-1]:],
+        #     skip_special_tokens=True
+        # )
+
+        # # Simulate streaming by yielding character by character
+        # for i, char in enumerate(full_generated):
+        #     generated_text += char
+        #     if i % 5 == 0:  # Update every 5 characters for smoother streaming
+        #         yield generated_text
 
-        yield generated_text
+        # yield generated_text
 
    except Exception as e:
        yield f"Error during inference: {str(e)}"
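
For reference, below is a minimal, self-contained sketch of the streaming pattern this commit adopts: model.generate blocks until generation finishes, so it runs on a worker Thread while TextIteratorStreamer hands decoded text chunks back to the generator function, which a UI like Gradio can consume as a live stream. The model name, prompt, and the stream_reply helper are placeholders for illustration only, not values taken from app.py.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration; app.py loads its own ChatTS model/tokenizer.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the stream;
    # timeout makes the iterator raise instead of hanging if generation stalls.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=128)
    # generate() blocks until done, so it runs on a worker thread while this
    # generator yields the accumulated text back to the caller.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial


final = ""
for final in stream_reply("Briefly describe what a time series is."):
    pass  # a UI such as Gradio would re-render `final` on every yield
print(final)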