File size: 970 Bytes
28e30fe
007c626
 
aa665ef
 
007c626
aa665ef
8b0b1ab
aa665ef
 
 
 
 
c150fa5
 
8b0b1ab
aa665ef
4938516
aa665ef
c150fa5
8b0b1ab
aa665ef
4938516
8b0b1ab
aa665ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from transformers import pipeline
import gradio as gr

# Build the text-generation pipeline once at import time so every chat
# request reuses the same loaded model and tokenizer.
generator = pipeline(
    task="text-generation",
    model="ckiplab/gpt2-base-chinese",
    tokenizer="ckiplab/gpt2-base-chinese",
)

def chat_fn(message, history):
    """Generate one chatbot reply and return the updated conversation.

    Args:
        message: Text the user just submitted.
        history: Flat list of earlier transcript lines, alternating
            ``"你: ..."`` and ``"AI: ..."`` entries. ``None`` or an empty
            list starts a new conversation.

    Returns:
        A ``(messages, history)`` tuple. ``messages`` is a list of
        ``(user_line, ai_line)`` pairs for the chatbot display, and
        ``history`` is a new flat transcript list that includes this turn.
    """
    # Work on a copy so the list passed in by the caller is never mutated.
    history = list(history or [])

    # Nothing to answer: return the transcript unchanged instead of asking
    # the model to continue an empty user turn.
    if not message or not message.strip():
        return list(zip(history[::2], history[1::2])), history

    # Only the most recent lines are fed to the model so the prompt cannot
    # grow without bound as the conversation gets longer.
    max_context_lines = 12  # six user/AI exchanges
    prompt_lines = history[-max_context_lines:] + [f"你: {message}", "AI:"]
    input_text = "\n".join(prompt_lines)

    # return_full_text=False makes the pipeline return only the newly
    # generated continuation, so the reply does not have to be located by
    # searching the echoed prompt for the last "AI:" marker.
    output = generator(
        input_text,
        max_new_tokens=80,
        pad_token_id=0,
        return_full_text=False,
    )[0]["generated_text"]

    # Keep only the first reply: cut at the earliest point where the model
    # starts inventing another speaker turn.
    response = output
    for marker in ("你:", "AI:"):
        response = response.split(marker)[0]
    response = response.strip()

    history.append(f"你: {message}")
    history.append(f"AI: {response}")

    # Pair consecutive (user, AI) lines for the tuple-style chatbot display.
    messages = list(zip(history[::2], history[1::2]))
    return messages, history

# Assemble the UI: a chat display, per-session transcript state, and an
# input box whose submit event drives the conversation.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="中文聊天機器人", type="tuples")
    state = gr.State(value=[])
    textbox = gr.Textbox(placeholder="請輸入訊息")

    # First listener: generate the reply and refresh both the chat display
    # and the stored transcript.
    textbox.submit(
        fn=chat_fn,
        inputs=[textbox, state],
        outputs=[chatbot, state],
    )
    # Second listener on the same event: clear the input box.
    textbox.submit(fn=lambda: "", inputs=None, outputs=textbox)

demo.launch()