File size: 3,871 Bytes
ecfed43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import random
import requests
from mtranslate import translate
import streamlit as st
# Endpoint of the Hugging Face hosted Inference API serving papuGaPT2.
MODEL_URL = "https://api-inference.huggingface.co/models/flax-community/papuGaPT2"
# Preset prompts for the "Prompt" dropdown. Each label shown to the user maps
# to a list of candidate seed texts; one candidate is picked at random to
# pre-fill the text area.
PROMPT_LIST = {
    label: [seed]
    for label, seed in (
        ("Najsmaczniejszy owoc to...", "Najsmaczniejszy owoc to "),
        ("Cześć, mam na imię...", "Cześć, mam na imię "),
        ("Największym polskim poetą był...", "Największym polskim poetą był "),
    )
}
def query(payload, model_url, timeout=60):
    """POST ``payload`` to a model endpoint and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body; it is encoded with
            ``json.dumps`` and sent as the raw POST body.
        model_url: Full URL of the model endpoint to POST to.
        timeout: Timeout in seconds handed to ``requests`` (applies to
            connecting and to waiting for response data). Defaults to 60 so a
            stalled request cannot block the app forever; pass ``None`` to
            wait indefinitely, as the previous version did.

    Returns:
        The response body decoded as UTF-8 JSON. If the body cannot be decoded
        (for example an HTML gateway error page), returns
        ``{"error": "<description>"}`` instead, so callers can reuse the same
        error handling they apply to errors reported by the API itself.

    Raises:
        requests.exceptions.RequestException: on connection failures or when
            ``timeout`` expires.
    """
    print("model url:", model_url)
    response = requests.post(model_url, data=json.dumps(payload), timeout=timeout)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return {
            "error": "Unexpected non-JSON response from the model API "
            f"(HTTP {response.status_code})"
        }
def process(
    text: str,
    model_name: str,
    max_len: int,
    temp: float,
    top_k: int,
    top_p: float,
    do_sample: bool = True,
    repetition_penalty: float = 2.0,
):
    """Request a text continuation of ``text`` from the model endpoint.

    Args:
        text: Prompt to continue.
        model_name: Endpoint URL of the model; despite the name it is passed
            straight to ``query`` as the URL to POST to.
        max_len: Maximum number of new tokens to generate (``max_new_tokens``).
        temp: Sampling temperature.
        top_k: Number of highest-probability tokens kept for top-k filtering.
        top_p: Cumulative-probability threshold for top-p filtering.
        do_sample: Sample when True, decode greedily when False. Defaults to
            True, matching the sidebar's default choice.
        repetition_penalty: Penalty applied to repeated tokens; defaults to the
            previously hard-coded 2.0.

    Returns:
        Whatever ``query`` returns for the request: the decoded JSON reply,
        which the page below treats as either a list of generations or a dict
        with an ``"error"`` key.
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
        },
        "options": {
            "use_cache": True,
        },
    }
    return query(payload, model_name)


# Page
st.set_page_config(page_title="papuGaPT2 (Polish GPT-2) Demo")
st.title("papuGaPT2 (Polish GPT-2)")

# Sidebar: generation parameters forwarded to `process` when "Run" is pressed.
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    min_value=1,
    help="The maximum number of new tokens to generate.",
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to modulate the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    min_value=1,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    min_value=0.0,
    max_value=1.0,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)

# Body
st.markdown(
    """
    papuGaPT2 (Polish GPT-2) model trained from scratch on OSCAR dataset.
    
    The models were trained with Jax and Flax using TPUs as part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104) organised by HuggingFace.
    """
)
model_name = MODEL_URL  # was the undefined name `MODEL_name` (NameError at startup)

# Prompt picker: a preset label pre-fills the text area with one of its seeds.
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = random.choice(PROMPT_LIST[prompt])
text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        print(
            f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}, "
            f"do_sample:{do_sample}"
        )
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
            do_sample=do_sample,
        )
        print("result:", result)
        if isinstance(result, dict) and "error" in result:
            error = result["error"]
            if isinstance(error, list):
                for item in error:
                    st.write(f"{item}")
            else:
                # Show the error and the optional retry hint as one message
                # (st.write has no print-style `end=` argument).
                message = f"{error}."
                if "estimated_time" in result:
                    message += (
                        f' Please try again in about {result["estimated_time"]:.0f} seconds.'
                    )
                st.write(message)
        else:
            generated = result[0]["generated_text"]
            # Two trailing spaces turn each newline into a markdown line break.
            st.write(generated.replace("\n", "  \n"))
            st.text("English translation")
            # mtranslate: translate(text, to_language, from_language); the
            # model writes Polish, so translate pl -> en (was wrongly "es").
            st.write(translate(generated, "en", "pl").replace("\n", "  \n"))