Ilvir committed on
Commit
96cc24c
·
1 Parent(s): 9f33ac3

Update gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +40 -20
gpt.py CHANGED
@@ -1,7 +1,7 @@
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
  import streamlit as st
3
  import torch
4
-
5
 
6
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
7
  model = GPT2LMHeadModel.from_pretrained(
@@ -16,32 +16,52 @@ model.load_state_dict(torch.load('modelgpt.pt', map_location=torch.device('cpu')
16
 
17
 
18
  prompt = st.text_input('Введите текст prompt:')
19
- length = st.slider('Длина генерируемой последовательности:', 10, 1000, 15)
20
  num_samples = st.slider('Число генераций:', 1, 10, 1)
21
  temperature = st.slider('Температура:', 1.0, 10.0, 2.0)
 
 
22
 
23
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
25
 
26
- def generate_text(model, tokenizer, prompt, length, num_samples, temperature):
27
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
28
- output_sequences = model.generate(
29
- input_ids=input_ids,
30
- max_length=length,
31
- num_return_sequences=num_samples,
32
- temperature=temperature
33
- )
34
-
35
- generated_texts = []
36
- for output_sequence in output_sequences:
37
- generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
38
- generated_texts.append(generated_text)
39
-
40
- return generated_texts
41
 
42
 
43
  if st.button('Сгенерировать текст'):
44
- generated_texts = generate_text(model, tokenizer, prompt, length, num_samples, temperature)
45
- for i, text in enumerate(generated_texts):
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  st.write(f'Текст {i+1}:')
47
- st.write(text)
 
 
 
 
 
 
 
 
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
  import streamlit as st
3
  import torch
4
+ import textwrap
5
 
6
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
7
  model = GPT2LMHeadModel.from_pretrained(
 
16
 
17
 
18
  prompt = st.text_input('Введите текст prompt:')
19
+ length = st.slider('Длина генерируемой последовательности:', 8, 256, 15)
20
  num_samples = st.slider('Число генераций:', 1, 10, 1)
21
  temperature = st.slider('Температура:', 1.0, 10.0, 2.0)
22
+ top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
23
+ top_k = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
24
 
25
 
26
+ # def generate_text(model, tokenizer, prompt, length, num_samples, temperature):
27
+ # input_ids = tokenizer.encode(prompt, return_tensors='pt')
28
+ # output_sequences = model.generate(
29
+ # input_ids=input_ids,
30
+ # max_length=length,
31
+ # num_return_sequences=num_samples,
32
+ # temperature=temperature
33
+ # )
34
 
35
+ # generated_texts = []
36
+ # for output_sequence in output_sequences:
37
+ # generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
38
+ # generated_texts.append(generated_text)
39
 
40
+ # return generated_texts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  if st.button('Сгенерировать текст'):
44
+
45
+ with torch.inference_mode():
46
+ prompt = tokenizer.encode(prompt, return_tensors='pt').to(device)
47
+ out = model.generate(
48
+ input_ids=prompt,
49
+ max_length=length,
50
+ num_beams=num_samples,
51
+ do_sample=True,
52
+ temperature=temperature,
53
+ top_k=top_k,
54
+ top_p=top_p,
55
+ no_repeat_ngram_size=3,
56
+ num_return_sequences=3,
57
+ ).cpu().numpy()
58
+ for i, out_ in enumerate(out):
59
  st.write(f'Текст {i+1}:')
60
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))
61
+
62
+
63
+
64
+ # generated_texts = generate_text(model, tokenizer, prompt, length, num_samples, temperature)
65
+ # for i, text in enumerate(generated_texts):
66
+ # st.write(f'Текст {i+1}:')
67
+ # st.write(text)