File size: 3,642 Bytes
7bc1fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15f3d13
7bc1fb2
 
 
 
 
 
 
 
aa2f0b6
7bc1fb2
 
 
 
 
 
 
 
 
 
 
 
e5f4437
7bc1fb2
 
 
 
 
b94e276
 
7bc1fb2
 
 
 
 
 
 
15f3d13
 
 
 
 
7bc1fb2
 
 
 
 
 
 
 
 
 
 
9aa8502
 
7bc1fb2
 
 
 
 
 
 
 
 
 
 
a5cd868
7bc1fb2
 
fbf6b2e
575630d
fbf6b2e
76f9de4
fbf6b2e
 
 
7bc1fb2
 
 
 
 
 
 
 
8287bc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import json
import requests
import streamlit as st
import random
# HTTP headers sent with every Inference API request.
# NOTE(review): empty — no Authorization bearer token, so calls run against
# the public, rate-limited endpoint; confirm this is intended.
headers = {}
# Registry of selectable models, keyed by the display name shown in the UI.
# NOTE(review): the "GPT-2 Base" entry points at an English
# InformalToFormalLincoln checkpoint, not a Thai GPT-2 — verify the URL.
MODELS = {
    "GPT-2 Base": {
        "url": "https://api-inference.huggingface.co/models/BigSalmon/InformalToFormalLincoln14"
    }
}
def query(payload, model_name):
    """POST *payload* to the Inference API endpoint for *model_name*.

    Args:
        payload: JSON-serializable request body (inputs/parameters/options).
        model_name: Key into the module-level ``MODELS`` registry.

    Returns:
        The parsed JSON response: a list of generation dicts on success,
        or a dict containing an ``"error"`` key on failure.

    Raises:
        requests.exceptions.Timeout: if the endpoint does not respond
            within the timeout window.
    """
    url = MODELS[model_name]["url"]
    print("model url:", url)  # debug trace of the endpoint being hit
    # requests.post with json= serializes the payload and sets the
    # Content-Type header; the explicit timeout prevents a stalled
    # inference endpoint from hanging the Streamlit app indefinitely
    # (the original had no timeout at all).
    response = requests.post(url, headers=headers, json=payload, timeout=60)
    # response.json() replaces the manual
    # json.loads(response.content.decode("utf-8")).
    return response.json()
def process(text: str,
            model_name: str,
            max_len: int,
            temp: float,
            top_k: int,
            num_return_sequences: int,
            top_p: float,
            repetition_penalty: float = 2.0):
    """Build a text-generation payload and submit it via ``query``.

    Args:
        text: Prompt text to continue.
        model_name: Key into ``MODELS`` selecting the endpoint.
        max_len: Maximum number of new tokens to generate.
        temp: Sampling temperature (modulates next-token probabilities).
        top_k: Keep only the ``top_k`` most probable tokens when sampling.
        num_return_sequences: Number of generated sequences to return.
        top_p: Nucleus-sampling probability threshold.
        repetition_penalty: Penalty applied to already-generated tokens.
            Previously hard-coded to 2.0; the default preserves the old
            behavior while letting callers tune it.

    Returns:
        The parsed JSON response from ``query`` (list of generations on
        success, error dict on failure).
    """
    payload = {
        "inputs": text,
        "parameters": {
            "max_new_tokens": max_len,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temp,
            "num_return_sequences": num_return_sequences,
            "repetition_penalty": repetition_penalty,
        },
        "options": {
            # Allow the API to serve cached results for identical requests.
            "use_cache": True,
        },
    }
    return query(payload, model_name)
st.set_page_config(page_title="Thai GPT2 Demo")
st.title("🐘 Thai GPT2")

st.sidebar.subheader("Configurable parameters")
# st.text_input always returns a string; numeric fields are cast with
# int()/float() right before the API call below.
max_len = st.sidebar.text_input(
    "Maximum length",
    value=5,
    help="The maximum length of the sequence to be generated."
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.7,
    max_value=1.5,
    help="The value used to module the next token probabilities."
)
top_k = st.sidebar.text_input(
    "Top k",
    value=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
)
num_return_sequences = st.sidebar.text_input(
    "Returns",
    value=5,
    help="Number to return."
)
top_p = st.sidebar.text_input(
    "Top p",
    value=0.95,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
)
# NOTE(review): this widget is displayed but its value is never passed to
# process()/the API — confirm whether do_sample should be wired through.
do_sample = st.sidebar.selectbox(
    'Sampling?', (True, False), help="Whether or not to use sampling; use greedy decoding otherwise.")

st.markdown(
    """Thai GPT-2 demo. Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/)."""
)

model_name = st.selectbox('Model', (['GPT-2 Base']))

hello = ['Custom', 'Yellow']
prompt = st.selectbox('Prompt', options=hello)
# Default the text box to the selected preset; "Custom" shows a placeholder.
# (Previously prompt_box was only assigned for "Custom", so selecting any
# other preset raised NameError at st.text_area.)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    prompt_box = prompt
text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        # Log the actual parameters (the old message printed a bogus
        # hard-coded "maxlen:8,644").
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
        result = process(text=text,
                         model_name=model_name,
                         max_len=int(max_len),
                         temp=temp,
                         num_return_sequences=int(num_return_sequences),
                         top_k=int(top_k),
                         top_p=float(top_p))
        # The API returns an error dict on failure and a list of generation
        # dicts on success — check for failure FIRST, instead of iterating
        # an error dict's keys and displaying garbage (as before).
        if isinstance(result, dict) and "error" in result:
            if isinstance(result["error"], str):
                # estimated_time is not always present alongside the error.
                eta = result.get("estimated_time", 0)
                st.write(f'{result["error"]}. Please try it again in about {eta:.0f} seconds')
            elif isinstance(result["error"], list):
                for error in result["error"]:
                    st.write(f'{error}')
        else:
            # Strip the prompt from each generation so only the newly
            # generated continuation is displayed.
            prompt_text = str(text)
            outputs = []
            for item in result:
                if isinstance(item, dict):
                    generated = item.get("generated_text", str(item))
                else:
                    generated = str(item)
                outputs.append(generated.replace(prompt_text, ""))
            st.write(outputs)