Spaces:
Sleeping
Sleeping
File size: 5,527 Bytes
2cc6477 9e1bace 2cc6477 7ab20d4 9e1bace 2cc6477 d44ef31 9e1bace 7ab20d4 9e1bace |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import torch
from diffusers import StableDiffusionPipeline
import gradio as gr
# Use the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pipeline: half precision on the GPU, full precision on the CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
)
pipe = pipe.to(device)
# Generation function
def generate(prompt):
    """Generate an image for the given text prompt.

    Calls the module-level pipeline ``pipe`` with ``prompt`` and returns
    the first entry of the result's ``images``.
    """
    result = pipe(prompt)
    return result.images[0]
# Gradio ์ธํฐํ์ด์ค ์ ์
interface = gr.Interface(
fn=generate,
inputs=gr.Textbox(label="ํ๋กฌํํธ๋ฅผ ์
๋ ฅํ์ธ์", placeholder="์: a cute caricature of a cat in a hat"),
outputs=gr.Image(type="pil"),
title="Text to Image - Stable Diffusion",
description="Stable Diffusion์ ์ฌ์ฉํ ํ
์คํธ-์ด๋ฏธ์ง ์์ฑ๊ธฐ์
๋๋ค."
)
if __name__ == "__main__":
interface.launch()
# import os
# import torch
# import random
# import importlib
# from PIL import Image
# from huggingface_hub import snapshot_download
# import gradio as gr
# from transformers import AutoProcessor, AutoModelForCausalLM, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
# from diffusers import StableDiffusionPipeline, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel
# # Environment settings
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
# # Local download
# LOCAL_FLORENCE = snapshot_download("microsoft/Florence-2-base", revision=REVISION)
# LOCAL_TURBOX = snapshot_download("tensorart/stable-diffusion-3.5-large-TurboX")
# # Device and dtype settings
# device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# # Model loading (per-component loading + dtype applied)
# scheduler = EulerDiscreteScheduler.from_pretrained(
# LOCAL_TURBOX, subfolder="scheduler", torch_dtype=dtype
# )
# text_encoder = CLIPTextModel.from_pretrained(LOCAL_TURBOX, subfolder="text_encoder", torch_dtype=dtype)
# tokenizer = CLIPTokenizer.from_pretrained(LOCAL_TURBOX, subfolder="tokenizer")
# feature_extractor = CLIPFeatureExtractor.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="feature_extractor")
# unet = UNet2DConditionModel.from_pretrained(LOCAL_TURBOX, subfolder="unet", torch_dtype=dtype)
# florence_model = AutoModelForCausalLM.from_pretrained(
# LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=dtype
# )
# florence_model.to("cpu").eval()
# florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)
# # Stable Diffusion pipeline
# pipe = DiffusionPipeline.from_pretrained(
# LOCAL_TURBOX,
# torch_dtype=dtype,
# trust_remote_code=True,
# safety_checker=None,
# feature_extractor=None
# )
# pipe = pipe.to(device)
# pipe.scheduler = scheduler
# pipe.enable_attention_slicing() # save memory
# # Constants
# MAX_SEED = 2**31 - 1
# # Text styler
# def pseudo_translate_to_korean_style(en_prompt: str) -> str:
# return f"Cartoon styled {en_prompt} handsome or pretty people"
# # Prompt generation
# def generate_prompt(image):
# if not isinstance(image, Image.Image):
# image = Image.fromarray(image)
# inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to("cpu")
# with torch.no_grad():
# generated_ids = florence_model.generate(
# input_ids=inputs["input_ids"],
# pixel_values=inputs["pixel_values"],
# max_new_tokens=256,
# num_beams=3
# )
# generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
# parsed_answer = florence_processor.post_process_generation(
# generated_text,
# task="<MORE_DETAILED_CAPTION>",
# image_size=(image.width, image.height)
# )
# prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
# cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
# return cartoon_prompt
# # Image generation function
# def generate_image(prompt, seed=42, randomize_seed=False):
# if randomize_seed:
# seed = random.randint(0, MAX_SEED)
# generator = torch.Generator().manual_seed(seed)
# image = pipe(
# prompt=prompt,
# guidance_scale=1.5,
# num_inference_steps=6, # optimized number of steps
# width=512,
# height=512,
# generator=generator
# ).images[0]
# return image, seed
# # Gradio UI
# with gr.Blocks() as demo:
# gr.Markdown("# 🖼 이미지 → 설명 생성 → 카툰 이미지 자동 생성기")
# gr.Markdown("**๐ ์ฌ์ฉ๋ฒ ์๋ด (ํ๊ตญ์ด)**\n"
# "- ์ด๋ฏธ์ง๋ฅผ ์
๋ก๋ํ๋ฉด AI๊ฐ ์ค๋ช
โ ์คํ์ผ ๋ณํ โ ์นดํฐ ์ด๋ฏธ์ง ์์ฑ๊น์ง ์๋์ผ๋ก ์ํํฉ๋๋ค.")
# with gr.Row():
# with gr.Column():
# input_img = gr.Image(label="🎨 원본 이미지 업로드")
# run_button = gr.Button("โจ ์์ฑ ์์")
# with gr.Column():
# prompt_out = gr.Textbox(label="๐ ์คํ์ผ ์ ์ฉ๋ ํ๋กฌํํธ", lines=3, show_copy_button=True)
# output_img = gr.Image(label="๐ ์์ฑ๋ ์ด๋ฏธ์ง")
# def full_process(img):
# prompt = generate_prompt(img)
# image, seed = generate_image(prompt, randomize_seed=True)
# return prompt, image
# run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
# demo.launch()
|