File size: 5,569 Bytes
2ce4afd
ad32177
cef31a4
ad32177
 
2ce4afd
 
 
 
 
 
b4a4c25
 
ad32177
b4a4c25
ad32177
 
2ce4afd
 
 
 
 
 
 
 
 
 
 
 
8511f5e
2ce4afd
 
b4a4c25
 
40bbb95
b4a4c25
 
ad32177
40bbb95
b4a4c25
 
cef31a4
b4a4c25
cef31a4
 
 
b4a4c25
 
cef31a4
 
b4a4c25
8511f5e
cef31a4
 
b4a4c25
 
 
 
 
40bbb95
 
2ce4afd
cef31a4
2ce4afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8511f5e
2ce4afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cef31a4
2ce4afd
 
 
 
cef31a4
2ce4afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4a4c25
ad32177
2ce4afd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import secrets

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

import streamlit as st

# Set page config
st.set_page_config(
    page_title="GPT-2 Text Generator",
    page_icon="πŸ€–",
    layout="wide"
)

# Load environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY") 
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

@st.cache_resource
def load_model():
    """Load the GPT-2 tokenizer and model, cached for the app's lifetime.

    Returns:
        (tokenizer, model) on success, or (None, None) after reporting the
        failure with st.error.

    NOTE(review): since this function is cached with st.cache_resource, a
    (None, None) failure result may itself be cached until the resource
    cache is cleared -- confirm whether a retry path is needed.
    """
    with st.spinner("Loading GPT-2 model..."):
        try:
            gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
            # GPT-2 ships without a pad token; reuse EOS so padding is defined.
            gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None, None
        return gpt2_tokenizer, gpt2_model

def generate_text(prompt, max_length, temperature, tokenizer, model):
    """Generate a GPT-2 continuation of *prompt*.

    Args:
        prompt: User-supplied text to continue (1-500 characters).
        max_length: Number of new tokens to generate beyond the prompt.
        temperature: Sampling temperature forwarded to model.generate.
        tokenizer: A GPT2Tokenizer (pad token expected to be set to EOS).
        model: A GPT2LMHeadModel used for inference.

    Returns:
        The newly generated text (prompt excluded), or a human-readable
        validation/error message. Never raises.
    """
    if not prompt:
        return "Please enter a prompt"

    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)"

    try:
        # Encode the prompt, capped at 300 tokens so the total request stays
        # well inside GPT-2's 1024-token context window.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)

        # Generate text; no gradient tracking is needed for inference.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2
            )

        # Decode ONLY the newly generated tokens. The previous approach sliced
        # the full decoded string by len(prompt), which breaks whenever the
        # decoded prefix differs from the raw prompt -- e.g. after the token
        # truncation above, or tokenizer whitespace normalization.
        new_tokens = outputs[0][inputs.shape[1]:]
        new_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return new_text if new_text else "No text generated. Try a different prompt."

    except Exception as e:
        return f"Error generating text: {str(e)}"

def check_auth():
    """Gate the app behind a password when ADMIN_PASSWORD is configured.

    Returns:
        True when no password is configured or the user is already
        authenticated; False while the login form is being shown.
    """
    if ADMIN_PASSWORD:
        # Initialize the flag once per browser session.
        if "authenticated" not in st.session_state:
            st.session_state.authenticated = False

        if not st.session_state.authenticated:
            st.title("🔒 Authentication Required")
            password = st.text_input("Enter admin password:", type="password")
            if st.button("Login"):
                # Constant-time comparison so response timing does not leak
                # how much of the password prefix matched (== short-circuits).
                if secrets.compare_digest(password.encode(), ADMIN_PASSWORD.encode()):
                    st.session_state.authenticated = True
                    st.rerun()
                else:
                    st.error("Invalid password")
            return False
    return True

def main():
    """Render the app: auth gate, cached model load, generation UI, examples."""
    # Check authentication
    if not check_auth():
        return

    # Load model (cached across reruns by st.cache_resource)
    tokenizer, model = load_model()
    if tokenizer is None or model is None:
        st.error("Failed to load model. Please check the logs.")
        return

    # Main interface
    st.title("🤖 GPT-2 Text Generator")
    st.markdown("Generate text using GPT-2 language model")

    # Security status indicators -- read-only, reflect which env vars are set
    col1, col2, col3 = st.columns(3)
    with col1:
        if HF_TOKEN:
            st.success("🔑 HF Token: Active")
        else:
            st.info("🔑 HF Token: Not set")

    with col2:
        if API_KEY:
            st.success("🔒 API Auth: Enabled")
        else:
            st.info("🔒 API Auth: Disabled")

    with col3:
        if ADMIN_PASSWORD:
            st.success("👤 Admin Auth: Active")
        else:
            st.info("👤 Admin Auth: Disabled")

    # Input section
    st.subheader("📝 Input")

    col1, col2 = st.columns([2, 1])

    with col1:
        # Pre-fill the prompt with the selected example, if any. Previously
        # the example buttons stored text in session_state but the widget
        # never read it back, so "Use Example" had no visible effect.
        prompt = st.text_area(
            "Enter your prompt:",
            value=st.session_state.get("example_prompt", ""),
            placeholder="Type your text here...",
            height=100
        )

        # API key input if needed
        api_key = ""
        if API_KEY:
            api_key = st.text_input("API Key:", type="password")

    with col2:
        st.subheader("⚙️ Settings")
        max_length = st.slider("Max Length", 20, 200, 100, 10)
        temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)

        generate_btn = st.button("🚀 Generate Text", type="primary")

    # API key validation -- constant-time compare to avoid timing leaks
    if API_KEY and generate_btn:
        if not api_key or not secrets.compare_digest(api_key.encode(), API_KEY.encode()):
            st.error("🔒 Invalid or missing API key")
            return

    # Generate text
    if generate_btn and prompt:
        with st.spinner("Generating text..."):
            result = generate_text(prompt, max_length, temperature, tokenizer, model)

        st.subheader("📄 Generated Text")
        st.text_area("Output:", value=result, height=200)

        # Copy button
        st.code(result)

    elif generate_btn:
        st.warning("Please enter a prompt")

    # Examples
    st.subheader("💡 Example Prompts")
    examples = [
        "Once upon a time in a distant galaxy,",
        "The future of artificial intelligence is",
        "In the heart of the ancient forest,",
        "The detective walked into the room and noticed"
    ]

    cols = st.columns(len(examples))
    for i, example in enumerate(examples):
        with cols[i]:
            if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
                st.session_state.example_prompt = example
                st.rerun()

    # Show which example is active ("in" is the idiomatic session_state check,
    # rather than hasattr on the session object)
    if "example_prompt" in st.session_state:
        st.info(f"Example selected: {st.session_state.example_prompt}")

# Run the app only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()