kovacsvi committed
Commit 654bf8b · 1 Parent(s): ca62943

jit tracing fix

Files changed (1):
  1. utils.py +29 -32
utils.py CHANGED
@@ -69,40 +69,37 @@ def download_hf_models():
             token=HF_TOKEN,
             device_map="auto"
         )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            token=HF_TOKEN
-        )
-
-        model.eval()
-
-        # Dummy input for tracing
-        dummy_input = tokenizer(
-            "Hello, world!",
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=256
-        )
-
-        # JIT trace
-        traced_model = torch.jit.trace(
-            model,
-            (dummy_input["input_ids"], dummy_input["attention_mask"])
-        )
-
-        # Save traced model
+        tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
+
         safe_model_name = model_id.replace("/", "_")
         traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
-        traced_model.save(traced_model_path)
-        print(f"✔️ Saved JIT model to: {traced_model_path}")
-
-    for tokenizer_id in tokenizers:
-        print(f"Downloading tokenizer: {tokenizer_id}")
-        AutoTokenizer.from_pretrained(
-            tokenizer_id,
-            token=HF_TOKEN
-        )
+
+        if os.path.exists(traced_model_path):
+            print(f"⏩ Skipping JIT — already exists: {traced_model_path}")
+        else:
+            print(f"⚙️ Tracing and saving: {traced_model_path}")
+
+            model.eval()
+
+            # Dummy input for tracing
+            dummy_input = tokenizer(
+                "Hello, world!",
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=256
+            )
+
+            # JIT trace
+            traced_model = torch.jit.trace(
+                model,
+                (dummy_input["input_ids"], dummy_input["attention_mask"]),
+                strict=False
+            )
+
+            # Save traced model
+            traced_model.save(traced_model_path)
+            print(f"✔️ Saved JIT model to: {traced_model_path}")
 
 def df_h():
     result = subprocess.run(["df", "-H"], capture_output=True, text=True)
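
For reference, a minimal sketch (not part of this commit) of how one of the traced models saved above could be loaded back and run. `JIT_DIR` and `model_id` here are illustrative assumptions; the tokenizer and input settings mirror the diff. The `strict=False` argument added by the commit lets `torch.jit.trace` record the dict-style outputs that Transformers models return, which strict tracing would otherwise reject.

```python
# Minimal sketch (not part of the commit): load a model traced by
# download_hf_models() and run it on one sentence. JIT_DIR and model_id are
# illustrative assumptions; the tokenizer and max_length mirror the diff above.
import os

import torch
from transformers import AutoTokenizer

JIT_DIR = "jit_models"        # assumed; must match the directory used when tracing
model_id = "org/some-model"   # hypothetical model id
safe_model_name = model_id.replace("/", "_")
traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")

# Load the TorchScript module saved by traced_model.save(...)
traced_model = torch.jit.load(traced_model_path)
traced_model.eval()

# The commit switches every model to the shared xlm-roberta-large tokenizer
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
inputs = tokenizer(
    "Hello, world!",
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=256,
)

with torch.no_grad():
    # Positional inputs must match the tuple passed to torch.jit.trace
    outputs = traced_model(inputs["input_ids"], inputs["attention_mask"])

# With strict=False tracing the output is typically dict-like; inspect it to
# find the entry you need (often "logits").
print(outputs)
```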