|
from flask import Flask, request, jsonify |
|
import torch |
|
from transformers import ( |
|
UNet2DConditionModel, |
|
AutoTokenizer, |
|
CLIPTextModel, |
|
CLIPTextModelWithProjection, |
|
CLIPVisionModelWithProjection |
|
) |
|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
|
|
app = Flask(__name__) |
|
|
|
|
|
# Lazily-initialized model components. load_models() fills these in on the
# first request; None means "not loaded yet".
unet = UNet_Encoder = None
tokenizer_one = tokenizer_two = None
noise_scheduler = None
text_encoder_one = text_encoder_two = None
image_encoder = None
vae = None
|
|
|
|
|
def load_models():
    """Lazily initialize all model components (idempotent).

    Each module-level global is populated only if it is still None, so calls
    after the first are cheap no-ops. The first call downloads pretrained
    weights from the Hugging Face hub (network and disk I/O).
    """
    global unet, tokenizer_one, tokenizer_two, noise_scheduler
    global text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder

    # Bug fix: DDPMScheduler and AutoencoderKL were referenced below but never
    # imported anywhere in this file, so those branches raised NameError on
    # first call. All three classes live in `diffusers` (not `transformers`).
    # Imported locally so this function is correct on its own.
    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

    # NOTE(review): "stabilityai/stable-diffusion-v1-4" looks wrong — the SD
    # v1-4 repo on the hub is "CompVis/stable-diffusion-v1-4", and loading a
    # single component out of a full pipeline repo normally requires
    # subfolder="unet" / "vae" / "scheduler". TODO: confirm these repo ids.
    if unet is None:
        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")

    if tokenizer_one is None:
        tokenizer_one = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    if tokenizer_two is None:
        tokenizer_two = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")

    if noise_scheduler is None:
        noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4")

    if text_encoder_one is None:
        text_encoder_one = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    if text_encoder_two is None:
        text_encoder_two = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14-336")

    if image_encoder is None:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

    if vae is None:
        vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-v1-4")

    if UNet_Encoder is None:
        UNet_Encoder = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-v1-4")
|
|
|
|
|
def decode_image(image_base64):
    """Decode a base64-encoded image payload into an RGB PIL image."""
    raw_bytes = base64.b64decode(image_base64)
    return Image.open(BytesIO(raw_bytes)).convert("RGB")
|
|
|
|
|
def encode_image(image):
    """Serialize an image (anything exposing PIL's .save API) to a
    base64-encoded PNG string."""
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    raw = png_buffer.getvalue()
    return base64.b64encode(raw).decode('utf-8')
|
|
|
|
|
@app.route('/process_image', methods=['POST'])
def process_image():
    """Handle POST /process_image.

    Expects a JSON body: {"image_base64": "<base64-encoded image>"}.
    Returns {"processed_image": "<base64 PNG>"} on success, or a JSON error
    with status 400 when the body is not JSON or the image field is missing.
    """
    # Robustness fix: request.json fails outright for non-JSON bodies; using
    # get_json(silent=True) lets malformed requests get a clean 400 instead.
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({"error": "No image provided"}), 400

    # Ensure all model components are resident (no-op after the first call).
    load_models()

    image_base64 = data.get('image_base64')
    if not image_base64:
        return jsonify({"error": "No image provided"}), 400

    image = decode_image(image_base64)

    # NOTE(review): placeholder — the loaded models are never used and the
    # input image is returned unchanged. TODO: plug in the real inference.
    processed_image = image

    processed_image_base64 = encode_image(processed_image)
    return jsonify({"processed_image": processed_image_base64})
|
|
|
if __name__ == '__main__':
    # Bind to all interfaces so the service is reachable from outside a
    # container; port 7860 is the conventional Gradio/HF Spaces port.
    app.run(port=7860, host='0.0.0.0')
|
|