File size: 2,207 Bytes
5a0eb94
46db6b6
 
5a0eb94
 
46db6b6
 
 
 
 
 
 
 
5a0eb94
46db6b6
5a0eb94
 
46db6b6
5a0eb94
46db6b6
 
 
 
 
5a0eb94
46db6b6
 
 
 
 
 
5a0eb94
46db6b6
 
 
 
 
5a0eb94
 
 
46db6b6
 
 
5a0eb94
46db6b6
 
 
 
 
 
 
5a0eb94
 
46db6b6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from flask import Flask, request, send_file
from flask_cors import CORS
import torch
from diffusers import DiffusionPipeline
import io
import numpy as np
import random

# Initialize the Flask app.
# The instance is exposed under both ``myapp`` and ``app`` so that code using
# either name (route decorators, the ``__main__`` entry point, a WSGI target
# such as ``module:myapp``) refers to the same application instead of raising
# NameError.
myapp = Flask(__name__)
app = myapp
CORS(myapp)  # Enable CORS (flask_cors defaults) so browser clients on other origins can call the API

# Use the default CUDA device when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DiffusionPipeline for "prompthero/openjourney-v4" once, at import
# time, so every request reuses the same in-memory weights.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4").to(device)

# Define max values for seed and image size.
MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1: upper bound for randomly drawn seeds
MAX_IMAGE_SIZE = 1344  # intended maximum width/height, in pixels, of a generated image

@myapp.route('/')
def home():
    """Return a plain-text welcome message for the root URL.

    Registered on ``myapp`` (the Flask instance created at module level); the
    original decorator referenced an undefined name ``app``.
    """
    return "Welcome to the OpenJourney Image Generation API!"

def _parse_generation_params(data):
    """Validate a decoded request body and build the pipeline arguments.

    Args:
        data: The decoded JSON object (a dict) sent by the client. Every key
            is optional; missing keys fall back to the defaults below.

    Returns:
        ``(pipe_kwargs, seed)``: ``pipe_kwargs`` holds the keyword arguments
        for ``pipe(...)`` (everything except ``generator``) and ``seed`` is the
        integer seed to use for this request.

    Raises:
        TypeError, ValueError, OverflowError: if a value cannot be converted
            to the expected type or lies outside the accepted range.
    """
    prompt = data.get('prompt', 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k')
    negative_prompt = data.get('negative_prompt', None)

    # When randomizing, the client-supplied seed is ignored (as before), so it
    # is only converted/validated when it will actually be used.
    if data.get('randomize_seed', True):
        seed = random.randint(0, MAX_SEED)
    else:
        seed = int(data.get('seed', 0))

    width = int(data.get('width', 1024))
    height = int(data.get('height', 1024))
    guidance_scale = float(data.get('guidance_scale', 7.5))  # Default to a higher guidance scale for better results
    num_inference_steps = int(data.get('num_inference_steps', 50))  # Default number of steps

    # Bound the output size (untrusted input could otherwise request images
    # large enough to exhaust memory) and require multiples of 8, which the
    # Stable Diffusion pipelines reject otherwise.
    for name, value in (('width', width), ('height', height)):
        if not (8 <= value <= MAX_IMAGE_SIZE) or value % 8 != 0:
            raise ValueError(
                f"{name} must be a multiple of 8 between 8 and {MAX_IMAGE_SIZE}, got {value}"
            )
    if num_inference_steps < 1:
        raise ValueError(f"num_inference_steps must be at least 1, got {num_inference_steps}")

    pipe_kwargs = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'width': width,
        'height': height,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
    }
    return pipe_kwargs, seed


@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate one image from a JSON request and return it as a PNG.

    JSON body keys (all optional): ``prompt``, ``negative_prompt``, ``seed``,
    ``randomize_seed`` (default true), ``width``/``height`` (default 1024),
    ``guidance_scale`` (default 7.5), ``num_inference_steps`` (default 50).

    Returns:
        The PNG image with mimetype ``image/png``. The seed actually used is
        sent in the ``X-Seed`` response header so randomized results can be
        reproduced (browser clients on another origin only see it if the CORS
        setup lists it in ``Access-Control-Expose-Headers``).
        A plain-text 400 response when the body is not a JSON object or a
        parameter is invalid.
    """
    # silent=True: a missing/malformed JSON body yields None instead of an
    # exception, so it can be reported as a clean 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return "Request body must be a JSON object.", 400

    try:
        pipe_kwargs, seed = _parse_generation_params(data)
    except (TypeError, ValueError, OverflowError) as err:
        return str(err), 400

    # Generate the image. A CPU generator is kept, as in the original, which
    # the diffusers reproducibility guide recommends for device-independent seeds.
    # NOTE(review): `pipe` is shared by all requests; if the server handles
    # requests concurrently, consider serialising calls — confirm thread-safety.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(**pipe_kwargs, generator=generator).images[0]

    # Save the image to an in-memory PNG and rewind so send_file reads it from the start.
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)

    response = send_file(img_byte_arr, mimetype='image/png')
    response.headers['X-Seed'] = str(seed)
    return response

# Script entry point: when this file is executed directly (rather than being
# imported by a WSGI server), start Flask's built-in server on every network
# interface, port 7860.
if __name__ == "__main__":
    listen_host, listen_port = '0.0.0.0', 7860
    myapp.run(host=listen_host, port=listen_port)