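"""GP Consultation Summarizer: a Gradio demo that transcribes short
consultation audio with Whisper, labels speaker turns with pyannote
diarization, and summarizes the transcript with T5-small."""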
import os
import tempfile

import gradio as gr
import torch
import whisper
from pydub import AudioSegment
from pyannote.audio import Pipeline as DiarizationPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the Whisper model once at startup ("small" or "medium" give better
# accuracy at the cost of speed)
whisper_model = whisper.load_model("base")

# Load the T5-small tokenizer and model used by summarize_text() below
summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")


def convert_to_wav(input_path, output_path):
    """Convert any pydub-readable audio file to WAV."""
    audio = AudioSegment.from_file(input_path)
    audio.export(output_path, format="wav")


def transcribe_audio(audio_path):
    """Transcribe audio with the preloaded Whisper model."""
    result = whisper_model.transcribe(audio_path, fp16=torch.cuda.is_available())
    return result["text"]


def diarize_audio(audio_path, hf_token):
    """Run pyannote speaker diarization, authenticated with the user's token."""
    diarization_pipeline = DiarizationPipeline.from_pretrained(
        "pyannote/speaker-diarization", use_auth_token=hf_token
    )
    return diarization_pipeline(audio_path)
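
# Note: "pyannote/speaker-diarization" is a gated model; the token's account
# must have accepted its terms on huggingface.co or loading will fail.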


def combine_diarized_transcript(diarization, full_text):
    """Basic speaker labeling: list who spoke when, then append the full text.

    Note: this is a simplified alignment using time chunks only, not a
    word-level merge of the diarization and the transcript.
    """
    chunks = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        chunks.append(f"{speaker}: [from {turn.start:.1f}s to {turn.end:.1f}s]")
    return "\n".join(chunks) + "\n" + full_text
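
# Illustrative output of combine_diarized_transcript (timestamps made up;
# pyannote emits labels such as SPEAKER_00):
#
#   SPEAKER_00: [from 0.0s to 4.2s]
#   SPEAKER_01: [from 4.2s to 9.8s]
#   <full transcript text>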


def summarize_text(text):
    """Summarize with T5-small (T5 expects the "summarize: " task prefix)."""
    prefix = "summarize: " + text.strip()
    inputs = summarizer_tokenizer.encode(prefix, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = summarizer_model.generate(
        inputs, max_length=100, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True
    )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
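
# Sketch of how the helpers above chain together outside the UI
# ("consultation.mp3" is a hypothetical path; assumes HF_TOKEN is set):
#
#   convert_to_wav("consultation.mp3", "consultation.wav")
#   text = transcribe_audio("consultation.wav")
#   dia = diarize_audio("consultation.wav", os.environ["HF_TOKEN"])
#   print(combine_diarized_transcript(dia, text))
#   print(summarize_text(text))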


def process_pipeline(audio_file, hf_token):
    """Run the full pipeline: convert -> transcribe -> diarize -> summarize."""
    if not hf_token:
        return "", "", "Error: Hugging Face token is required."
    if not audio_file or not os.path.exists(audio_file) or os.path.getsize(audio_file) == 0:
        return "", "", "Error: Uploaded file is missing or empty."

    # Step 1: Convert the upload to WAV so Whisper and pyannote can read it
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
        tmp_path = tmp_wav.name
    try:
        convert_to_wav(audio_file, tmp_path)
    except Exception as e:
        return "", "", f"Audio conversion failed: {e}"

    # Step 2: Transcription with the preloaded Whisper model
    try:
        transcript = transcribe_audio(tmp_path)
    except Exception as e:
        return "", "", f"Transcription failed: {e}"

    # Step 3: Speaker diarization; fall back to an error note so the raw
    # transcript and summary still come through if diarization fails
    try:
        diarization = diarize_audio(tmp_path, hf_token)
        labeled_transcript = combine_diarized_transcript(diarization, transcript)
    except Exception as e:
        labeled_transcript = f"Diarization failed: {e}"

    # Step 4: Summarization with T5-small
    try:
        summary = summarize_text(transcript)
    except Exception as e:
        return transcript, labeled_transcript, f"Summarization failed: {e}"

    return transcript, labeled_transcript, summary


description = """
### 🩺 GP Consultation Summarizer (Demo App)

This app:

1. Transcribes short consultation audio using Whisper
2. Identifies who spoke when using pyannote speaker diarization
3. Combines both into a labeled transcript
4. Generates a short summary using T5-small

⚠️ **Note:** Best for short consultations (under 5–6 minutes).

⚠️ You must provide your own Hugging Face token (required for diarization).
"""

app = gr.Interface(
    fn=process_pipeline,
    inputs=[
        gr.Audio(type="filepath", label="Upload Consultation Audio"),
        gr.Textbox(label="Your Hugging Face Token", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Raw Transcript"),
        gr.Textbox(label="Labeled Transcript (with Speaker Info)"),
        gr.Textbox(label="Summary"),
    ],
    title="GP Consultation Summarizer",
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    app.launch(share=True)