yukimama commited on
Commit
ad4fa8c
·
verified ·
1 Parent(s): 096a14b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -20
app.py CHANGED
@@ -1,9 +1,8 @@
1
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
- import gradio as gr
3
 
4
- # Load the pre-trained GPT2 model and tokenizer
5
- model = GPT2LMHeadModel.from_pretrained("gpt2")
6
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
7
 
8
  # 設置填充標記 ID
9
  tokenizer.pad_token = tokenizer.eos_token
@@ -28,18 +27,4 @@ def generate_command(prompt, max_length=100):
28
  end = len(generated_text)
29
  command = generated_text[start:end].strip()
30
 
31
- return command
32
-
33
- def predict(input_text):
34
- output = generate_command(input_text)
35
- return output
36
-
37
- iface = gr.Interface(
38
- fn=predict,
39
- inputs=gr.Textbox(lines=2, placeholder="請輸入你的指令生成提示..."),
40
- outputs="text",
41
- title="使用 GPT2 生成指令",
42
- description="根據你的中文輸入提示生成 Bash 指令。"
43
- )
44
-
45
- iface.launch()
 
1
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
 
2
 
3
+ # 使用緩存加速載入
4
+ model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir="./model_cache")
5
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="./model_cache")
6
 
7
  # 設置填充標記 ID
8
  tokenizer.pad_token = tokenizer.eos_token
 
27
  end = len(generated_text)
28
  command = generated_text[start:end].strip()
29
 
30
+ return command