# mk3d / app.py  (Hugging Face Space file; author: yongyeol, commit cc80057 "Update app.py")
# ────────────────────────────────────────────────────────────────────────────
# app.py – Text ➜ 2D (FLUX.1-schnell text→image, FLUX.1-Kontext-dev editing) ➜ 3D (Hunyuan3D-2)
# • Fits into ≈16 GB system RAM: lightweight models + lazy loading + offload
# • 2025-07-07: fixed repo names, added HF token + trust_remote_code, cleaned logs
# ────────────────────────────────────────────────────────────────────────────
import os
import tempfile
from typing import List
import gradio as gr
import torch
from PIL import Image
from huggingface_hub import login
# ─────────────────────── Auth ───────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
# Fail at import time when the token is missing: every from_pretrained call in
# this file passes token=HF_TOKEN, so the app cannot work without it.
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=False)
else:
    raise RuntimeError(
        "HF_TOKEN이 설정되지 않았습니다. Space Settings → Secrets에서 "
        "HF_TOKEN=<your_read_token> 을 등록한 뒤 재시작하세요."
    )
# ─────────────────────── Device & dtype ───────────────────────
# Choose the compute target once: half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# ─────────────────────── Lazy loaders ───────────────────────
from diffusers import FluxKontextPipeline, FluxPipeline

# Hugging Face repositories (publicly released checkpoints).
MINI_KONTEXT_REPO = "black-forest-labs/FLUX.1-Kontext-dev"  # image editing / extension
MINI_T2I_REPO = "black-forest-labs/FLUX.1-schnell"  # text → image (4-step distilled)
HUNYUAN_REPO = "tencent/Hunyuan3D-2"  # 3D shape generation & texture painting

# Pipeline-level device_map. Per the author's note, "auto" (offload) is not
# supported here, so "balanced" is used instead.
DEVICE_MAP_STRATEGY = "balanced"

# Process-wide model caches; each stays None until its load_* function runs.
kontext_pipe: FluxKontextPipeline | None = None
_text2img_pipe: FluxPipeline | None = None
shape_pipe = None  # Hunyuan3D-2 shape pipeline
paint_pipe = None  # Hunyuan3D-2 texture pipeline
# ──────────────────────────── Loaders ────────────────────────────
def load_kontext() -> FluxKontextPipeline:
    """Return the FLUX.1-Kontext-dev image-editing pipeline, loading it on first use.

    The pipeline is cached in the module-level ``kontext_pipe``, so later calls
    return the same object without reloading weights.
    """
    global kontext_pipe
    if kontext_pipe is not None:
        return kontext_pipe

    print("[+] Loading FLUX.1-Kontext-dev … (balanced offload)")
    kontext_pipe = FluxKontextPipeline.from_pretrained(
        MINI_KONTEXT_REPO,
        torch_dtype=DTYPE,
        device_map=DEVICE_MAP_STRATEGY,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    # Keep the Space logs free of per-step progress bars.
    kontext_pipe.set_progress_bar_config(disable=True)
    return kontext_pipe
def load_text2img() -> FluxPipeline:
    """Return the FLUX.1-schnell text-to-image pipeline, loading it on first use.

    The pipeline is cached in the module-level ``_text2img_pipe``, so later calls
    return the same object without reloading weights.
    """
    global _text2img_pipe
    if _text2img_pipe is not None:
        return _text2img_pipe

    print("[+] Loading FLUX.1-schnell (text→image)…")
    _text2img_pipe = FluxPipeline.from_pretrained(
        MINI_T2I_REPO,
        torch_dtype=DTYPE,
        device_map=DEVICE_MAP_STRATEGY,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    # Keep the Space logs free of per-step progress bars.
    _text2img_pipe.set_progress_bar_config(disable=True)
    return _text2img_pipe
def load_hunyuan():
    """Return the Hunyuan3D-2 ``(shape_pipe, paint_pipe)`` pair, loading lazily.

    Each pipeline is cached in its own module-level global and loaded only if
    that global is still None. If loading the texture pipeline fails, the next
    call therefore reuses the already-loaded shape pipeline instead of loading
    it a second time (which would briefly hold two copies in memory).

    Returns:
        Tuple ``(shape_pipe, paint_pipe)``.
    """
    global shape_pipe, paint_pipe
    if shape_pipe is not None and paint_pipe is not None:
        return shape_pipe, paint_pipe

    print("[+] Loading Hunyuan3D-2 (shape & texture)…")
    # hy3dgen is imported only when 3D generation is first requested. Both
    # imports happen before any weights load, so a missing module fails fast.
    from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
    from hy3dgen.texgen import Hunyuan3DPaintPipeline

    # NOTE(review): these kwargs mirror the diffusers loaders in this file;
    # confirm the installed hy3dgen version's from_pretrained accepts them and
    # that its pipelines expose set_progress_bar_config.
    load_kwargs = {
        "torch_dtype": DTYPE,
        "device_map": DEVICE_MAP_STRATEGY,
        "low_cpu_mem_usage": True,
        "token": HF_TOKEN,
        "trust_remote_code": True,
    }
    if shape_pipe is None:
        shape_pipe = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(HUNYUAN_REPO, **load_kwargs)
        shape_pipe.set_progress_bar_config(disable=True)
    if paint_pipe is None:
        paint_pipe = Hunyuan3DPaintPipeline.from_pretrained(HUNYUAN_REPO, **load_kwargs)
        paint_pipe.set_progress_bar_config(disable=True)
    return shape_pipe, paint_pipe
# ───────────────────────────── Helpers ─────────────────────────────
def generate_single_2d(
    prompt: str,
    image: Image.Image | None,
    guidance_scale: float,
    *,
    t2i_steps: int = 4,
) -> Image.Image:
    """Generate one 2D image from text, or edit ``image`` when one is supplied.

    Args:
        prompt: Text description (text-to-image) or edit instruction (image edit).
        image: Optional reference image. ``None`` selects FLUX.1-schnell
            text-to-image; otherwise FLUX.1-Kontext-dev edits this image.
        guidance_scale: Forwarded unchanged to whichever pipeline runs.
        t2i_steps: Denoising steps for FLUX.1-schnell. That checkpoint is
            timestep-distilled for ~4 steps. If this value were omitted,
            diffusers' FluxPipeline would use its generic default (28 steps),
            about 7x the compute.

    Returns:
        The first image produced by the pipeline.
    """
    if image is None:
        t2i = load_text2img()
        return t2i(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=t2i_steps,
            # FLUX.1-schnell is documented for use with a 256-token T5 prompt
            # length (FluxPipeline otherwise defaults to 512).
            max_sequence_length=256,
        ).images[0]

    kontext = load_kontext()
    return kontext(image=image, prompt=prompt, guidance_scale=guidance_scale).images[0]
def generate_multiview(prompt: str, base_image: Image.Image, guidance_scale: float) -> List[Image.Image]:
    """Return ``[base_image, left, right, back]`` views for 3D reconstruction.

    Each additional view is a FLUX.1-Kontext-dev edit of ``base_image``. The
    edit prompt is ``prompt`` with a view hint appended.
    """
    editor = load_kontext()
    views = [base_image]
    for view_hint in ("left side view", "right side view", "back view"):
        edited = editor(image=base_image, prompt=f"{prompt}, {view_hint}", guidance_scale=guidance_scale)
        views.append(edited.images[0])
    return views
def build_3d_mesh(prompt: str, images: List[Image.Image]) -> str:
    """Reconstruct and texture a mesh from the 2D views and save it as a GLB file.

    A single view is handed to Hunyuan3D as a bare image; several views are
    handed over as the whole list.
    NOTE(review): confirm the installed hy3dgen pipelines accept a list here.

    Returns:
        Path to ``mesh.glb`` inside a newly created temporary directory. The
        directory is not removed here, because the returned path is what the UI
        serves.
    """
    shape_model, paint_model = load_hunyuan()
    if len(images) > 1:
        conditioning = images
    else:
        conditioning = images[0]

    untextured = shape_model(image=conditioning, prompt=prompt)[0]
    textured = paint_model(untextured, image=conditioning)  # texture painting

    glb_path = os.path.join(tempfile.mkdtemp(), "mesh.glb")
    textured.export(glb_path)
    return glb_path
# ──────────────────────────────── UI ────────────────────────────────
# Custom stylesheet for the Blocks page: hide the page's <footer> element.
CSS = "footer {visibility:hidden;}"
def workflow(prompt: str, input_image: Image.Image | None, multiview: bool, guidance_scale: float):
    """Run the full text → 2D → 3D pipeline for one UI request.

    Args:
        prompt: Description (or edit instruction) entered by the user.
        input_image: Optional reference image to edit instead of generating from text.
        multiview: If true, also generate left/right/back views before meshing.
        guidance_scale: Guidance scale forwarded to the 2D pipelines.

    Returns:
        ``(images, model_path, model_path)``: the 2D images for the gallery,
        then the GLB path twice (3D preview and download components).

    Raises:
        gr.Error: if the prompt is empty or contains only whitespace.
    """
    # A whitespace-only prompt would pass a plain truthiness check and start
    # the full (expensive) 2D + 3D pipeline with no real description.
    if not prompt or not prompt.strip():
        raise gr.Error("프롬프트(설명)를 입력하세요 📌")

    base_img = generate_single_2d(prompt, input_image, guidance_scale)
    images = generate_multiview(prompt, base_img, guidance_scale) if multiview else [base_img]
    model_path = build_3d_mesh(prompt, images)
    return images, model_path, model_path
def build_ui():
    """Build and return the Gradio Blocks app: inputs on the left, results on the right."""
    with gr.Blocks(css=CSS, title="Text ➜ 2D ➜ 3D (mini)") as demo:
        gr.Markdown("# 🌀 텍스트 → 2D → 3D 생성기 (경량 버전)")
        gr.Markdown("Kontext-dev + Hunyuan3D-2. 16 GB RAM에서도 동작합니다.")

        with gr.Row():
            # Left column: request inputs.
            with gr.Column():
                prompt_box = gr.Textbox(label="프롬프트 / 설명", placeholder="예: 파란 모자를 쓴 귀여운 로봇")
                ref_image = gr.Image(label="(선택) 편집할 참조 이미지", type="pil")
                multiview_toggle = gr.Checkbox(label="멀티뷰(좌/우/후면 포함)", value=True)
                guidance_slider = gr.Slider(minimum=0.5, maximum=7.5, value=2.5, step=0.1, label="Guidance Scale")
                generate_button = gr.Button("🚀 생성하기", variant="primary")
            # Right column: generated outputs.
            with gr.Column():
                result_gallery = gr.Gallery(label="🎨 2D 결과", columns=2, height="auto")
                mesh_viewer = gr.Model3D(label="🧱 3D 미리보기", clear_color=[1, 1, 1, 0])
                glb_file = gr.File(label="⬇️ GLB 다운로드")

        # The order of inputs/outputs must match workflow()'s parameters and its
        # (images, model_path, model_path) return tuple.
        generate_button.click(
            fn=workflow,
            inputs=[prompt_box, ref_image, multiview_toggle, guidance_slider],
            outputs=[result_gallery, mesh_viewer, glb_file],
            api_name="generate",
            scroll_to_output=True,
            show_progress="full",
        )
    return demo
if __name__ == "__main__":
    app = build_ui()
    # Enable the request queue (at most 3 pending requests), then serve the app.
    app.queue(max_size=3)
    app.launch()