Spaces:
Running
on
Zero
Running
on
Zero
AbstractPhil
committed on
Commit
·
9dc2118
1
Parent(s):
f7e1fb5
yes
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ except ImportError:
|
|
32 |
print("⚠ Triton not configured for MX - run install.sh")
|
33 |
|
34 |
# ===== MAIN IMPORTS =====
|
35 |
-
import os, gc, json,
|
36 |
import subprocess, sys
|
37 |
from dataclasses import dataclass
|
38 |
from typing import List, Dict, Optional, Any, Union
|
@@ -42,6 +42,9 @@ import spaces # required for ZeroGPU
|
|
42 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
43 |
import numpy as np
|
44 |
|
|
|
|
|
|
|
45 |
# Suppress warnings
|
46 |
warnings.filterwarnings("ignore", message=".*microscaling.*")
|
47 |
warnings.filterwarnings("ignore", message=".*mx.*")
|
@@ -169,6 +172,8 @@ def detect_mx_format(model) -> bool:
|
|
169 |
|
170 |
def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
171 |
"""Load the base model with proper MX format handling."""
|
|
|
|
|
172 |
print(f"\n{'='*50}")
|
173 |
print(f"Loading model: {MODEL_ID}")
|
174 |
print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
|
@@ -198,6 +203,8 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
|
198 |
else:
|
199 |
print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
|
200 |
print(" This will likely cause LoRA compatibility issues!")
|
|
|
|
|
201 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
202 |
|
203 |
# Explicitly disable MX
|
@@ -205,6 +212,7 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
|
205 |
os.environ["FORCE_MX_QUANTIZATION"] = "0"
|
206 |
else:
|
207 |
# Non-GPT-OSS models
|
|
|
208 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
209 |
|
210 |
try:
|
@@ -240,6 +248,7 @@ def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
|
240 |
|
241 |
# Try to load without MX as fallback
|
242 |
print("Attempting to load model without MX format...")
|
|
|
243 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
244 |
os.environ["FORCE_MX_QUANTIZATION"] = "0"
|
245 |
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
|
@@ -425,6 +434,8 @@ def generate_on_gpu(
|
|
425 |
seed: Optional[int]
|
426 |
) -> Dict[str, str]:
|
427 |
"""Run generation on GPU."""
|
|
|
|
|
428 |
try:
|
429 |
# Set seed if provided
|
430 |
if seed is not None:
|
@@ -441,6 +452,7 @@ def generate_on_gpu(
|
|
441 |
model.eval()
|
442 |
|
443 |
# Prepare inputs
|
|
|
444 |
device = next(model.parameters()).device
|
445 |
|
446 |
if HARMONY_AVAILABLE and isinstance(prompt, list):
|
@@ -492,6 +504,7 @@ def generate_on_gpu(
|
|
492 |
|
493 |
finally:
|
494 |
# Cleanup
|
|
|
495 |
if 'model' in locals():
|
496 |
del model
|
497 |
gc.collect()
|
|
|
32 |
print("⚠ Triton not configured for MX - run install.sh")
|
33 |
|
34 |
# ===== MAIN IMPORTS =====
|
35 |
+
import os, gc, json, warnings, traceback
|
36 |
import subprocess, sys
|
37 |
from dataclasses import dataclass
|
38 |
from typing import List, Dict, Optional, Any, Union
|
|
|
42 |
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
43 |
import numpy as np
|
44 |
|
45 |
+
# IMPORTANT: Don't import torch at module level for ZeroGPU
|
46 |
+
# It will be imported inside GPU-decorated functions
|
47 |
+
|
48 |
# Suppress warnings
|
49 |
warnings.filterwarnings("ignore", message=".*microscaling.*")
|
50 |
warnings.filterwarnings("ignore", message=".*mx.*")
|
|
|
172 |
|
173 |
def load_base_model(device_map: Optional[str] = "auto") -> AutoModelForCausalLM:
|
174 |
"""Load the base model with proper MX format handling."""
|
175 |
+
import torch # Import torch here for ZeroGPU compatibility
|
176 |
+
|
177 |
print(f"\n{'='*50}")
|
178 |
print(f"Loading model: {MODEL_ID}")
|
179 |
print(f"MX Format Available: {_HAS_TRITON_KERNELS}")
|
|
|
203 |
else:
|
204 |
print("⚠ No triton_kernels - falling back to bf16 (dequantized)")
|
205 |
print(" This will likely cause LoRA compatibility issues!")
|
206 |
+
# Load the model - torch imported inside function
|
207 |
+
import torch
|
208 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
209 |
|
210 |
# Explicitly disable MX
|
|
|
212 |
os.environ["FORCE_MX_QUANTIZATION"] = "0"
|
213 |
else:
|
214 |
# Non-GPT-OSS models
|
215 |
+
import torch
|
216 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
217 |
|
218 |
try:
|
|
|
248 |
|
249 |
# Try to load without MX as fallback
|
250 |
print("Attempting to load model without MX format...")
|
251 |
+
import torch
|
252 |
load_kwargs["torch_dtype"] = torch.bfloat16
|
253 |
os.environ["FORCE_MX_QUANTIZATION"] = "0"
|
254 |
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
|
|
|
434 |
seed: Optional[int]
|
435 |
) -> Dict[str, str]:
|
436 |
"""Run generation on GPU."""
|
437 |
+
import torch # Import torch inside GPU function for ZeroGPU
|
438 |
+
|
439 |
try:
|
440 |
# Set seed if provided
|
441 |
if seed is not None:
|
|
|
452 |
model.eval()
|
453 |
|
454 |
# Prepare inputs
|
455 |
+
import torch # Make sure torch is available
|
456 |
device = next(model.parameters()).device
|
457 |
|
458 |
if HARMONY_AVAILABLE and isinstance(prompt, list):
|
|
|
504 |
|
505 |
finally:
|
506 |
# Cleanup
|
507 |
+
import torch
|
508 |
if 'model' in locals():
|
509 |
del model
|
510 |
gc.collect()
|