# myapp.py
import os
import io
import random
import torch
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from diffusers import DiffusionPipeline
import numpy as np

# Initialize the Flask app
myapp = Flask(__name__)
CORS(myapp)  # Enable CORS if needed

# Load the model
device = "cpu"
repo = "prompthero/openjourney-v4"
pipe = DiffusionPipeline.from_pretrained(repo).to(device)

MAX_SEED = np.iinfo(np.int32).max


@myapp.route('/')
def home():
    return "Welcome to the Image Generation API!"  # Basic home response

@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    data = request.json

    # Get inputs from the request JSON
    prompt = data.get('prompt', '')
    negative_prompt = data.get('negative_prompt', None)
    seed = data.get('seed', 0)
    randomize_seed = data.get('randomize_seed', True)

    # Get width and height; the pipeline requires dimensions divisible by 8
    width = data.get('width', 1024)
    height = data.get('height', 1024)

    # Round width and height down to the nearest multiple of 8
    width = (width // 8) * 8
    height = (height // 8) * 8

    guidance_scale = data.get('guidance_scale', 5.0)
    num_inference_steps = data.get('num_inference_steps', 28)

    # Randomize the seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Generate the image
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]

    # Save the image to an in-memory byte buffer
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)  # Move the pointer back to the start of the buffer

    # Return the image as a PNG response
    return send_file(img_byte_arr, mimetype='image/png')

# Run the Flask app when this script is executed directly
if __name__ == "__main__":
    myapp.run(host='0.0.0.0', port=7860)  # Serve on port 7860
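
# Example client call (a minimal sketch, not part of the app itself): it assumes
# the server is running locally on port 7860 and that the `requests` package is
# installed; the prompt text and size values below are placeholders.
#
#   import requests
#
#   payload = {
#       "prompt": "a watercolor painting of a lighthouse at dusk",
#       "negative_prompt": "blurry, low quality",
#       "width": 512,
#       "height": 512,
#       "guidance_scale": 7.0,
#       "num_inference_steps": 28,
#       "randomize_seed": True,
#   }
#   response = requests.post("http://localhost:7860/generate_image", json=payload)
#   with open("output.png", "wb") as f:
#       f.write(response.content)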