Spaces:
Running
Running
| # app.py | |
| import streamlit as st | |
| from cerebras.cloud.sdk import Cerebras | |
| import openai | |
| import os | |
| from dotenv import load_dotenv | |
| # --- Assuming config.py and utils.py exist --- | |
| import config | |
| import utils | |
# --- Import BASE_PROMPT ---
# The system prompt lives in prompt.py. Without it the app cannot build LLM
# requests, so a failed import shows an error and halts this script run.
try:
    from prompt import BASE_PROMPT
except ImportError as exc:
    # ImportError (and its subclass ModuleNotFoundError) is also raised for
    # failures *inside* prompt.py, e.g. a dependency it imports is missing.
    # Include the underlying message so that case is not misreported as
    # "file not found".
    st.error(
        "Error: 'prompt.py' not found or 'BASE_PROMPT' is not defined within it."
        f" (details: {exc})")
    st.stop()
| # --- Import column rendering functions --- | |
| from chat_column import render_chat_column | |
# --- Load environment variables from a .env file (python-dotenv) ---
load_dotenv()

# --- Streamlit page configuration ---
# The same title string is used for the browser tab and the page heading.
_APP_TITLE = "Prompt & Image Generator"
st.set_page_config(
    page_title=_APP_TITLE,
    page_icon="🤖",
    layout="wide",
)

# --- Page header ---
utils.display_icon("🤖")
st.title(_APP_TITLE)
st.subheader(
    "Generate text prompts (left) and edit/generate images (right)",
    anchor=False,
    divider="orange",
)
# --- API key handling ---
# Prefer CEREBRAS_API_KEY from the environment. The sidebar only prompts for
# a key when that variable is unset or empty.
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not api_key_from_env
cerebras_api_key = None

# --- Sidebar settings ---
with st.sidebar:
    st.title("Settings")

    # Cerebras API key: confirm the environment key, or ask the user for one.
    if not show_api_key_input:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")
    else:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input(
            "Cerebras API Key:",
            type="password",
            key="cerebras_api_key_input_field",
        )
        # An empty text box leaves the key unset (None).
        cerebras_api_key = api_key_input or None

    # LLM model picker. The options are the keys of config.MODELS, displayed
    # via each entry's "name" field.
    model_option = st.selectbox(
        "Choose a LLM model:",
        options=list(config.MODELS),
        format_func=lambda model_id: config.MODELS[model_id]["name"],
        key="model_select",
    )

    # Response-length slider. It is capped by the selected model's "tokens"
    # limit and defaults to 2048, or to the limit if that is smaller.
    # NOTE(review): st.slider rejects min_value > max_value, so this assumes
    # every model's "tokens" limit is at least 512 — confirm in config.py.
    max_tokens_range = config.MODELS[model_option]["tokens"]
    default_tokens = min(2048, max_tokens_range)
    max_tokens = st.slider(
        "Max Tokens (LLM):",
        min_value=512,
        max_value=max_tokens_range,
        value=default_tokens,
        step=512,
        help="Max tokens for the LLM's text prompt response.",
    )

    # When enabled, the LLM client is an OpenAI-compatible client pointed at
    # config.BASE_URL instead of the native Cerebras SDK client.
    use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False)
# --- Main application logic ---
# Halt the run unless a Cerebras API key was resolved in the sidebar.
# No separate st.session_state lookup is needed: the sidebar text_input is
# keyed "cerebras_api_key_input_field", so its return value already equals
# st.session_state["cerebras_api_key_input_field"] on every run where it is
# rendered.
if not cerebras_api_key:
    st.error(
        "Cerebras API Key is required. Please enter it in the sidebar or set "
        "the CEREBRAS_API_KEY environment variable.",
        icon="🚨",
    )
    st.stop()
# --- API client initialization ---
# Builds exactly one LLM client:
# - with Optillm enabled, an OpenAI-compatible client aimed at config.BASE_URL;
# - otherwise, the native Cerebras SDK client.
# image_client is not assigned in this section and remains None here.
llm_client = None
image_client = None

# Validate configuration *before* the try block. st.stop() halts the run by
# raising an internal exception. Keeping it out of `except Exception` means
# it can never be caught and re-reported as a client-initialization failure,
# regardless of that exception's base class in a given Streamlit version.
_optillm_base_url = getattr(config, "BASE_URL", None)
if use_optillm and not _optillm_base_url:
    st.error("Optillm selected, but BASE_URL is not configured.", icon="🚨")
    st.stop()

# Only the client constructors are guarded. A constructor failure is shown
# to the user and stops the run.
try:
    if use_optillm:
        llm_client = openai.OpenAI(
            base_url=_optillm_base_url, api_key=cerebras_api_key)
    else:
        llm_client = Cerebras(api_key=cerebras_api_key)
except Exception as e:
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()
# --- Session state initialization ---
# Seed per-session defaults the first time each key is missing.
# generated_images_list stores multiple generated images along with their
# prompts. It supersedes the older single "latest_generated_image" entry.
for _state_key, _state_default in (
    ("messages", []),
    ("current_image_prompt_text", ""),
    ("generated_images_list", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Optional cleanup: drop the obsolete single-image entry if a session has it.
if "latest_generated_image" in st.session_state:
    del st.session_state["latest_generated_image"]

if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = None

# --- Reset chat and image history whenever the selected model changes ---
# This also fires on a session's first run, when selected_model is still None.
if st.session_state["selected_model"] != model_option:
    st.session_state["messages"] = []
    st.session_state["current_image_prompt_text"] = ""
    st.session_state["generated_images_list"] = []
    st.session_state["selected_model"] = model_option
    st.rerun()
# --- Main layout ---
# Two columns with relative widths 2:1: chat/prompt generation on the left,
# the image area on the right.
# NOTE(review): image_col is not used in the code visible here; presumably
# the image panel is rendered further down — confirm.
chat_col, image_col = st.columns(spec=[2, 1])

# --- Render the chat column via the imported helper ---
with chat_col:
    render_chat_column(st, llm_client, model_option, max_tokens, BASE_PROMPT)