Spaces:
Running
Running
kovacsvi
commited on
Commit
·
654bf8b
1
Parent(s):
ca62943
jit tracing fix
Browse files
utils.py
CHANGED
@@ -69,40 +69,37 @@ def download_hf_models():
|
|
69 |
token=HF_TOKEN,
|
70 |
device_map="auto"
|
71 |
)
|
72 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
73 |
-
|
74 |
-
token=HF_TOKEN
|
75 |
-
)
|
76 |
-
|
77 |
-
model.eval()
|
78 |
-
|
79 |
-
# Dummy input for tracing
|
80 |
-
dummy_input = tokenizer(
|
81 |
-
"Hello, world!",
|
82 |
-
return_tensors="pt",
|
83 |
-
padding=True,
|
84 |
-
truncation=True,
|
85 |
-
max_length=256
|
86 |
-
)
|
87 |
-
|
88 |
-
# JIT trace
|
89 |
-
traced_model = torch.jit.trace(
|
90 |
-
model,
|
91 |
-
(dummy_input["input_ids"], dummy_input["attention_mask"])
|
92 |
-
)
|
93 |
-
|
94 |
-
# Save traced model
|
95 |
safe_model_name = model_id.replace("/", "_")
|
96 |
traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def df_h():
|
108 |
result = subprocess.run(["df", "-H"], capture_output=True, text=True)
|
|
|
69 |
token=HF_TOKEN,
|
70 |
device_map="auto"
|
71 |
)
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
|
73 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
safe_model_name = model_id.replace("/", "_")
|
75 |
traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
|
76 |
+
|
77 |
+
if os.path.exists(traced_model_path):
|
78 |
+
print(f"⏩ Skipping JIT — already exists: {traced_model_path}")
|
79 |
+
else:
|
80 |
+
print(f"⚙️ Tracing and saving: {traced_model_path}")
|
81 |
+
|
82 |
+
model.eval()
|
83 |
+
|
84 |
+
# Dummy input for tracing
|
85 |
+
dummy_input = tokenizer(
|
86 |
+
"Hello, world!",
|
87 |
+
return_tensors="pt",
|
88 |
+
padding=True,
|
89 |
+
truncation=True,
|
90 |
+
max_length=256
|
91 |
+
)
|
92 |
+
|
93 |
+
# JIT trace
|
94 |
+
traced_model = torch.jit.trace(
|
95 |
+
model,
|
96 |
+
(dummy_input["input_ids"], dummy_input["attention_mask"]),
|
97 |
+
strict=False
|
98 |
+
)
|
99 |
+
|
100 |
+
# Save traced model
|
101 |
+
traced_model.save(traced_model_path)
|
102 |
+
print(f"✔️ Saved JIT model to: {traced_model_path}")
|
103 |
|
104 |
def df_h():
|
105 |
result = subprocess.run(["df", "-H"], capture_output=True, text=True)
|