File size: 970 Bytes
28e30fe 007c626 aa665ef 007c626 aa665ef 8b0b1ab aa665ef c150fa5 8b0b1ab aa665ef 4938516 aa665ef c150fa5 8b0b1ab aa665ef 4938516 8b0b1ab aa665ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from transformers import pipeline
import gradio as gr
# Text-generation pipeline created once at module import; the same
# checkpoint name is used for both the model weights and the tokenizer.
generator = pipeline(
    task="text-generation",
    model="ckiplab/gpt2-base-chinese",
    tokenizer="ckiplab/gpt2-base-chinese",
)
def _extract_reply(generated):
    """Return the assistant's reply from the newly generated text.

    The model may keep writing past its own turn and invent further
    "你:" / "AI:" lines; everything from the first such speaker marker
    onward is discarded.
    """
    reply = generated
    for marker in ("你:", "AI:"):
        # Cutting sequentially leaves the text before the earliest marker.
        reply = reply.partition(marker)[0]
    return reply.strip()


def chat_fn(message, history):
    """Generate one chatbot reply and return the updated conversation.

    Args:
        message: The text the user just submitted.
        history: Flat list of earlier turns, alternating "你: ..." and
            "AI: ..." strings. ``None`` or an empty list starts a new
            conversation.

    Returns:
        A ``(messages, history)`` tuple. ``messages`` is a list of
        ``(user_line, ai_line)`` pairs for the chatbot display and
        ``history`` is a new flat turn list that includes this exchange.
    """
    # Work on a copy so the caller's list is never mutated in place.
    turns = list(history or [])

    # A blank submission adds nothing to the conversation and skips generation.
    if not message or not message.strip():
        return list(zip(turns[::2], turns[1::2])), turns

    user_line = f"你: {message}"
    prompt = "\n".join([*turns, user_line, "AI:"])
    # NOTE(review): the prompt grows with every turn; very long chats may
    # exceed the model's context window -- consider truncating old turns.

    # return_full_text=False yields only the continuation, so the reply no
    # longer has to be located by searching the echoed prompt for "AI:".
    generated = generator(
        prompt,
        max_new_tokens=80,
        pad_token_id=0,
        return_full_text=False,
    )[0]["generated_text"]
    response = _extract_reply(generated)

    turns += [user_line, f"AI: {response}"]
    # Pair each user line with the AI line that follows it.
    messages = list(zip(turns[::2], turns[1::2]))
    return messages, turns
# Assemble the chat UI: a chatbot display, a per-session state holding the
# flat turn list, and a single-line input box.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="中文聊天機器人", type="tuples")
    state = gr.State(value=[])
    textbox = gr.Textbox(placeholder="請輸入訊息")

    # On submit, run the model and refresh both the display and the state.
    textbox.submit(
        fn=chat_fn,
        inputs=[textbox, state],
        outputs=[chatbot, state],
    )
    # A second listener on the same event empties the input box.
    textbox.submit(fn=lambda: "", inputs=None, outputs=textbox)

demo.launch()
|