"""Streamlit app that generates text continuations with GPT-2.

Optional environment variables:
    HF_TOKEN:       Hugging Face token (only shown in the status bar).
    API_KEY:        When set, users must enter this key before generating.
    ADMIN_PASSWORD: When set, the whole app is gated behind this password.
"""

import hmac
import os

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# ----------------------------
# Page config
# ----------------------------
st.set_page_config(
    page_title="GPT-2 Text Generator",
    page_icon="🤖",
    layout="wide",
)

# ----------------------------
# Load environment variables
# ----------------------------
# NOTE(review): HF_TOKEN is only displayed in the status bar; it is not passed
# to from_pretrained() below (the public "gpt2" checkpoint does not need it).
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

# Input limits: characters accepted from the user, tokens fed to the model.
MAX_PROMPT_CHARS = 500
MAX_PROMPT_TOKENS = 300


# ----------------------------
# Helpers
# ----------------------------
def _rerun():
    """Rerun the script; st.experimental_rerun is deprecated/removed in newer Streamlit."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _secrets_match(supplied, expected):
    """Return True if *supplied* equals *expected*, compared in constant time.

    Empty or None values never match. Strings are UTF-8 encoded first because
    hmac.compare_digest only accepts ASCII-only ``str`` arguments.
    """
    if not supplied or not expected:
        return False
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


def _use_example(example):
    """Button callback: put *example* into the prompt box and remember it.

    Callbacks run before the next script run, i.e. before the ``prompt``
    text_area is re-created, which is when Streamlit allows setting its value.
    """
    st.session_state.prompt = example
    st.session_state.example_prompt = example


# ----------------------------
# Model loading
# ----------------------------
@st.cache_resource(show_spinner="Loading GPT-2 model...")
def _load_model_cached():
    """Load the GPT-2 tokenizer and model; cached by st.cache_resource.

    Raises on failure. Exceptions are not cached, so a transient failure
    (e.g. a network error while downloading) is retried on the next rerun.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so padding-aware calls work.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model


def load_model():
    """Return the cached ``(tokenizer, model)`` pair, or ``(None, None)`` on error.

    On failure the error is shown in the UI instead of being raised.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None


# ----------------------------
# Text generation
# ----------------------------
def generate_text(prompt, max_length, temperature, tokenizer, model):
    """Generate a GPT-2 continuation of *prompt*.

    Args:
        prompt: User text to continue (at most MAX_PROMPT_CHARS characters).
        max_length: Number of tokens to generate beyond the prompt.
        temperature: Sampling temperature passed to ``model.generate``.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.

    Returns:
        The generated continuation (without the prompt), or a human-readable
        message if the prompt is empty/too long, nothing was generated, or
        generation raised an exception.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt"
    if len(prompt) > MAX_PROMPT_CHARS:
        return f"Prompt too long (max {MAX_PROMPT_CHARS} characters)"

    try:
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=MAX_PROMPT_TOKENS,
            truncation=True,
        )
        input_ids = encoded["input_ids"]
        prompt_token_count = input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                # Explicit mask: pad == eos, so the model cannot infer it.
                attention_mask=encoded["attention_mask"],
                max_length=prompt_token_count + max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2,
            )

        # Decode only the new tokens. Slicing the full decoded text by
        # len(prompt) breaks when the prompt was truncated or when decoding
        # normalizes whitespace, cutting into (or leaking) the continuation.
        new_tokens = outputs[0][prompt_token_count:]
        new_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return new_text or "No text generated. Try a different prompt."
    except Exception as e:
        return f"Error generating text: {e}"


# ----------------------------
# Authentication
# ----------------------------
def check_auth():
    """Gate the app behind ADMIN_PASSWORD when it is set.

    Returns:
        True if no password is configured or the session is authenticated;
        False while the login form is being shown.
    """
    if not ADMIN_PASSWORD:
        return True

    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False
    if st.session_state.authenticated:
        return True

    st.title("🔒 Authentication Required")
    password = st.text_input("Enter admin password:", type="password")
    if st.button("Login"):
        if _secrets_match(password, ADMIN_PASSWORD):
            st.session_state.authenticated = True
            _rerun()
        else:
            st.error("Invalid password")
    return False


# ----------------------------
# Main UI
# ----------------------------
def main():
    """Render the app: auth gate, status bar, inputs, generated output, examples."""
    if not check_auth():
        return

    tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load model. Please check the logs.")
        return

    st.title("🤖 GPT-2 Text Generator")
    st.markdown("Generate text using GPT-2 language model")

    # Security status: green when a feature is configured, yellow when not.
    statuses = [
        (HF_TOKEN, "🔑 HF Token: Active", "🔑 HF Token: Not set"),
        (API_KEY, "🔒 API Auth: Enabled", "🔒 API Auth: Disabled"),
        (ADMIN_PASSWORD, "👤 Admin Auth: Active", "👤 Admin Auth: Disabled"),
    ]
    for col, (configured, on_msg, off_msg) in zip(st.columns(len(statuses)), statuses):
        with col:
            if configured:
                st.success(on_msg)
            else:
                st.warning(off_msg)

    # Input section
    st.subheader("📝 Input")
    col1, col2 = st.columns([2, 1])
    with col1:
        # key="prompt" lets the example buttons prefill this box via session state.
        prompt = st.text_area(
            "Enter your prompt:",
            placeholder="Type your text here...",
            height=100,
            key="prompt",
        )
        api_key = ""
        if API_KEY:
            api_key = st.text_input("API Key:", type="password")

    with col2:
        st.subheader("⚙️ Settings")
        max_length = st.slider("Max Length", 20, 200, 100, 10)
        temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
        generate_btn = st.button("🚀 Generate Text", type="primary")

    # API key validation (only enforced when API_KEY is configured).
    if API_KEY and generate_btn and not _secrets_match(api_key, API_KEY):
        st.error("🔒 Invalid or missing API key")
        return

    # Generate text
    if generate_btn and prompt and prompt.strip():
        with st.spinner("Generating text..."):
            result = generate_text(prompt, max_length, temperature, tokenizer, model)
        st.subheader("📄 Generated Text")
        st.text_area("Output:", value=result, height=200)
        st.code(result)
    elif generate_btn:
        st.warning("Please enter a prompt")

    # Example prompts
    st.subheader("💡 Example Prompts")
    examples = [
        "Once upon a time in a distant galaxy,",
        "The future of artificial intelligence is",
        "In the heart of the ancient forest,",
        "The detective walked into the room and noticed",
    ]
    for i, (col, example) in enumerate(zip(st.columns(len(examples)), examples)):
        with col:
            st.button(
                f"Use Example {i + 1}",
                key=f"ex_{i}",
                on_click=_use_example,
                args=(example,),
            )

    if "example_prompt" in st.session_state:
        st.info(f"Example selected: {st.session_state.example_prompt}")


if __name__ == "__main__":
    main()