metadata
license: gpl-3.0
base_model:
  - microsoft/resnet-50
library_name: transformers

cat2vec

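cat2vec is an image-embedding model for searching cat photos. It maps each image to a feature vector, so images can be compared and retrieved by vector similarity.

It was fine-tuned from microsoft/resnet-50 on the Labeled Cats In The Wild dataset with a triplet loss.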
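The card does not say which distance, margin, or triplet-mining strategy was used. For illustration only, here is the standard triplet loss, L = max(d(a, p) - d(a, n) + m, 0). It pulls an anchor a toward a positive p and pushes it away from a negative n until the negative is at least m farther away. The Euclidean distance and margin of 1.0 are assumptions that match the defaults of PyTorch's `torch.nn.TripletMarginLoss`. They are not confirmed details of cat2vec's training. The sketch is dependency-free so the formula is easy to follow.

```python
import math


def euclidean(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss (margin=1.0 is an assumed default).

    The loss is zero once the negative is at least `margin` farther
    from the anchor than the positive is.
    """
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


def batch_triplet_loss(triplets, margin=1.0):
    """Mean triplet loss over a list of (anchor, positive, negative) tuples."""
    return sum(triplet_loss(a, p, n, margin) for a, p, n in triplets) / len(triplets)


# An "easy" triplet (negative already far away) contributes no loss;
# a "hard" triplet (negative closer than positive) is penalised.
easy = ([0.0, 0.0], [0.0, 1.0], [3.0, 4.0])  # d(a,p)=1, d(a,n)=5
hard = ([0.0, 0.0], [3.0, 4.0], [0.0, 1.0])  # d(a,p)=5, d(a,n)=1
print(triplet_loss(*easy), triplet_loss(*hard), batch_triplet_loss([easy, hard]))
```

In real training the vectors would be ResNet-50 embeddings and the loss would be computed on batched tensors. This sketch only shows the objective.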

Usage

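import torch
from datasets import load_dataset
from transformers import AutoImageProcessor, ResNetModel

# Assumption: "aeth0r/cat2vec" is this repository's Hub id; adjust it if it differs.
model_id = "aeth0r/cat2vec"

dataset = load_dataset("huggingface/cats-image")
images = dataset["test"][:2]["image"]  # list of PIL images (the split may hold fewer than 2)

# Preprocessing follows the base model, microsoft/resnet-50.
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetModel.from_pretrained(model_id)  # load the cat2vec weights, not the base model

inputs = processor(images, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# pooler_output has shape (batch, 2048, 1, 1); flatten it to one embedding per image.
embeddings = outputs.pooler_output.flatten(1)
print(embeddings.shape)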
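To use the embeddings for search, one option is to rank a gallery of stored embeddings by cosine similarity to a query embedding. The card does not state which similarity measure cat2vec is meant to be used with. Cosine similarity, the `search` helper, and the tiny 2-D example vectors below are illustrative assumptions, and the sketch is kept dependency-free. Real embeddings from the code above can be passed in as plain lists via `embeddings.tolist()`.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def search(query, gallery, k=3):
    """Hypothetical helper: return (index, score) pairs for the k gallery
    embeddings most similar to `query`, ranked by cosine similarity."""
    q = normalize(query)
    scores = [sum(a * b for a, b in zip(q, normalize(g))) for g in gallery]
    ranked = sorted(range(len(gallery)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in ranked[:k]]


# Toy 2-D "embeddings"; real cat2vec vectors have 2048 dimensions.
gallery = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for index, score in search([1.0, 0.1], gallery, k=2):
    print(index, round(score, 3))
```

For large galleries, the same ranking is usually done with one matrix multiplication over pre-normalized embeddings, or with an approximate nearest-neighbour index.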