Nick021402 committed on
Commit
7baec98
·
verified ·
1 Parent(s): d914104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -19
app.py CHANGED
@@ -1,24 +1,33 @@
1
  import gradio as gr
2
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
  import torch
4
- import librosa
 
 
 
 
 
5
 
6
- # Load pretrained model and processor
7
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
8
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model.to(device)
 
12
 
13
- # Transcription function
14
- def transcribe(audio_path):
15
- if audio_path is None:
16
  return "Please upload or record an audio file."
 
 
 
 
 
17
 
18
- # Load audio file and resample to 16kHz mono
19
- audio_np, sample_rate = librosa.load(audio_path, sr=16000)
20
 
21
- # Process and transcribe
22
  input_values = processor(audio_np, sampling_rate=16000, return_tensors="pt").input_values.to(device)
23
  with torch.no_grad():
24
  logits = model(input_values).logits
@@ -26,16 +35,33 @@ def transcribe(audio_path):
26
  transcription = processor.decode(predicted_ids[0])
27
  return transcription.lower()
28
 
29
- # Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  with gr.Blocks(theme=gr.themes.Soft()) as app:
31
- gr.Markdown("# Voice2PersonaAI")
32
- gr.Markdown("Upload or record your voice, and this app will transcribe what you say.")
33
 
34
  with gr.Row():
35
- audio_input = gr.Audio(label="🎤 Record or Upload Your Voice", type="filepath")
36
- output_text = gr.Textbox(label="📝 Transcribed Text")
 
 
 
37
 
38
- transcribe_button = gr.Button("Transcribe")
39
- transcribe_button.click(fn=transcribe, inputs=audio_input, outputs=output_text)
40
 
41
  app.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ import numpy as np
4
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ # Set device
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
+ # Load Wav2Vec2 model and processor for speech recognition
11
  processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
12
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)
13
 
14
+ # Load FLAN-T5 model for personality generation
15
+ gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
16
+ gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base").to(device)
17
 
18
+ # Function to transcribe audio to text
19
+ def transcribe(audio):
20
+ if audio is None:
21
  return "Please upload or record an audio file."
22
+
23
+ if isinstance(audio, tuple):
24
+ audio_np = audio[1]
25
+ else:
26
+ audio_np = audio
27
 
28
+ if isinstance(audio_np, np.ndarray) and audio_np.ndim > 1:
29
+ audio_np = np.mean(audio_np, axis=1)
30
 
 
31
  input_values = processor(audio_np, sampling_rate=16000, return_tensors="pt").input_values.to(device)
32
  with torch.no_grad():
33
  logits = model(input_values).logits
 
35
  transcription = processor.decode(predicted_ids[0])
36
  return transcription.lower()
37
 
38
+ # Function to generate personality from transcription
39
def generate_persona_from_text(transcription):
    """Generate a fictional-character description of the speaker.

    Wraps the transcription in an instruction prompt, runs it through the
    module-level ``gen_model`` and returns the decoded text.

    Args:
        transcription: Text of what the speaker said.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # The transcription is quoted on its own line beneath the instruction.
    prompt = (
        f"Describe the speaker's personality and role as if they are a fictional character, based on this message:\n\"{transcription}\""
    )

    # Tokenize and move the tensors to the same device as the model.
    encoded = gen_tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(device)

    generated = gen_model.generate(**encoded, max_length=100)
    first_sequence = generated[0]
    return gen_tokenizer.decode(first_sequence, skip_special_tokens=True)
44
+
45
+ # Complete function for Gradio
46
def analyze_speaker(audio):
    """Transcribe a recording and describe the speaker as a fictional persona.

    Args:
        audio: Value delivered by the Gradio audio component, or ``None``
            when nothing was recorded.

    Returns:
        A ``(transcription, persona)`` tuple of strings. When no audio was
        supplied, the first item is the message asking for audio and the
        persona is an empty string.
    """
    # Detect missing input directly instead of inspecting the returned text.
    # The previous substring test ("please upload" in transcription) was
    # case-sensitive and never matched the capitalized "Please upload ..."
    # message, so that message was fed to the persona model. It could also
    # misfire on genuine speech, since real transcriptions are lower-cased.
    if audio is None:
        # Same text transcribe() returns for missing audio.
        return "Please upload or record an audio file.", ""

    transcription = transcribe(audio)
    persona = generate_persona_from_text(transcription)
    return transcription, persona
52
+
53
+ # Gradio Interface
54
  with gr.Blocks(theme=gr.themes.Soft()) as app:
55
+ gr.Markdown("# Voice2Persona AI")
56
+ gr.Markdown("Upload or record your voice. We'll transcribe it and guess your fictional personality.")
57
 
58
  with gr.Row():
59
+ audio_input = gr.Audio(source="microphone", type="numpy", label="🎤 Your Voice")
60
+ transcribed_text = gr.Textbox(label="📝 Transcription")
61
+ persona_output = gr.Textbox(label="🧠 Persona Analysis")
62
+
63
+ analyze_button = gr.Button("Analyze")
64
 
65
+ analyze_button.click(fn=analyze_speaker, inputs=audio_input, outputs=[transcribed_text, persona_output])
 
66
 
67
  app.launch()