Ilvir commited on
Commit
e4b5631
·
1 Parent(s): 6dd30c8

Update gpt.py

Browse files
Files changed (1) hide show
  1. gpt.py +20 -22
gpt.py CHANGED
@@ -15,10 +15,6 @@ model.load_state_dict(torch.load('modelgpt.pt', map_location=torch.device('cpu')
15
 
16
  col1, col2 = st.columns([7, 7])
17
 
18
- with col2:
19
-
20
- prompt = st.text_input('Введите текст prompt:')
21
-
22
  with col1:
23
 
24
  length = st.slider('Длина генерируемой последовательности:', 8, 256, 15)
@@ -27,23 +23,25 @@ with col1:
27
  top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
28
  top_p = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
29
 
 
30
 
 
31
 
32
- if st.button('Сгенерировать текст'):
33
-
34
- with torch.inference_mode():
35
- prompt = tokenizer.encode(prompt, return_tensors='pt')
36
- out = model.generate(
37
- input_ids=prompt,
38
- max_length=length,
39
- num_beams=5,
40
- do_sample=True,
41
- temperature=temperature,
42
- top_k=top_k,
43
- top_p=top_p,
44
- no_repeat_ngram_size=3,
45
- num_return_sequences=num_samples,
46
- ).cpu().numpy()
47
- for i, out_ in enumerate(out):
48
- st.write(f'Текст {i+1}:')
49
- st.write(textwrap.fill(tokenizer.decode(out_), 100))
 
15
 
16
  col1, col2 = st.columns([7, 7])
17
 
 
 
 
 
18
  with col1:
19
 
20
  length = st.slider('Длина генерируемой последовательности:', 8, 256, 15)
 
23
  top_k = st.slider('Количество наиболее вероятных слов генерации:', 10, 200, 50)
24
  top_p = st.slider('Минимальная суммарная вероятность топовых слов:', 0.4, 1.0, 0.9)
25
 
26
+ with col2:
27
 
28
+ prompt = st.text_input('Введите текст prompt:')
29
 
30
+ if st.button('Сгенерировать текст'):
31
+
32
+ with torch.inference_mode():
33
+ prompt = tokenizer.encode(prompt, return_tensors='pt')
34
+ out = model.generate(
35
+ input_ids=prompt,
36
+ max_length=length,
37
+ num_beams=5,
38
+ do_sample=True,
39
+ temperature=temperature,
40
+ top_k=top_k,
41
+ top_p=top_p,
42
+ no_repeat_ngram_size=3,
43
+ num_return_sequences=num_samples,
44
+ ).cpu().numpy()
45
+ for i, out_ in enumerate(out):
46
+ st.write(f'Текст {i+1}:')
47
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))