File size: 5,569 Bytes
2ce4afd ad32177 cef31a4 ad32177 2ce4afd b4a4c25 ad32177 b4a4c25 ad32177 2ce4afd 8511f5e 2ce4afd b4a4c25 40bbb95 b4a4c25 ad32177 40bbb95 b4a4c25 cef31a4 b4a4c25 cef31a4 b4a4c25 cef31a4 b4a4c25 8511f5e cef31a4 b4a4c25 40bbb95 2ce4afd cef31a4 2ce4afd 8511f5e 2ce4afd cef31a4 2ce4afd cef31a4 2ce4afd b4a4c25 ad32177 2ce4afd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import hmac
import os

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Configure the browser tab (title and icon) and request the full-width layout.
# This runs at import time, before any other st.* call made by this script.
# NOTE(review): the emoji literals in this file (e.g. the page_icon below) look
# mis-encoded (mojibake) — confirm the intended characters.
st.set_page_config(page_title="GPT-2 Text Generator", page_icon="π€", layout="wide")

# Optional settings pulled from the process environment; each is None when unset.
# NOTE(review): HF_TOKEN is only reported in the status panel; it is not passed
# to from_pretrained() when the model is loaded.
HF_TOKEN, API_KEY, ADMIN_PASSWORD = (
    os.getenv(var_name) for var_name in ("HF_TOKEN", "API_KEY", "ADMIN_PASSWORD")
)
@st.cache_resource(show_spinner="Loading GPT-2 model...")
def _load_model_cached():
    """Load the GPT-2 tokenizer and model, caching them across reruns.

    Any exception propagates to the caller. Because nothing is returned when
    loading fails, no failure value is stored in the cache, and the next call
    retries the load.
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # Reuse EOS as the padding token (GPT-2's tokenizer has no dedicated one).
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model


def load_model():
    """Load and cache the GPT-2 model.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading
        failed. On failure the error is shown in the UI. Failures are not
        cached, so a later rerun tries the load again instead of being stuck
        with ``(None, None)`` for the life of the server process.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing the page.
        st.error(f"Error loading model: {e}")
        return None, None
def generate_text(prompt, max_length, temperature, tokenizer, model):
    """Generate a continuation of ``prompt`` using GPT-2.

    Args:
        prompt: User-supplied text; must be non-blank and at most 500 characters.
        max_length: Maximum number of new tokens to generate after the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        tokenizer: Callable tokenizer exposing ``decode`` and ``eos_token_id``.
        model: Language model exposing ``generate``.

    Returns:
        The generated continuation (without the prompt). For a blank or overlong
        prompt, empty output, or a generation error, a human-readable message
        string is returned instead.
    """
    if not prompt or prompt.isspace():
        return "Please enter a prompt"
    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)"
    try:
        # Encode the prompt, truncated to at most 300 tokens. The attention mask
        # is passed explicitly because pad_token_id below equals the EOS id, so
        # generate() cannot infer it from the ids.
        encoded = tokenizer(prompt, return_tensors="pt", max_length=300, truncation=True)
        input_ids = encoded["input_ids"]
        prompt_len = input_ids.shape[1]
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=encoded["attention_mask"],
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2,
            )
        # Decode only the tokens produced after the prompt. Slicing the fully
        # decoded text by len(prompt) is wrong whenever decoding does not
        # reproduce the prompt exactly (token truncation, whitespace clean-up).
        new_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        return new_text if new_text else "No text generated. Try a different prompt."
    except Exception as e:
        return f"Error generating text: {str(e)}"
def check_auth():
    """Gate the app behind ADMIN_PASSWORD when one is configured.

    Returns:
        True if no admin password is configured or this session has already
        logged in. Otherwise renders a login form and returns False. A
        successful login sets ``st.session_state.authenticated`` and reruns
        the script.
    """
    if not ADMIN_PASSWORD:
        return True
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False
    if st.session_state.authenticated:
        return True

    st.title("π Authentication Required")
    password = st.text_input("Enter admin password:", type="password")
    if st.button("Login"):
        # Constant-time comparison so response timing does not reveal how much of
        # the password matched. Compare bytes: compare_digest rejects non-ASCII str.
        if hmac.compare_digest(password.encode("utf-8"), ADMIN_PASSWORD.encode("utf-8")):
            st.session_state.authenticated = True
            st.rerun()
        else:
            st.error("Invalid password")
    return False
# Session-state key of the prompt text area, so example buttons can fill it in.
_PROMPT_KEY = "prompt_input"


def _api_key_is_valid(api_key):
    """Return True if no API_KEY is configured, or ``api_key`` matches it."""
    if not API_KEY:
        return True
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    return bool(api_key) and hmac.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    )


def _render_security_status():
    """Show one status badge each for HF_TOKEN, API_KEY and ADMIN_PASSWORD."""
    col1, col2, col3 = st.columns(3)
    with col1:
        if HF_TOKEN:
            st.success("π HF Token: Active")
        else:
            st.info("π HF Token: Not set")
    with col2:
        if API_KEY:
            st.success("π API Auth: Enabled")
        else:
            st.info("π API Auth: Disabled")
    with col3:
        if ADMIN_PASSWORD:
            st.success("π€ Admin Auth: Active")
        else:
            st.info("π€ Admin Auth: Disabled")


def _select_example(example):
    """Button callback: remember the chosen example and load it into the prompt.

    Callbacks run before the script reruns, i.e. before the prompt text area is
    re-created, which is when its keyed session-state value may be assigned.
    """
    st.session_state.example_prompt = example
    st.session_state[_PROMPT_KEY] = example


def _render_examples():
    """Render the example-prompt buttons and the currently selected example."""
    st.subheader("π‘ Example Prompts")
    examples = [
        "Once upon a time in a distant galaxy,",
        "The future of artificial intelligence is",
        "In the heart of the ancient forest,",
        "The detective walked into the room and noticed",
    ]
    cols = st.columns(len(examples))
    for i, (col, example) in enumerate(zip(cols, examples)):
        with col:
            st.button(
                f"Use Example {i+1}",
                key=f"ex_{i}",
                on_click=_select_example,
                args=(example,),
            )
    if "example_prompt" in st.session_state:
        st.info(f"Example selected: {st.session_state.example_prompt}")


def main():
    """Render the GPT-2 text generator page.

    Flow: optional admin login, model load, status badges, prompt/settings
    inputs, then on "Generate" an API-key check and text generation, and
    finally the example prompts (always rendered).
    """
    if not check_auth():
        return

    tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load model. Please check the logs.")
        return

    st.title("π€ GPT-2 Text Generator")
    st.markdown("Generate text using GPT-2 language model")
    _render_security_status()

    st.subheader("π Input")
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt = st.text_area(
            "Enter your prompt:",
            placeholder="Type your text here...",
            height=100,
            key=_PROMPT_KEY,
        )
        # The API key field is shown only when an API_KEY is configured.
        api_key = ""
        if API_KEY:
            api_key = st.text_input("API Key:", type="password")
    with col2:
        st.subheader("βοΈ Settings")
        max_length = st.slider("Max Length", 20, 200, 100, 10)
        temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
        generate_btn = st.button("π Generate Text", type="primary")

    if generate_btn:
        if not _api_key_is_valid(api_key):
            st.error("π Invalid or missing API key")
        elif prompt:
            with st.spinner("Generating text..."):
                result = generate_text(prompt, max_length, temperature, tokenizer, model)
            st.subheader("π Generated Text")
            st.text_area("Output:", value=result, height=200)
            # Rendered again as a code block, which offers a copy-to-clipboard button.
            st.code(result)
        else:
            st.warning("Please enter a prompt")

    _render_examples()
# Script entry point. (The original line carried a stray " |" from the paste,
# which is a syntax error.)
if __name__ == "__main__":
    main()