File size: 3,767 Bytes
760431c
e27c591
 
1b3fa51
36fde64
760431c
 
7fe97c0
e27c591
760431c
 
 
e27c591
1b3fa51
36fde64
e27c591
 
1b3fa51
e27c591
 
1b3fa51
e27c591
 
 
 
760431c
e27c591
 
 
 
 
 
1b3fa51
 
e27c591
 
1b3fa51
e27c591
 
 
1b3fa51
760431c
107fb80
7fe97c0
760431c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fe97c0
760431c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adb694f
36fde64
760431c
ad32177
760431c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch


# --- Model setup (runs once at import time; blocks until weights are loaded) ---
# Load model and tokenizer (using smaller GPT-2 for free tier)
model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
tokenizer = GPT2Tokenizer.from_pretrained(model_name)  # downloads/caches on first run
model = GPT2LMHeadModel.from_pretrained(model_name)  # CPU by default; no .to(device) here


# GPT-2 ships with no pad token; alias it to EOS so generation can pad.
# This matches the explicit pad_token_id=tokenizer.eos_token_id passed to
# model.generate() in generate_text() below.
tokenizer.pad_token = tokenizer.eos_token


def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a GPT-2 continuation of *prompt* and return only the new text.

    Args:
        prompt: Text to continue. Encoded with truncation to 512 tokens.
        max_length: Maximum number of NEW tokens to generate (the total
            prompt + continuation is still capped at 512 tokens).
        temperature: Sampling temperature; higher values are more random.
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Restrict sampling to the k most likely next tokens.

    Returns:
        The generated continuation (prompt excluded), stripped of
        surrounding whitespace — or a human-readable error message.
        This feeds a Gradio textbox directly, so it never raises.
    """
    # Guard: an empty prompt encodes to an empty tensor, which gives the
    # model nothing to condition on and can crash generate().
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    try:
        # Encode input (truncated so it always fits the 512-token budget)
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)

        # Generate — inference only, so skip autograd bookkeeping
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + inputs.shape[-1], 512),  # limit total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no native pad token
                num_return_sequences=1
            )

        # BUGFIX: strip the prompt in *token* space, not character space.
        # The previous `generated_text[len(prompt):]` broke whenever
        # decode() did not reproduce the prompt byte-for-byte (tokenizer
        # whitespace normalization, or the prompt being truncated at 512
        # tokens above), leaking prompt fragments or chopping output.
        new_tokens = outputs[0][inputs.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        # Deliberate best-effort boundary: surface the failure in the UI
        # textbox instead of crashing the Gradio worker.
        return f"Error generating text: {str(e)}"


# Create Gradio interface.
# NOTE: statement order inside the Blocks context defines the rendered
# layout — components appear in the order they are constructed.
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
    
    with gr.Row():
        # Left column: prompt input plus the four sampling controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3
            )
            
            with gr.Row():
                # Slider defaults mirror generate_text()'s keyword defaults
                # (max_length=100, temperature=0.7).
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            
            with gr.Row():
                # Nucleus (top-p) and top-k sampling controls; defaults
                # mirror generate_text() (top_p=0.9, top_k=50).
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k"
                )
            
            generate_btn = gr.Button("Generate Text", variant="primary")
        
        # Right column: read-only output box for the generated text.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )
    
    # Clickable example prompts that pre-fill the prompt textbox.
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input
    )
    
    # Connect the function with explicit API endpoint name.
    # Input order must match generate_text's parameter order.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict"  # Explicit API endpoint for external calls
    )


# Launch the app only when run as a script (not when imported, e.g. by a
# Space runtime that launches `demo` itself).
if __name__ == "__main__":
    # Default Gradio settings: local server, no public share link.
    demo.launch()