File size: 2,905 Bytes
7d6f77f
 
 
 
518485c
833aa0a
da9ee26
833aa0a
7d6f77f
 
47d0a4a
7d6f77f
 
724876e
1bc822d
724876e
7d6f77f
724876e
7d6f77f
 
 
 
 
a16dba0
d2e6254
bb72c45
7d6f77f
b86439f
7d6f77f
 
 
 
 
 
d2e6254
1bc822d
7d6f77f
 
 
 
 
833aa0a
 
7d6f77f
da9ee26
d2e6254
da9ee26
 
7d6f77f
833aa0a
 
7d6f77f
 
833aa0a
7d6f77f
 
 
ded5ed0
da9ee26
833aa0a
 
 
 
 
 
4bd4566
 
 
 
 
 
1bc822d
 
 
d2e6254
4bd4566
ded5ed0
da9ee26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import transformers
import torch
import tokenizers
import streamlit as st
import re

from PIL import Image
from huggingface_hub import hf_hub_download


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name, model_path):
    """Build a GPT-2 tokenizer/model pair and load fine-tuned weights.

    Downloads the pretrained tokenizer and model named by ``model_name``,
    registers '[EOS]' as the end-of-sequence token (resizing the embedding
    table to match), then overwrites the weights with the local checkpoint
    at ``model_path`` (loaded onto CPU) and switches to eval mode.

    Returns:
        (model, tokenizer) tuple ready for generation.
    """
    tok = transformers.GPT2Tokenizer.from_pretrained(model_name)
    tok.add_special_tokens({
        'eos_token': '[EOS]'
    })

    lm = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    # Vocabulary grew by one ([EOS]); embeddings must match before loading weights.
    lm.resize_token_embeddings(len(tok))

    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    lm.load_state_dict(checkpoint)
    lm.eval()

    return lm, tok


def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
    """Generate a continuation of ``text`` with the given model/tokenizer.

    A trailing newline is appended to the prompt before encoding (the model
    was presumably fine-tuned on newline-terminated prompts — TODO confirm).
    Sampling is combined with beam search; generation stops at the tokenizer's
    EOS token or after ``length_of_generated`` new tokens.

    Returns:
        The first decoded sequence, including the prompt text.
    """
    prompt_ids = tokenizer.encode(text + '\n', return_tensors="pt")
    prompt_len = len(prompt_ids[0])

    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            do_sample=True,
            num_beams=n_beams,
            temperature=temperature,
            top_p=top_p,
            # Budget counts the prompt, so the model adds at most
            # length_of_generated fresh tokens.
            max_length=prompt_len + length_of_generated,
            eos_token_id=tokenizer.eos_token_id,
        )

    decoded = [tokenizer.decode(seq) for seq in generated]
    return decoded[0]


# Load both fine-tuned checkpoints up front; st.cache on get_model makes the
# second and later script reruns cheap.
medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')

# --- UI ---
image = Image.open('korzh.jpg')
st.image(image, caption='NeuroKorzh')

option = st.selectbox('Model to be used', ('medium', 'large'))

st.markdown("\n")

text = st.text_area(label='Starting point for text generation', value='Что делать, Макс?', height=200)
button = st.button('Go')

if button:
    with st.spinner("Generation in progress"):
        if option == 'medium':
            result = predict(text, medium_model, medium_tokenizer)
        elif option == 'large':
            result = predict(text, large_model, large_tokenizer)
        else:
            # BUG FIX: the original did `raise st.error(...)` — st.error returns
            # a DeltaGenerator, not an exception, so the raise itself would crash
            # with "TypeError: exceptions must derive from BaseException".
            # Show the error in the UI and halt this script run instead.
            st.error('Error in selectbox')
            st.stop()

    st.text_area(label='', value=result, height=1200)