Commit ead7a82 · 1 Parent(s): 15a96ee
Update app.py
app.py CHANGED
@@ -20,6 +20,7 @@ import torchvision.transforms as transforms
 import av
 import subprocess
 import librosa
+import uuid
 
 args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
         "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -33,7 +34,7 @@ class dotdict(dict):
 
 args = dotdict(args)
 
-generated_audio_files = []
+generated_audio_files = {}
 
 llama_type = args.llama_type
 llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
@@ -117,7 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
     return text, outputs
 
 
-def save_audio_to_local(audio, sec):
+def save_audio_to_local(uid, audio, sec):
     global generated_audio_files
     if not os.path.exists('temp'):
         os.mkdir('temp')
@@ -126,11 +127,11 @@ def save_audio_to_local(audio, sec):
         scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
     else:
         scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
-    generated_audio_files.append(filename)
+    generated_audio_files[uid].append(filename)
     return filename
 
 
-def parse_reponse(model_outputs, audio_length_in_s):
+def parse_reponse(uid, model_outputs, audio_length_in_s):
     response = ''
     text_outputs = []
     for output_i, p in enumerate(model_outputs):
@@ -146,7 +147,7 @@ def parse_reponse(model_outputs, audio_length_in_s):
                 response += '<br>'
                 _temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
             else:
-                filename = save_audio_to_local(m, audio_length_in_s)
+                filename = save_audio_to_local(uid, m, audio_length_in_s)
                 print(filename)
                 _temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
                 response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
@@ -161,15 +162,15 @@ def reset_user_input():
     return gr.update(value='')
 
 
-def reset_dialog():
+def reset_dialog(uid):
     global generated_audio_files
-    generated_audio_files = []
+    generated_audio_files[uid] = []
     return [], []
 
 
-def reset_state():
+def reset_state(uid):
     global generated_audio_files
-    generated_audio_files = []
+    generated_audio_files[uid] = []
     return None, None, None, None, [], [], []
 
 
@@ -218,6 +219,7 @@ def get_audio_length(filename):
 
 
 def predict(
+        uid,
         prompt_input,
         image_path,
         audio_path,
@@ -247,28 +249,30 @@ def predict(
         indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
         video = read_video_pyav(container=container, indices=indices)
 
-    if len(generated_audio_files) != 0:
-        audio_length_in_s = get_audio_length(generated_audio_files[-1])
+    if uid in generated_audio_files and len(generated_audio_files[uid]) != 0:
+        audio_length_in_s = get_audio_length(generated_audio_files[uid][-1])
         sample_rate = 24000
-        waveform, sr = torchaudio.load(generated_audio_files[-1])
+        waveform, sr = torchaudio.load(generated_audio_files[uid][-1])
         if sample_rate != sr:
             waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
         audio = torch.mean(waveform, 0)
         audio_length_in_s = int(len(audio)//sample_rate)
         print(f"Audio Length: {audio_length_in_s}")
+    else:
+        generated_audio_files[uid] = []
     if video_path is not None:
         audio_length_in_s = get_video_length(video_path)
         print(f"Video Length: {audio_length_in_s}")
     if audio_path is not None:
         audio_length_in_s = get_audio_length(audio_path)
-        generated_audio_files.append(audio_path)
+        generated_audio_files[uid].append(audio_path)
         print(f"Audio Length: {audio_length_in_s}")
 
     print(image, video, audio)
     response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
                               audio_length_in_s=audio_length_in_s)
     print(response)
-    response_chat, response_outputs = parse_reponse(response, audio_length_in_s)
+    response_chat, response_outputs = parse_reponse(uid, response, audio_length_in_s)
     print('text_outputs: ', response_outputs)
     user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
     chatbot.append((user_chat, response_chat))
@@ -319,9 +323,11 @@ with gr.Blocks() as demo:
 
     history = gr.State([])
     modality_cache = gr.State([])
+    uid = gr.State(uuid.uuid4())
 
     submitBtn.click(
         predict, [
+            uid,
             user_input,
             image_path,
             audio_path,
@@ -343,8 +349,8 @@ with gr.Blocks() as demo:
         show_progress=True
     )
 
-    submitBtn.click(reset_user_input, [], [user_input])
-    emptyBtn.click(reset_state, outputs=[
+    submitBtn.click(reset_user_input, [uid], [user_input])
+    emptyBtn.click(reset_state, [uid], outputs=[
         image_path,
         audio_path,
         video_path,
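For context, the pattern this commit moves to is Gradio's per-session `gr.State` used as a key into a module-level dict, so each browser session tracks its own generated audio files rather than appending to one global list shared by every visitor. Below is a minimal, self-contained sketch of that pattern, not the Space's actual code: `predict`, `reset_state`, and the `temp/...` filename are illustrative stand-ins. Note that the value passed to `gr.State` is created once when the Blocks is built and deep-copied into each session.

```python
import uuid

import gradio as gr

# uid -> list of file paths; module-level, so it is shared by all sessions
# unless every access is keyed by the session's uid (as in this commit).
generated_audio_files = {}


def predict(uid, prompt):
    # Create this session's bucket on first use so a fresh session starts empty.
    files = generated_audio_files.setdefault(uid, [])
    files.append(f"temp/{uid}_{len(files)}.wav")  # illustrative filename only
    return f"{prompt!r} -> {files[-1]}"


def reset_state(uid):
    # Clear only this session's entries; other sessions are untouched.
    generated_audio_files[uid] = []
    return ""


with gr.Blocks() as demo:
    # gr.State holds per-session state; its initial value (a UUID generated
    # once at build time) is deep-copied into each new session.
    uid = gr.State(uuid.uuid4())
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    submitBtn = gr.Button("Submit")
    emptyBtn = gr.Button("Clear")

    submitBtn.click(predict, [uid, prompt], [output])
    emptyBtn.click(reset_state, [uid], [output])

demo.launch()
```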