Ilvir commited on
Commit
ccbc7d2
·
1 Parent(s): ed24903

Delete pages/gpt.py

Browse files
Files changed (1) hide show
  1. pages/gpt.py +0 -35
pages/gpt.py DELETED
@@ -1,35 +0,0 @@
1
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
- import torch
3
-
4
- model = GPT2LMHeadModel.from_pretrained(
5
- 'sberbank-ai/rugpt3small_based_on_gpt2',
6
- output_attentions = False,
7
- output_hidden_states = False,
8
- )
9
- # Вешаем сохраненные веса на нашу модель
10
- model.load_state_dict(torch.load('models/model.pt'), map_location=torch.device('cpu'))
11
-
12
-
13
- def generate_text(model, tokenizer, prompt, length, num_samples, temperature):
14
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
15
- output_sequences = model.generate(
16
- input_ids=input_ids,
17
- max_length=length,
18
- num_return_sequences=num_samples,
19
- temperature=temperature
20
- )
21
-
22
- generated_texts = []
23
- for output_sequence in output_sequences:
24
- generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
25
- generated_texts.append(generated_text)
26
-
27
- return generated_texts
28
-
29
-
30
- if st.button('Сгенерировать текст'):
31
- generated_texts = generate_text(model, tokenizer, prompt, length, num_samples, temperature)
32
- for i, text in enumerate(generated_texts):
33
- st.write(f'Текст {i+1}:')
34
- st.write(text)
35
-