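# NOTE: the following lines are web-page residue from the hosting site's file
# viewer, kept here as comments so the module parses as Python:
#   swimmiing's picture
#   Add requirements.txt
#   8095871
#   raw
#   history blame
#   3.29 kB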
import gradio as gr
import torch
import numpy as np
from modules.models import *
from util import get_prompt_template
from torchvision import transforms as vt
import torchaudio
from PIL import Image
def _load_model():
    """Return the ACL model, building it on first use and reusing it afterwards.

    Reading the YAML config and the checkpoint from disk on every request is
    wasteful, so the constructed model is memoised on the function object.
    """
    model = getattr(_load_model, '_cached', None)
    if model is None:
        device = torch.device('cpu')
        model = ACL('./config/model/ACL_ViT16.yaml', device)
        model.train(False)  # inference mode for dropout / batch-norm layers
        model.load('./pretrain/Param_best.pth')
        _load_model._cached = model
    return model


def _prepare_audio(samples, sample_rate, desired_sample_rate=16000, set_length=10):
    """Turn a Gradio numpy clip into a (1, desired_sample_rate * set_length) tensor.

    The clip is scaled, resampled along the time axis, flattened to a single
    channel, then cut or zero-padded to exactly ``set_length`` seconds.
    """
    # Same scaling as before: samples are treated as 16-bit PCM.
    # NOTE(review): assumes gr.Audio hands over int16 data - confirm.
    waveform = torch.from_numpy(samples.astype(np.float32, order='C') / 32768.0)

    if waveform.dim() == 2:
        # Gradio delivers multi-channel audio as (samples, channels), while
        # torchaudio resamples along the LAST axis, so move time to the end.
        waveform = waveform.t().contiguous()

    if desired_sample_rate != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, desired_sample_rate)

    # Stereo -> mono by laying the channels back to back (x2 duration), as the
    # original code intended. A mono clip is already 1-D and is unchanged.
    waveform = waveform.reshape(-1)

    target_len = desired_sample_rate * set_length
    waveform = waveform[:target_len]
    missing = target_len - waveform.shape[0]
    if missing > 0:
        # zero padding
        waveform = torch.cat((waveform, torch.zeros(missing, dtype=waveform.dtype)), dim=0)

    return waveform.unsqueeze(0)


def greet(image, audio):
    """Localise the sound source of ``audio`` inside ``image``.

    Args:
        image: PIL image supplied by the ``gr.Image(type='pil')`` input.
        audio: ``(sample_rate, numpy_array)`` tuple supplied by ``gr.Audio()``.

    Returns:
        A numpy array holding the input image blended 50/50 with the
        colour-mapped heatmap, at the input image's resolution.

    Raises:
        gr.Error: if either input is missing.
    """
    if image is None or audio is None:
        # Surface a readable message in the UI instead of an unpacking TypeError.
        raise gr.Error('Please provide both an image and an audio clip.')

    # Get model
    model = _load_model()

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    sample_rate, samples = audio
    audio_file = _prepare_audio(samples, sample_rate)

    # Normalisation below expects three channels, and the overlay is blended
    # with a three-channel heatmap, so drop alpha / expand greyscale up front.
    image = image.convert('RGB')
    image_transform = vt.Compose([
        vt.Resize((352, 352), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])
    image_file = image_transform(image).unsqueeze(0)

    # Inference (no autograd graph is needed for a forward-only demo)
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(audio_file.to(model.device), placeholder_tokens,
                                                    text_pos_at_prompt, prompt_length)
        out_dict = model(image_file.to(model.device), audio_driven_embedding, 352)

    # Localization result
    seg = out_dict['heatmap'][0:1]
    # NOTE(review): the map is inverted (1 - seg) before colour mapping;
    # presumably this compensates for OpenCV's BGR output being shown as RGB
    # - confirm before changing.
    seg_image = ((1 - seg.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    # Bring the heatmap to the input image's size; addWeighted requires both
    # operands to have identical shape. PIL sizes are (width, height).
    seg_image = Image.fromarray(seg_image).resize(image.size, Image.BILINEAR)

    # NOTE(review): cv2 is not imported explicitly in this file; presumably it
    # arrives through the star import from modules.models - confirm.
    heatmap_image = cv2.applyColorMap(np.array(seg_image), cv2.COLORMAP_JET)
    overlaid_image = cv2.addWeighted(np.array(image), 0.5, heatmap_image, 0.5, 0)
    return overlaid_image
# ---------------------------------------------------------------------------
# Gradio front end: static page text, then the interface wiring and launch.
# ---------------------------------------------------------------------------

# Heading shown at the top of the demo page.
title = "Zero-shot sound source localization with ACL"

# Short usage blurb rendered under the heading.
description = (
    "This is simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?'\n"
    "To use it, simply upload an image and corresponding audio to mask (identify in the image) "
    "or use one of the examples below and click 'submit'. "
    "Results will show up in a few seconds."
)

# Footer HTML linking to the paper and the official repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a>"
    " | "
    "<a href='https://github.com/swimmiing/ACL-SSL'>Official Github Repository</a>"
    "</p>"
)

# One image plus one audio clip go in; a single overlay image comes out.
demo = gr.Interface(
    greet,
    [gr.Image(type='pil'), gr.Audio()],
    gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)

# Start the web server; debug=True is passed through to Gradio's launcher.
demo.launch(debug=True)