File size: 2,804 Bytes
06a81df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import streamlit as st
import requests
import base64
from PIL import Image
from io import BytesIO

# ───────────────────────────────
# CONFIG
# ───────────────────────────────
st.set_page_config(page_title="동화 μ‚½ν™” 생성기", layout="wide")
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-Kontext-dev"
headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}

# ───────────────────────────────
# UTILS
# ───────────────────────────────
def query_flux(prompt, image_bytes=None):
    inputs = {"prompt": prompt}
    if image_bytes:
        inputs["image"] = base64.b64encode(image_bytes).decode("utf-8")

    response = requests.post(API_URL, headers=headers, json=inputs)
    if response.status_code == 200:
        image_data = base64.b64decode(response.json()["image"])
        return Image.open(BytesIO(image_data))
    else:
        st.error(f"Error: {response.status_code} - {response.text}")
        return None

# ───────────────────────────────
# UI
# ───────────────────────────────
st.title("πŸ“š 동화 μ‚½ν™” 생성기 (FLUX.1 Kontext)")
st.markdown("각 μž₯면의 μ„€λͺ…을 μž…λ ₯ν•˜λ©΄, 이야기 흐름에 λ§žλŠ” μ‚½ν™”λ₯Ό μžλ™μœΌλ‘œ μƒμ„±ν•©λ‹ˆλ‹€.")

with st.expander("λ“±μž₯인물 및 μŠ€νƒ€μΌ 정보 (선택)"):
    character_prompt = st.text_area("λ“±μž₯인물 μ„€λͺ… (ex. λΉ¨κ°„ 망토λ₯Ό μ“΄ μ†Œλ…€, νšŒμƒ‰ λŠ‘λŒ€ λ“±)", height=100)
    reference_image = st.file_uploader("μ°Έμ‘° 이미지 μ—…λ‘œλ“œ (선택)", type=["jpg", "png"])

st.markdown("### πŸ“ 동화 λ‚΄μš©μ„ 5개 μž₯면으둜 λ‚˜λˆ„μ–΄ μž…λ ₯ν•˜μ„Έμš”")

scene_prompts = []
for i in range(5):
    text = st.text_area(f"μž₯λ©΄ {i+1} μ„€λͺ…", key=f"scene_{i}", height=80)
    scene_prompts.append(text)

if st.button("🎨 μ‚½ν™” μƒμ„±ν•˜κΈ°"):
    if not any(scene_prompts):
        st.warning("μ΅œμ†Œ ν•˜λ‚˜ μ΄μƒμ˜ μž₯λ©΄ μ„€λͺ…이 ν•„μš”ν•©λ‹ˆλ‹€.")
    else:
        ref_bytes = reference_image.read() if reference_image else None

        with st.spinner("이미지λ₯Ό 생성 μ€‘μž…λ‹ˆλ‹€..."):
            cols = st.columns(5)
            for i, prompt in enumerate(scene_prompts):
                if not prompt.strip():
                    continue

                full_prompt = f"{character_prompt}. {prompt}" if character_prompt else prompt
                img = query_flux(full_prompt, ref_bytes)
                if img:
                    cols[i].image(img, caption=f"μž₯λ©΄ {i+1}")