sumuks (HF Staff) committed
Commit 0440349 · verified · 1 Parent(s): bae72e4

Update app.py

Files changed (1): app.py (+25, -11)
app.py CHANGED
@@ -1,7 +1,8 @@
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+from threading import Thread
+from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
 
 model_id = "textcleanlm/textclean-4B"
 model = None
@@ -37,19 +38,32 @@ def load_model():
 def clean_text(text):
     model, tokenizer = load_model()
 
-    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
+    inputs = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
     inputs = {k: v.cuda() for k, v in inputs.items()}
 
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_length=512,
-            num_beams=4,
-            early_stopping=True
-        )
+    # Enable streaming
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
-    cleaned_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return cleaned_text
+    generation_kwargs = dict(
+        **inputs,
+        max_length=4096,
+        num_beams=1,  # Set to 1 for streaming
+        do_sample=True,
+        temperature=1.0,
+        streamer=streamer,
+    )
+
+    # Run generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Yield text as it's generated
+    generated_text = ""
+    for new_text in streamer:
+        generated_text += new_text
+        yield generated_text
+
+    thread.join()
 
 iface = gr.Interface(
     fn=clean_text,
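
For reference, here is a minimal self-contained sketch of the streaming pattern this commit adopts: model.generate() blocks until decoding finishes, so it runs in a worker thread while TextIteratorStreamer hands decoded chunks back to the main thread, and because clean_text is now a generator, gr.Interface re-renders the output on every yield. This sketch is not the Space's code: it swaps in gpt2 as a stand-in checkpoint so it runs on CPU and omits the Space-specific pieces (the spaces import, @spaces decorators, and .cuda() placement); the function names and generation defaults below are illustrative only.

from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Stand-in checkpoint so the sketch runs anywhere; the Space loads textcleanlm/textclean-4B
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_clean(text):
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

    # skip_prompt=True drops the echoed input tokens for decoder-only models
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so run it in a worker thread and drain the streamer here
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=128, do_sample=True, streamer=streamer),
    )
    thread.start()

    generated = ""
    for chunk in streamer:   # yields decoded text pieces as tokens arrive
        generated += chunk
        yield generated      # each yield pushes a partial result to the UI
    thread.join()

# A generator fn is all gr.Interface needs to stream partial outputs
iface = gr.Interface(fn=stream_clean, inputs="text", outputs="text")

if __name__ == "__main__":
    iface.launch()

One note on the diff itself: the streamers in transformers do not support beam search, since the leading hypothesis can change mid-decode, which is why the commit replaces num_beams=4 with num_beams=1 plus sampling and drops early_stopping.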