AkashDataScience commited on
Commit
324c700
·
1 Parent(s): 749f119

Updated encoding

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -23,7 +23,7 @@ encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list
23
  decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a
24
 
25
  def inference(input_text, max_new_tokens=500):
26
- context = torch.tensor(encode(input_text), dtype=torch.long, device=device)
27
 
28
  output_text = decode(model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
29
 
 
23
  decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a
24
 
25
  def inference(input_text, max_new_tokens=500):
26
+ context = torch.tensor(encode(input_text), dtype=torch.long, device=device).view(1, -1)
27
 
28
  output_text = decode(model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
29