soiz1 commited on
Commit
6af69d0
·
verified ·
1 Parent(s): a04f8cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -16
app.py CHANGED
@@ -4,47 +4,61 @@ import random, re
4
 
5
  app = Flask(__name__)
6
 
7
- # Initialize the GPT-2 pipeline
8
- gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', tokenizer='gpt2')
 
 
9
  with open("ideas.txt", "r") as f:
10
  lines = f.readlines()
11
 
12
- def generate_prompts(starting_text, num_prompts=1):
 
 
 
 
 
13
  response_list = []
14
-
 
15
  for _ in range(num_prompts):
16
- for count in range(4): # Attempt up to 4 times to generate valid response
17
  seed = random.randint(100, 1000000)
18
  set_seed(seed)
19
 
20
- # Choose a random line from the file if the input text is empty
21
  if starting_text == "":
22
  starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
23
  starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
24
 
25
- # Generate text
26
  response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
27
  generated_text = response[0]['generated_text'].strip()
28
 
29
- # Clean and check the generated response
30
  if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
31
- cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text) # Remove strings like 'abc.xyz'
32
  cleaned_text = cleaned_text.replace("<", "").replace(">", "")
33
  response_list.append(cleaned_text)
34
- break # Stop trying further once a valid prompt is added
35
 
36
  return response_list[:num_prompts]
37
 
38
- # Define the API endpoint
39
  @app.route('/', methods=['GET'])
40
  def generate_api():
41
  starting_text = request.args.get('text', default="", type=str)
42
- num_prompts = request.args.get('n', default=1, type=int) # Get the number of prompts to return, default is 1
43
-
44
- # Generate the prompts
45
- results = generate_prompts(starting_text, num_prompts=num_prompts)
 
 
 
 
 
 
 
46
  return jsonify(results)
47
 
48
  if __name__ == '__main__':
49
- # Run the Flask app on port 7860
50
  app.run(host='0.0.0.0', port=7860)
 
4
 
5
  app = Flask(__name__)
6
 
7
+ # モデルキャッシュ用の辞書
8
+ model_cache = {}
9
+
10
+ # テキスト候補読み込み
11
  with open("ideas.txt", "r") as f:
12
  lines = f.readlines()
13
 
14
+ def get_pipeline(model_name):
15
+ if model_name not in model_cache:
16
+ model_cache[model_name] = pipeline('text-generation', model=model_name, tokenizer='gpt2')
17
+ return model_cache[model_name]
18
+
19
+ def generate_prompts(starting_text, model_name, num_prompts=1):
20
  response_list = []
21
+ gpt2_pipe = get_pipeline(model_name)
22
+
23
  for _ in range(num_prompts):
24
+ for count in range(4): # 最大4回試行
25
  seed = random.randint(100, 1000000)
26
  set_seed(seed)
27
 
28
+ # 入力テキストが空ならランダムに選ぶ
29
  if starting_text == "":
30
  starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
31
  starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
32
 
33
+ # テキスト生成
34
  response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
35
  generated_text = response[0]['generated_text'].strip()
36
 
37
+ # テキストをチェック・クリーン
38
  if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
39
+ cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text)
40
  cleaned_text = cleaned_text.replace("<", "").replace(">", "")
41
  response_list.append(cleaned_text)
42
+ break
43
 
44
  return response_list[:num_prompts]
45
 
46
+ # APIエンドポイント
47
  @app.route('/', methods=['GET'])
48
  def generate_api():
49
  starting_text = request.args.get('text', default="", type=str)
50
+ num_prompts = request.args.get('n', default=1, type=int)
51
+ model_param = request.args.get('model', default="sd", type=str).lower()
52
+
53
+ # モデル選択
54
+ if model_param == "dall":
55
+ model_name = "Gustavosta/MagicPrompt-Dalle"
56
+ else:
57
+ model_name = "Gustavosta/MagicPrompt-Stable-Diffusion"
58
+
59
+ # プロンプト生成
60
+ results = generate_prompts(starting_text, model_name, num_prompts=num_prompts)
61
  return jsonify(results)
62
 
63
  if __name__ == '__main__':
 
64
  app.run(host='0.0.0.0', port=7860)