Samuel Stevens committed · Commit 4cc1f09

1 Parent(s): 09ef17b

add app.py

- app.py +73 -0
- requirements.txt +4 -0
app.py ADDED

@@ -0,0 +1,73 @@
+import gradio as gr
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+from open_clip import create_model, get_tokenizer
+from open_clip.training.imagenet_zeroshot_data import openai_imagenet_template
+
+model_str = "ViT-B-16"
+pretrained = "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"
+
+preprocess_img = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=(0.48145466, 0.4578275, 0.40821073),
+            std=(0.26862954, 0.26130258, 0.27577711),
+        ),
+    ]
+)
+
+
+@torch.no_grad()
+def get_txt_features(classnames, templates):
+    all_features = []
+    for classname in classnames:
+        txts = [template(classname) for template in templates]
+        txts = tokenizer(txts)
+        txt_features = model.encode_text(txts)
+        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
+        txt_features /= txt_features.norm()
+        all_features.append(txt_features)
+    all_features = torch.stack(all_features, dim=1)
+    return all_features
+
+
+@torch.no_grad()
+def predict(img, cls_str: str) -> dict[str, float]:
+    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
+    txt_features = get_txt_features(classes, openai_imagenet_template)
+
+    img = preprocess_img(img)
+
+    img_features = model.encode_image(img.unsqueeze(0))
+    img_features = F.normalize(img_features, dim=-1)
+    logits = (img_features @ txt_features).squeeze()
+    probs = F.softmax(logits, dim=0).tolist()
+    return {cls: prob for cls, prob in zip(classes, probs)}
+
+
+if __name__ == "__main__":
+    print("Starting.")
+    model = create_model(model_str, pretrained, output_dict=True)
+    print("Created model.")
+
+    model = torch.compile(model)
+    print("Compiled model.")
+
+    tokenizer = get_tokenizer(model_str)
+
+    demo = gr.Interface(
+        fn=predict,
+        inputs=[
+            gr.Image(shape=(224, 224)),
+            gr.Textbox(
+                placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
+            ),
+        ],
+        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
+    )
+
+    demo.launch()
+
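For readers unfamiliar with open_clip's zero-shot setup, here is a minimal sketch (not part of the commit) of the prompt-template mechanism that get_txt_features relies on: openai_imagenet_template is a list of callables that each turn a class name into a prompt string, and a class's text embedding is the normalized mean over the embeddings of all its prompts. The two templates and the 4-dimensional fake embeddings below are illustrative stand-ins, not the real open_clip values.

import torch
import torch.nn.functional as F

# Stand-in templates; the real openai_imagenet_template list is much longer.
templates = [
    lambda c: f"a photo of a {c}.",
    lambda c: f"a blurry photo of a {c}.",
]

prompts = [template("dog") for template in templates]
print(prompts)  # ['a photo of a dog.', 'a blurry photo of a dog.']

# Pretend each prompt was encoded to a 4-dim embedding by the text encoder.
fake_embeddings = torch.randn(len(prompts), 4)
class_feature = F.normalize(fake_embeddings, dim=-1).mean(dim=0)
class_feature /= class_feature.norm()  # unit-length class prototype, as in get_txt_features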
requirements.txt ADDED

@@ -0,0 +1,4 @@
+open_clip_torch
+torchvision
+torch
+gradio
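Note that these requirements carry no version pins. The gr.Image(shape=(224, 224)) call in app.py uses the Gradio 3.x signature (the shape argument was dropped in later major releases), so a pinned variant along the lines of the sketch below may be safer; the exact pin is an assumption, not part of the commit.

open_clip_torch
torchvision
torch
gradio<4  # assumption: stay on the 3.x gr.Image(shape=...) API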