# Hosting-page residue (not code), kept as provenance:
# author: qfuxa · commit b9f09f7 ("refacto 0") · file size 4.01 kB
import asyncio
import logging
from time import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from timed_objects import ASRToken
logger = logging.getLogger(__name__)


class SharedState:
    """
    Shared mutable state for streaming transcription and diarization.

    Coordinates the audio-processing, transcription and diarization tasks.
    Every public coroutine acquires ``self.lock`` (an ``asyncio.Lock``), so
    coroutines running on the same event loop always observe consistent
    values. ``asyncio.Lock`` is not a threading primitive, so this object is
    NOT safe to share between OS threads.
    """

    def __init__(self):
        self.lock = asyncio.Lock()
        self._set_initial_values()

    def _set_initial_values(self) -> None:
        """Assign every piece of mutable state its initial value.

        Used by both ``__init__`` and ``reset`` so they cannot disagree about
        what "initial" means. Does not acquire ``self.lock``; callers that
        may run concurrently with other coroutines must hold it.
        """
        self.tokens: List[ASRToken] = []
        self.buffer_transcription: str = ""
        self.buffer_diarization: str = ""
        self.full_transcription: str = ""
        self.end_buffer: float = 0
        self.end_attributed_speaker: float = 0
        # Reference instant: add_dummy_token and get_current_state measure
        # elapsed time as seconds since this wall-clock value.
        self.beg_loop: float = time()
        self.sep: str = " "  # Default separator
        self.last_response_content: str = ""  # To track changes in response

    async def update_transcription(self, new_tokens: List[ASRToken], buffer: str,
                                   end_buffer: float, full_transcription: str, sep: str) -> None:
        """Record a new transcription result.

        Args:
            new_tokens: Tokens appended to the accumulated token list.
            buffer: Replaces the current transcription buffer text.
            end_buffer: Replaces the end time of the transcription buffer.
            full_transcription: Replaces the full transcription text.
            sep: Replaces the separator reported by ``get_current_state``.
        """
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker: float, buffer_diarization: str = "") -> None:
        """Record diarization progress.

        Args:
            end_attributed_speaker: Time up to which speakers have been attributed.
            buffer_diarization: New diarization buffer text. An empty string
                (the default) leaves the existing buffer unchanged.
        """
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self) -> None:
        """Append a placeholder token so the token list advances without transcription.

        The token starts at the current time relative to ``beg_loop``, lasts
        one second, has text ".", and is flagged ``is_dummy=True``.
        """
        async with self.lock:
            current_time = time() - self.beg_loop
            dummy_token = ASRToken(
                start=current_time,
                end=current_time + 1,
                text=".",
                speaker=-1,  # presumably "no speaker assigned" — confirm against ASRToken
                is_dummy=True
            )
            self.tokens.append(dummy_token)

    async def get_current_state(self) -> Dict[str, Any]:
        """Return a snapshot of the state plus derived lag figures.

        ``tokens`` in the result is a shallow copy of the internal list.

        Two derived lag figures are included. Both are rounded to 2 decimals
        and clamped at 0.

        - ``remaining_time_transcription``: wall-clock seconds since
          ``beg_loop`` beyond ``end_buffer``. It is 0 until ``end_buffer > 0``.
        - ``remaining_time_diarization``: how far the later of ``end_buffer``
          and the last token's end is ahead of ``end_attributed_speaker``.
          It is 0 when there are no tokens.
        """
        async with self.lock:
            current_time = time()
            remaining_time_transcription = 0
            remaining_time_diarization = 0

            if self.end_buffer > 0:
                remaining_time_transcription = max(
                    0, round(current_time - self.beg_loop - self.end_buffer, 2)
                )

            if self.tokens:
                # Already inside the non-empty guard, so tokens[-1] is safe.
                latest_end = max(self.end_buffer, self.tokens[-1].end)
                remaining_time_diarization = max(
                    0, round(latest_end - self.end_attributed_speaker, 2)
                )

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_time_transcription,
                "remaining_time_diarization": remaining_time_diarization
            }

    async def reset(self) -> None:
        """Reset every field (including ``sep``) to its initial value.

        The lock itself is kept, so coroutines waiting on it are unaffected.
        """
        async with self.lock:
            self._set_initial_values()