Spaces:
Sleeping
Sleeping
Commit
·
c9281d9
1
Parent(s):
467a29f
Updated device
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ enc = tiktoken.get_encoding('gpt2')
|
|
20 |
def inference(input_text, num_return_sequences, max_length):
|
21 |
input_tokens = torch.tensor(enc.encode(input_text), dtype=torch.long)
|
22 |
input_tokens = input_tokens.unsqueeze(0).repeat(num_return_sequences, 1)
|
23 |
-
x = input_tokens.to(
|
24 |
|
25 |
while x.size(1) < max_length:
|
26 |
# forward the model to get the logits
|
|
|
20 |
def inference(input_text, num_return_sequences, max_length):
|
21 |
input_tokens = torch.tensor(enc.encode(input_text), dtype=torch.long)
|
22 |
input_tokens = input_tokens.unsqueeze(0).repeat(num_return_sequences, 1)
|
23 |
+
x = input_tokens.to(device)
|
24 |
|
25 |
while x.size(1) < max_length:
|
26 |
# forward the model to get the logits
|