Spaces:
Build error
Build error
prepare to integrate whisper
Browse files- .gitignore +1 -0
- Cargo.lock +20 -0
- Cargo.toml +1 -0
- config.yaml +22 -0
- src/config.rs +76 -0
- src/main.rs +21 -4
- src/whisper.rs +29 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
/target
|
2 |
.idea/
|
|
|
|
1 |
/target
|
2 |
.idea/
|
3 |
+
models
|
Cargo.lock
CHANGED
@@ -1422,6 +1422,7 @@ dependencies = [
|
|
1422 |
"poem",
|
1423 |
"serde",
|
1424 |
"serde_json",
|
|
|
1425 |
"tokio",
|
1426 |
"tokio-stream",
|
1427 |
"tracing-subscriber",
|
@@ -1740,6 +1741,19 @@ dependencies = [
|
|
1740 |
"serde",
|
1741 |
]
|
1742 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1743 |
[[package]]
|
1744 |
name = "sha1"
|
1745 |
version = "0.10.6"
|
@@ -2184,6 +2198,12 @@ dependencies = [
|
|
2184 |
"tinyvec",
|
2185 |
]
|
2186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
2187 |
[[package]]
|
2188 |
name = "untrusted"
|
2189 |
version = "0.7.1"
|
|
|
1422 |
"poem",
|
1423 |
"serde",
|
1424 |
"serde_json",
|
1425 |
+
"serde_yaml",
|
1426 |
"tokio",
|
1427 |
"tokio-stream",
|
1428 |
"tracing-subscriber",
|
|
|
1741 |
"serde",
|
1742 |
]
|
1743 |
|
1744 |
+
[[package]]
|
1745 |
+
name = "serde_yaml"
|
1746 |
+
version = "0.9.25"
|
1747 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
1748 |
+
checksum = "1a49e178e4452f45cb61d0cd8cebc1b0fafd3e41929e996cef79aa3aca91f574"
|
1749 |
+
dependencies = [
|
1750 |
+
"indexmap 2.0.2",
|
1751 |
+
"itoa",
|
1752 |
+
"ryu",
|
1753 |
+
"serde",
|
1754 |
+
"unsafe-libyaml",
|
1755 |
+
]
|
1756 |
+
|
1757 |
[[package]]
|
1758 |
name = "sha1"
|
1759 |
version = "0.10.6"
|
|
|
2198 |
"tinyvec",
|
2199 |
]
|
2200 |
|
2201 |
+
[[package]]
|
2202 |
+
name = "unsafe-libyaml"
|
2203 |
+
version = "0.2.9"
|
2204 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
2205 |
+
checksum = "f28467d3e1d3c6586d8f25fa243f544f5800fec42d97032474e17222c2b75cfa"
|
2206 |
+
|
2207 |
[[package]]
|
2208 |
name = "untrusted"
|
2209 |
version = "0.7.1"
|
Cargo.toml
CHANGED
@@ -16,6 +16,7 @@ tracing-subscriber = "0.3.17"
|
|
16 |
futures-util = "0.3.28"
|
17 |
serde = { version = "1.0.189", features = ["derive"] }
|
18 |
serde_json = { version = "1.0.107", features = [] }
|
|
|
19 |
whisper-rs = { version = "0.8.0" , features = ["coreml"] }
|
20 |
|
21 |
[dependencies.poem]
|
|
|
16 |
futures-util = "0.3.28"
|
17 |
serde = { version = "1.0.189", features = ["derive"] }
|
18 |
serde_json = { version = "1.0.107", features = [] }
|
19 |
+
serde_yaml = "0.9.25"
|
20 |
whisper-rs = { version = "0.8.0" , features = ["coreml"] }
|
21 |
|
22 |
[dependencies.poem]
|
config.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
server:
|
2 |
+
port: 8080
|
3 |
+
host: ::1
|
4 |
+
whisper:
|
5 |
+
n_threads: 4
|
6 |
+
step_ms: 500
|
7 |
+
length_ms: 5000
|
8 |
+
keep_ms: 200
|
9 |
+
capture_id: -1
|
10 |
+
max_tokens: 32
|
11 |
+
audio_ctx: 0
|
12 |
+
vad_thold: 0.6
|
13 |
+
freq_thold: 100.0
|
14 |
+
speed_up: false
|
15 |
+
translate: false
|
16 |
+
no_fallback: false
|
17 |
+
print_special: false
|
18 |
+
no_context: true
|
19 |
+
no_timestamps: false
|
20 |
+
tinydiarize: false
|
21 |
+
language: "en"
|
22 |
+
model: "models/ggml-base.bin"
|
src/config.rs
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use std::ffi::c_int;
|
2 |
+
use std::fs;
|
3 |
+
use std::net::IpAddr;
|
4 |
+
use serde::{Deserialize};
|
5 |
+
use whisper_rs::FullParams;
|
6 |
+
|
7 |
+
/// Whisper settings deserialized from the `whisper` field of `Config`.
///
/// `to_full_params` reads `print_special`, `no_timestamps`, `translate`,
/// `max_tokens`, `language`, `n_threads`, `audio_ctx`, `speed_up`,
/// `no_fallback` and `no_context`. `model` is the path handed to
/// `WhisperContext::new` in `run_whisper`. The remaining fields
/// (`step_ms`, `length_ms`, `keep_ms`, `capture_id`, `vad_thold`,
/// `freq_thold`, `tinydiarize`) are parsed but not read by any code visible
/// here.
// NOTE(review): the `*_ms` names suggest durations in milliseconds and the
// `*_thold` names suggest thresholds; nothing visible here confirms units or ranges.
#[derive(Debug, Deserialize)]
|
8 |
+
pub(crate) struct WhisperParams {
|
9 |
+
pub(crate) n_threads: Option<usize>,
|
10 |
+
pub(crate) step_ms: i32,
|
11 |
+
pub(crate) length_ms: i32,
|
12 |
+
pub(crate) keep_ms: i32,
|
13 |
+
pub(crate) capture_id: i32,
|
14 |
+
pub(crate) max_tokens: i32,
|
15 |
+
pub(crate) audio_ctx: i32,
|
16 |
+
pub(crate) vad_thold: f32,
|
17 |
+
pub(crate) freq_thold: f32,
|
18 |
+
pub(crate) speed_up: bool,
|
19 |
+
pub(crate) translate: bool,
|
20 |
+
pub(crate) no_fallback: bool,
|
21 |
+
pub(crate) print_special: bool,
|
22 |
+
pub(crate) no_context: bool,
|
23 |
+
pub(crate) no_timestamps: bool,
|
24 |
+
pub(crate) tinydiarize: bool,
|
25 |
+
pub(crate) language: Option<String>,
|
26 |
+
pub(crate) model: String,
|
27 |
+
}
|
28 |
+
|
29 |
+
const NONE: [c_int;0] = [];
|
30 |
+
|
31 |
+
impl WhisperParams {
|
32 |
+
pub(crate) fn to_full_params<'a, 'b>(&'a self) -> FullParams<'a, 'b> {
|
33 |
+
let mut param = FullParams::new(Default::default());
|
34 |
+
param.set_print_progress(false);
|
35 |
+
param.set_print_special(self.print_special);
|
36 |
+
param.set_print_realtime(false);
|
37 |
+
param.set_print_timestamps(!self.no_timestamps);
|
38 |
+
param.set_translate(self.translate);
|
39 |
+
param.set_single_segment(true);
|
40 |
+
param.set_max_tokens(self.max_tokens);
|
41 |
+
let lang = self.language.as_ref().map(|s| s.as_str());
|
42 |
+
param.set_language(lang);
|
43 |
+
let num_cpus = std::thread::available_parallelism().map(|c| c.get()).unwrap_or(4);
|
44 |
+
param.set_n_threads(self.n_threads.unwrap_or(num_cpus) as c_int);
|
45 |
+
param.set_audio_ctx(self.audio_ctx);
|
46 |
+
param.set_speed_up(self.speed_up);
|
47 |
+
// param.set_tdrz_enable(self.tinydiarize);
|
48 |
+
if self.no_fallback {
|
49 |
+
param.set_temperature_inc(0.0);
|
50 |
+
}
|
51 |
+
if self.no_context {
|
52 |
+
param.set_tokens(&NONE);
|
53 |
+
}
|
54 |
+
|
55 |
+
param
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
/// Server settings deserialized from the `server` field of `Config`.
// NOTE(review): presumably the address/port the HTTP server binds to; the
// code that consumes these fields is not visible here, so confirm at the call site.
#[derive(Debug, Deserialize)]
|
60 |
+
pub(crate) struct Server {
|
61 |
+
pub(crate) port: u16,
|
62 |
+
pub(crate) host: IpAddr,
|
63 |
+
}
|
64 |
+
|
65 |
+
/// Top-level application configuration, deserialized from YAML
/// (`load_config` and the `load` test both parse it from `config.yaml`).
#[derive(Debug, Deserialize)]
|
66 |
+
pub(crate) struct Config {
|
67 |
+
pub(crate) whisper: WhisperParams,
|
68 |
+
pub(crate) server: Server,
|
69 |
+
}
|
70 |
+
|
71 |
+
/// Smoke test: `config.yaml` in the working directory must read and deserialize
/// into a [`Config`]. The parsed value is printed for manual inspection only.
// Nothing here awaits, so a plain synchronous test avoids spinning up a Tokio runtime.
#[test]
fn load() {
    let config_str = fs::read_to_string("config.yaml").expect("failed to read config file");
    let params: Config = serde_yaml::from_str(&config_str).expect("failed to parse config file");
    println!("{:?}", params);
}
|
src/main.rs
CHANGED
@@ -17,13 +17,17 @@ use poem::web::websocket::{Message, WebSocket};
|
|
17 |
use futures_util::stream::StreamExt;
|
18 |
use poem::web::{Data, Query};
|
19 |
|
20 |
-
use tokio::select;
|
21 |
use serde::{Deserialize, Serialize};
|
22 |
use whisper_rs::WhisperContext;
|
23 |
use lesson::{LessonsManager};
|
|
|
24 |
use crate::lesson::Viseme;
|
|
|
25 |
|
26 |
mod lesson;
|
|
|
|
|
27 |
|
28 |
|
29 |
#[derive(Debug, Parser)]
|
@@ -46,12 +50,25 @@ struct Context {
|
|
46 |
lessons_manager: LessonsManager,
|
47 |
}
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
#[tokio::main]
|
50 |
async fn main() -> Result<(), std::io::Error> {
|
51 |
tracing_subscriber::fmt::init();
|
52 |
-
|
53 |
-
let
|
54 |
-
|
55 |
|
56 |
let Opt {
|
57 |
region,
|
|
|
17 |
use futures_util::stream::StreamExt;
|
18 |
use poem::web::{Data, Query};
|
19 |
|
20 |
+
use tokio::{fs, select};
|
21 |
use serde::{Deserialize, Serialize};
|
22 |
use whisper_rs::WhisperContext;
|
23 |
use lesson::{LessonsManager};
|
24 |
+
use crate::config::Config;
|
25 |
use crate::lesson::Viseme;
|
26 |
+
use crate::whisper::run_whisper;
|
27 |
|
28 |
mod lesson;
|
29 |
+
mod config;
|
30 |
+
mod whisper;
|
31 |
|
32 |
|
33 |
#[derive(Debug, Parser)]
|
|
|
50 |
lessons_manager: LessonsManager,
|
51 |
}
|
52 |
|
53 |
+
#[derive(Debug)]
|
54 |
+
enum Error {
|
55 |
+
IoError(std::io::Error),
|
56 |
+
ConfigError(serde_yaml::Error),
|
57 |
+
}
|
58 |
+
|
59 |
+
async fn load_config() -> Result<Config, Error> {
|
60 |
+
let config_str = fs::read_to_string("config.yaml").await.map_err(|e| Error::IoError(e))?;
|
61 |
+
let config: Config = serde_yaml::from_str(config_str.as_str())
|
62 |
+
.map_err(|e| Error::ConfigError(e))?;
|
63 |
+
return Ok(config)
|
64 |
+
}
|
65 |
+
|
66 |
#[tokio::main]
|
67 |
async fn main() -> Result<(), std::io::Error> {
|
68 |
tracing_subscriber::fmt::init();
|
69 |
+
|
70 |
+
let config = load_config().await.expect("failed to load config");
|
71 |
+
run_whisper(&config).await;
|
72 |
|
73 |
let Opt {
|
74 |
region,
|
src/whisper.rs
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
use whisper_rs::WhisperContext;
|
2 |
+
use crate::config::Config;
|
3 |
+
|
4 |
+
pub(crate) async fn run_whisper(config: &Config) {
|
5 |
+
let ctx = WhisperContext::new(&*config.whisper.model).expect("failed to load whisper context");
|
6 |
+
let mut _state = ctx.create_state().expect("failed to create state");
|
7 |
+
let params = (&config.whisper).to_full_params();
|
8 |
+
_state.full(params, &[]).expect("TODO: panic message");
|
9 |
+
}
|
10 |
+
|
11 |
+
/// Decodes little-endian signed 16-bit PCM bytes into `f32` samples.
///
/// Each pair of bytes becomes one sample scaled by `1 / 32768`, so the output lies
/// in `[-1.0, 1.0)` and `i16::MIN` maps to exactly `-1.0`. A trailing odd byte,
/// which cannot form a full sample, is ignored.
async fn pcm_i16_to_f32(input: &[u8]) -> Vec<f32> {
    input
        .chunks_exact(2)
        // `chunks_exact(2)` guarantees both indices exist.
        .map(|pair| f32::from(i16::from_le_bytes([pair[0], pair[1]])) / 32768.0)
        .collect()
}
|
26 |
+
|
27 |
+
/// Placeholder type: it has no fields, and no methods or uses are visible in
/// this file yet.
struct WhisperHandler {
|
28 |
+
|
29 |
+
}
|