MoHamdyy committed
Commit 9d962cb · 0 Parent(s)

Initial clean deployment

Files changed (7)
  1. .gitattributes +35 -0
  2. .gitignore +12 -0
  3. Dockerfile +34 -0
  4. README.md +13 -0
  5. app.py +500 -0
  6. requirements.txt +17 -0
  7. static/index.html +101 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # Virtual Environment
+ .venv/
+ venv/
+ env/
+
+ # Python cache
+ __pycache__/
+ *.pyc
+
+ # IDE and editor folders
+ .vscode/
+ .idea/
Dockerfile ADDED
@@ -0,0 +1,34 @@
+ # Use an official Python 3.11 slim image as a parent image
+ FROM python:3.11-slim
+
+ # --- Step 1: Install system dependencies as root ---
+ # We need ffmpeg for pydub/torchaudio and git/git-lfs to download large models
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     ffmpeg \
+     libsndfile1 \
+     git \
+     git-lfs \
+     && rm -rf /var/lib/apt/lists/*
+
+ # --- Step 2: Set up a non-root user for better security ---
+ # This is a best practice from the Hugging Face team
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+ WORKDIR /home/user/app
+
+ # --- Step 3: Install Python dependencies as the non-root user ---
+ # Copy requirements first to leverage Docker layer caching
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ # --- Step 4: Copy the rest of the application code ---
+ # This includes app.py and the static/ folder
+ COPY --chown=user . .
+
+ # --- Step 5: Run the application ---
+ # Expose the port the app runs on
+ EXPOSE 7860
+
+ # Command to run the application using uvicorn
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
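Since the image is fully self-contained, the Space can also be tried locally with standard Docker commands (the image tag below is only an example):

    docker build -t speech-translator .
    docker run --rm -p 7860:7860 speech-translator

The frontend is then served at http://localhost:7860/ and the API at /process-speech/.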
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Arabic-English Speech Translator
+ emoji: ⚡
+ colorFrom: gray
+ colorTo: indigo
+ sdk: docker
+ pinned: false
+ license: mit
+ short_description: Neural Translation Stack
+ app_port: 7860
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,500 @@
+ import os
+ import re
+ import time
+ import random
+ import numpy as np
+ import pandas as pd
+ import math
+ import shutil
+ import base64
+
+ # Torch and Audio
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import torch.nn.functional as F
+ import torchaudio
+ import librosa
+ import librosa.display
+
+ # Text and Audio Processing
+ from unidecode import unidecode
+ from inflect import engine
+ import pydub
+ import soundfile as sf
+
+ # Transformers
+ from transformers import (
+     WhisperProcessor, WhisperForConditionalGeneration,
+     MarianTokenizer, MarianMTModel,
+ )
+
+ # API Server
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.staticfiles import StaticFiles  # serves the static frontend
+
+
+
+ # Part 2: TTS Model Components (ported from the training notebook)
+
+
+ # Hyperparameters
+ class Hyperparams:
+     seed = 42
+     # We won't use these dataset paths, but keep them for hp object integrity
+     csv_path = "path/to/metadata.csv"
+     wav_path = "path/to/wavs"
+     symbols = [
+         'EOS', ' ', '!', ',', '-', '.', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
+         'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
+         't', 'u', 'v', 'w', 'x', 'y', 'z', 'à', 'â', 'è', 'é', 'ê', 'ü',
+         '’', '“', '”'
+     ]
+     sr = 22050
+     n_fft = 2048
+     n_stft = int((n_fft//2) + 1)
+     hop_length = int(n_fft/8.0)
+     win_length = int(n_fft/2.0)
+     mel_freq = 128
+     max_mel_time = 1024
+     power = 2.0
+     text_num_embeddings = 2*len(symbols)
+     embedding_size = 256
+     encoder_embedding_size = 512
+     dim_feedforward = 1024
+     postnet_embedding_size = 1024
+     encoder_kernel_size = 3
+     postnet_kernel_size = 5
+     ampl_multiplier = 10.0
+     ampl_amin = 1e-10
+     db_multiplier = 1.0
+     ampl_ref = 1.0
+     ampl_power = 1.0
+     max_db = 100
+     scale_db = 10
+
+ hp = Hyperparams()
+
+ # Text to Sequence
+ symbol_to_id = {s: i for i, s in enumerate(hp.symbols)}
+ def text_to_seq(text):
+     text = text.lower()
+     seq = []
+     for s in text:
+         _id = symbol_to_id.get(s, None)
+         if _id is not None:
+             seq.append(_id)
+     seq.append(symbol_to_id["EOS"])
+     return torch.IntTensor(seq)
+
+ # Audio Processing
+ spec_transform = torchaudio.transforms.Spectrogram(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length, power=hp.power)
+ mel_scale_transform = torchaudio.transforms.MelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft)
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ mel_inverse_transform = torchaudio.transforms.InverseMelScale(n_mels=hp.mel_freq, sample_rate=hp.sr, n_stft=hp.n_stft).to(DEVICE)
+ griffnlim_transform = torchaudio.transforms.GriffinLim(n_fft=hp.n_fft, win_length=hp.win_length, hop_length=hp.hop_length).to(DEVICE)
+
+ def pow_to_db_mel_spec(mel_spec):
+     mel_spec = torchaudio.functional.amplitude_to_DB(mel_spec, multiplier=hp.ampl_multiplier, amin=hp.ampl_amin, db_multiplier=hp.db_multiplier, top_db=hp.max_db)
+     mel_spec = mel_spec/hp.scale_db
+     return mel_spec
+
+ def db_to_power_mel_spec(mel_spec):
+     mel_spec = mel_spec*hp.scale_db
+     mel_spec = torchaudio.functional.DB_to_amplitude(mel_spec, ref=hp.ampl_ref, power=hp.ampl_power)
+     return mel_spec
+
+ def inverse_mel_spec_to_wav(mel_spec):
+     power_mel_spec = db_to_power_mel_spec(mel_spec.to(DEVICE))
+     spectrogram = mel_inverse_transform(power_mel_spec)
+     pseudo_wav = griffnlim_transform(spectrogram)
+     return pseudo_wav
+
+ def mask_from_seq_lengths(sequence_lengths: torch.Tensor, max_length: int) -> torch.BoolTensor:
+     ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
+     range_tensor = ones.cumsum(dim=1)
+     return sequence_lengths.unsqueeze(1) >= range_tensor
+
+ # --- TransformerTTS Model Architecture (ported from the training notebook) ---
+ class EncoderBlock(nn.Module):
+     def __init__(self):
+         super(EncoderBlock, self).__init__()
+         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
+         self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
+         self.dropout_1 = torch.nn.Dropout(0.1)
+         self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
+         self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
+         self.dropout_2 = torch.nn.Dropout(0.1)
+         self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
+         self.dropout_3 = torch.nn.Dropout(0.1)
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         x_out = self.norm_1(x)
+         x_out, _ = self.attn(query=x_out, key=x_out, value=x_out, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+         x_out = self.dropout_1(x_out)
+         x = x + x_out
+         x_out = self.norm_2(x)
+         x_out = self.linear_1(x_out)
+         x_out = F.relu(x_out)
+         x_out = self.dropout_2(x_out)
+         x_out = self.linear_2(x_out)
+         x_out = self.dropout_3(x_out)
+         x = x + x_out
+         return x
+
+ class DecoderBlock(nn.Module):
+     def __init__(self):
+         super(DecoderBlock, self).__init__()
+         self.norm_1 = nn.LayerNorm(normalized_shape=hp.embedding_size)
+         self.self_attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
+         self.dropout_1 = torch.nn.Dropout(0.1)
+         self.norm_2 = nn.LayerNorm(normalized_shape=hp.embedding_size)
+         self.attn = torch.nn.MultiheadAttention(embed_dim=hp.embedding_size, num_heads=4, dropout=0.1, batch_first=True)
+         self.dropout_2 = torch.nn.Dropout(0.1)
+         self.norm_3 = nn.LayerNorm(normalized_shape=hp.embedding_size)
+         self.linear_1 = nn.Linear(hp.embedding_size, hp.dim_feedforward)
+         self.dropout_3 = torch.nn.Dropout(0.1)
+         self.linear_2 = nn.Linear(hp.dim_feedforward, hp.embedding_size)
+         self.dropout_4 = torch.nn.Dropout(0.1)
+     def forward(self, x, memory, x_attn_mask=None, x_key_padding_mask=None, memory_attn_mask=None, memory_key_padding_mask=None):
+         x_out, _ = self.self_attn(query=x, key=x, value=x, attn_mask=x_attn_mask, key_padding_mask=x_key_padding_mask)
+         x_out = self.dropout_1(x_out)
+         x = self.norm_1(x + x_out)
+         x_out, _ = self.attn(query=x, key=memory, value=memory, attn_mask=memory_attn_mask, key_padding_mask=memory_key_padding_mask)
+         x_out = self.dropout_2(x_out)
+         x = self.norm_2(x + x_out)
+         x_out = self.linear_1(x)
+         x_out = F.relu(x_out)
+         x_out = self.dropout_3(x_out)
+         x_out = self.linear_2(x_out)
+         x_out = self.dropout_4(x_out)
+         x = self.norm_3(x + x_out)
+         return x
+
+ class EncoderPreNet(nn.Module):
+     def __init__(self):
+         super(EncoderPreNet, self).__init__()
+         self.embedding = nn.Embedding(num_embeddings=hp.text_num_embeddings, embedding_dim=hp.encoder_embedding_size)
+         self.linear_1 = nn.Linear(hp.encoder_embedding_size, hp.encoder_embedding_size)
+         self.linear_2 = nn.Linear(hp.encoder_embedding_size, hp.embedding_size)
+         self.conv_1 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
+         self.bn_1 = nn.BatchNorm1d(hp.encoder_embedding_size)
+         self.dropout_1 = torch.nn.Dropout(0.5)
+         self.conv_2 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
+         self.bn_2 = nn.BatchNorm1d(hp.encoder_embedding_size)
+         self.dropout_2 = torch.nn.Dropout(0.5)
+         self.conv_3 = nn.Conv1d(hp.encoder_embedding_size, hp.encoder_embedding_size, kernel_size=hp.encoder_kernel_size, stride=1, padding=int((hp.encoder_kernel_size - 1) / 2), dilation=1)
+         self.bn_3 = nn.BatchNorm1d(hp.encoder_embedding_size)
+         self.dropout_3 = torch.nn.Dropout(0.5)
+     def forward(self, text):
+         x = self.embedding(text)
+         x = self.linear_1(x)
+         x = x.transpose(2, 1)
+         x = self.conv_1(x)
+         x = self.bn_1(x)
+         x = F.relu(x)
+         x = self.dropout_1(x)
+         x = self.conv_2(x)
+         x = self.bn_2(x)
+         x = F.relu(x)
+         x = self.dropout_2(x)
+         x = self.conv_3(x)
+         x = self.bn_3(x)
+         x = F.relu(x)
+         x = self.dropout_3(x)
+         x = x.transpose(1, 2)
+         x = self.linear_2(x)
+         return x
+
+ class PostNet(nn.Module):
+     def __init__(self):
+         super(PostNet, self).__init__()
+         self.conv_1 = nn.Conv1d(hp.mel_freq, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_1 = nn.BatchNorm1d(hp.postnet_embedding_size)
+         self.dropout_1 = torch.nn.Dropout(0.5)
+         self.conv_2 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_2 = nn.BatchNorm1d(hp.postnet_embedding_size)
+         self.dropout_2 = torch.nn.Dropout(0.5)
+         self.conv_3 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_3 = nn.BatchNorm1d(hp.postnet_embedding_size)
+         self.dropout_3 = torch.nn.Dropout(0.5)
+         self.conv_4 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_4 = nn.BatchNorm1d(hp.postnet_embedding_size)
+         self.dropout_4 = torch.nn.Dropout(0.5)
+         self.conv_5 = nn.Conv1d(hp.postnet_embedding_size, hp.postnet_embedding_size, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_5 = nn.BatchNorm1d(hp.postnet_embedding_size)
+         self.dropout_5 = torch.nn.Dropout(0.5)
+         self.conv_6 = nn.Conv1d(hp.postnet_embedding_size, hp.mel_freq, kernel_size=hp.postnet_kernel_size, stride=1, padding=int((hp.postnet_kernel_size - 1) / 2), dilation=1)
+         self.bn_6 = nn.BatchNorm1d(hp.mel_freq)
+         self.dropout_6 = torch.nn.Dropout(0.5)
+     def forward(self, x):
+         x = x.transpose(2, 1)
+         x = self.conv_1(x)
+         x = self.bn_1(x); x = torch.tanh(x); x = self.dropout_1(x)
+         x = self.conv_2(x)
+         x = self.bn_2(x); x = torch.tanh(x); x = self.dropout_2(x)
+         x = self.conv_3(x)
+         x = self.bn_3(x); x = torch.tanh(x); x = self.dropout_3(x)
+         x = self.conv_4(x)
+         x = self.bn_4(x); x = torch.tanh(x); x = self.dropout_4(x)
+         x = self.conv_5(x)
+         x = self.bn_5(x); x = torch.tanh(x); x = self.dropout_5(x)
+         x = self.conv_6(x)
+         x = self.bn_6(x); x = self.dropout_6(x)
+         x = x.transpose(1, 2)
+         return x
+
+ class DecoderPreNet(nn.Module):
+     def __init__(self):
+         super(DecoderPreNet, self).__init__()
+         self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
+         self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
+     def forward(self, x):
+         x = self.linear_1(x)
+         x = F.relu(x)
+         x = F.dropout(x, p=0.5, training=True)
+         x = self.linear_2(x)
+         x = F.relu(x)
+         x = F.dropout(x, p=0.5, training=True)
+         return x
+
+ class TransformerTTS(nn.Module):
+     def __init__(self, device=DEVICE):
+         super(TransformerTTS, self).__init__()
+         self.encoder_prenet = EncoderPreNet()
+         self.decoder_prenet = DecoderPreNet()
+         self.postnet = PostNet()
+         self.pos_encoding = nn.Embedding(num_embeddings=hp.max_mel_time, embedding_dim=hp.embedding_size)
+         self.encoder_block_1 = EncoderBlock()
+         self.encoder_block_2 = EncoderBlock()
+         self.encoder_block_3 = EncoderBlock()
+         self.decoder_block_1 = DecoderBlock()
+         self.decoder_block_2 = DecoderBlock()
+         self.decoder_block_3 = DecoderBlock()
+         self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
+         self.linear_2 = nn.Linear(hp.embedding_size, 1)
+         self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)
+     def forward(self, text, text_len, mel, mel_len):
+         N = text.shape[0]; S = text.shape[1]; TIME = mel.shape[1]
+         self.src_key_padding_mask = torch.zeros((N, S), device=text.device).masked_fill(~mask_from_seq_lengths(text_len, max_length=S), float("-inf"))
+         self.src_mask = torch.zeros((S, S), device=text.device).masked_fill(torch.triu(torch.full((S, S), True, dtype=torch.bool), diagonal=1).to(text.device), float("-inf"))
+         self.tgt_key_padding_mask = torch.zeros((N, TIME), device=mel.device).masked_fill(~mask_from_seq_lengths(mel_len, max_length=TIME), float("-inf"))
+         self.tgt_mask = torch.zeros((TIME, TIME), device=mel.device).masked_fill(torch.triu(torch.full((TIME, TIME), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+         self.memory_mask = torch.zeros((TIME, S), device=mel.device).masked_fill(torch.triu(torch.full((TIME, S), True, device=mel.device, dtype=torch.bool), diagonal=1), float("-inf"))
+         text_x = self.encoder_prenet(text)
+         pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time).to(mel.device))
+         S = text_x.shape[1]; text_x = text_x + pos_codes[:S]
+         text_x = self.encoder_block_1(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+         text_x = self.encoder_block_2(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+         text_x = self.encoder_block_3(text_x, attn_mask = self.src_mask, key_padding_mask = self.src_key_padding_mask)
+         text_x = self.norm_memory(text_x)
+         mel_x = self.decoder_prenet(mel); mel_x = mel_x + pos_codes[:TIME]
+         mel_x = self.decoder_block_1(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+         mel_x = self.decoder_block_2(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+         mel_x = self.decoder_block_3(x=mel_x, memory=text_x, x_attn_mask=self.tgt_mask, x_key_padding_mask=self.tgt_key_padding_mask, memory_attn_mask=self.memory_mask, memory_key_padding_mask=self.src_key_padding_mask)
+         mel_linear = self.linear_1(mel_x)
+         mel_postnet = self.postnet(mel_linear)
+         mel_postnet = mel_linear + mel_postnet
+         stop_token = self.linear_2(mel_x)
+         bool_mel_mask = self.tgt_key_padding_mask.ne(0).unsqueeze(-1).repeat(1, 1, hp.mel_freq)
+         mel_linear = mel_linear.masked_fill(bool_mel_mask, 0)
+         mel_postnet = mel_postnet.masked_fill(bool_mel_mask, 0)
+         stop_token = stop_token.masked_fill(bool_mel_mask[:, :, 0].unsqueeze(-1), 1e3).squeeze(2)
+         return mel_postnet, mel_linear, stop_token
+
+     @torch.no_grad()
+     def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
+         self.eval(); self.train(False)
+         text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
+         N = 1
+         SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
+         mel_padded = SOS
+         mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
+         stop_token_outputs = torch.FloatTensor([]).to(text.device)
+         iters = range(max_length)
+         for _ in iters:
+             mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
+             mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
+             if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
+                 break
+             else:
+                 stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
+             mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
+         return mel_postnet, stop_token_outputs
+
+
+ # ---------------------------------
+ # Part 3: Model Loading (from Hugging Face Hub)
+ # ---------------------------------
+
+ # IMPORTANT: Replace "your-username" with your Hugging Face username
+ # and the model names with the ones you created on the Hub.
+ TTS_MODEL_HUB_ID = "MoHamdyy/marian-ar-en-translation/transformer-tts-ljspeech"
+ ASR_HUB_ID = "MoHamdyy/whisper-stt-model/whisper-arabic-test"
+ MARIAN_HUB_ID = "your-username/marian-ar-en-translation"
+
+ # Helper function to download the TTS model file from the Hub
+ from huggingface_hub import hf_hub_download
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Loading models from Hugging Face Hub to device:", DEVICE)
+
+ # Load TTS Model from Hub
+ try:
+     print("Loading TTS model...")
+     # Download the .pt file from its repo
+     tts_model_path = hf_hub_download(repo_id=TTS_MODEL_HUB_ID, filename="train_SimpleTransfromerTTS.pt")
+     state = torch.load(tts_model_path, map_location=DEVICE)
+     TTS_MODEL = TransformerTTS().to(DEVICE)
+     # Check for the correct key in the state dictionary
+     if "model" in state:
+         TTS_MODEL.load_state_dict(state["model"])
+     elif "state_dict" in state:
+         TTS_MODEL.load_state_dict(state["state_dict"])
+     else:
+         TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
+     TTS_MODEL.eval()
+     print("TTS model loaded successfully.")
+ except Exception as e:
+     print(f"Error loading TTS model: {e}")
+     TTS_MODEL = None
+
+ # Load STT (Whisper) Model from Hub
+ try:
+     print("Loading STT (Whisper) model...")
+     stt_processor = WhisperProcessor.from_pretrained(ASR_HUB_ID)
+     stt_model = WhisperForConditionalGeneration.from_pretrained(ASR_HUB_ID).to(DEVICE).eval()
+     print("STT model loaded successfully.")
+ except Exception as e:
+     print(f"Error loading STT model: {e}")
+     stt_processor = None
+     stt_model = None
+
+ # Load TTT (MarianMT) Model from Hub
+ try:
+     print("Loading TTT (MarianMT) model...")
+     mt_tokenizer = MarianTokenizer.from_pretrained(MARIAN_HUB_ID)
+     mt_model = MarianMTModel.from_pretrained(MARIAN_HUB_ID).to(DEVICE).eval()
+     print("TTT model loaded successfully.")
+ except Exception as e:
+     print(f"Error loading TTT model: {e}")
+     mt_tokenizer = None
+     mt_model = None
+
+
+
+ # Part 4: Full Pipeline Function
+
+
+ def full_speech_translation_pipeline(audio_input_path: str):
+     print(f"--- PIPELINE START: Processing {audio_input_path} ---")
+     if audio_input_path is None or not os.path.exists(audio_input_path):
+         msg = "Error: Audio file not provided or not found."
+         print(msg)
+         # Return empty/default values
+         return "Error: No file", "", (hp.sr, np.array([]).astype(np.float32))
+
+     # STT Stage
+     arabic_transcript = "STT Error: Processing failed."
+     try:
+         print("STT: Loading and resampling audio...")
+         wav, sr = torchaudio.load(audio_input_path)
+         if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True)
+         target_sr_stt = stt_processor.feature_extractor.sampling_rate
+         if sr != target_sr_stt: wav = torchaudio.transforms.Resample(sr, target_sr_stt)(wav)
+         audio_array_stt = wav.squeeze().cpu().numpy()
+
+         print("STT: Extracting features and transcribing...")
+         inputs = stt_processor(audio_array_stt, sampling_rate=target_sr_stt, return_tensors="pt").input_features.to(DEVICE)
+         forced_ids = stt_processor.get_decoder_prompt_ids(language="arabic", task="transcribe")
+         with torch.no_grad():
+             generated_ids = stt_model.generate(inputs, forced_decoder_ids=forced_ids, max_length=448)
+         arabic_transcript = stt_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
+         print(f"STT Output: {arabic_transcript}")
+     except Exception as e:
+         print(f"STT Error: {e}")
+
+     # TTT Stage
+     english_translation = "TTT Error: Processing failed."
+     if arabic_transcript and not arabic_transcript.startswith("STT Error"):
+         try:
+             print("TTT: Translating to English...")
+             batch = mt_tokenizer(arabic_transcript, return_tensors="pt", padding=True).to(DEVICE)
+             with torch.no_grad():
+                 translated_ids = mt_model.generate(**batch, max_length=512)
+             english_translation = mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True)[0].strip()
+             print(f"TTT Output: {english_translation}")
+         except Exception as e:
+             print(f"TTT Error: {e}")
+     else:
+         english_translation = "(Skipped TTT due to STT failure)"
+         print(english_translation)
+
+     # TTS Stage
+     synthesized_audio_np = np.array([]).astype(np.float32)
+     if english_translation and not english_translation.startswith("TTT Error"):
+         try:
+             print("TTS: Synthesizing English speech...")
+             sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
+             generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.5, with_tqdm=False)
+
+             print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
+             if generated_mel is not None and generated_mel.numel() > 0:
+                 mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
+                 audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
+                 synthesized_audio_np = audio_tensor.cpu().numpy()
+                 print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
+         except Exception as e:
+             print(f"TTS Error: {e}")
+
+     print(f"--- PIPELINE END ---")
+     return arabic_transcript, english_translation, (hp.sr, synthesized_audio_np)
+
+
+ # Part 5: FastAPI Application
+
+ app = FastAPI()
+
+ # Allow Cross-Origin Resource Sharing (CORS) for your frontend
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"], # Allows all origins
+     allow_credentials=True,
+     allow_methods=["*"], # Allows all methods
+     allow_headers=["*"], # Allows all headers
+ )
+
+ @app.post("/process-speech/")
+ async def create_upload_file(file: UploadFile = File(...)):
+     # Save the uploaded file temporarily
+     temp_path = f"/tmp/{file.filename}"
+     with open(temp_path, "wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+
+     # Run the full pipeline
+     arabic, english, (sr, audio_np) = full_speech_translation_pipeline(temp_path)
+
+     # Prepare the audio to be sent back as base64
+     audio_base64 = ""
+     if audio_np.size > 0:
+         temp_wav_path = "/tmp/output.wav"
+         sf.write(temp_wav_path, audio_np, sr)
+         with open(temp_wav_path, "rb") as wav_file:
+             audio_bytes = wav_file.read()
+         audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+
+     # Return all results in a single JSON response
+     return {
+         "arabic_transcript": arabic,
+         "english_translation": english,
+         "audio_data": {
+             "sample_rate": sr,
+             "base64": audio_base64
+         }
+     }
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
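The /process-speech/ endpoint takes a multipart upload under the field name "file" and returns the Arabic transcript, the English translation, and the synthesized speech as base64. A minimal client sketch, assuming the requests package is installed and a local recording named sample_ar.wav (both are illustrative, not part of this repo):

    import base64
    import requests

    # Post an Arabic recording to the pipeline endpoint defined in app.py
    with open("sample_ar.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/process-speech/",
            files={"file": ("sample_ar.wav", f, "audio/wav")},
        )
    resp.raise_for_status()
    result = resp.json()

    print("Arabic transcript:", result["arabic_transcript"])
    print("English translation:", result["english_translation"])

    # The synthesized English speech comes back base64-encoded under audio_data
    if result["audio_data"]["base64"]:
        with open("translated.wav", "wb") as out:
            out.write(base64.b64decode(result["audio_data"]["base64"]))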
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ transformers[torch]
+ torchaudio
+ safetensors
+ gradio
+ unidecode
+ inflect
+ pydub
+ accelerate
+ fastapi
+ uvicorn[standard]
+ python-multipart
+ soundfile
+ librosa
+ matplotlib
+ sentencepiece
+ sacremoses
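For development outside Docker, the same stack can be started directly, assuming Python 3.11 and ffmpeg are available on the host (mirroring what the Dockerfile installs):

    pip install -r requirements.txt
    uvicorn app:app --host 0.0.0.0 --port 7860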
static/index.html ADDED
@@ -0,0 +1,101 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Speech Translator</title>
+     <style>
+         body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; line-height: 1.6; padding: 20px; background-color: #f4f4f4; color: #333; }
+         .container { max-width: 700px; margin: auto; background: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 0 10px rgba(0,0,0,0.1); }
+         h1 { text-align: center; color: #1a1a1a; }
+         .upload-section, .result-box { margin-bottom: 20px; }
+         .result-box { border: 1px solid #ddd; padding: 15px; border-radius: 5px; }
+         label, h3 { display: block; margin-bottom: 10px; font-weight: bold; }
+         input[type="file"] { display: block; margin-bottom: 10px; }
+         button { display: block; width: 100%; padding: 12px; background-color: #007bff; color: white; border: none; border-radius: 5px; cursor: pointer; font-size: 16px; }
+         button:hover { background-color: #0056b3; }
+         #status { text-align: center; font-style: italic; color: #555; margin-top: 15px; height: 20px; }
+         audio { width: 100%; margin-top: 10px; }
+     </style>
+ </head>
+ <body>
+     <div class="container">
+         <h1>Arabic to English Speech Translator</h1>
+         <div class="upload-section">
+             <label for="audioFileInput">Upload or Record Arabic Speech:</label>
+             <input type="file" id="audioFileInput" accept="audio/*">
+             <button id="submitButton">Submit</button>
+             <div id="status"></div>
+         </div>
+
+         <div class="result-box">
+             <h3>Arabic Transcript (STT):</h3>
+             <p id="arabicText">...</p>
+         </div>
+         <div class="result-box">
+             <h3>English Translation (TTT):</h3>
+             <p id="englishText">...</p>
+         </div>
+         <div class="result-box">
+             <h3>Synthesized English Speech (TTS):</h3>
+             <audio id="audioPlayer" controls></audio>
+         </div>
+     </div>
+
+     <script>
+         const fileInput = document.getElementById('audioFileInput');
+         const submitButton = document.getElementById('submitButton');
+         const statusDiv = document.getElementById('status');
+         const arabicText = document.getElementById('arabicText');
+         const englishText = document.getElementById('englishText');
+         const audioPlayer = document.getElementById('audioPlayer');
+
+         submitButton.addEventListener('click', async () => {
+             const file = fileInput.files[0];
+             if (!file) {
+                 alert("Please select a file first.");
+                 return;
+             }
+
+             statusDiv.textContent = "Uploading and processing... This may take a moment.";
+             submitButton.disabled = true;
+
+             const formData = new FormData();
+             formData.append('file', file); // The key 'file' must match the FastAPI endpoint parameter name
+
+             try {
+                 // A relative URL works here because this page is served by the same FastAPI app
+                 const apiUrl = '/process-speech/';
+                 const response = await fetch(apiUrl, {
+                     method: 'POST',
+                     body: formData,
+                 });
+
+                 if (!response.ok) {
+                     throw new Error(`HTTP error! status: ${response.status}`);
+                 }
+
+                 const result = await response.json();
+
+                 arabicText.textContent = result.arabic_transcript;
+                 englishText.textContent = result.english_translation;
+
+                 if (result.audio_data.base64) {
+                     const audioData = `data:audio/wav;base64,${result.audio_data.base64}`;
+                     audioPlayer.src = audioData;
+                     audioPlayer.style.display = 'block';
+                 } else {
+                     audioPlayer.style.display = 'none';
+                 }
+
+                 statusDiv.textContent = "Processing complete!";
+             } catch (error) {
+                 console.error('Error:', error);
+                 statusDiv.textContent = `An error occurred: ${error.message}`;
+             } finally {
+                 submitButton.disabled = false;
+             }
+         });
+     </script>
+ </body>
+ </html>