vincentamato committed on
Commit
8549414
·
1 Parent(s): 6142e6b

precache clip

Browse files
Files changed (1) hide show
  1. aria/image_encoder.py +34 -9
aria/image_encoder.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
5
  from typing import Tuple, Union
 
6
 
7
  class ImageEncoder(nn.Module):
8
  def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
@@ -30,15 +31,39 @@ class ImageEncoder(nn.Module):
30
  print(f"Initializing ImageEncoder with {self.clip_model_name}...")
31
  print("Loading CLIP model from local cache (network disabled)...")
32
 
33
- # Load CLIP model and processor strictly from the local Hugging-Face cache
34
- self.clip_model = CLIPModel.from_pretrained(
35
- self.clip_model_name,
36
- local_files_only=True # fail fast if cache is missing
37
- )
38
- self.processor = CLIPProcessor.from_pretrained(
39
- self.clip_model_name,
40
- local_files_only=True
41
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  print("CLIP model loaded successfully")
44
 
 
3
  from transformers import CLIPProcessor, CLIPModel
4
  from PIL import Image
5
  from typing import Tuple, Union
6
+ import os
7
 
8
  class ImageEncoder(nn.Module):
9
  def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
 
31
  print(f"Initializing ImageEncoder with {self.clip_model_name}...")
32
  print("Loading CLIP model from local cache (network disabled)...")
33
 
34
+ # Prefer loading strictly from the local Hugging Face cache that `app.py` populates.
35
+ # If the files are genuinely missing (e.g. first run without network), we fall back
36
+ # to an online download so the user still gets a working application.
37
+
38
+ # Determine the cache directory from env – this is set in `app.py`.
39
+ hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)
40
+
41
+ try:
42
+ self.clip_model = CLIPModel.from_pretrained(
43
+ self.clip_model_name,
44
+ cache_dir=hf_cache_dir,
45
+ local_files_only=True, # use cache only on the first attempt
46
+ )
47
+ self.processor = CLIPProcessor.from_pretrained(
48
+ self.clip_model_name,
49
+ cache_dir=hf_cache_dir,
50
+ local_files_only=True,
51
+ )
52
+ print("CLIP model loaded successfully from local cache")
53
+ except (OSError, EnvironmentError) as cache_err:
54
+ print(
55
+ "Local cache for CLIP model not found – attempting a one-time online download..."
56
+ )
57
+ # Note: this will still respect HF_HUB_CACHE so the files are cached for future runs.
58
+ self.clip_model = CLIPModel.from_pretrained(
59
+ self.clip_model_name,
60
+ cache_dir=hf_cache_dir,
61
+ )
62
+ self.processor = CLIPProcessor.from_pretrained(
63
+ self.clip_model_name,
64
+ cache_dir=hf_cache_dir,
65
+ )
66
+ print("CLIP model downloaded and cached successfully")
67
 
68
  print("CLIP model loaded successfully")
69