soiz commited on
Commit
7198e43
·
verified ·
1 Parent(s): 29baf33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -29
app.py CHANGED
@@ -9,40 +9,41 @@ gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Dif
9
  with open("ideas.txt", "r") as f:
10
  lines = f.readlines()
11
 
12
- def generate_prompt(starting_text):
13
- for count in range(4):
14
- seed = random.randint(100, 1000000)
15
- set_seed(seed)
16
-
17
- # Choose a random line from the file if the input text is empty
18
- if starting_text == "":
19
- starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
20
- starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
21
-
22
- # Generate text
23
- response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=4)
24
- response_list = []
25
- for x in response:
26
- resp = x['generated_text'].strip()
27
- if resp != starting_text and len(resp) > (len(starting_text) + 4) and not resp.endswith((":", "-", "—")):
28
- response_list.append(resp + '\n')
29
-
30
- # Clean the generated text
31
- response_end = "\n".join(response_list)
32
- response_end = re.sub(r'[^ ]+\.[^ ]+', '', response_end) # Removes strings like 'abc.xyz'
33
- response_end = response_end.replace("<", "").replace(">", "")
34
-
35
- if response_end:
36
- return response_end
37
- if count == 4:
38
- return response_end
39
 
40
  # Define the API endpoint
41
  @app.route('/', methods=['GET'])
42
  def generate_api():
43
  starting_text = request.args.get('text', default="", type=str)
44
- result = generate_prompt(starting_text)
45
- return jsonify({"generated_text": result})
 
 
 
46
 
47
  if __name__ == '__main__':
48
  # Run the Flask app on port 7860
 
9
  with open("ideas.txt", "r") as f:
10
  lines = f.readlines()
11
 
12
+ def generate_prompts(starting_text, num_prompts=1):
13
+ response_list = []
14
+
15
+ for _ in range(num_prompts):
16
+ for count in range(4): # Attempt up to 4 times to generate valid response
17
+ seed = random.randint(100, 1000000)
18
+ set_seed(seed)
19
+
20
+ # Choose a random line from the file if the input text is empty
21
+ if starting_text == "":
22
+ starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
23
+ starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
24
+
25
+ # Generate text
26
+ response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
27
+ generated_text = response[0]['generated_text'].strip()
28
+
29
+ # Clean and check the generated response
30
+ if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
31
+ cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text) # Remove strings like 'abc.xyz'
32
+ cleaned_text = cleaned_text.replace("<", "").replace(">", "")
33
+ response_list.append(cleaned_text)
34
+ break # Stop trying further once a valid prompt is added
35
+
36
+ return response_list[:num_prompts]
 
 
37
 
38
  # Define the API endpoint
39
  @app.route('/', methods=['GET'])
40
  def generate_api():
41
  starting_text = request.args.get('text', default="", type=str)
42
+ num_prompts = request.args.get('n', default=1, type=int) # Get the number of prompts to return, default is 1
43
+
44
+ # Generate the prompts
45
+ results = generate_prompts(starting_text, num_prompts=num_prompts)
46
+ return jsonify(results)
47
 
48
  if __name__ == '__main__':
49
  # Run the Flask app on port 7860