import sys
# if 'google.colab' in sys.modules:
#     print('Running in Colab.')
#     !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
#     !git clone https://github.com/salesforce/BLIP
#     %cd BLIP

import gradio as gr
import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ResNet-18 ImageNet classifier used by the image-classification demo.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()

# Download the ImageNet class labels.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def classify(inp):
    # inp is a PIL image; convert it to a (1, 3, H, W) tensor.
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    return confidences

# Gradio interface for the classifier; built for reference, but not the interface
# launched at the bottom of this file.
classifier_demo = gr.Interface(fn=classify,
                               inputs=gr.Image(type="pil"),
                               outputs=gr.Label(num_top_classes=3))
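
# The classifier can be sanity-checked directly; a rough sketch (SOME_IMAGE_URL is a
# placeholder you supply, not something from the original code):
#   img = Image.open(requests.get(SOME_IMAGE_URL, stream=True).raw).convert('RGB')
#   top3 = sorted(classify(img).items(), key=lambda kv: kv[1], reverse=True)[:3]
#   print(top3)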

def load_demo_image(image_size, device, imageurl):
    """Download an image from a URL and preprocess it for BLIP."""
    raw_image = Image.open(requests.get(imageurl, stream=True).raw).convert('RGB')
    # Notebook-only preview (IPython's display() is unavailable in a plain script):
    # w, h = raw_image.size
    # display(raw_image.resize((w // 5, h // 5)))
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    image = transform(raw_image).unsqueeze(0).to(device)
    return image

# blip_decoder comes from the cloned BLIP repository (see the setup comments at the top).
from models.blip import blip_decoder

image_size = 384
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
# Load the BLIP captioning model once at startup instead of on every request, and keep
# it under its own name so it does not overwrite the ResNet-18 `model` above.
blip_model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
blip_model.eval()
blip_model = blip_model.to(device)

def predict(imageurl):
    image = load_demo_image(image_size=image_size, device=device, imageurl=imageurl)
    with torch.no_grad():
        # beam search
        caption = blip_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
        # nucleus sampling
        # caption = blip_model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
    return 'caption: ' + caption[0]

# The captioning demo takes an image URL as text and returns the caption as text.
demo = gr.Interface(fn=predict,
                    inputs="text",
                    outputs="text")

demo.launch()
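
# Optional variant (a sketch, not part of the original Space): serve both interfaces
# from one app by replacing the demo.launch() call above with a tabbed layout, e.g.:
#   gr.TabbedInterface([classifier_demo, demo],
#                      ["ImageNet classification", "BLIP captioning"]).launch()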