Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from modules.models import * | |
| from util import get_prompt_template | |
| from torchvision import transforms as vt | |
| import torchaudio | |
| from PIL import Image | |
| import cv2 | |
# Paths and preprocessing constants shared by every request.
_MODEL_CONFIG_PATH = './config/model/ACL_ViT16.yaml'
_MODEL_WEIGHTS_PATH = './pretrain/Param_best.pth'
_TARGET_SAMPLE_RATE = 16000  # Hz the waveform is resampled to
_CLIP_SECONDS = 10           # every clip is cut / zero-padded to this length
_IMAGE_SIDE = 352            # square size used for the model input and heatmap

# Filled lazily on the first request so the checkpoint is read from disk once
# instead of on every call.
_MODEL_CACHE = {}


def _get_model(device):
    """Return the ACL model for *device*, building and loading it on first use."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = ACL(_MODEL_CONFIG_PATH, device)
        model.train(False)  # evaluation mode
        model.load(_MODEL_WEIGHTS_PATH)
        # NOTE(review): sharing one instance assumes inference does not mutate
        # model state between requests -- confirm against the ACL class.
        _MODEL_CACHE[key] = model
    return model


def _prepare_audio(sample_rate, samples):
    """Turn raw samples into a (1, 16000 * 10) float32 waveform tensor."""
    samples = np.asarray(samples)

    # Scale integer PCM to [-1, 1). For int16 this is exactly the 32768.0 the
    # code always used; wider integer types get their own full-scale value and
    # float input is taken as already normalised.
    if np.issubdtype(samples.dtype, np.signedinteger):
        scale = float(np.iinfo(samples.dtype).max) + 1.0
    elif np.issubdtype(samples.dtype, np.floating):
        scale = 1.0
    else:
        scale = 32768.0  # unknown sample layout: keep the int16 assumption
    waveform = torch.from_numpy(samples.astype(np.float32, order='C') / scale)

    if waveform.dim() == 2:
        # Stereo -> mono (x2 duration): the first two channels are laid end to
        # end (all of channel 0, then all of channel 1) rather than averaged.
        # Slicing with :2 also accepts a single-channel (N, 1) array.
        waveform = waveform[:, :2].T.reshape(1, -1)
    else:
        waveform = waveform.unsqueeze(0)

    if sample_rate != _TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, _TARGET_SAMPLE_RATE)
    waveform = waveform.squeeze(0)

    # Force a fixed length: truncate long clips, zero-pad short ones.
    target_len = _TARGET_SAMPLE_RATE * _CLIP_SECONDS
    waveform = waveform[:target_len]
    missing = target_len - waveform.shape[0]
    if missing > 0:
        waveform = torch.cat((waveform, torch.zeros(missing)), dim=0)

    return waveform.unsqueeze(0)


def _overlay_heatmap(image, heatmap):
    """Blend *heatmap* (resized to the image) 50/50 over the RGB PIL *image*."""
    # The map is inverted before the JET lookup.
    # NOTE(review): cv2 colour maps are BGR while the image array is RGB; the
    # inversion looks intended to compensate so strong responses end up red
    # in the blended output -- confirm before changing either step.
    inverted = ((1 - heatmap.squeeze().detach().cpu().numpy()) * 255).astype(np.uint8)
    mask = Image.fromarray(inverted).resize(image.size, Image.BICUBIC)
    colored = cv2.applyColorMap(np.array(mask), cv2.COLORMAP_JET)
    return cv2.addWeighted(np.array(image), 0.5, colored, 0.5, 0)


def greet(image, audio):
    """Localize the source of *audio* inside *image* and return the overlay.

    Args:
        image: PIL image to search in.
        audio: ``(sample_rate, samples)`` pair; ``samples`` is a NumPy array,
            either 1-D or 2-D indexed as ``[sample, channel]``.

    Returns:
        NumPy array with the colour-mapped heatmap blended over the image at
        the image's original size.

    Raises:
        gr.Error: if either input is missing.
    """
    if image is None or audio is None:
        # Surface a readable message in the UI instead of an unpacking error.
        raise gr.Error('Please provide both an image and an audio clip.')

    model = _get_model(torch.device('cpu'))

    # Get placeholder text
    prompt_template, text_pos_at_prompt, prompt_length = get_prompt_template()

    # Input pre processing
    sample_rate, samples = audio
    audio_file = _prepare_audio(sample_rate, samples)

    # Normalisation and the overlay both need exactly three channels.
    image = image.convert('RGB')
    image_transform = vt.Compose([
        vt.Resize((_IMAGE_SIDE, _IMAGE_SIDE), vt.InterpolationMode.BICUBIC),
        vt.ToTensor(),
        vt.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),  # CLIP
    ])
    image_file = image_transform(image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
        audio_driven_embedding = model.encode_audio(
            audio_file.to(model.device), placeholder_tokens, text_pos_at_prompt, prompt_length)
        # Localization result
        out_dict = model(image_file.to(model.device), audio_driven_embedding, _IMAGE_SIDE)

    return _overlay_heatmap(image, out_dict['heatmap'][0:1])
# User-facing copy shown around the Gradio interface.
title = "Audio-Grounded Contrastive Learning"
description = """<p>
This is a simple demo of our WACV'24 paper 'Can CLIP Help Sound Source Localization?', zero-shot visual sound localization.<br><br>
To use it simply upload an image and corresponding audio to mask (identify in the image), or use one of the examples below and click ‘submit’.<br><br>
Results will show up in a few seconds.
</p>"""
# Footer links to the paper and the code ("Offical" typo corrected).
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2311.04066'>Can CLIP Help Sound Source Localization?</a> | <a href='https://github.com/swimmiing/ACL-SSL'>Official Github repo</a></p>"

# One image input and one audio input feed `greet`; its overlay is shown as
# an image.
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(type='pil'), gr.Audio()],
    outputs=gr.Image(type="pil"),
    title=title,
    description=description,
    article=article,
)

# Start the web server as soon as the script runs.
demo.launch(debug=True)