# Source: Hugging Face Space by "sonyps1928" — commit 2ce4afd ("update app"), raw file, 5.57 kB
import hmac
import os

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Set page config
# Must be the first Streamlit call executed in the script.
st.set_page_config(
    page_title="GPT-2 Text Generator",
    page_icon="πŸ€–",
    layout="wide"
)
# Load environment variables
# All three are optional; each one enables an extra layer when set.
HF_TOKEN = os.getenv("HF_TOKEN")              # only used for the status badge in main()
API_KEY = os.getenv("API_KEY")                # when set, generation requires a matching key
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")  # when set, the whole app sits behind a login
@st.cache_resource
def load_model():
    """Fetch the pretrained GPT-2 tokenizer and model, cached across reruns.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure — an
        error banner is shown in the failure case.
    """
    with st.spinner("Loading GPT-2 model..."):
        try:
            gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
            # GPT-2 ships without a pad token; reuse EOS so generate() can pad.
            gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
            return gpt2_tokenizer, gpt2_model
        except Exception as exc:
            st.error(f"Error loading model: {exc}")
            return None, None
def generate_text(prompt, max_length, temperature, tokenizer, model):
    """Generate a continuation of `prompt` with GPT-2.

    Args:
        prompt: user text to continue (non-empty, at most 500 characters).
        max_length: number of NEW tokens to generate beyond the prompt.
        temperature: sampling temperature passed to `model.generate`.
        tokenizer: GPT2Tokenizer instance.
        model: GPT2LMHeadModel instance.

    Returns:
        The newly generated text (prompt excluded), or a human-readable
        validation/error message string — errors are returned, not raised,
        so the UI can display them directly.
    """
    if not prompt:
        return "Please enter a prompt"
    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)"
    try:
        # Encode the prompt; cap at 300 tokens to bound generation cost.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)
        # Generate text (inference only — no gradients needed)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2
            )
        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(prompt) (the previous approach) is wrong whenever the prompt
        # was truncated to 300 tokens, or decode() does not reproduce the
        # prompt text character-for-character.
        new_text = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()
        return new_text if new_text else "No text generated. Try a different prompt."
    except Exception as e:
        return f"Error generating text: {str(e)}"
def check_auth():
    """Gate the app behind a password when ADMIN_PASSWORD is configured.

    Returns:
        True if no admin password is set or the user has already
        authenticated this session; False while the login form is shown.
    """
    if ADMIN_PASSWORD:
        if "authenticated" not in st.session_state:
            st.session_state.authenticated = False
        if not st.session_state.authenticated:
            st.title("πŸ”’ Authentication Required")
            password = st.text_input("Enter admin password:", type="password")
            if st.button("Login"):
                # Constant-time comparison — plain `==` leaks the match
                # prefix length through response timing.
                if hmac.compare_digest(password.encode("utf-8"), ADMIN_PASSWORD.encode("utf-8")):
                    st.session_state.authenticated = True
                    st.rerun()
                else:
                    st.error("Invalid password")
            return False
    return True
def main():
    """Top-level page: auth gate, model load, generation form, and examples."""
    # Check authentication
    if not check_auth():
        return
    # Load model
    tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load model. Please check the logs.")
        return
    # Main interface
    st.title("πŸ€– GPT-2 Text Generator")
    st.markdown("Generate text using GPT-2 language model")
    # Security status — purely informational badges for each optional secret.
    col1, col2, col3 = st.columns(3)
    with col1:
        if HF_TOKEN:
            st.success("πŸ”‘ HF Token: Active")
        else:
            st.info("πŸ”‘ HF Token: Not set")
    with col2:
        if API_KEY:
            st.success("πŸ”’ API Auth: Enabled")
        else:
            st.info("πŸ”’ API Auth: Disabled")
    with col3:
        if ADMIN_PASSWORD:
            st.success("πŸ‘€ Admin Auth: Active")
        else:
            st.info("πŸ‘€ Admin Auth: Disabled")
    # Input section
    st.subheader("πŸ“ Input")
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt = st.text_area(
            "Enter your prompt:",
            placeholder="Type your text here...",
            height=100
        )
        # API key input if needed
        api_key = ""
        if API_KEY:
            api_key = st.text_input("API Key:", type="password")
    with col2:
        st.subheader("βš™οΈ Settings")
        # slider args: (label, min, max, default, step)
        max_length = st.slider("Max Length", 20, 200, 100, 10)
        temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
        generate_btn = st.button("πŸš€ Generate Text", type="primary")
    # API key validation
    # Compared with plain `==`; see check_auth for the constant-time variant.
    if API_KEY and generate_btn:
        if not api_key or api_key != API_KEY:
            st.error("πŸ”’ Invalid or missing API key")
            return
    # Generate text
    if generate_btn and prompt:
        with st.spinner("Generating text..."):
            result = generate_text(prompt, max_length, temperature, tokenizer, model)
        st.subheader("πŸ“„ Generated Text")
        st.text_area("Output:", value=result, height=200)
        # Copy button
        st.code(result)
    elif generate_btn:
        st.warning("Please enter a prompt")
    # Examples
    st.subheader("πŸ’‘ Example Prompts")
    examples = [
        "Once upon a time in a distant galaxy,",
        "The future of artificial intelligence is",
        "In the heart of the ancient forest,",
        "The detective walked into the room and noticed"
    ]
    cols = st.columns(len(examples))
    for i, example in enumerate(examples):
        with cols[i]:
            if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
                # Stash the choice in session state, then rerun so it survives.
                st.session_state.example_prompt = example
                st.rerun()
    # Use selected example
    # NOTE(review): this only shows an info banner — the selected example is
    # never copied into the prompt text area; confirm that is intended.
    if hasattr(st.session_state, 'example_prompt'):
        st.info(f"Example selected: {st.session_state.example_prompt}")
# Run the app only when executed as a script (not when imported).
if __name__ == "__main__":
    main()