# app.py
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import openai
import os
from dotenv import load_dotenv

# --- Assuming config.py and utils.py exist ---
import config
import utils
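# NOTE: config.MODELS is assumed (its definition is not shown in this file) to map
# model IDs to display metadata used by the sidebar below. A minimal hypothetical shape:
#
#   MODELS = {
#       "llama3.1-8b": {"name": "Llama 3.1 8B", "tokens": 8192},
#       "llama-3.3-70b": {"name": "Llama 3.3 70B", "tokens": 8192},
#   }
#
# utils.display_icon() is likewise assumed to render a simple emoji/icon header.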

# --- Import BASE_PROMPT ---
try:
    from prompt import BASE_PROMPT
except ImportError:
    st.error(
        "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it.")
    st.stop()

# --- Import column rendering functions ---
from chat_column import render_chat_column

# --- Load environment variables ---
load_dotenv()

# --- Streamlit page configuration ---
st.set_page_config(page_icon="🤖", layout="wide",
                   page_title="Veo3 JSON Creator")

# --- UI rendering ---
utils.display_icon("🤖")
st.title("Veo3 JSON Creator")
st.subheader("Generate json for Veo3",
             divider="blue", anchor=False)

# --- API key handling ---
# (API Key logic remains the same)
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not bool(api_key_from_env)
cerebras_api_key = None

# --- Sidebar settings ---
# (Sidebar logic remains the same)
with st.sidebar:
    st.title("Settings")
    # Cerebras Key Input
    if show_api_key_input:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input(
            "Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
        if api_key_input:
            cerebras_api_key = api_key_input
    else:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")

    # Model selection
    model_option = st.selectbox(
        "Choose a LLM model:",
        options=list(config.MODELS.keys()),
        format_func=lambda x: config.MODELS[x]["name"],
        key="model_select"
    )
    # Max tokens slider
    max_tokens_range = config.MODELS[model_option]["tokens"]
    default_tokens = min(2048, max_tokens_range)
    max_tokens = st.slider(
        "Max Tokens (LLM):",
        min_value=512,
        max_value=max_tokens_range,
        value=default_tokens,
        step=512,
        help="Max tokens for the LLM's text prompt response."
    )


# --- Main application logic ---
# Re-check Cerebras API key
if not cerebras_api_key and show_api_key_input and st.session_state.get("cerebras_api_key_input_field"):
    cerebras_api_key = st.session_state.cerebras_api_key_input_field

if not cerebras_api_key:
    st.error("Cerebras API Key is required. Please enter it in the sidebar or set the CEREBRAS_API_KEY environment variable.", icon="🚨")
    st.stop()

# API client initialization
# (Client initialization remains the same)
llm_client = None
image_client = None
try:
    llm_client = Cerebras(api_key=cerebras_api_key)
except Exception as e:
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()
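
# For reference, render_chat_column() (imported above) is expected to call the
# Cerebras client roughly like this. This is only a sketch of the assumed usage,
# not the actual chat_column code; the message content is illustrative:
#
#   response = llm_client.chat.completions.create(
#       model=model_option,
#       messages=[{"role": "system", "content": BASE_PROMPT},
#                 {"role": "user", "content": "A drone shot over a neon city"}],
#       max_tokens=max_tokens,  # or max_completion_tokens, depending on SDK version
#   )
#   reply = response.choices[0].message.content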


# --- Session State Initialization ---
# Initialize state variables if they don't exist
if "messages" not in st.session_state:
    st.session_state.messages = []
if "current_image_prompt_text" not in st.session_state:
    st.session_state.current_image_prompt_text = ""
# --- MODIFICATION START ---
# Replace single image state with a list to store multiple images and their prompts
if "generated_images_list" not in st.session_state:
    st.session_state.generated_images_list = []  # Initialize as empty list
# Remove old state variable if it exists (optional cleanup)
if "latest_generated_image" in st.session_state:
    del st.session_state["latest_generated_image"]
# --- MODIFICATION END ---
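# (Assumed convention: each entry in generated_images_list would be a small dict,
#  e.g. {"prompt": <str>, "image": <bytes or PIL.Image>}, populated by the image
#  column renderer; the exact schema is not defined in this file.)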
if "selected_model" not in st.session_state:
    st.session_state.selected_model = None


# --- Track selected model, but do not clear chat or image state on model change ---
if st.session_state.selected_model != model_option:
    st.session_state.selected_model = model_option
    # Optionally rerun to update UI, but do not clear messages or images
    st.rerun()

# --- Define Main Columns ---
chat_col, image_col = st.columns([2, 1])

# --- Render Columns using imported functions ---
with chat_col:
    render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)