Spaces:
Build error
Build error
| /* | |
| * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | |
| * SPDX-License-Identifier: Apache-2.0. | |
| */ | |
| use std::default::Default; | |
| use std::error::Error; | |
| use std::fmt::{Debug, Display, Formatter}; | |
| use std::future::Future; | |
| use async_stream::stream; | |
| use aws_config::meta::region::RegionProviderChain; | |
| use aws_sdk_transcribestreaming::{config::Region, meta::PKG_VERSION}; | |
| use clap::Parser; | |
| use poem::{Endpoint, EndpointExt, get, handler, IntoResponse, listener::TcpListener, Route, Server}; | |
| use futures_util::{Sink, SinkExt, TryFutureExt, TryStreamExt}; | |
| use poem::endpoint::{StaticFileEndpoint, StaticFilesEndpoint}; | |
| use poem::web::websocket::{Message, WebSocket}; | |
| use futures_util::stream::StreamExt; | |
| use poem::web::{Data, Query}; | |
| use tokio::select; | |
| use tokio_stream::Stream; | |
| use serde::{Deserialize, Serialize}; | |
| use lesson::{LessonsManager}; | |
| use crate::lesson::Viseme; | |
| mod lesson; | |
| struct Opt { | |
| /// The AWS Region. | |
| region: Option<String>, | |
| // | |
| // /// The name of the audio file. | |
| // #[structopt(short, long)] | |
| // audio_file: String, | |
| // | |
| /// Whether to display additional information. | |
| verbose: bool, | |
| } | |
/// One pipeline result: transcript text, translated text, or synthesized output.
///
/// Not referenced anywhere else in this file.
enum ReplyEvent {
    /// Text produced by transcription.
    Transcribed(String),
    /// Text produced by translation.
    Translated(String),
    /// Raw synthesized bytes (presumably audio — confirm with the producer).
    Synthesized(Vec<u8>),
}
| struct Context { | |
| lessons_manager: LessonsManager, | |
| } | |
| async fn main() -> Result<(), std::io::Error> { | |
| tracing_subscriber::fmt::init(); | |
| let Opt { | |
| region, | |
| verbose, | |
| } = Opt::parse(); | |
| let region_provider = RegionProviderChain::first_try(region.map(Region::new)) | |
| .or_default_provider() | |
| .or_else(Region::new("us-west-2")); | |
| println!(); | |
| if verbose { | |
| println!("Transcribe client version: {}", PKG_VERSION); | |
| println!( | |
| "Region: {}", | |
| region_provider.region().await.unwrap().as_ref() | |
| ); | |
| println!(); | |
| } | |
| let shared_config = aws_config::from_env().region(region_provider).load().await; | |
| let ctx = Context { | |
| lessons_manager: LessonsManager::new(&shared_config), | |
| }; | |
| let app = Route::new() | |
| .nest( | |
| "/", | |
| StaticFilesEndpoint::new("./static") | |
| .show_files_listing() | |
| .index_file("index.html"), | |
| ) | |
| .at("/ws/lesson-speaker", get(stream_speaker)) | |
| .at("/ws/lesson-listener", get(stream_listener)) | |
| .at("lesson-speaker", StaticFileEndpoint::new("./static/index.html")) | |
| .at("lesson-listener", StaticFileEndpoint::new("./static/index.html")) | |
| .data(ctx); | |
| let listener = TcpListener::bind("[::]:8080"); | |
| let server = Server::new(listener); | |
| server.run(app).await | |
| } | |
| pub struct LessonSpeakerQuery { | |
| id: u32, | |
| lang: String, | |
| } | |
| async fn stream_speaker(ctx: Data<&Context>, query: Query<LessonSpeakerQuery>, ws: WebSocket) -> impl IntoResponse { | |
| let lesson = ctx.lessons_manager.create_lesson(query.id, query.lang.clone().parse().expect("Not supported lang")).await; | |
| ws.on_upgrade(|mut socket| async move { | |
| let origin_tx = lesson.voice_channel(); | |
| let mut transcribe_rx = lesson.transcript_channel(); | |
| loop { | |
| select! { | |
| msg = socket.next() => { | |
| match msg.as_ref() { | |
| Some(Ok(Message::Binary(bin))) => { | |
| if origin_tx.send(bin.to_vec()).await.is_err() { | |
| println!("tx closed"); | |
| break; | |
| } | |
| }, | |
| Some(Ok(_)) => { | |
| println!("Other: {:?}", msg); | |
| }, | |
| Some(Err(e)) => { | |
| println!("Error: {:?}", e); | |
| }, | |
| None => { | |
| let _ = socket.close().await; | |
| println!("Other: {:?}", msg); | |
| break; | |
| } | |
| } | |
| }, | |
| output = transcribe_rx.recv() => { | |
| if let Ok(transcript) = output { | |
| println!("Transcribed: {}", transcript); | |
| let evt = LiveLessonTextEvent::Transcription { text: transcript.clone() }; | |
| let json = serde_json::to_string(&evt).expect("failed to serialize"); | |
| let _ = socket.send(Message::Text(json)).await.expect("failed to send"); | |
| } | |
| }, | |
| } | |
| } | |
| }) | |
| } | |
| pub struct LessonListenerQuery { | |
| id: u32, | |
| lang: String, | |
| voice: String, | |
| } | |
| enum LiveLessonTextEvent { | |
| Transcription { text: String }, | |
| Translation { text: String }, | |
| LipSync{ visemes: Vec<Viseme> }, | |
| } | |
| async fn stream_listener(ctx: Data<&Context>, query: Query<LessonListenerQuery>, ws: WebSocket) -> impl IntoResponse { | |
| let lesson_opt = ctx.lessons_manager.get_lesson(query.id).await; | |
| println!("{:?}", query); | |
| let voice_id = query.voice.parse().expect("Not supported voice"); | |
| ws.on_upgrade(|mut socket| async move { | |
| let Some(lesson) = lesson_opt else { | |
| let _ = socket.send(Message::Text("lesson not found".to_string())).await; | |
| return | |
| }; | |
| let mut transcript_rx = lesson.transcript_channel(); | |
| let mut lang_lesson = lesson.get_or_init(query.lang.clone()).await; | |
| let mut translate_rx = lang_lesson.translated_channel(); | |
| let mut voice_lesson = lang_lesson.get_or_init(voice_id).await; | |
| let mut voice_rx = voice_lesson.voice_channel(); | |
| let mut lip_sync_rx = voice_lesson.lip_sync_channel(); | |
| loop { | |
| select! { | |
| transcript = transcript_rx.recv() => { | |
| if let Ok(transcript) = transcript { | |
| let evt = LiveLessonTextEvent::Transcription { text: transcript }; | |
| let json = serde_json::to_string(&evt).expect("failed to serialize"); | |
| println!("Transcribed: {}", json); | |
| let _ = socket.send(Message::Text(json)).await; | |
| } | |
| }, | |
| translated = translate_rx.recv() => { | |
| if let Ok(translated) = translated { | |
| let evt = LiveLessonTextEvent::Translation { text: translated }; | |
| let json = serde_json::to_string(&evt).expect("failed to serialize"); | |
| println!("Translated: {}", json); | |
| let _ = socket.send(Message::Text(json)).await; | |
| } | |
| }, | |
| voice = voice_rx.recv() => { | |
| if let Ok(voice) = voice { | |
| println!("Synthesized: {:?}", voice.len()); | |
| let _ = socket.send(Message::Binary(voice)).await; | |
| } | |
| }, | |
| visemes = lip_sync_rx.recv() => { | |
| if let Ok(visemes) = visemes { | |
| let evt = LiveLessonTextEvent::LipSync { visemes }; | |
| let json = serde_json::to_string(&evt).expect("failed to serialize"); | |
| println!("Visemes: {:?}", json); | |
| let _ = socket.send(Message::Text(json)).await; | |
| } | |
| }, | |
| } | |
| } | |
| }) | |
| } | |
/// Error type for a transcription streaming session.
///
/// `std::error::Error` requires `Debug + Display`, so `Debug` is derived here.
/// `Display` is implemented by hand below.
///
/// No construction sites are visible in this file. The variant descriptions
/// below follow the variant names; confirm them against the code that
/// produces these errors.
#[derive(Debug)]
enum StreamTranscriptionError {
    /// Failure while establishing the stream; wraps the underlying error.
    EstablishStreamError(Box<dyn Error + Send + Sync>),
    /// Failure reported by the transcript result stream; wraps the underlying error.
    TranscriptResultStreamError(Box<dyn Error + Send + Sync>),
    /// The session was shut down.
    Shutdown,
    /// Unclassified failure with no further detail.
    Unknown,
}

impl Display for StreamTranscriptionError {
    /// Renders `<Variant>: <inner error>` for wrapping variants, `<Variant>` otherwise.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EstablishStreamError(e) => write!(f, "EstablishStreamError: {}", e),
            Self::TranscriptResultStreamError(e) => {
                write!(f, "TranscriptResultStreamError: {}", e)
            }
            Self::Shutdown => f.write_str("Shutdown"),
            Self::Unknown => f.write_str("Unknown"),
        }
    }
}

impl Error for StreamTranscriptionError {
    /// Returns the wrapped error for the two wrapping variants, `None` otherwise.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::EstablishStreamError(e) | Self::TranscriptResultStreamError(e) => {
                Some(e.as_ref())
            }
            Self::Shutdown | Self::Unknown => None,
        }
    }
}