soiz1 committed on
Commit
726e80a
·
verified ·
1 Parent(s): 6af69d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -4,10 +4,8 @@ import random, re
4
 
5
  app = Flask(__name__)
6
 
7
- # モデルキャッシュ用の辞書
8
  model_cache = {}
9
 
10
- # テキスト候補読み込み
11
  with open("ideas.txt", "r") as f:
12
  lines = f.readlines()
13
 
@@ -16,25 +14,25 @@ def get_pipeline(model_name):
16
  model_cache[model_name] = pipeline('text-generation', model=model_name, tokenizer='gpt2')
17
  return model_cache[model_name]
18
 
19
- def generate_prompts(starting_text, model_name, num_prompts=1):
20
  response_list = []
21
  gpt2_pipe = get_pipeline(model_name)
22
 
23
  for _ in range(num_prompts):
24
- for count in range(4): # 最大4回試行
25
  seed = random.randint(100, 1000000)
26
  set_seed(seed)
27
 
28
- # 入力テキストが空ならランダムに選ぶ
29
  if starting_text == "":
30
  starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
31
  starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
32
 
33
- # テキスト生成
34
- response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
 
 
35
  generated_text = response[0]['generated_text'].strip()
36
 
37
- # テキストをチェック・クリーン
38
  if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
39
  cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text)
40
  cleaned_text = cleaned_text.replace("<", "").replace(">", "")
@@ -43,7 +41,6 @@ def generate_prompts(starting_text, model_name, num_prompts=1):
43
 
44
  return response_list[:num_prompts]
45
 
46
- # APIエンドポイント
47
  @app.route('/', methods=['GET'])
48
  def generate_api():
49
  starting_text = request.args.get('text', default="", type=str)
@@ -51,13 +48,21 @@ def generate_api():
51
  model_param = request.args.get('model', default="sd", type=str).lower()
52
 
53
  # モデル選択
54
- if model_param == "dall":
55
- model_name = "Gustavosta/MagicPrompt-Dalle"
56
- else:
57
- model_name = "Gustavosta/MagicPrompt-Stable-Diffusion"
58
 
59
- # プロンプト生成
60
- results = generate_prompts(starting_text, model_name, num_prompts=num_prompts)
 
 
 
 
 
 
 
 
 
 
 
61
  return jsonify(results)
62
 
63
  if __name__ == '__main__':
 
4
 
5
  app = Flask(__name__)
6
 
 
7
  model_cache = {}
8
 
 
9
  with open("ideas.txt", "r") as f:
10
  lines = f.readlines()
11
 
 
14
  model_cache[model_name] = pipeline('text-generation', model=model_name, tokenizer='gpt2')
15
  return model_cache[model_name]
16
 
17
+ def generate_prompts(starting_text, model_name, num_prompts=1, generation_args=None):
18
  response_list = []
19
  gpt2_pipe = get_pipeline(model_name)
20
 
21
  for _ in range(num_prompts):
22
+ for count in range(4):
23
  seed = random.randint(100, 1000000)
24
  set_seed(seed)
25
 
 
26
  if starting_text == "":
27
  starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
28
  starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
29
 
30
+ response = gpt2_pipe(
31
+ starting_text,
32
+ **generation_args # 各種パラメーターを一括で渡す
33
+ )
34
  generated_text = response[0]['generated_text'].strip()
35
 
 
36
  if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
37
  cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text)
38
  cleaned_text = cleaned_text.replace("<", "").replace(">", "")
 
41
 
42
  return response_list[:num_prompts]
43
 
 
44
  @app.route('/', methods=['GET'])
45
  def generate_api():
46
  starting_text = request.args.get('text', default="", type=str)
 
48
  model_param = request.args.get('model', default="sd", type=str).lower()
49
 
50
  # モデル選択
51
+ model_name = "Gustavosta/MagicPrompt-Dalle" if model_param == "dall" else "Gustavosta/MagicPrompt-Stable-Diffusion"
 
 
 
52
 
53
+ # URL パラメータから生成設定を取得
54
+ generation_args = {
55
+ "max_length": request.args.get('max_length', default=random.randint(60, 90), type=int),
56
+ "min_length": request.args.get('min_length', default=0, type=int),
57
+ "temperature": request.args.get('temperature', default=1.0, type=float),
58
+ "top_k": request.args.get('top_k', default=50, type=int),
59
+ "top_p": request.args.get('top_p', default=0.95, type=float),
60
+ "repetition_penalty": request.args.get('repetition_penalty', default=1.0, type=float),
61
+ "do_sample": request.args.get('do_sample', default=True, type=lambda v: v.lower() in ['true', '1', 'yes']),
62
+ "num_return_sequences": 1
63
+ }
64
+
65
+ results = generate_prompts(starting_text, model_name, num_prompts=num_prompts, generation_args=generation_args)
66
  return jsonify(results)
67
 
68
  if __name__ == '__main__':