iimran commited on
Commit
6d7bc3e
·
verified ·
1 Parent(s): b9bfbd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -31
app.py CHANGED
@@ -1,54 +1,44 @@
1
  import gradio as gr
2
- import torch
3
  from transformers import BartTokenizer, BartForConditionalGeneration
4
 
5
- # Load the model and tokenizer from Hugging Face hub
6
  model_name = "iimran/SAM-TheSummariserV2"
7
  tokenizer = BartTokenizer.from_pretrained(model_name)
8
  model = BartForConditionalGeneration.from_pretrained(model_name)
9
- model.eval() # Set the model to evaluation mode
10
 
11
- # Function to summarize the input text
12
  def summarize(input_text):
13
- # Tokenize the input text with truncation (adjust max_length as needed)
14
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
15
-
16
- # Create global attention mask: assign global attention to the first token (required by LED)
17
- global_attention_mask = torch.zeros(inputs["input_ids"].shape, dtype=torch.long)
18
- global_attention_mask[:, 0] = 1
19
-
20
- # Generate the summary using beam search (you can adjust parameters as needed)
21
  summary_ids = model.generate(
22
  inputs["input_ids"],
23
- attention_mask=inputs["attention_mask"],
24
- global_attention_mask=global_attention_mask,
25
- max_length=512,
26
- num_beams=4,
27
- early_stopping=True,
28
  )
29
-
30
- # Decode the generated ids to a summary string, skipping special tokens
31
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
32
  return summary
33
 
34
- # Create a Gradio interface with a title, description, submit button, and larger input text area
35
  iface = gr.Interface(
36
- fn=summarize, # Function that handles the summarization
37
  inputs=gr.Textbox(
38
  label="Enter Text to Summarize",
39
- lines=10, # Make the input area larger by increasing the number of lines
40
- placeholder="Paste or type the text you want to summarize here...",
41
  ),
42
  outputs=gr.Textbox(
43
- label="Summary",
44
- lines=5, # Adjust output area size (number of lines)
45
  placeholder="Summary will appear here..."
46
  ),
47
- live=False, # Disable live updates, use the submit button instead
48
- allow_flagging="never", # Optionally disable flagging
49
- title="SAM - The Summariser", # Title of the page
50
- description="SAM is a model which will help summarize large knowledge base articles into small summaries.", # Description of the model
51
  )
52
 
53
- # Launch the interface
54
- iface.launch()
 
1
  import gradio as gr
 
2
  from transformers import BartTokenizer, BartForConditionalGeneration
3
 
4
+ # Load model and tokenizer from Hugging Face hub using the provided model name
5
  model_name = "iimran/SAM-TheSummariserV2"
6
  tokenizer = BartTokenizer.from_pretrained(model_name)
7
  model = BartForConditionalGeneration.from_pretrained(model_name)
 
8
 
9
+ # Define the summarization function
10
  def summarize(input_text):
11
+ # Tokenize the input text with truncation
12
+ inputs = tokenizer(input_text, max_length=1024, truncation=True, return_tensors="pt")
13
+
14
+ # Generate the summary using beam search
 
 
 
 
15
  summary_ids = model.generate(
16
  inputs["input_ids"],
17
+ num_beams=4, # Use beam search with 4 beams for quality summaries
18
+ max_length=128, # Set maximum length for the generated summary
19
+ early_stopping=True # Enable early stopping if all beams finish
 
 
20
  )
21
+
22
+ # Decode the generated summary tokens to a string
23
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
24
  return summary
25
 
26
+ # Create a Gradio interface
27
  iface = gr.Interface(
28
+ fn=summarize,
29
  inputs=gr.Textbox(
30
  label="Enter Text to Summarize",
31
+ lines=10,
32
+ placeholder="Paste or type the text you want to summarize here..."
33
  ),
34
  outputs=gr.Textbox(
35
+ label="Summary",
36
+ lines=5,
37
  placeholder="Summary will appear here..."
38
  ),
39
+ title="SAM - The Summariser",
40
+ description="SAM is a model that summarizes large texts into concise summaries."
 
 
41
  )
42
 
43
+ # Launch the Gradio interface
44
+ iface.launch()