Spaces:
Sleeping
Sleeping
Commit
·
1c251e8
1
Parent(s):
d6a5fdd
Update huffman_baseline.py
Browse files- huffman_baseline.py +2 -2
huffman_baseline.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
4 |
from huffman import HuffmanCoding
|
5 |
from utils import kl, entropy, is_sent_finish, limit_past
|
6 |
|
7 |
-
def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='
|
8 |
length = len(message)
|
9 |
|
10 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
@@ -71,7 +71,7 @@ def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=Fals
|
|
71 |
|
72 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
73 |
|
74 |
-
def decode_huffman(model, enc, text, context, bits_per_word, device='
|
75 |
# inp is a list of token indices
|
76 |
# context is a list of token indices
|
77 |
inp = enc.encode(text)
|
|
|
4 |
from huffman import HuffmanCoding
|
5 |
from utils import kl, entropy, is_sent_finish, limit_past
|
6 |
|
7 |
+
def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cpu'):
|
8 |
length = len(message)
|
9 |
|
10 |
context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
|
|
|
71 |
|
72 |
return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
|
73 |
|
74 |
+
def decode_huffman(model, enc, text, context, bits_per_word, device='cpu'):
|
75 |
# inp is a list of token indices
|
76 |
# context is a list of token indices
|
77 |
inp = enc.encode(text)
|