import transformers
import torch
import tokenizers
import streamlit as st
import re
from PIL import Image
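
# Note: a minimal way to run this app locally, assuming the script is saved
# as app.py (the filename is an assumption) next to the fine-tuned weights
# file korzh-medium_best_eval_loss.bin referenced below:
#   streamlit run app.py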


# st.cache cannot hash tokenizer internals (Tokenizer, AddedToken, compiled
# regexes), so hashing is disabled for those types; the loaded model and
# tokenizer are then reused across Streamlit reruns.
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name, model_path):
    """Load the base GPT-2 model and tokenizer, then apply the fine-tuned weights."""
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({
        'eos_token': '[EOS]'
    })
    model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    # Resize the embedding matrix for the added [EOS] token so the shapes
    # match the fine-tuned checkpoint before loading it.
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    return model, tokenizer


def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=300):
    """Generate a continuation of `text` with beam-search sampling."""
    text += '\n'
    input_ids = tokenizer.encode(text, return_tensors="pt")
    length_of_prompt = len(input_ids[0])
    with torch.no_grad():
        out = model.generate(input_ids,
                             do_sample=True,
                             num_beams=n_beams,
                             temperature=temperature,
                             top_p=top_p,
                             # Allow max_length new tokens on top of the prompt.
                             max_length=max_length + length_of_prompt,
                             eos_token_id=tokenizer.eos_token_id)

    # Decode only the first of the returned sequences.
    return tokenizer.decode(out[0])
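
# A minimal sketch of calling the generator outside Streamlit, using the same
# model name and weights file as below:
#   model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2',
#                                'korzh-medium_best_eval_loss.bin')
#   print(predict('Что делать, Макс?', model, tokenizer))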


# Load the Russian GPT-2 base (rugpt3medium) and the fine-tuned NeuroKorzh weights.
model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')

image = Image.open('korzh.jpg')
st.image(image, caption='NeuroKorzh')

st.markdown("\n")

# The default prompt is Russian for "What should I do, Max?"; the underlying
# rugpt3 model expects Russian input.
text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')
button = st.button('Go')

if button:
    try:
        with st.spinner("Generation in progress"):
            result = predict(text, model, tokenizer)

        st.text_area(label='Max Korzh:', value=result)

    except Exception:
        st.error("Oops, something went wrong. Please try again and report the issue to me, tg: @vladyur")