# NOTE(review): removed non-Python residue ("Spaces:", "Build error" x2) that
# was pasted in from the Hugging Face Spaces build page — it made this file
# unparseable and was the likely cause of the build error itself.
import os
import shutil
import tempfile

import gradio as gr
import torch
import whisper
from pyannote.audio import Pipeline as DiarizationPipeline
from pydub import AudioSegment
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# --- Model setup (runs once at import time) ---

# Whisper ASR model; "base" trades accuracy for speed — use "small" or
# "medium" for better transcripts at the cost of load time and memory.
whisper_model = whisper.load_model("base")

# T5-small summarization components, shared by summarize_text() below.
summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
summarizer = pipeline("summarization", model=summarizer_model, tokenizer=summarizer_tokenizer)
def convert_to_wav(input_path, output_path):
    """Decode any audio file that pydub/ffmpeg understands and write it as WAV."""
    decoded = AudioSegment.from_file(input_path)
    decoded.export(output_path, format="wav")
def transcribe_audio(audio_path):
    """Run Whisper on *audio_path* and return the plain transcript text."""
    # fp16 inference is only enabled when CUDA is available.
    use_fp16 = torch.cuda.is_available()
    return whisper_model.transcribe(audio_path, fp16=use_fp16)["text"]
def diarize_audio(audio_path, hf_token):
    """Run pyannote speaker diarization on *audio_path*.

    *hf_token* authorizes access to the gated pyannote model; it is also
    exported via the HF_TOKEN environment variable for downstream hub calls.
    """
    os.environ["HF_TOKEN"] = hf_token
    pipe = DiarizationPipeline.from_pretrained(
        "pyannote/speaker-diarization", use_auth_token=hf_token
    )
    return pipe(audio_path)
def combine_diarized_transcript(diarization, full_text):
    """Prepend a speaker/time-span listing to the raw transcript.

    NOTE: this is a coarse demo alignment — speaker turns are listed by time
    range only; individual words are not attributed to speakers.
    """
    turn_lines = [
        f"{speaker}: [from {turn.start:.1f}s to {turn.end:.1f}s]"
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]
    return "\n".join(turn_lines) + "\n" + full_text
def summarize_text(text):
    """Summarize *text* with the module-level T5-small model and return it."""
    # T5 expects a task prefix; truncate to the model's 512-token window.
    task_input = "summarize: " + text.strip()
    token_ids = summarizer_tokenizer.encode(
        task_input, return_tensors="pt", max_length=512, truncation=True
    )
    output_ids = summarizer_model.generate(
        token_ids,
        max_length=100,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return summarizer_tokenizer.decode(output_ids[0], skip_special_tokens=True)
def process_pipeline(audio_file, hf_token):
    """Run the full demo pipeline: validate, convert to WAV, transcribe,
    diarize, and summarize.

    Returns a 3-tuple matching the Gradio output boxes:
    (raw transcript, speaker-labeled transcript, summary).
    On failure the error message is placed in the last slot.
    """
    if not hf_token:
        return "", "", "Error: HuggingFace token is required."
    # Guard against a missing upload (audio_file may be None) or an empty file.
    if not audio_file or not os.path.exists(audio_file) or os.path.getsize(audio_file) == 0:
        return "", "", "Error: Uploaded file is missing or empty."

    # Step 1: normalize to WAV so both Whisper and pyannote accept the input.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
        tmp_path = tmp_wav.name
    try:
        try:
            AudioSegment.from_file(audio_file).export(tmp_path, format="wav")
        except Exception as e:
            return "", "", f"Audio conversion failed: {str(e)}"

        # Step 2: transcription (Whisper via the transformers ASR pipeline).
        try:
            transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base",
                return_timestamps=True,
                device=0 if torch.cuda.is_available() else -1,
            )
            transcript = transcriber(tmp_path)["text"]
        except Exception as e:
            return "", "", f"Transcription failed: {str(e)}"

        # Step 3: speaker diarization — best effort. The raw transcript is
        # still useful if the gated pyannote model cannot be loaded, so a
        # failure here only degrades the "Labeled Transcript" output.
        # (Bug fix: this step was previously never run, leaving the UI's
        # labeled-transcript box unused.)
        try:
            diarization = diarize_audio(tmp_path, hf_token)
            labeled = combine_diarized_transcript(diarization, transcript)
        except Exception as e:
            labeled = f"Diarization failed: {str(e)}"

        # Step 4: summarization.
        try:
            summarizer = pipeline(
                "summarization",
                model="facebook/bart-large-cnn",
                device=0 if torch.cuda.is_available() else -1,
            )
            summary = summarizer(
                transcript, max_length=130, min_length=30, do_sample=False
            )[0]["summary_text"]
        except Exception as e:
            return transcript, labeled, f"Summarization failed: {str(e)}"

        # Bug fix: the temp-file path was previously returned as the first
        # output; the UI's "Raw Transcript" box expects the transcript text.
        return transcript, labeled, summary
    finally:
        # Bug fix: the temporary WAV was previously leaked on every call.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
# Markdown shown above the Gradio interface (defect fixed: the string
# previously contained " | |" extraction artifacts on every line).
description = """
### 🩺 GP Consultation Summarizer (Demo App)
This app:
1. Transcribes short consultation audio using Whisper
2. Identifies who spoke when using PyAnnote speaker diarization
3. Combines both into a labeled transcript
4. Generates a short summary using T5-small
⚠️ **Note:** Best for short consultations (under 5–6 minutes).
⚠️ You must provide your own Hugging Face token (required for diarization).
"""
# Gradio UI: audio upload + HF token in, three text panes out. The three
# outputs correspond positionally to process_pipeline's return tuple.
app = gr.Interface(
    fn=process_pipeline,
    inputs=[
        gr.Audio(type="filepath", label="Upload Consultation Audio (.wav)"),
        gr.Textbox(label="Your Hugging Face Token", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Raw Transcript"),
        gr.Textbox(label="Labeled Transcript (with Speaker Info)"),
        gr.Textbox(label="Summary"),
    ],
    title="GP Consultation Summarizer",
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # share=True requests a public tunnel link when run locally
    # (ignored when hosted on Hugging Face Spaces).
    app.launch(share=True)