import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch


# Load model and tokenizer (using smaller GPT-2 for free tier)
model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)


# Set pad token
tokenizer.pad_token = tokenizer.eos_token


def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text using GPT-2"""
    try:
        # Encode input
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
        
        # Generate
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
        
        # Decode only the newly generated tokens (slice off the prompt tokens so the
        # result stays correct even if the prompt was truncated or re-tokenized)
        generated_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

        return generated_text.strip()
    
    except Exception as e:
        return f"Error generating text: {str(e)}"
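
# Quick local sanity check (optional sketch; assumes the model and tokenizer above
# loaded successfully):
#   print(generate_text("The future of artificial intelligence is", max_length=40))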


# Create Gradio interface
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
    
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3
            )
            
            with gr.Row():
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k"
                )
            
            generate_btn = gr.Button("Generate Text", variant="primary")
        
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )
    
    # Examples
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input
    )
    
    # Connect the function with explicit API endpoint name
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict"  # Explicit API endpoint for external calls
    )
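
    # Example external call (a sketch; assumes the `gradio_client` package is
    # installed and that <your-space-url> points at this running app -- replace it
    # with your actual Space or server URL):
    #
    #   from gradio_client import Client
    #   client = Client("<your-space-url>")
    #   result = client.predict(
    #       "Once upon a time in a distant galaxy,",  # prompt
    #       100,   # max_length
    #       0.7,   # temperature
    #       0.9,   # top_p
    #       50,    # top_k
    #       api_name="/predict",
    #   )
    #   print(result)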


# Launch the app
if __name__ == "__main__":
    demo.launch()