import base64
import faster_whisper
import tempfile
import torch
import time
import requests
import logging
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
import websockets
from pydantic import BaseModel
from typing import Optional
import sys
import asyncio

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
#logging.getLogger("asyncio").setLevel(logging.DEBUG)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Device selected: {device}')

model_name = 'ivrit-ai/faster-whisper-v2-d4'
logging.info(f'Loading model: {model_name}')
model = faster_whisper.WhisperModel(model_name, device=device)
logging.info('Model loaded successfully')

# Maximum data size: 200MB
MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')

app = FastAPI()


class InputData(BaseModel):
    type: str
    data: Optional[str] = None  # Used for blob input
    url: Optional[str] = None  # Used for url input
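
# Example request bodies this schema accepts (illustrative placeholders only):
#   {"type": "blob", "data": "<base64-encoded audio>"}
#   {"type": "url",  "url": "https://example.com/audio.mp3"}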


def download_file(url, max_size_bytes, output_filename, api_key=None):
    """
    Download a file from a given URL with size limit and optional API key.
    """
    logging.debug(f'Starting file download from URL: {url}')
    try:
        headers = {}
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'
            logging.debug('API key provided, added to headers')

        response = requests.get(url, stream=True, headers=headers)
        response.raise_for_status()

        file_size = int(response.headers.get('Content-Length', 0))
        logging.info(f'File size: {file_size} bytes')

        if file_size > max_size_bytes:
            logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
            return False

        downloaded_size = 0
        with open(output_filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                downloaded_size += len(chunk)
                logging.debug(f'Downloaded {downloaded_size} bytes')
                if downloaded_size > max_size_bytes:
                    logging.error('Downloaded size exceeds maximum allowed payload size')
                    return False
                file.write(chunk)

        logging.info(f'File downloaded successfully: {output_filename}')
        return True

    except requests.RequestException as e:
        logging.error(f"Error downloading file: {e}")
        return False
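

# Usage sketch (illustrative only, not called by the service): fetch a remote audio
# file into a temporary path before transcription, respecting the global payload
# limit. The URL default is a placeholder.
def _example_download(url='https://example.com/audio.mp3'):
    with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as tmp:
        if download_file(url, MAX_PAYLOAD_SIZE, tmp.name):
            return tmp.name  # caller is responsible for deleting the file when done
    return None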

@app.get("/")
async def read_root():
    return {"message": "This is the Ivrit AI Streaming service."}




def transcribe_core_ws(audio_file, last_transcribed_time):
    """
    Transcribe the audio file and return only the segments that have not been processed yet.

    :param audio_file: Path to the growing audio file.
    :param last_transcribed_time: The last time (in seconds) that was transcribed.
    :return: Newly transcribed segments and the updated last transcribed time.
    """
    logging.info(f"Starting transcription for file: {audio_file} from {last_transcribed_time} seconds.")

    ret = {'new_segments': []}
    new_last_transcribed_time = last_transcribed_time

    try:
        # Transcribe the entire audio file
        logging.debug(f"Initiating model transcription for file: {audio_file}")
        segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
        logging.info('Transcription completed successfully.')
    except Exception as e:
        logging.error(f"Error during transcription: {e}")
        raise

    # Track the new segments and update the last transcribed time
    for s in segs:
        logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")

        # Only process segments that start after the last transcribed time
        if s.start >= last_transcribed_time:
            logging.info(f"New segment found starting at {s.start} seconds.")
            words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]

            seg = {
                'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
                'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
                'no_speech_prob': s.no_speech_prob, 'words': words
            }
            logging.info(f'Adding new transcription segment: {seg}')
            ret['new_segments'].append(seg)

            # Update the last transcribed time to the end of the current segment
            new_last_transcribed_time = s.end
            logging.debug(f"Updated last transcribed time to: {new_last_transcribed_time} seconds")

    #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
    return ret, new_last_transcribed_time
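

# Usage sketch (illustrative only): drive transcribe_core_ws incrementally on a
# growing file, carrying last_transcribed_time forward so only new segments are
# returned on each call. The audio path is supplied by the caller.
def _example_incremental_transcription(audio_path, last_time=0.0):
    result, last_time = transcribe_core_ws(audio_path, last_time)
    for seg in result['new_segments']:
        logging.info(f"{seg['start']:.2f}-{seg['end']:.2f}: {seg['text']}")
    return last_time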


@app.websocket("/wtranscribe")
async def websocket_transcribe(websocket: WebSocket):
    logging.info("New WebSocket connection request received.")
    await websocket.accept()
    logging.info("WebSocket connection established successfully.")

    try:
        processed_segments = []  # Keeps track of the segments already transcribed
        accumulated_audio_size = 0  # Track how much audio data has been buffered
        accumulated_audio_time = 0  # Track the total audio duration accumulated
        last_transcribed_time = 0.0
        #min_transcription_time = 5.0  # Minimum duration of audio in seconds before transcription starts

        # A temporary file to store the growing audio data
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
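            # Note: delete=False means the file persists on disk after the handler
            # exits; it is not removed automatically in the cleanup below.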
            logging.info(f"Temporary audio file created at {temp_audio_file.name}")

            while True:
                try:
                    # Receive the next chunk of audio data
                    audio_chunk = await websocket.receive_bytes()
                    if not audio_chunk:
                        logging.warning("Received empty audio chunk, skipping processing.")
                        continue

                    # Write audio chunk to file and accumulate size and time
                    temp_audio_file.write(audio_chunk)
                    temp_audio_file.flush()
                    accumulated_audio_size += len(audio_chunk)

                    # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
                    chunk_duration = len(audio_chunk) / (16000 * 2)  # Assuming 16kHz mono WAV (2 bytes per sample)
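                    # e.g. a 32,000-byte chunk -> 32,000 / (16,000 * 2) = 1.0 second of audio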
                    accumulated_audio_time += chunk_duration

                    partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
                    accumulated_audio_time = 0  # Reset the accumulated audio time
                    processed_segments.extend(partial_result['new_segments'])

                    # Reset the accumulated audio size after transcription
                    accumulated_audio_size = 0

                    # Send the transcription result back to the client with both new and all processed segments
                    response = {
                        "new_segments": partial_result['new_segments'],
                        "processed_segments": processed_segments
                    }
                    logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
                    await websocket.send_json(response)

                except WebSocketDisconnect:
                    logging.info("WebSocket connection closed by the client.")
                    break

    except Exception as e:
        logging.error(f"Unexpected error during WebSocket transcription: {e}")
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            logging.warning("Could not deliver the error to the client; connection already closed.")

    finally:
        logging.info("Cleaning up and closing WebSocket connection.")