Upload 2 files
Browse files- image_feature.py +4 -13
image_feature.py
CHANGED
@@ -55,9 +55,10 @@ DEVICE = torch.device('cpu')
|
|
55 |
# model = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k").to(DEVICE)
|
56 |
# processor = AutoImageProcessor.from_pretrained("chanhua/autotrain-izefx-v3qh0")
|
57 |
# model = AutoModel.from_pretrained("chanhua/autotrain-izefx-v3qh0").to(DEVICE)
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
61 |
|
62 |
|
63 |
# tensor([0.6061], device='cuda:0', grad_fn=<SumBackward1>)
|
@@ -102,16 +103,6 @@ def infer4(url1, url2):
|
|
102 |
# 无论是否发生异常,都会执行此代码块
|
103 |
print("这是finally块")
|
104 |
|
105 |
-
# 推理
|
106 |
-
def infer3(url):
|
107 |
-
# image_real = Image.open(requests.get(img_urls[0], stream=True).raw).convert("RGB")
|
108 |
-
# image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
109 |
-
image = Image.open(url).convert('RGB')
|
110 |
-
|
111 |
-
inputs = processor(image, return_tensors="pt").to(DEVICE)
|
112 |
-
outputs = model(**inputs)
|
113 |
-
return outputs.pooler_output
|
114 |
-
|
115 |
# 推理
|
116 |
def infer2(url):
|
117 |
# image_real = Image.open(requests.get(img_urls[0], stream=True).raw).convert("RGB")
|
|
|
55 |
# model = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k").to(DEVICE)
|
56 |
# processor = AutoImageProcessor.from_pretrained("chanhua/autotrain-izefx-v3qh0")
|
57 |
# model = AutoModel.from_pretrained("chanhua/autotrain-izefx-v3qh0").to(DEVICE)
|
58 |
+
processor = ViTImageProcessor.from_pretrained('google/vit-large-patch16-224-in21k')
|
59 |
+
model = ViTModel.from_pretrained('google/vit-large-patch16-224-in21k')
|
60 |
+
# processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
|
61 |
+
# model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
62 |
|
63 |
|
64 |
# tensor([0.6061], device='cuda:0', grad_fn=<SumBackward1>)
|
|
|
103 |
# 无论是否发生异常,都会执行此代码块
|
104 |
print("这是finally块")
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
# 推理
|
107 |
def infer2(url):
|
108 |
# image_real = Image.open(requests.get(img_urls[0], stream=True).raw).convert("RGB")
|