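# tadmztxi / myapp.py
# Last commit: 5a0eb94 "Update myapp.py" by Geek7 (verified); file size 2.21 kB.
# (Header captured from a web file view; the UI labels "raw" and
#  "history blame" were turned into this comment so the file is valid Python.)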
from flask import Flask, request, send_file
from flask_cors import CORS
import torch
from diffusers import DiffusionPipeline
import io
import numpy as np
import random
# Flask application object that the routes below are registered on.
myapp = Flask(__name__)
# Let browser front-ends served from other origins call this API.
CORS(myapp)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the "prompthero/openjourney-v4" pipeline once at import time and move it
# to the chosen device; request handlers reuse this single `pipe` object.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4")
pipe = pipe.to(device)

# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Largest image side length (width or height) intended for generation.
MAX_IMAGE_SIZE = 1344
@myapp.route('/')
def home():
    """Return a plain-text welcome message for the API root URL."""
    # Registered on `myapp`: the original decorator referenced an undefined
    # name `app`, which raised NameError as soon as the module was imported.
    return "Welcome to the OpenJourney Image Generation API!"
def _read_number(data, key, default, cast):
    """Fetch ``data[key]`` converted with *cast* (``int`` or ``float``).

    A missing key or an explicit JSON ``null`` yields *default*.

    Raises:
        ValueError: if the value is a boolean or *cast* cannot convert it.
    """
    value = data.get(key)
    if value is None:
        return default
    # bool is a subclass of int, so reject it explicitly: `true` is not a size.
    if isinstance(value, bool):
        raise ValueError(f"'{key}' must be a number")
    try:
        return cast(value)
    except (TypeError, ValueError):
        raise ValueError(f"'{key}' must be a number") from None


@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate one image from the JSON request body and return it as a PNG.

    Optional JSON fields: ``prompt``, ``negative_prompt``, ``seed``,
    ``randomize_seed`` (default true), ``width``/``height`` (default 1024),
    ``guidance_scale`` (default 7.5) and ``num_inference_steps`` (default 50).

    Returns:
        The generated image as ``image/png``. The seed actually used is sent
        in the ``X-Seed`` response header so a randomized result can be
        reproduced. If a parameter is invalid, returns a JSON
        ``{"error": ...}`` body with status 400 instead.
    """
    # silent=True: an empty or unparsable body falls back to the defaults.
    # Depending on the Flask version, request.json may be None here (which
    # crashed data.get) or abort the request.
    data = request.get_json(silent=True)
    if data is None:
        data = {}
    if not isinstance(data, dict):
        return {"error": "request body must be a JSON object"}, 400

    prompt = data.get('prompt', 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k')
    negative_prompt = data.get('negative_prompt', None)
    randomize_seed = data.get('randomize_seed', True)
    try:
        # A supplied seed is only parsed when it will actually be used.
        seed = random.randint(0, MAX_SEED) if randomize_seed else _read_number(data, 'seed', 0, int)
        width = _read_number(data, 'width', 1024, int)
        height = _read_number(data, 'height', 1024, int)
        guidance_scale = _read_number(data, 'guidance_scale', 7.5, float)
        num_inference_steps = _read_number(data, 'num_inference_steps', 50, int)
    except ValueError as err:
        return {"error": str(err)}, 400

    # diffusers' Stable Diffusion pipelines reject sizes not divisible by 8.
    # Validate here so the client gets a 400 instead of a server error.
    # MAX_IMAGE_SIZE bounds the requested resolution.
    for name, size in (('width', width), ('height', height)):
        if not 8 <= size <= MAX_IMAGE_SIZE or size % 8:
            return {"error": f"'{name}' must be a multiple of 8 between 8 and {MAX_IMAGE_SIZE}"}, 400
    if num_inference_steps < 1:
        return {"error": "'num_inference_steps' must be at least 1"}, 400

    # Seeded generator: the same seed and parameters reproduce the same image.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]

    # Encode to PNG in memory and rewind so send_file streams from the start.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)

    response = send_file(buffer, mimetype='image/png')
    response.headers['X-Seed'] = str(seed)
    return response
# When executed as a script (not imported), start Flask's built-in server
# on every network interface, port 7860.
if __name__ == "__main__":
    myapp.run(port=7860, host="0.0.0.0")