Xinyoumeng233hu commited on
Commit
1c251e8
·
1 Parent(s): d6a5fdd

Update huffman_baseline.py

Browse files
Files changed (1) hide show
  1. huffman_baseline.py +2 -2
huffman_baseline.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
4
  from huffman import HuffmanCoding
5
  from utils import kl, entropy, is_sent_finish, limit_past
6
 
7
- def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cuda'):
8
  length = len(message)
9
 
10
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
@@ -71,7 +71,7 @@ def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=Fals
71
 
72
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
73
 
74
- def decode_huffman(model, enc, text, context, bits_per_word, device='cuda'):
75
  # inp is a list of token indices
76
  # context is a list of token indices
77
  inp = enc.encode(text)
 
4
  from huffman import HuffmanCoding
5
  from utils import kl, entropy, is_sent_finish, limit_past
6
 
7
+ def encode_huffman(model, enc, message, context, bits_per_word, finish_sent=False, device='cpu'):
8
  length = len(message)
9
 
10
  context = torch.tensor(context[-1022:], device=device, dtype=torch.long)
 
71
 
72
  return output[len(context):].tolist(), avg_NLL, avg_KL, words_per_bit
73
 
74
+ def decode_huffman(model, enc, text, context, bits_per_word, device='cpu'):
75
  # inp is a list of token indices
76
  # context is a list of token indices
77
  inp = enc.encode(text)