import os
import time
from collections import defaultdict

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
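
# Secrets are read from environment variables (e.g. Hugging Face Space secrets).
# HF_TOKEN is optional for the public gpt2 checkpoint; API_KEY enables request
# authentication, and ADMIN_PASSWORD enables basic auth on the web UI.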
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

print("🔒 Security Status:")
print(f"  HF_TOKEN: {'✅ Set' if HF_TOKEN else '❌ Not set'}")
print(f"  API_KEY: {'✅ Set' if API_KEY else '❌ Not set'}")
print(f"  ADMIN_PASSWORD: {'✅ Set' if ADMIN_PASSWORD else '❌ Not set'}")
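
# Per-key timestamps of recent requests, used for the hourly rate limit.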
request_counts = defaultdict(list)
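
# Load the GPT-2 tokenizer and weights once at startup. A token is only needed
# for gated or private models; the public gpt2 checkpoint works without one.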
model_name = "gpt2"
print("📦 Loading model...")

try:
    if HF_TOKEN:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name, token=HF_TOKEN)
        model = GPT2LMHeadModel.from_pretrained(model_name, token=HF_TOKEN)
        print("✅ Model loaded with HF token")
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        print("✅ Model loaded without token")

    # GPT-2 has no pad token, so reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Model ready!")

except Exception as e:
    print(f"❌ Model loading failed: {e}")
    raise
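
# If no API_KEY secret is configured, the app runs in public mode and skips all
# key and rate-limit checks.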
def check_api_key(provided_key):
    """Simple API key validation with rate limiting."""
    if not API_KEY:
        return True, "Public access"

    if not provided_key or provided_key != API_KEY:
        return False, "Invalid or missing API key"

    # Keep only requests from the last hour (sliding window).
    now = time.time()
    hour_ago = now - 3600
    request_counts[provided_key] = [
        t for t in request_counts[provided_key] if t > hour_ago
    ]

    if len(request_counts[provided_key]) >= 100:
        return False, "Rate limit exceeded (100/hour)"

    request_counts[provided_key].append(now)
    return True, f"Authenticated ({len(request_counts[provided_key])}/100)"
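

# Click handler for the Generate button (wired up in the UI section below).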
def generate_text(prompt, max_length, temperature, top_p, top_k, api_key):
    """Generate text with GPT-2."""
    # Authenticate and rate-limit before doing any work.
    valid, msg = check_api_key(api_key)
    if not valid:
        return f"🔒 Error: {msg}"

    # Basic input validation.
    if not prompt.strip():
        return "❌ Please enter a prompt"
    if len(prompt) > 1000:
        return "❌ Prompt too long (max 1000 chars)"

    try:
        print(f"🔑 {msg}")
        print(f"📝 Generating: {prompt[:50]}...")

        # Tokenize the prompt, truncating so it leaves room for new tokens.
        inputs = tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=400,
            truncation=True
        )

        # Clamp sampling parameters to safe ranges; the sliders return floats,
        # so cast the integer-valued ones before handing them to generate().
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(inputs.shape[1] + int(max_length), 500),
                temperature=max(0.1, min(2.0, temperature)),
                top_p=max(0.1, min(1.0, top_p)),
                top_k=int(max(1, min(100, top_k))),
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                no_repeat_ngram_size=2
            )

        # Decode and strip the prompt so only the continuation is returned.
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        result = generated[len(prompt):].strip()

        print(f"✅ Generated {len(result)} characters")
        return result if result else "❌ No text generated"

    except Exception as e:
        error = f"❌ Generation failed: {str(e)}"
        print(error)
        return error

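
# Build the Gradio interface: prompt and sampling controls on the left,
# generated text on the right.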
demo = gr.Blocks(title="GPT-2 Text Generator")

with demo:
    gr.Markdown("# 🤖 GPT-2 Text Generator")

    if API_KEY:
        gr.Markdown("🔒 **API Authentication Required**")
    else:
        gr.Markdown("🌐 **Public Access Mode**")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt...",
                lines=3
            )

            # Only show the API key field when a key is required; otherwise keep
            # a hidden, empty textbox so the click handler's input list is unchanged.
            if API_KEY:
                api_key = gr.Textbox(
                    label="API Key",
                    type="password",
                    placeholder="Enter API key..."
                )
            else:
                api_key = gr.Textbox(value="", visible=False)

            max_length = gr.Slider(10, 200, value=100, label="Max Length")
            temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, label="Top-k")

            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )

    gr.Examples([
        ["Once upon a time"],
        ["The future of AI is"],
        ["In a world where technology"],
    ], inputs=prompt)

    generate_btn.click(
        generate_text,
        inputs=[prompt, max_length, temperature, top_p, top_k, api_key],
        outputs=output
    )

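
# Entry point: optionally protect the whole UI behind Gradio's built-in basic auth.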
if __name__ == "__main__":
    auth = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None

    if auth:
        print("🔒 Admin auth enabled")

    print("🚀 Starting server...")
    demo.launch(auth=auth)
print("β
Server running!") |