3v324v23's picture
Зафиксирована рабочая версия TEN-Agent для HuggingFace Space
87337b1
#
#
# Agora Real Time Engagement
# Created by XinHui Li in 2024.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from dataclasses import dataclass
from typing import AsyncIterator, Tuple
from ten_ai_base.config import BaseConfig
from ten import (
AsyncTenEnv,
)
import copy
import websockets
import uuid
import json
import gzip
import asyncio
import threading
from datetime import datetime
MESSAGE_TYPES = {
11: "audio-only server response",
12: "frontend server response",
15: "error message from server",
}
MESSAGE_TYPE_SPECIFIC_FLAGS = {
0: "no sequence number",
1: "sequence number > 0",
2: "last message from server (seq < 0)",
3: "sequence number < 0",
}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}
LATENCY_SAMPLE_INTERVAL_MS = 5
@dataclass
class TTSConfig(BaseConfig):
# Parameters, refer to: https://www.volcengine.com/docs/6561/79823.
appid: str = ""
token: str = ""
# Refer to: https://www.volcengine.com/docs/6561/1257544.
voice_type: str = "BV001_streaming"
sample_rate: int = 16000
api_url: str = "wss://openspeech.bytedance.com/api/v1/tts/ws_binary"
cluster: str = "volcano_tts"
class TTSClient:
def __init__(self, config: TTSConfig, ten_env: AsyncTenEnv) -> None:
self.config = config
self.websocket = None
self.ten_env = ten_env
# Refer to: https://www.volcengine.com/docs/6561/79823.
self.request_template = {
"app": {
"appid": self.config.appid,
"token": "access_token",
"cluster": self.config.cluster,
},
"user": {"uid": ""}, # Any non-empty string, used for tracing.
"audio": {
"rate": self.config.sample_rate,
"voice_type": self.config.voice_type,
"encoding": "pcm",
"speed_ratio": 1.0,
"volume_ratio": 1.0,
"pitch_ratio": 1.0,
},
"request": {
"reqid": "", # Must be unique for each request.
"text": "", # Text to be synthesized.
"text_type": "plain",
"operation": "submit",
},
}
# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
self.default_header = bytearray(b"\x11\x10\x11\x00")
self._cancel = threading.Event()
# Latency.
self._latest_record_time = None
def is_cancelled(self) -> bool:
return self._cancel.is_set()
async def cancel(self) -> None:
self._cancel.set()
async def connect(self) -> None:
header = {"Authorization": f"Bearer; {self.config.token}"}
self.websocket = await websockets.connect(
self.config.api_url,
extra_headers=header,
ping_interval=None,
close_timeout=1, # Fast close, as the `flush` cmd will close the connection.
)
self.ten_env.log_info("Websocket connection established.")
async def close(self) -> None:
if self.websocket is not None:
await self.websocket.close()
self.websocket = None
self.ten_env.log_info("Websocket connection closed.")
else:
self.ten_env.log_info("Websocket is not connected.")
async def reconnect(self) -> None:
await self.close()
await self.connect()
def parse_response(self, response: websockets.Data) -> Tuple[bytes, bool]:
protocol_version = response[0] >> 4
header_size = response[0] & 0x0F
message_type = response[1] >> 4
message_type_specific_flags = response[1] & 0x0F
serialization_method = response[2] >> 4
message_compression = response[2] & 0x0F
reserved = response[3]
header_extensions = response[4 : header_size * 4]
payload = response[header_size * 4 :]
self.ten_env.log_debug(
f"Protocol version: {protocol_version:#x} - version {protocol_version}"
)
self.ten_env.log_debug(
f"Header size: {header_size:#x} - {header_size * 4} bytes"
)
self.ten_env.log_debug(
f"Message type: {message_type:#x} - {MESSAGE_TYPES[message_type]}"
)
self.ten_env.log_debug(
f"Message type specific flags: {message_type_specific_flags:#x} - {MESSAGE_TYPE_SPECIFIC_FLAGS[message_type_specific_flags]}"
)
self.ten_env.log_debug(
f"Message serialization method: {serialization_method:#x} - {MESSAGE_SERIALIZATION_METHODS[serialization_method]}"
)
self.ten_env.log_debug(
f"Message compression: {message_compression:#x} - {MESSAGE_COMPRESSIONS[message_compression]}"
)
self.ten_env.log_debug(f"Reserved: {reserved:#04x}")
if header_size != 1:
self.ten_env.log_debug(f"Header extensions: {header_extensions}")
if message_type == 0xB: # audio-only server response
if message_type_specific_flags == 0: # no sequence number as ACK
self.ten_env.log_debug("Payload size: 0")
return None, False
else:
sequence_number = int.from_bytes(payload[:4], "big", signed=True)
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
payload = payload[8:]
self.ten_env.log_debug(f"Sequence number: {sequence_number}")
self.ten_env.log_debug(f"Payload size: {payload_size} bytes")
if sequence_number < 0:
return payload, True
else:
return payload, False
elif message_type == 0xF:
code = int.from_bytes(payload[:4], "big", signed=False)
msg_size = int.from_bytes(payload[4:8], "big", signed=False)
error_msg = payload[8:]
if message_compression == 1:
error_msg = gzip.decompress(error_msg)
error_msg = str(error_msg, "utf-8")
self.ten_env.log_error(f"Error message code: {code}")
self.ten_env.log_error(f"Error message size: {msg_size} bytes")
self.ten_env.log_error(f"Error message: {error_msg}")
return None, True
elif message_type == 0xC:
msg_size = int.from_bytes(payload[:4], "big", signed=False)
payload = payload[4:]
if message_compression == 1:
payload = gzip.decompress(payload)
self.ten_env.log_debug(f"Frontend message: {payload}")
else:
self.ten_env.log_error("undefined message type!")
return None, True
def record_latency(self, request_id: str, start: datetime) -> None:
end_time = datetime.now()
if self._latest_record_time:
sample_interval = datetime.now() - self._latest_record_time
if sample_interval.total_seconds() < LATENCY_SAMPLE_INTERVAL_MS:
return
self._latest_record_time = end_time
latency = int((end_time - start).total_seconds() * 1000)
self.ten_env.log_info(f"Request ({request_id}), ttfb {latency}ms.")
async def text_to_speech_stream(self, text: str) -> AsyncIterator[bytes]:
ws = self.websocket
if ws is None:
await self.connect()
ws = self.websocket
start_ms = datetime.now()
request_id = str(uuid.uuid4())
request = copy.deepcopy(self.request_template)
request["request"]["reqid"] = request_id
request["request"]["text"] = text
request["user"]["uid"] = str(uuid.uuid4())
request_bytes = str.encode(json.dumps(request))
request_bytes = gzip.compress(request_bytes)
full_request = bytearray(self.default_header)
# payload size(4 bytes)
full_request.extend((len(request_bytes)).to_bytes(4, "big"))
# payload
full_request.extend(request_bytes)
try:
await ws.send(full_request)
self.ten_env.log_info(f"Sent: {request}")
while True:
if self.is_cancelled():
self.ten_env.log_info(f"Request ({request_id}) has been cancelled.")
# Current connection should be closed, as the server will not drop the remain data.
await self.close()
self._cancel.clear()
break
resp = await ws.recv()
payload, done = self.parse_response(resp)
if payload:
yield payload
self.record_latency(request_id, start_ms)
if done:
self.ten_env.log_info(
f"Response is completed for request: {request_id}."
)
break
except websockets.exceptions.ConnectionClosedError as e:
self.ten_env.log_error(
f"Connection is closed with error: {e}, request: {request_id}."
)
await self.connect()
except asyncio.TimeoutError:
self.ten_env.log_error("Timeout waiting for response.")