Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
2 |
-
import gradio as gr
|
3 |
|
4 |
-
#
|
5 |
-
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
6 |
-
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
7 |
|
8 |
# 設置填充標記 ID
|
9 |
tokenizer.pad_token = tokenizer.eos_token
|
@@ -28,18 +27,4 @@ def generate_command(prompt, max_length=100):
|
|
28 |
end = len(generated_text)
|
29 |
command = generated_text[start:end].strip()
|
30 |
|
31 |
-
return command
|
32 |
-
|
33 |
-
def predict(input_text):
|
34 |
-
output = generate_command(input_text)
|
35 |
-
return output
|
36 |
-
|
37 |
-
iface = gr.Interface(
|
38 |
-
fn=predict,
|
39 |
-
inputs=gr.Textbox(lines=2, placeholder="請輸入你的指令生成提示..."),
|
40 |
-
outputs="text",
|
41 |
-
title="使用 GPT2 生成指令",
|
42 |
-
description="根據你的中文輸入提示生成 Bash 指令。"
|
43 |
-
)
|
44 |
-
|
45 |
-
iface.launch()
|
|
|
1 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
|
|
|
2 |
|
3 |
+
# 使用緩存加速載入
|
4 |
+
model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir="./model_cache")
|
5 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="./model_cache")
|
6 |
|
7 |
# 設置填充標記 ID
|
8 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
27 |
end = len(generated_text)
|
28 |
command = generated_text[start:end].strip()
|
29 |
|
30 |
+
return command
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|