|
import streamlit as st |
|
import os |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import torch |
|
|
|
|
|
# Configure the page first: Streamlit requires set_page_config to be the
# first st.* command executed in the script.
st.set_page_config(
    page_title="GPT-2 Text Generator",
    page_icon="π€",
    layout="wide"
)

# Optional deployment secrets, all read from the environment:
#   HF_TOKEN       - Hugging Face token (display-only below; not passed to from_pretrained)
#   API_KEY        - when set, the UI requires this key before generating
#   ADMIN_PASSWORD - when set, a login gate is shown before the app renders
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
|
|
|
@st.cache_resource
def load_model():
    """Download GPT-2 and its tokenizer once and reuse them across reruns.

    Returns:
        A ``(tokenizer, model)`` pair on success, or ``(None, None)`` on
        failure (the error is surfaced in the UI via ``st.error``).
    """
    with st.spinner("Loading GPT-2 model..."):
        try:
            tok = GPT2Tokenizer.from_pretrained("gpt2")
            lm = GPT2LMHeadModel.from_pretrained("gpt2")
            # GPT-2 ships without a pad token; reuse EOS so generate() can pad.
            tok.pad_token = tok.eos_token
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None, None
        return tok, lm
|
|
|
def generate_text(prompt, max_length, temperature, tokenizer, model):
    """Generate a continuation of *prompt* with GPT-2.

    Args:
        prompt: User-supplied seed text (1-500 characters).
        max_length: Number of NEW tokens to generate beyond the prompt.
        temperature: Sampling temperature (higher = more random).
        tokenizer: A GPT2Tokenizer (or compatible) instance.
        model: A GPT2LMHeadModel (or compatible) instance.

    Returns:
        The newly generated text, or a human-readable error message.
        Callers display the return value either way, so failures are
        reported as strings rather than raised.
    """
    if not prompt:
        return "Please enter a prompt"

    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)"

    try:
        # Second guard at the token level: cap the prompt at 300 tokens.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)

        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2
            )

        # BUGFIX: strip the prompt by TOKEN count, not character count.
        # decode(encode(prompt)) is not guaranteed to reproduce the prompt
        # byte-for-byte (token truncation above, tokenizer whitespace
        # normalization), so the old `generated_text[len(prompt):]` slice
        # could drop or duplicate characters at the seam.
        new_tokens = outputs[0][inputs.shape[1]:]
        new_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return new_text if new_text else "No text generated. Try a different prompt."

    except Exception as e:
        return f"Error generating text: {str(e)}"
|
|
|
def check_auth():
    """Gate the app behind a password when ADMIN_PASSWORD is configured.

    Returns:
        True when the user may proceed (no password configured, or this
        session already logged in); otherwise renders the login form and
        returns False.
    """
    # No password configured: authentication is disabled entirely.
    if not ADMIN_PASSWORD:
        return True

    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    if st.session_state.authenticated:
        return True

    # Not yet authenticated: show the login form instead of the app.
    st.title("π Authentication Required")
    password = st.text_input("Enter admin password:", type="password")
    if st.button("Login"):
        if password == ADMIN_PASSWORD:
            st.session_state.authenticated = True
            st.rerun()
        else:
            st.error("Invalid password")
    return False
|
|
|
def main():
    """Render the Streamlit UI: auth gate, model load, controls, generation."""
    if not check_auth():
        return

    tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load model. Please check the logs.")
        return

    st.title("π€ GPT-2 Text Generator")
    st.markdown("Generate text using GPT-2 language model")

    # Status row: show which optional secrets are configured.
    col1, col2, col3 = st.columns(3)
    with col1:
        if HF_TOKEN:
            st.success("π HF Token: Active")
        else:
            st.info("π HF Token: Not set")

    with col2:
        if API_KEY:
            st.success("π API Auth: Enabled")
        else:
            st.info("π API Auth: Disabled")

    with col3:
        if ADMIN_PASSWORD:
            st.success("π€ Admin Auth: Active")
        else:
            st.info("π€ Admin Auth: Disabled")

    st.subheader("π Input")

    col1, col2 = st.columns([2, 1])

    with col1:
        # BUGFIX: pre-fill the prompt with a selected example. Previously the
        # "Use Example" buttons stored st.session_state.example_prompt and
        # reran, but the text area never read it, so the feature had no effect.
        prompt = st.text_area(
            "Enter your prompt:",
            value=st.session_state.get("example_prompt", ""),
            placeholder="Type your text here...",
            height=100
        )

        api_key = ""
        if API_KEY:
            api_key = st.text_input("API Key:", type="password")

    with col2:
        st.subheader("βοΈ Settings")
        max_length = st.slider("Max Length", 20, 200, 100, 10)
        temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)

    generate_btn = st.button("π Generate Text", type="primary")

    # When API auth is enabled, require a matching key before generating.
    if API_KEY and generate_btn:
        if not api_key or api_key != API_KEY:
            st.error("π Invalid or missing API key")
            return

    if generate_btn and prompt:
        with st.spinner("Generating text..."):
            result = generate_text(prompt, max_length, temperature, tokenizer, model)

        st.subheader("π Generated Text")
        st.text_area("Output:", value=result, height=200)

        # Also render as a code block for easy copying.
        st.code(result)

    elif generate_btn:
        st.warning("Please enter a prompt")

    st.subheader("π‘ Example Prompts")
    examples = [
        "Once upon a time in a distant galaxy,",
        "The future of artificial intelligence is",
        "In the heart of the ancient forest,",
        "The detective walked into the room and noticed"
    ]

    cols = st.columns(len(examples))
    for i, example in enumerate(examples):
        with cols[i]:
            # Selecting an example stores it and reruns; the text area above
            # picks it up via session_state on the next run.
            if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
                st.session_state.example_prompt = example
                st.rerun()

    if hasattr(st.session_state, 'example_prompt'):
        st.info(f"Example selected: {st.session_state.example_prompt}")
|
|
|
# Script entry point (Streamlit executes this file top-to-bottom on each run).
if __name__ == "__main__":
    main()