Spaces:
Sleeping
Sleeping
Commit
·
0e35ac3
1
Parent(s):
1c251e8
Update meteor.py
Browse files
meteor.py
CHANGED
@@ -13,7 +13,7 @@ sample_seed_prefix = b'sample'
|
|
13 |
sample_nonce_counter = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
14 |
|
15 |
|
16 |
-
def encode_meteor(model, enc, message, context, finish_sent=False, device='
|
17 |
|
18 |
if randomize_key:
|
19 |
input_key = os.urandom(64)
|
@@ -142,7 +142,7 @@ def encode_meteor(model, enc, message, context, finish_sent=False, device='cuda'
|
|
142 |
|
143 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
144 |
|
145 |
-
def decode_meteor(model, enc, text, context, device='
|
146 |
# inp is a list of token indices
|
147 |
# context is a list of token indices
|
148 |
inp = enc.encode(text)
|
|
|
13 |
sample_nonce_counter = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
14 |
|
15 |
|
16 |
+
def encode_meteor(model, enc, message, context, finish_sent=False, device='cpu', temp=1.0, precision=16, topk=50000, is_sort=False, randomize_key=False, input_key=sample_key, input_nonce=sample_nonce_counter):
|
17 |
|
18 |
if randomize_key:
|
19 |
input_key = os.urandom(64)
|
|
|
142 |
|
143 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
144 |
|
145 |
+
def decode_meteor(model, enc, text, context, device='cpu', temp=1.0, precision=16, topk=50000, is_sort=False, input_key=sample_key, input_nonce=sample_nonce_counter):
|
146 |
# inp is a list of token indices
|
147 |
# context is a list of token indices
|
148 |
inp = enc.encode(text)
|