import io

import librosa
import requests
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, Form
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Initialize FastAPI app
app = FastAPI()

# Log that the application has started
print("FastAPI application started.")

# Load the Whisper model and processor
model_name = "openai/whisper-base"
print(f"Loading Whisper model: {model_name}")

try:
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    print(f"Model {model_name} successfully loaded.")
except Exception as e:
    print(f"Error loading the model: {e}")
    raise e

# Move model to the appropriate device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model is using device: {device}")


@app.post("/transcribe/")
def transcribe_audio_url(audio_url: str = Form(...)):
    # Download the audio file from the provided URL
    try:
        response = requests.get(audio_url, timeout=30)  # Avoid hanging on unresponsive hosts
        if response.status_code != 200:
            return {"error": f"Failed to download audio from URL. Status code: {response.status_code}"}
        print(f"Successfully downloaded audio from URL: {audio_url}")
        audio_data = io.BytesIO(response.content)  # Store audio data in memory
    except Exception as e:
        print(f"Error downloading the audio file: {e}")
        return {"error": f"Error downloading the audio file: {e}"}

    # Read and normalize the audio (Whisper expects mono 16 kHz input)
    try:
        audio_input, sample_rate = sf.read(audio_data)  # Read from the in-memory BytesIO
        if audio_input.ndim > 1:
            # Downmix multi-channel audio to a single channel
            audio_input = audio_input.mean(axis=1)
        if sample_rate != 16000:
            # Resample so the rate matches what the Whisper processor expects
            audio_input = librosa.resample(audio_input, orig_sr=sample_rate, target_sr=16000)
        print("Audio file from URL successfully read.")
    except Exception as e:
        print(f"Error reading the audio file: {e}")
        return {"error": f"Error reading the audio file: {e}"}

    # Preprocess the audio for Whisper
    try:
        inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
        print(f"Audio file preprocessed for transcription.")
    except Exception as e:
        print(f"Error processing the audio file: {e}")
        return {"error": f"Error processing the audio file: {e}"}

    # Move inputs to the appropriate device
    inputs = {key: value.to(device) for key, value in inputs.items()}
    print("Inputs moved to the appropriate device.")

    # Generate the transcription
    try:
        with torch.no_grad():
            predicted_ids = model.generate(inputs["input_features"])
        print("Transcription successfully generated.")
    except Exception as e:
        print(f"Error during transcription generation: {e}")
        return {"error": f"Error during transcription generation: {e}"}

    # Decode the transcription
    try:
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        print("Transcription successfully decoded.")
    except Exception as e:
        print(f"Error decoding the transcription: {e}")
        return {"error": f"Error decoding the transcription: {e}"}

    return {"transcription": transcription}

@app.get("/")
def read_root():
    return {"message": "Welcome to the Whisper transcription API"}

if __name__ == "__main__":
    # Print when starting the FastAPI server
    print("Starting FastAPI server with Uvicorn...")

    # Run the FastAPI app on the default port (7860)
    uvicorn.run(app, host="0.0.0.0", port=7860)
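
# Example request (a minimal sketch; the URL below is a placeholder and must
# point at an audio file in a format soundfile can decode, e.g. WAV or FLAC):
#
#   curl -X POST http://localhost:7860/transcribe/ \
#        -F "audio_url=https://example.com/sample.wav"
#
# On success the endpoint returns JSON of the form {"transcription": "..."}.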