Spaces: Running on Zero
Commit · 8549414
Parent(s): 6142e6b
precache clip
aria/image_encoder.py  CHANGED  (+34 -9)
@@ -3,6 +3,7 @@ import torch.nn as nn
 from transformers import CLIPProcessor, CLIPModel
 from PIL import Image
 from typing import Tuple, Union
+import os
 
 class ImageEncoder(nn.Module):
     def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
@@ -30,15 +31,39 @@ class ImageEncoder(nn.Module):
         print(f"Initializing ImageEncoder with {self.clip_model_name}...")
         print("Loading CLIP model from local cache (network disabled)...")
 
-        #
-
-
-
-
-
-
-
-
+        # Prefer loading strictly from the local Hugging Face cache that `app.py` populates.
+        # If the files are genuinely missing (e.g. first run without network), we fall back
+        # to an online download so the user still gets a working application.
+
+        # Determine the cache directory from env – this is set in `app.py`.
+        hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)
+
+        try:
+            self.clip_model = CLIPModel.from_pretrained(
+                self.clip_model_name,
+                cache_dir=hf_cache_dir,
+                local_files_only=True,  # use cache only on the first attempt
+            )
+            self.processor = CLIPProcessor.from_pretrained(
+                self.clip_model_name,
+                cache_dir=hf_cache_dir,
+                local_files_only=True,
+            )
+            print("CLIP model loaded successfully from local cache")
+        except (OSError, EnvironmentError) as cache_err:
+            print(
+                "Local cache for CLIP model not found – attempting a one-time online download..."
+            )
+            # Note: this will still respect HF_HUB_CACHE so the files are cached for future runs.
+            self.clip_model = CLIPModel.from_pretrained(
+                self.clip_model_name,
+                cache_dir=hf_cache_dir,
+            )
+            self.processor = CLIPProcessor.from_pretrained(
+                self.clip_model_name,
+                cache_dir=hf_cache_dir,
+            )
+            print("CLIP model downloaded and cached successfully")
 
         print("CLIP model loaded successfully")
 
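The comments in the new code refer to a cache that `app.py` populates and to the `HF_HUB_CACHE` environment variable it sets. `app.py` itself is not part of this commit, so the precache step it performs is not shown here; it presumably looks something like the sketch below. The cache path and the use of `snapshot_download` are illustrative assumptions, not the Space's actual code.

# Hypothetical sketch of the precache step in app.py (not shown in this commit).
# It downloads the CLIP files once at startup, while the network is still
# available, into the directory that ImageEncoder later reads via HF_HUB_CACHE.
import os
from huggingface_hub import snapshot_download

CLIP_MODEL = "openai/clip-vit-large-patch14-336"

# Illustrative path; the Space's real cache location may differ.
os.environ.setdefault("HF_HUB_CACHE", "/data/hf-cache")

# Fetch weights, config, tokenizer, and processor files into the cache so that
# the later from_pretrained(..., local_files_only=True) call resolves from disk.
snapshot_download(repo_id=CLIP_MODEL, cache_dir=os.environ["HF_HUB_CACHE"])

With the cache warmed this way, the `local_files_only=True` attempt in `ImageEncoder.__init__` loads entirely from disk, and the `except (OSError, EnvironmentError)` fallback only fires on a run where the precache step never completed.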