ejschwartz committed on
Commit
4c9f7ae
·
1 Parent(s): 5672f53

Do not use safetensors

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -38,11 +38,12 @@ print("Loaded vardecoder model successfully.")
38
  logger.info("Loading fielddecoder model...")
39
 
40
  fielddecoder_model = None
41
- #fielddecoder_model = AutoModelForCausalLM.from_pretrained(
42
- # "ejschwartz/resym-fielddecoder",
43
- # torch_dtype=torch.bfloat16,
44
- #)
45
- #logger.info("Successfully loaded fielddecoder model")
 
46
 
47
  make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
48
 
@@ -104,7 +105,7 @@ def infer(code):
104
 
105
  print(f"Prompt:\n{repr(var_prompt)}")
106
 
107
- var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
108
  :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
109
  ]
110
  var_output = vardecoder_model.generate(
@@ -131,7 +132,7 @@ def infer(code):
131
  if len(fields) == 0:
132
  field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
133
  else:
134
- field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
135
  :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
136
  ]
137
 
 
38
  logger.info("Loading fielddecoder model...")
39
 
40
  fielddecoder_model = None
41
+ fielddecoder_model = AutoModelForCausalLM.from_pretrained(
42
+ "ejschwartz/resym-fielddecoder",
43
+ torch_dtype=torch.bfloat16,
44
+ use_safetensors=False
45
+ )
46
+ logger.info("Successfully loaded fielddecoder model")
47
 
48
  make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
49
 
 
105
 
106
  print(f"Prompt:\n{repr(var_prompt)}")
107
 
108
+ var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").to(vardecoder_model.device)[
109
  :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
110
  ]
111
  var_output = vardecoder_model.generate(
 
132
  if len(fields) == 0:
133
  field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
134
  else:
135
+ field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").to(fielddecoder_model.device)[
136
  :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
137
  ]
138