AbstractPhil committed on
Commit
9dc2118
·
1 Parent(s): f7e1fb5
Files changed (1) hide show
  1. app.py +14 -1
app.py CHANGED
@@ -32,7 +32,7 @@ except ImportError:
32
  print("⚠ Triton not configured for MX - run install.sh")
33
 
34
  # ===== MAIN IMPORTS =====
35
- import os, gc, json, torch, warnings, traceback
36
  import subprocess, sys
37
  from dataclasses import dataclass
38
  from typing import List, Dict, Optional, Any, Union
@@ -42,6 +42,9 @@ import spaces # required for ZeroGPU
42
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
43
  import numpy as np
44
 
 
 
 
45
  # Suppress warnings
46
  warnings.filterwarnings("ignore", message=".*microscaling.*")
47
  warnings.filterwarnings("ignore", message=".*mx.*")
@@ -169,6 +172,8 @@ def detect_mx_format(model) -> bool:
169
 
170
  def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
171
  """Load the base model with proper MX format handling."""
 
 
172
  print(f"\n{'='*50}")
173
  print(f"Loading model: {MODEL_ID}")
174
  print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
@@ -198,6 +203,8 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
198
  else:
199
  print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
200
  print(" This will likely cause LoRA compatibility issues!")
 
 
201
  load_kwargs["torch_dtype"] = torch.bfloat16
202
 
203
  # Explicitly disable MX
@@ -205,6 +212,7 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
205
  os.environ["FORCE_MX_QUANTIZATION"] = "0"
206
  else:
207
  # Non-GPT-OSS models
 
208
  load_kwargs["torch_dtype"] = torch.bfloat16
209
 
210
  try:
@@ -240,6 +248,7 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
240
 
241
  # Try to load without MX as fallback
242
  print("Attempting to load model without MX format...")
 
243
  load_kwargs["torch_dtype"] = torch.bfloat16
244
  os.environ["FORCE_MX_QUANTIZATION"] = "0"
245
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
@@ -425,6 +434,8 @@ def generate_on_gpu(
425
  seed: Optional[int]
426
  ) -> Dict[str, str]:
427
  """Run generation on GPU."""
 
 
428
  try:
429
  # Set seed if provided
430
  if seed is not None:
@@ -441,6 +452,7 @@ def generate_on_gpu(
441
  model.eval()
442
 
443
  # Prepare inputs
 
444
  device = next(model.parameters()).device
445
 
446
  if HARMONY_AVAILABLE and isinstance(prompt, list):
@@ -492,6 +504,7 @@ def generate_on_gpu(
492
 
493
  finally:
494
  # Cleanup
 
495
  if 'model' in locals():
496
  del model
497
  gc.collect()
 
32
  print("⚠ Triton not configured for MX - run install.sh")
33
 
34
  # ===== MAIN IMPORTS =====
35
+ import os, gc, json, warnings, traceback
36
  import subprocess, sys
37
  from dataclasses import dataclass
38
  from typing import List, Dict, Optional, Any, Union
 
42
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
43
  import numpy as np
44
 
45
+ # IMPORTANT: Don't import torch at module level for ZeroGPU
46
+ # It will be imported inside GPU-decorated functions
47
+
48
  # Suppress warnings
49
  warnings.filterwarnings("ignore", message=".*microscaling.*")
50
  warnings.filterwarnings("ignore", message=".*mx.*")
 
172
 
173
  def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
174
  """Load the base model with proper MX format handling."""
175
+ import torch # Import torch here for ZeroGPU compatibility
176
+
177
  print(f"\n{'='*50}")
178
  print(f"Loading model: {MODEL_ID}")
179
  print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
 
203
  else:
204
  print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
205
  print(" This will likely cause LoRA compatibility issues!")
206
+ # Load the model - torch imported inside function
207
+ import torch
208
  load_kwargs["torch_dtype"] = torch.bfloat16
209
 
210
  # Explicitly disable MX
 
212
  os.environ["FORCE_MX_QUANTIZATION"] = "0"
213
  else:
214
  # Non-GPT-OSS models
215
+ import torch
216
  load_kwargs["torch_dtype"] = torch.bfloat16
217
 
218
  try:
 
248
 
249
  # Try to load without MX as fallback
250
  print("Attempting to load model without MX format...")
251
+ import torch
252
  load_kwargs["torch_dtype"] = torch.bfloat16
253
  os.environ["FORCE_MX_QUANTIZATION"] = "0"
254
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
 
434
  seed: Optional[int]
435
  ) -> Dict[str, str]:
436
  """Run generation on GPU."""
437
+ import torch # Import torch inside GPU function for ZeroGPU
438
+
439
  try:
440
  # Set seed if provided
441
  if seed is not None:
 
452
  model.eval()
453
 
454
  # Prepare inputs
455
+ import torch # Make sure torch is available
456
  device = next(model.parameters()).device
457
 
458
  if HARMONY_AVAILABLE and isinstance(prompt, list):
 
504
 
505
  finally:
506
  # Cleanup
507
+ import torch
508
  if 'model' in locals():
509
  del model
510
  gc.collect()