AkashDataScience commited on
Commit
c9281d9
·
1 Parent(s): 467a29f

Updated device

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -20,7 +20,7 @@ enc = tiktoken.get_encoding('gpt2')
20
  def inference(input_text, num_return_sequences, max_length):
21
  input_tokens = torch.tensor(enc.encode(input_text), dtype=torch.long)
22
  input_tokens = input_tokens.unsqueeze(0).repeat(num_return_sequences, 1)
23
- x = input_tokens.to('cuda')
24
 
25
  while x.size(1) < max_length:
26
  # forward the model to get the logits
 
20
  def inference(input_text, num_return_sequences, max_length):
21
  input_tokens = torch.tensor(enc.encode(input_text), dtype=torch.long)
22
  input_tokens = input_tokens.unsqueeze(0).repeat(num_return_sequences, 1)
23
+ x = input_tokens.to(device)
24
 
25
  while x.size(1) < max_length:
26
  # forward the model to get the logits