storresbusquets committed on
Commit 0905a09 · 1 Parent(s): 79513d7

Update app.py

Files changed (1): app.py (+15 -5)
app.py CHANGED

@@ -1,7 +1,7 @@
 import gradio as gr
 import whisper
 from pytube import YouTube
-from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration
+from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM

 class GradioInference():
     def __init__(self):
@@ -10,8 +10,11 @@ class GradioInference():
         self.current_size = "base"
         self.loaded_model = whisper.load_model(self.current_size)
         self.yt = None
-        self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
-
+        # self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
+
+        self.tokenizer_model = AutoTokenizer.from_pretrained("google/pegasus-large")
+        self.summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-large")
+
         # Initialize VoiceLabT5 model and tokenizer
         self.keyword_model = T5ForConditionalGeneration.from_pretrained("Voicelab/vlt5-base-keywords")
         self.keyword_tokenizer = T5Tokenizer.from_pretrained("Voicelab/vlt5-base-keywords")
@@ -33,8 +36,15 @@ class GradioInference():

         results = self.loaded_model.transcribe(path, language=lang)

+        inputs = tokenizer(results["text"], max_length=1024, truncation=True, return_tensors="pt")
+
+        summary_ids = self.keyword_model.generate(inputs["input_ids"])
+        summary = self.keyword_tokenizer.batch_decode(summary_ids,
+                                                      skip_special_tokens=True,
+                                                      clean_up_tokenization_spaces=False)
+
         # Perform summarization on the transcription
-        transcription_summary = self.summarizer(results["text"], max_length=130, min_length=30, do_sample=False)
+        # transcription_summary = self.summarizer(results["text"], max_length=130, min_length=30, do_sample=False)

         # Extract keywords using VoiceLabT5
         task_prefix = "Keywords: "
@@ -46,7 +56,7 @@ class GradioInference():

         label = self.classifier(results["text"])[0]["label"]

-        return results["text"], transcription_summary[0]["summary_text"], keywords, label
+        return results["text"], summary[0], keywords, label

     def populate_metadata(self, link):
         self.yt = YouTube(link)
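
For reference, below is a minimal standalone sketch of the Pegasus-based summarization path this commit switches to. It is an illustrative reconstruction, not part of the commit: `transcript` is a stand-in for Whisper's `results["text"]`, and the sketch assumes the bare `tokenizer` call and the `self.keyword_model` / `self.keyword_tokenizer` references in the added hunk were intended to be the newly added `self.tokenizer_model` and `self.summarizer_model` attributes.

# Illustrative sketch only (not from the commit): exercises the
# google/pegasus-large summarizer added in __init__, outside the Gradio class.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer_model = AutoTokenizer.from_pretrained("google/pegasus-large")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-large")

transcript = "..."  # stand-in for the Whisper transcription, results["text"]

# Tokenize the transcription, truncating to the model's 1024-token input limit.
inputs = tokenizer_model(transcript, max_length=1024, truncation=True, return_tensors="pt")

# Generate and decode the summary with the Pegasus pair (the committed hunk
# calls self.keyword_model / self.keyword_tokenizer here, which looks like a
# slip for the summarizer model and tokenizer).
summary_ids = summarizer_model.generate(inputs["input_ids"])
summary = tokenizer_model.batch_decode(summary_ids,
                                       skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)

print(summary[0])  # corresponds to the summary[0] returned by the app

In the committed code, summary[0] replaces transcription_summary[0]["summary_text"] in the return tuple, since batch_decode returns a list of plain strings rather than the pipeline's list of dicts.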