Update app.py
Browse files
app.py
CHANGED
@@ -9,40 +9,41 @@ gpt2_pipe = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Dif
|
|
9 |
with open("ideas.txt", "r") as f:
|
10 |
lines = f.readlines()
|
11 |
|
12 |
-
def
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
if count == 4:
|
38 |
-
return response_end
|
39 |
|
40 |
# Define the API endpoint
|
41 |
@app.route('/', methods=['GET'])
|
42 |
def generate_api():
|
43 |
starting_text = request.args.get('text', default="", type=str)
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
46 |
|
47 |
if __name__ == '__main__':
|
48 |
# Run the Flask app on port 7860
|
|
|
9 |
with open("ideas.txt", "r") as f:
|
10 |
lines = f.readlines()
|
11 |
|
12 |
+
def generate_prompts(starting_text, num_prompts=1):
|
13 |
+
response_list = []
|
14 |
+
|
15 |
+
for _ in range(num_prompts):
|
16 |
+
for count in range(4): # Attempt up to 4 times to generate valid response
|
17 |
+
seed = random.randint(100, 1000000)
|
18 |
+
set_seed(seed)
|
19 |
+
|
20 |
+
# Choose a random line from the file if the input text is empty
|
21 |
+
if starting_text == "":
|
22 |
+
starting_text = lines[random.randrange(0, len(lines))].strip().lower().capitalize()
|
23 |
+
starting_text = re.sub(r"[,:\-–.!;?_]", '', starting_text)
|
24 |
+
|
25 |
+
# Generate text
|
26 |
+
response = gpt2_pipe(starting_text, max_length=random.randint(60, 90), num_return_sequences=1)
|
27 |
+
generated_text = response[0]['generated_text'].strip()
|
28 |
+
|
29 |
+
# Clean and check the generated response
|
30 |
+
if generated_text != starting_text and len(generated_text) > (len(starting_text) + 4):
|
31 |
+
cleaned_text = re.sub(r'[^ ]+\.[^ ]+', '', generated_text) # Remove strings like 'abc.xyz'
|
32 |
+
cleaned_text = cleaned_text.replace("<", "").replace(">", "")
|
33 |
+
response_list.append(cleaned_text)
|
34 |
+
break # Stop trying further once a valid prompt is added
|
35 |
+
|
36 |
+
return response_list[:num_prompts]
|
|
|
|
|
37 |
|
38 |
# Define the API endpoint
|
39 |
@app.route('/', methods=['GET'])
|
40 |
def generate_api():
|
41 |
starting_text = request.args.get('text', default="", type=str)
|
42 |
+
num_prompts = request.args.get('n', default=1, type=int) # Get the number of prompts to return, default is 1
|
43 |
+
|
44 |
+
# Generate the prompts
|
45 |
+
results = generate_prompts(starting_text, num_prompts=num_prompts)
|
46 |
+
return jsonify(results)
|
47 |
|
48 |
if __name__ == '__main__':
|
49 |
# Run the Flask app on port 7860
|