File size: 3,656 Bytes
160cee9
412c852
aa756f5
160cee9
ab662d2
922cd73
60ce8fc
c906256
922cd73
412c852
5f6cbd7
ab662d2
5f6cbd7
793e132
 
00c9bf5
 
160cee9
 
 
 
 
 
 
 
 
 
 
 
191d30d
26832b7
 
793e132
 
26832b7
412c852
2e8cc61
 
 
 
 
 
 
793e132
2e8cc61
 
 
160cee9
 
 
 
 
 
 
 
 
 
2e8cc61
 
00c9bf5
2e8cc61
 
 
 
00c9bf5
2e8cc61
 
fb40cda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
import gradio as gr
import re
import torch
from pyctcdecode import BeamSearchDecoderCTC
import torch
import librosa
import time


# Hub repo of the n-gram language model handed to the CTC beam-search decoder.
lmID = "aware-ai/german-lowercase-4gram-kenlm"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# Speech-to-text pipeline; decoding goes through the LM-backed decoder above.
p = pipeline(
    "automatic-speech-recognition",
    model="aware-ai/robust-wav2vec2-base-german-lowercase",
    decoder=decoder,
)

# Text-to-text pipeline applied to the raw transcript (grammar model, per its repo name).
ttp = pipeline(
    "text2text-generation",
    model="aware-ai/marian-german-grammar",
)

# The M2M100 translation model is currently not loaded at import time:
#model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
#tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")

_M2M100_CHECKPOINT = "facebook/m2m100_1.2B"
# Filled with the "model" and "tokenizer" entries on the first translate() call.
_m2m100_cache = {}


def _load_m2m100():
    """Return the M2M100 ``(model, tokenizer)`` pair, loading it on first use."""
    if not _m2m100_cache:
        # Load both before caching so a failed download never leaves a
        # half-filled cache behind.
        model = M2M100ForConditionalGeneration.from_pretrained(_M2M100_CHECKPOINT)
        tokenizer = M2M100Tokenizer.from_pretrained(_M2M100_CHECKPOINT)
        _m2m100_cache["model"] = model
        _m2m100_cache["tokenizer"] = tokenizer
    return _m2m100_cache["model"], _m2m100_cache["tokenizer"]


def translate(src, tgt, text):
    """Translate ``text`` from the ``src`` language to the ``tgt`` language.

    Args:
        src: Source-language label whose last space-separated token is the
            parenthesised language code, e.g. ``"German (de)"``.
        tgt: Target-language label in the same format.
        text: The text to translate.

    Returns:
        The first decoded sequence generated by the model, as a string.
    """
    # "German (de)" -> "(de)" -> "de"
    src = src.split(" ")[-1][1:-1]
    tgt = tgt.split(" ")[-1][1:-1]

    # The module no longer creates `model`/`tokenizer` at import time, so
    # fetch them here instead of relying on globals that do not exist.
    model, tokenizer = _load_m2m100()

    # translate
    tokenizer.src_lang = src
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_src,
        forced_bos_token_id=tokenizer.get_lang_id(tgt),
        use_cache=True,
    )
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return result

def transcribe(audio):
    """Transcribe a recording and run the transcript through the text pipeline.

    Args:
        audio: Path of the recorded audio file, or ``None`` when the form was
            submitted without a recording.

    Returns:
        A ``(transcribed, punctuated)`` tuple: the raw text from the
        speech-recognition pipeline ``p`` and the text generated from it by
        the text2text pipeline ``ttp``. Both are empty strings when ``audio``
        is ``None``.
    """
    # Nothing was recorded: return empty results instead of passing None
    # into the ASR pipeline, which cannot process it.
    if audio is None:
        return "", ""

    # Long recordings are processed in 20-second chunks without overlap.
    transcribed = p(audio, chunk_length_s=20, stride_length_s=(0, 0))["text"]

    punctuated = ttp(transcribed, max_length=512)[0]["generated_text"]

    return transcribed, punctuated

def get_asr_interface():
    """Build the Gradio interface for the speech-recognition tab.

    The interface takes one microphone recording, delivered to ``transcribe``
    as a file path, and displays its two results in two textboxes.
    """
    microphone = gr.inputs.Audio(source="microphone", type="filepath")
    result_boxes = ["textbox", "textbox"]
    return gr.Interface(fn=transcribe, inputs=[microphone], outputs=result_boxes)
        

def get_translate_interface():
    """Build the Gradio interface for the translation tab.

    The interface offers a source-language dropdown, a target-language
    dropdown and a free-text input, and shows the result of ``translate`` in
    a textbox.
    """
    # Comma-separated "<name> (<code>)" labels. Names only use ';' internally,
    # so splitting on ',' is safe; translate() reads the code from the
    # trailing parentheses of the selected label.
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba),
    Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca),
    Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en),
    Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr),
    Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu),
    Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu),
    Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it),
    Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo),
    Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn),
    Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no),
    Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl),
    Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru),
    Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
    Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su),
    Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
    Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh),
    Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    # strip() removes the newlines and indentation left over from the literal.
    lang_list = [lang.strip() for lang in langs.split(',')]

    source_dropdown = gr.inputs.Dropdown(lang_list, label="Source Language")
    target_dropdown = gr.inputs.Dropdown(lang_list, label="Target Language")
    return gr.Interface(
        translate,
        inputs=[source_dropdown, target_dropdown, 'text'],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
        

# Tabs of the app in display order: interfaces[i] is shown under names[i].
# The translate tab (get_translate_interface() / "translate") is currently disabled.
interfaces = [get_asr_interface()]
names = ["ASR"]

# Listen on all network interfaces rather than only on localhost.
demo = gr.TabbedInterface(interfaces, names)
demo.launch(server_name="0.0.0.0")