vikigitonga11 committed on
Commit
6baacdd
·
verified ·
1 Parent(s): 1990293

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -1,17 +1,16 @@
1
  import gradio as gr
2
  import re
3
  import torch
 
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
 
6
- # Load T5 paraphrase model (faster than PEGASUS)
7
  model_name = "Vamsi/T5_Paraphrase_Paws"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16) # Use fp16 for speed
10
 
11
- # Move model to CPU (remove if using GPU)
12
- model.to("cpu")
13
 
14
- # Initialize paraphrase pipeline with optimized settings
15
  paraphrase_pipeline = pipeline(
16
  "text2text-generation",
17
  model=model,
@@ -23,31 +22,38 @@ def split_sentences(text):
23
  """Split text into sentences using regex (faster than nltk)."""
24
  return re.split(r'(?<=[.!?])\s+', text.strip())
25
 
26
- def paraphrase_text(text):
27
- """Paraphrases input text while maintaining sentence structure."""
28
  if not text.strip():
29
  return "⚠️ Please enter some text to paraphrase."
30
 
31
  sentences = split_sentences(text)
32
 
33
- # Apply T5 paraphrasing to each sentence
34
- paraphrased_results = paraphrase_pipeline(
35
  [f"paraphrase: {sentence} </s>" for sentence in sentences if sentence],
36
- max_length=50, do_sample=True, batch_size=8, num_return_sequences=1 # Faster settings
 
 
 
 
 
 
 
37
  )
38
 
 
39
  paraphrased_sentences = [result['generated_text'] for result in paraphrased_results]
40
  return " ".join(paraphrased_sentences)
41
 
42
- # Define Gradio Interface
43
- demo = gr.Interface(
44
- fn=paraphrase_text,
45
- inputs=gr.Textbox(label="Enter text", placeholder="Type your text to paraphrase...", lines=10),
46
- outputs=gr.Textbox(label="Paraphrased Text", lines=10),
47
- title="🚀 Fast & Clean T5 Paraphraser",
48
- description="Enter text and let AI generate a paraphrased version using an optimized T5 model!",
49
- theme="huggingface"
50
- )
51
 
52
  if __name__ == "__main__":
53
- demo.launch()
 
1
  import gradio as gr
2
  import re
3
  import torch
4
+ import asyncio
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
+ # Load T5 paraphrase model
8
  model_name = "Vamsi/T5_Paraphrase_Paws"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
11
 
 
 
12
 
13
+ # Initialize paraphrase pipeline
14
  paraphrase_pipeline = pipeline(
15
  "text2text-generation",
16
  model=model,
 
22
  """Split text into sentences using regex (faster than nltk)."""
23
  return re.split(r'(?<=[.!?])\s+', text.strip())
24
 
25
+ async def paraphrase_text(text):
26
+ """Paraphrases input text while maintaining sentence structure asynchronously."""
27
  if not text.strip():
28
  return "⚠️ Please enter some text to paraphrase."
29
 
30
  sentences = split_sentences(text)
31
 
32
+ # Apply T5 paraphrasing with optimized settings
33
+ paraphrased_results = await asyncio.to_thread(paraphrase_pipeline,
34
  [f"paraphrase: {sentence} </s>" for sentence in sentences if sentence],
35
+ max_length=80,
36
+ do_sample=True,
37
+ temperature=0.7,
38
+ top_p=0.85,
39
+ top_k=50,
40
+ repetition_penalty=1.2,
41
+ num_return_sequences=1,
42
+ batch_size=8
43
  )
44
 
45
+ # Extract and join paraphrased sentences
46
  paraphrased_sentences = [result['generated_text'] for result in paraphrased_results]
47
  return " ".join(paraphrased_sentences)
48
 
49
+ # Define Gradio Interface (Disable queueing)
50
+ with gr.Blocks() as demo:
51
+ gr.Markdown("# 🚀 Fast & Parallel T5 Paraphraser")
52
+ input_box = gr.Textbox(label="Enter text", placeholder="Type your text to paraphrase...", lines=10)
53
+ output_box = gr.Textbox(label="Paraphrased Text", lines=10)
54
+ button = gr.Button("Paraphrase")
55
+
56
+ button.click(paraphrase_text, inputs=input_box, outputs=output_box)
 
57
 
58
  if __name__ == "__main__":
59
+ demo.launch(share=True, queue=False) # Disable queueing