ejschwartz committed on
Commit
d5e44b4
·
1 Parent(s): 803e1b0

Name constants

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -11,6 +11,10 @@ import huggingface_hub
11
 
12
  import prep_decompiled
13
 
 
 
 
 
14
  hf_key = os.environ["HF_TOKEN"]
15
  huggingface_hub.login(token=hf_key)
16
 
@@ -85,11 +89,11 @@ def infer(code):
85
  print(f"Prompt:\n{repr(var_prompt)}")
86
 
87
  var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
88
- :, : 8192 - 1024
89
  ]
90
  var_output = vardecoder_model.generate(
91
  input_ids=var_input_ids,
92
- max_new_tokens=1024,
93
  num_beams=4,
94
  num_return_sequences=1,
95
  do_sample=False,
@@ -112,12 +116,12 @@ def infer(code):
112
  field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
113
  else:
114
  field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
115
- :, : 8192 - 1024
116
  ]
117
 
118
  field_output = fielddecoder_model.generate(
119
  input_ids=field_input_ids,
120
- max_new_tokens=1024,
121
  num_beams=4,
122
  num_return_sequences=1,
123
  do_sample=False,
 
11
 
12
  import prep_decompiled
13
 
14
+ # Model configuration constants
15
+ MAX_CONTEXT_LENGTH = 8192
16
+ MAX_NEW_TOKENS = 1024
17
+
18
  hf_key = os.environ["HF_TOKEN"]
19
  huggingface_hub.login(token=hf_key)
20
 
 
89
  print(f"Prompt:\n{repr(var_prompt)}")
90
 
91
  var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").cuda()[
92
+ :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
93
  ]
94
  var_output = vardecoder_model.generate(
95
  input_ids=var_input_ids,
96
+ max_new_tokens=MAX_NEW_TOKENS,
97
  num_beams=4,
98
  num_return_sequences=1,
99
  do_sample=False,
 
116
  field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
117
  else:
118
  field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").cuda()[
119
+ :, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
120
  ]
121
 
122
  field_output = fielddecoder_model.generate(
123
  input_ids=field_input_ids,
124
+ max_new_tokens=MAX_NEW_TOKENS,
125
  num_beams=4,
126
  num_return_sequences=1,
127
  do_sample=False,