cella110n commited on
Commit
fe88ff7
·
verified ·
1 Parent(s): 941802a

Upload 3 files

Browse files
Files changed (2) hide show
  1. app.py +140 -197
  2. requirements.txt +1 -1
app.py CHANGED
@@ -13,6 +13,7 @@ from dataclasses import dataclass
13
  from typing import List, Dict, Optional, Tuple
14
  import time
15
  import spaces # Required for @spaces.GPU
 
16
 
17
  import torch # Keep torch for device check in Tagger
18
  import timm # Restore timm
@@ -33,23 +34,48 @@ class LabelData:
33
  meta: list[np.int64]
34
  quality: list[np.int64]
35
 
36
- # Keep helpers needed for initialization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def load_tag_mapping(mapping_path):
 
38
  with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
 
39
  if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
40
  idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
41
  tag_to_category = tag_mapping_data["tag_to_category"]
42
  elif isinstance(tag_mapping_data, dict):
43
- tag_mapping_data = {int(k): v for k, v in tag_mapping_data.items()}
44
- idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data.items()}
45
- tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data.values()}
46
- else: raise ValueError("Unsupported tag mapping format")
 
 
 
 
 
 
47
  names = [None] * (max(idx_to_tag.keys()) + 1)
48
  rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
49
  for idx, tag in idx_to_tag.items():
50
  if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
51
  names[idx] = tag
52
- category = tag_to_category.get(tag, 'Unknown')
53
  idx_int = int(idx)
54
  if category == 'Rating': rating.append(idx_int)
55
  elif category == 'General': general.append(idx_int)
@@ -58,215 +84,132 @@ def load_tag_mapping(mapping_path):
58
  elif category == 'Copyright': copyright.append(idx_int)
59
  elif category == 'Meta': meta.append(idx_int)
60
  elif category == 'Quality': quality.append(idx_int)
61
- return LabelData(names=names, rating=np.array(rating), general=np.array(general), artist=np.array(artist),
62
- character=np.array(character), copyright=np.array(copyright), meta=np.array(meta), quality=np.array(quality)), tag_to_category
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # --- Constants ---
65
  REPO_ID = "cella110n/cl_tagger"
66
- SAFETENSORS_FILENAME = "lora_model_0426/checkpoint_epoch_4.safetensors"
67
- METADATA_FILENAME = "lora_model_0426/checkpoint_epoch_4_metadata.json"
 
68
  TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
69
  CACHE_DIR = "./model_cache"
70
- BASE_MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in1k' # Restore base model name
71
-
72
- # --- Tagger Class ---
73
- class Tagger:
74
- def __init__(self):
75
- print("Initializing Tagger...")
76
- self.safetensors_path = None
77
- self.metadata_path = None
78
- self.tag_mapping_path = None
79
- self.labels_data = None
80
- self.tag_to_category = None
81
- self.model = None # Model will be loaded later
82
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
- self._initialize_paths_and_labels()
84
- print("Tagger Initialized.") # Add confirmation
85
-
86
- def _download_files(self):
87
- # Check if paths are already set and files exist (useful for restarts)
88
- local_safetensors = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', SAFETENSORS_FILENAME)
89
- local_tag_mapping = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', TAG_MAPPING_FILENAME)
90
- local_metadata = os.path.join(CACHE_DIR, 'models--cella110n--cl_tagger', 'snapshots', '21e237f0ae461b8d9ebf7472ae8de003e5effe5b', METADATA_FILENAME)
91
-
92
- needs_download = False
93
- if not (self.safetensors_path and os.path.exists(self.safetensors_path)):
94
- if os.path.exists(local_safetensors):
95
- self.safetensors_path = local_safetensors
96
- print(f"Found existing safetensors: {self.safetensors_path}")
97
- else:
98
- needs_download = True
99
- if not (self.tag_mapping_path and os.path.exists(self.tag_mapping_path)):
100
- if os.path.exists(local_tag_mapping):
101
- self.tag_mapping_path = local_tag_mapping
102
- print(f"Found existing tag mapping: {self.tag_mapping_path}")
103
- else:
104
- needs_download = True
105
- # Metadata is optional, check similarly
106
- if not (self.metadata_path and os.path.exists(self.metadata_path)):
107
- if os.path.exists(local_metadata):
108
- self.metadata_path = local_metadata
109
- print(f"Found existing metadata: {self.metadata_path}")
110
- # Don't trigger download just for metadata if others exist
111
-
112
- if not needs_download and self.safetensors_path and self.tag_mapping_path:
113
- print("Required files already exist or paths set.")
114
- return
115
-
116
- print("Downloading model files...")
117
- hf_token = os.environ.get("HF_TOKEN")
118
- try:
119
- # Only download if not found locally
120
- if not self.safetensors_path:
121
- self.safetensors_path = hf_hub_download(repo_id=REPO_ID, filename=SAFETENSORS_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False) # Use force_download=False
122
- if not self.tag_mapping_path:
123
- self.tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
124
- print(f"Safetensors: {self.safetensors_path}")
125
- print(f"Tag mapping: {self.tag_mapping_path}")
126
- try:
127
- # Only download if not found locally
128
- if not self.metadata_path:
129
- self.metadata_path = hf_hub_download(repo_id=REPO_ID, filename=METADATA_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
130
- print(f"Metadata: {self.metadata_path}")
131
- except Exception as e_meta:
132
- # Handle case where metadata genuinely doesn't exist or download fails
133
- print(f"Metadata ({METADATA_FILENAME}) not found/download failed. Error: {e_meta}")
134
- self.metadata_path = None
135
-
136
- except Exception as e:
137
- print(f"Error downloading files: {e}")
138
- if "401 Client Error" in str(e) or "Repository not found" in str(e): raise gr.Error(f"Could not download files from {REPO_ID}. Check HF_TOKEN or repository status.")
139
- else: raise gr.Error(f"Error downloading files: {e}")
140
-
141
- def _initialize_paths_and_labels(self):
142
- # Call download first (it now checks existence)
143
- self._download_files()
144
- # Only load labels if not already loaded
145
- if self.labels_data is None:
146
- print("Loading labels...")
147
- if self.tag_mapping_path and os.path.exists(self.tag_mapping_path):
148
- try:
149
- self.labels_data, self.tag_to_category = load_tag_mapping(self.tag_mapping_path)
150
- print(f"Labels loaded. Count: {len(self.labels_data.names)}")
151
- except Exception as e: raise gr.Error(f"Error loading tag mapping: {e}")
152
- else:
153
- # This should ideally not happen if download worked
154
- raise gr.Error(f"Tag mapping file not found at expected path: {self.tag_mapping_path}")
155
- else:
156
- print("Labels already loaded.")
157
-
158
- # Restore model loading function
159
- def _load_model_on_gpu(self):
160
- # Only load if not already loaded on the correct device
161
- if self.model is not None and next(self.model.parameters()).device == self.device:
162
- print("Model already loaded on the correct device.")
163
- return True # Indicate success
164
-
165
- print("Loading PyTorch model for GPU worker...")
166
- if not self.safetensors_path or not self.labels_data:
167
- print("Error: Model paths or labels not initialized before loading.")
168
- return False # Indicate failure
169
- try:
170
- num_classes = len(self.labels_data.names)
171
- if num_classes <= 0: raise ValueError(f"Invalid num_classes: {num_classes}")
172
- print(f"Creating base model: {BASE_MODEL_NAME} with {num_classes} classes")
173
- # Load model structure (without pretrained weights initially if possible, or handle mismatch)
174
- # Using pretrained=True might download weights we immediately overwrite
175
- model = timm.create_model(BASE_MODEL_NAME, pretrained=True, num_classes=num_classes)
176
-
177
- print(f"Loading state dict from: {self.safetensors_path}")
178
- if not os.path.exists(self.safetensors_path): raise FileNotFoundError(f"File not found: {self.safetensors_path}")
179
- state_dict = safe_load_file(self.safetensors_path)
180
-
181
- # --- Key Adaptation Logic (Important!) ---
182
- # Assuming direct match based on previous code structure
183
- adapted_state_dict = state_dict
184
- # Example if keys were prefixed with 'base_model.':
185
- # adapted_state_dict = {k.replace('base_model.', ''): v for k, v in state_dict.items()}
186
- # -----------------------------------------
187
-
188
- print("Loading state dict into model...")
189
- missing_keys, unexpected_keys = model.load_state_dict(adapted_state_dict, strict=False)
190
- # Only print if there are actually missing/unexpected keys
191
- if missing_keys: print(f"State dict loaded. Missing keys: {missing_keys}")
192
- if unexpected_keys: print(f"State dict loaded. Unexpected keys: {unexpected_keys}")
193
- if any(k.startswith('head.') for k in missing_keys): print("Warning: Head weights seem missing/mismatched!")
194
-
195
- print(f"Moving model to device: {self.device}")
196
- model.to(self.device)
197
- model.eval()
198
- self.model = model # Store loaded model
199
- print("Model loaded successfully on GPU worker.")
200
- return True # Indicate success
201
- except Exception as e:
202
- print(f"(Worker) Error loading PyTorch model: {e}")
203
- import traceback; print(traceback.format_exc())
204
- # raise gr.Error(f"Error loading PyTorch model: {e}") # Don't raise here, return status
205
- return False # Indicate failure
206
-
207
- # Restore predict_on_gpu, but modify it to ONLY test model loading
208
- @spaces.GPU()
209
- def predict_on_gpu(self, image_input, gen_threshold, char_threshold, output_mode):
210
- print("--- predict_on_gpu function started (GPU worker - TESTING MODEL LOAD) ---")
211
-
212
- # Attempt to load the model
213
- load_success = self._load_model_on_gpu()
214
-
215
- if load_success:
216
- message = "Model loading successful on GPU worker."
217
- print(message)
218
- # Optional: Check model device again after loading
219
- if self.model is not None:
220
- print(f"Model device after load: {next(self.model.parameters()).device}")
221
- else:
222
- print("Model object is None even after successful load reported?")
223
- else:
224
- message = "Error: Model could not be loaded on GPU worker. Check logs."
225
- print(message)
226
-
227
- # Return only the status message for this test, and None for the image output
228
- return message, None
229
-
230
- # --- Original prediction logic (commented out for this test) ---
231
- # if self.model is None:
232
- # return "Error: Model could not be loaded on GPU worker.", None
233
- # if image_input is None: return "Please upload an image.", None
234
- # ... (image loading, preprocessing, inference, postprocessing) ...
235
-
236
- # Instantiate the tagger class (this will download files/load labels)
237
- tagger = Tagger()
238
 
239
- # --- Gradio Interface Definition (Minimal) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  with gr.Blocks() as demo:
241
  gr.Markdown("""
242
- # Tagger Initialization + Model Load Test
243
- Instantiates Tagger, then click the button below to attempt loading the model via `@spaces.GPU`.
244
- Check logs for Tagger initialization and model loading messages.
245
  """)
246
  with gr.Column():
247
- # Keep using the same button name for simplicity for now
248
- test_button = gr.Button("Test Model Load on GPU")
249
  output_text = gr.Textbox(label="Output")
250
- # Add dummy components to match the signature of the real predict_on_gpu eventually
251
- # These won't be used by the button click directly but might be needed if we switch fn later
252
- dummy_image = gr.Image(visible=False) # Hidden image input
253
- dummy_gen_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.55, visible=False)
254
- dummy_char_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.60, visible=False)
255
- dummy_radio = gr.Radio(choices=["Tags Only", "Tags + Visualization"], value="Tags + Visualization", visible=False)
256
- dummy_vis_output = gr.Image(visible=False) # Hidden image output
257
 
258
  test_button.click(
259
- fn=tagger.predict_on_gpu,
260
- # Provide dummy inputs matching the function signature
261
- # We only care about the first output (text) for this test
262
- inputs=[dummy_image, dummy_gen_slider, dummy_char_slider, dummy_radio],
263
- outputs=[output_text, dummy_vis_output] # Map outputs
264
  )
265
 
266
  # --- Main Block ---
267
  if __name__ == "__main__":
268
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
269
- # Tagger instance is created above
 
 
270
  demo.launch(share=True)
271
 
272
  # --- Commented out original UI and helpers/constants not needed for init/simple test ---
 
13
  from typing import List, Dict, Optional, Tuple
14
  import time
15
  import spaces # Required for @spaces.GPU
16
+ import onnxruntime as ort # Use ONNX Runtime
17
 
18
  import torch # Keep torch for device check in Tagger
19
  import timm # Restore timm
 
34
  meta: list[np.int64]
35
  quality: list[np.int64]
36
 
37
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
38
+ if image.mode not in ["RGB", "RGBA"]:
39
+ image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
40
+ if image.mode == "RGBA":
41
+ background = Image.new("RGB", image.size, (255, 255, 255))
42
+ background.paste(image, mask=image.split()[3])
43
+ image = background
44
+ return image
45
+
46
+ def pil_pad_square(image: Image.Image) -> Image.Image:
47
+ width, height = image.size
48
+ if width == height: return image
49
+ new_size = max(width, height)
50
+ new_image = Image.new(image.mode, (new_size, new_size), (255, 255, 255)) # Use image.mode
51
+ paste_position = ((new_size - width) // 2, (new_size - height) // 2)
52
+ new_image.paste(image, paste_position)
53
+ return new_image
54
+
55
  def load_tag_mapping(mapping_path):
56
+ # Use the implementation from the original app.py as it was confirmed working
57
  with open(mapping_path, 'r', encoding='utf-8') as f: tag_mapping_data = json.load(f)
58
+ # Check format compatibility (can be dict of dicts or dict with idx_to_tag/tag_to_category)
59
  if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
60
  idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
61
  tag_to_category = tag_mapping_data["tag_to_category"]
62
  elif isinstance(tag_mapping_data, dict):
63
+ # Assuming the dict-of-dicts format from previous tests
64
+ try:
65
+ tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
66
+ idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
67
+ tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
68
+ except (KeyError, ValueError) as e:
69
+ raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
70
+ else:
71
+ raise ValueError("Unsupported tag mapping format: Expected a dictionary.")
72
+
73
  names = [None] * (max(idx_to_tag.keys()) + 1)
74
  rating, general, artist, character, copyright, meta, quality = [], [], [], [], [], [], []
75
  for idx, tag in idx_to_tag.items():
76
  if idx >= len(names): names.extend([None] * (idx - len(names) + 1))
77
  names[idx] = tag
78
+ category = tag_to_category.get(tag, 'Unknown') # Handle missing category mapping gracefully
79
  idx_int = int(idx)
80
  if category == 'Rating': rating.append(idx_int)
81
  elif category == 'General': general.append(idx_int)
 
84
  elif category == 'Copyright': copyright.append(idx_int)
85
  elif category == 'Meta': meta.append(idx_int)
86
  elif category == 'Quality': quality.append(idx_int)
87
+
88
+ return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
89
+ character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64)), idx_to_tag, tag_to_category
90
+
91
+ def preprocess_image(image: Image.Image, target_size=(448, 448)):
92
+ # Adapted from onnx_predict.py's version
93
+ image = pil_ensure_rgb(image)
94
+ image = pil_pad_square(image)
95
+ image_resized = image.resize(target_size, Image.BICUBIC)
96
+ img_array = np.array(image_resized, dtype=np.float32) / 255.0
97
+ img_array = img_array.transpose(2, 0, 1) # HWC -> CHW
98
+ # Assuming model expects RGB based on original code, no BGR conversion here
99
+ # img_array = img_array[::-1, :, :] # BGR conversion if needed
100
+ mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
101
+ std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
102
+ img_array = (img_array - mean) / std
103
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
104
+ return image, img_array
105
 
106
  # --- Constants ---
107
  REPO_ID = "cella110n/cl_tagger"
108
+ # Use the specified ONNX model filename
109
+ ONNX_FILENAME = "cl_eva02_tagger_v1_250426/model.onnx"
110
+ # Keep the previously used tag mapping filename
111
  TAG_MAPPING_FILENAME = "lora_model_0426/tag_mapping.json"
112
  CACHE_DIR = "./model_cache"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ # --- Global variables for paths (initialized at startup) ---
115
+ g_onnx_model_path = None
116
+ g_tag_mapping_path = None
117
+ g_labels_data = None
118
+ g_idx_to_tag = None
119
+ g_tag_to_category = None
120
+
121
+ # --- Initialization Function ---
122
+ def initialize_onnx_paths():
123
+ global g_onnx_model_path, g_tag_mapping_path, g_labels_data, g_idx_to_tag, g_tag_to_category
124
+ print("Initializing ONNX paths and labels...")
125
+ hf_token = os.environ.get("HF_TOKEN")
126
+ try:
127
+ print(f"Attempting to download ONNX model: {ONNX_FILENAME}")
128
+ g_onnx_model_path = hf_hub_download(repo_id=REPO_ID, filename=ONNX_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
129
+ print(f"ONNX model path: {g_onnx_model_path}")
130
+
131
+ print(f"Attempting to download Tag mapping: {TAG_MAPPING_FILENAME}")
132
+ g_tag_mapping_path = hf_hub_download(repo_id=REPO_ID, filename=TAG_MAPPING_FILENAME, cache_dir=CACHE_DIR, token=hf_token, force_download=False)
133
+ print(f"Tag mapping path: {g_tag_mapping_path}")
134
+
135
+ print("Loading labels from mapping...")
136
+ g_labels_data, g_idx_to_tag, g_tag_to_category = load_tag_mapping(g_tag_mapping_path)
137
+ print(f"Labels loaded. Count: {len(g_labels_data.names)}")
138
+
139
+ except Exception as e:
140
+ print(f"Error during initialization: {e}")
141
+ import traceback; traceback.print_exc()
142
+ # Raise Gradio error to make it visible in the UI
143
+ raise gr.Error(f"Initialization failed: {e}. Check logs and HF_TOKEN.")
144
+
145
+ # --- ONNX Loading Test Function ---
146
+ @spaces.GPU()
147
+ def test_onnx_load():
148
+ print("--- test_onnx_load function started (GPU worker) ---")
149
+ if g_onnx_model_path is None:
150
+ message = "Error: ONNX model path not initialized. Check startup logs."
151
+ print(message)
152
+ return message
153
+
154
+ if not os.path.exists(g_onnx_model_path):
155
+ message = f"Error: ONNX file not found at {g_onnx_model_path}. Check download."
156
+ print(message)
157
+ return message
158
+
159
+ try:
160
+ print(f"Attempting to load ONNX session from: {g_onnx_model_path}")
161
+ # Determine providers (GPU if available)
162
+ available_providers = ort.get_available_providers()
163
+ print(f"Available ORT providers: {available_providers}")
164
+ providers = []
165
+ # Prioritize GPU providers
166
+ if 'CUDAExecutionProvider' in available_providers:
167
+ print("CUDAExecutionProvider found.")
168
+ providers.append('CUDAExecutionProvider')
169
+ elif 'DmlExecutionProvider' in available_providers: # For Windows with DirectML
170
+ print("DmlExecutionProvider found.")
171
+ providers.append('DmlExecutionProvider')
172
+ # Always include CPU as fallback
173
+ providers.append('CPUExecutionProvider')
174
+
175
+ print(f"Attempting to load session with providers: {providers}")
176
+ session = ort.InferenceSession(g_onnx_model_path, providers=providers)
177
+ active_provider = session.get_providers()[0]
178
+ message = f"ONNX session loaded successfully on GPU worker using provider: {active_provider}"
179
+ print(message)
180
+ # Clean up session immediately after test?
181
+ # del session # Optional, depends if we want to keep it loaded
182
+
183
+ except Exception as e:
184
+ message = f"Error loading ONNX session: {e}"
185
+ print(message)
186
+ import traceback; traceback.print_exc()
187
+
188
+ return message
189
+
190
+ # --- Gradio Interface Definition (Minimal for ONNX Load Test) ---
191
  with gr.Blocks() as demo:
192
  gr.Markdown("""
193
+ # ONNX Model Load Test
194
+ Downloads ONNX model and tag mapping, then attempts to load the ONNX session on the GPU worker when the button is clicked.
195
+ Check logs for download and loading messages.
196
  """)
197
  with gr.Column():
198
+ test_button = gr.Button("Test ONNX Load on GPU")
 
199
  output_text = gr.Textbox(label="Output")
 
 
 
 
 
 
 
200
 
201
  test_button.click(
202
+ fn=test_onnx_load,
203
+ inputs=[],
204
+ outputs=[output_text]
 
 
205
  )
206
 
207
  # --- Main Block ---
208
  if __name__ == "__main__":
209
  if not os.environ.get("HF_TOKEN"): print("Warning: HF_TOKEN environment variable not set.")
210
+ # Initialize paths and labels at startup
211
+ initialize_onnx_paths()
212
+ # Launch Gradio app
213
  demo.launch(share=True)
214
 
215
  # --- Commented out original UI and helpers/constants not needed for init/simple test ---
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  torch
3
  torchvision
4
  torchaudio
5
- # onnxruntime-gpu==1.19.0 # Removed ONNX Runtime
6
  safetensors
7
  transformers
8
  timm # Needed for EVA02 base model
 
2
  torch
3
  torchvision
4
  torchaudio
5
+ onnxruntime-gpu==1.19.0 # Removed ONNX Runtime
6
  safetensors
7
  transformers
8
  timm # Needed for EVA02 base model