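# Gradio demo of ACL (WACV'24, "Can CLIP Help Sound Source Localization?").
# Pipeline: load the pretrained model, preprocess the uploaded image and audio,
# encode the audio into a CLIP-text-space embedding, predict a localization
# heatmap, and overlay it on the image.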
import gradio as gr
import torch
import numpy as np
import cv2
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
def greet(image, audio):
    device = torch.device('cpu')
    image = image.convert('RGB')  # guard against RGBA / grayscale uploads

    # Get model (reloaded on every call for simplicity; hoist to module scope to avoid the reload cost)
    model_conf_file = './config/model/ACL_ViT16.yaml'
    model = ACL(model_conf_file, device)
    model.train(False)
    model.load('./pretrain/Param_best.pth')
    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()
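    # The audio embedding produced below is injected at `text_pos_at_prompt`
    # inside this template, so the sound effectively acts as a pseudo-word in a
    # CLIP text query (see the paper for details).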
    # Input pre-processing
    sample_rate, audio = audio
    audio = audio.astype(np.float32, order='C') / 32768.0  # int16 -> float in [-1, 1]
    desired_sample_rate = 16000
    set_length = 10  # seconds

    audio_file = torch.from_numpy(audio)
    if audio_file.ndim == 2:  # Gradio delivers stereo as (samples, channels)
        audio_file = audio_file.T  # -> (channels, samples)
    if desired_sample_rate != sample_rate:
        audio_file = torchaudio.functional.resample(audio_file, sample_rate, desired_sample_rate)
    if audio_file.ndim == 2 and audio_file.shape[0] == 2:
        audio_file = torch.concat([audio_file[0], audio_file[1]], dim=0)  # Stereo -> mono (x2 duration)
    audio_file = audio_file.squeeze()  # ensure a 1-D (samples,) tensor

    # Trim clips longer than set_length seconds
    if audio_file.shape[0] > (desired_sample_rate * set_length):
        audio_file = audio_file[:desired_sample_rate * set_length]

    # Zero-pad clips shorter than set_length seconds
    if audio_file.shape[0] < (desired_sample_rate * set_length):
        pad_len = (desired_sample_rate * set_length) - audio_file.shape[0]
        pad_val = torch.zeros(pad_len)
        audio_file = torch.cat((audio_file, pad_val), dim=0)
    audio_file = audio_file.unsqueeze(0)  # add batch dimension
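    # `audio_file` is now a fixed-size (1, 160000) tensor: 10 s at 16 kHz.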
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP statistics
    ])
    image_file = image_transform(image).unsqueeze(0)  # add batch dimension
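    # `image_file` is a (1, 3, 352, 352) tensor normalized with CLIP statistics.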
    # Inference: turn the audio into an embedding that stands in for the
    # placeholder token of the text prompt
    placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
    audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                text_pos_at_prompt, prompt_length)
    # Localization result
    out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)
    seg = out_dict['heatmap'][0:1]
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    seg_image = Image.fromarray(seg_image)

    # Colorize the heatmap and blend it with the input image
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    heatmap_image = cv2.cvtColor(heatmap_image, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR; Gradio expects RGB
    heatmap_image = cv2.resize(heatmap_image, image.size)  # match the input size; PIL's size is (width, height)
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)
    return overlaid_image
title = "Zero-shot sound source localization with ACL"
description = "This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?'\nTo use it, upload an image and the corresponding audio to localize (identify the sounding region in the image), or pick one of the examples below, then click 'Submit'. Results show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github Repository</a></p>"
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)

demo.launch(debug=True)
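# Run locally with `python app.py`; Gradio serves the demo at
# http://127.0.0.1:7860 by default. On Hugging Face Spaces, app.py is launched
# automatically.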