Update app.py
app.py CHANGED
@@ -4,47 +4,61 @@ import random, re
 
 app = Flask(__name__)
 
-#
-
+# Dictionary for caching loaded models
+model_cache = {}
+
+# Load the candidate idea texts
 with open("ideas.txt", "r") as f:
     lines = f.readlines()
 
-def
+def get_pipeline(model_name):
+    if model_name not in model_cache:
+        model_cache[model_name] = pipeline('text-generation', model=model_name, tokenizer='gpt2')
+    return model_cache[model_name]
+
+def generate_prompts(starting_text, model_name, num_prompts=1):
     response_list = []
-
+    gpt2_pipe = get_pipeline(model_name)
+
     for _ in range(num_prompts):
-        for count in range(4): #
+        for count in range(4):  # up to 4 attempts
             seed = random.randint(100, 1000000)
             set_seed(seed)
 
-            #
+            # If the input text is empty, pick one at random
             if starting_text == "":
                 starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
                 starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
 
-            #
+            # Generate text
             response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
             generated_text = response[0]['generated_text'].strip()
 
-            #
+            # Check and clean the text
             if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
-                cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text)
+                cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text)
                 cleaned_text = cleaned_text.replace("<", "").replace(">", "")
                 response_list.append(cleaned_text)
-                break
+                break
 
     return response_list[:num_prompts]
 
-#
+# API endpoint
 @app.route('/', methods=['GET'])
 def generate_api():
     starting_text = request.args.get('text', default="", type=str)
-    num_prompts = request.args.get('n', default=1, type=int)
-
-
-
+    num_prompts = request.args.get('n', default=1, type=int)
+    model_param = request.args.get('model', default="sd", type=str).lower()
+
+    # Model selection
+    if model_param == "dall":
+        model_name = "Gustavosta/MagicPrompt-Dalle"
+    else:
+        model_name = "Gustavosta/MagicPrompt-Stable-Diffusion"
+
+    # Generate prompts
+    results = generate_prompts(starting_text, model_name, num_prompts=num_prompts)
     return jsonify(results)
 
 if __name__ == '__main__':
-    # Run the Flask app on port 7860
     app.run(host='0.0.0.0', port=7860)
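
With this update, the `/` endpoint takes three query parameters: `text` (the starting text; a random line from ideas.txt is used when it is empty), `n` (how many prompts to return), and the new `model` switch (`dall` selects the DALL-E model, anything else falls back to Stable Diffusion). A minimal client sketch, assuming the Space is reachable at http://localhost:7860 and the `requests` package is installed; the base URL and example values below are illustrative and not part of the commit:

import requests

BASE_URL = "http://localhost:7860/"  # assumed local address of the running app

# Two Stable Diffusion prompts seeded with a starting phrase
sd_prompts = requests.get(
    BASE_URL,
    params={"text": "a castle in the clouds", "n": 2, "model": "sd"},
).json()

# One DALL-E prompt; an empty 'text' lets the server pick a random idea from ideas.txt
dalle_prompts = requests.get(
    BASE_URL,
    params={"text": "", "n": 1, "model": "dall"},
).json()

print(sd_prompts)     # JSON list with up to 2 cleaned prompt strings
print(dalle_prompts)  # JSON list with up to 1 prompt string

Because get_pipeline caches each pipeline in model_cache, only the first request for a given model pays the model-loading cost; later requests reuse the pipeline already held in memory.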