sonyps1928 committed
Commit 107fb80 · Parent(s): 36fde64
update app22

app.py
CHANGED
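The diff below replaces the Gradio front end in app.py with a dependency-free HTTP server built on Python's http.server (a JSON /generate endpoint, a /health check, and a hand-written HTML page); the previous Gradio version is checked in as app1.txt. A usage sketch follows each file.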
@@ -1,16 +1,21 @@
-import gradio as gr
+import json
+from http.server import HTTPServer, BaseHTTPRequestHandler
+from urllib.parse import urlparse, parse_qs
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 import torch
+import logging
 
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-# Load model and tokenizer (using smaller GPT-2 for free tier)
-model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
-tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+# Load model and tokenizer globally
+logger.info("Loading GPT-2 model and tokenizer...")
+model_name = "gpt2"
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)
-
-
-# Set pad token
 tokenizer.pad_token = tokenizer.eos_token
+logger.info("Model loaded successfully!")
 
 
 def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
@@ -23,7 +28,7 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
         with torch.no_grad():
             outputs = model.generate(
                 inputs,
-                max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
+                max_length=min(max_length + len(inputs[0]), 512),
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
@@ -39,83 +44,224 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
         return generated_text[len(prompt):].strip()
 
     except Exception as e:
-        return f"Error generating text: {str(e)}"
+        logger.error(f"Error generating text: {str(e)}")
+        return f"Error: {str(e)}"
 
 
-# Create Gradio interface
-with gr.Blocks(title="GPT-2 Text Generator") as demo:
-    gr.Markdown("# GPT-2 Text Generation Server")
-    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
-
-    with gr.Row():
-        with gr.Column():
-            prompt_input = gr.Textbox(
-                label="Prompt",
-                placeholder="Enter your text prompt here...",
-                lines=3
-            )
-
-            with gr.Row():
-                max_length = gr.Slider(
-                    minimum=10,
-                    maximum=200,
-                    value=100,
-                    step=10,
-                    label="Max Length"
-                )
-                temperature = gr.Slider(
-                    minimum=0.1,
-                    maximum=2.0,
-                    value=0.7,
-                    step=0.1,
-                    label="Temperature"
-                )
-
-            with gr.Row():
-                top_p = gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    value=0.9,
-                    step=0.1,
-                    label="Top-p"
-                )
-                top_k = gr.Slider(
-                    minimum=1,
-                    maximum=100,
-                    value=50,
-                    step=1,
-                    label="Top-k"
-                )
-
-            generate_btn = gr.Button("Generate Text", variant="primary")
-
-        with gr.Column():
-            output_text = gr.Textbox(
-                label="Generated Text",
-                lines=10,
-                placeholder="Generated text will appear here..."
-            )
-
-    # Examples
-    gr.Examples(
-        examples=[
-            ["Once upon a time in a distant galaxy,"],
-            ["The future of artificial intelligence is"],
-            ["In the heart of the ancient forest,"],
-            ["The detective walked into the room and noticed"],
-        ],
-        inputs=prompt_input
-    )
-
-    # Connect the function with explicit API endpoint name
-    generate_btn.click(
-        fn=generate_text,
-        inputs=[prompt_input, max_length, temperature, top_p, top_k],
-        outputs=output_text,
-        api_name="/predict"  # Explicit API endpoint for external calls
-    )
-
-
-# Launch the app
+class GPT2Handler(BaseHTTPRequestHandler):
+    def _set_headers(self, content_type='application/json'):
+        self.send_response(200)
+        self.send_header('Content-type', content_type)
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
+        self.send_header('Access-Control-Allow-Headers', 'Content-Type')
+        self.end_headers()
+
+    def _send_error(self, code, message):
+        self.send_response(code)
+        self.send_header('Content-type', 'application/json')
+        self.end_headers()
+        response = {'error': message}
+        self.wfile.write(json.dumps(response).encode())
+
+    def do_OPTIONS(self):
+        self._set_headers()
+
+    def do_GET(self):
+        parsed_path = urlparse(self.path)
+
+        if parsed_path.path == '/':
+            # Serve a simple HTML interface
+            self._set_headers('text/html')
+            html = '''
+            <!DOCTYPE html>
+            <html>
+            <head>
+                <title>GPT-2 Text Generator</title>
+                <style>
+                    body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
+                    .container { margin: 20px 0; }
+                    textarea, input, button { margin: 5px 0; padding: 8px; }
+                    textarea { width: 100%; height: 100px; }
+                    button { background: #007bff; color: white; border: none; padding: 10px 20px; cursor: pointer; }
+                    button:hover { background: #0056b3; }
+                    .output { background: #f8f9fa; padding: 15px; border-radius: 5px; min-height: 100px; }
+                    .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; }
+                    label { font-weight: bold; }
+                </style>
+            </head>
+            <body>
+                <h1>GPT-2 Text Generator</h1>
+                <p>Enter a prompt and generate text using GPT-2</p>
+
+                <div class="container">
+                    <label for="prompt">Prompt:</label>
+                    <textarea id="prompt" placeholder="Enter your text prompt here...">Once upon a time in a distant galaxy,</textarea>
+                </div>
+
+                <div class="controls">
+                    <div>
+                        <label for="maxLength">Max Length: <span id="maxLengthValue">100</span></label>
+                        <input type="range" id="maxLength" min="10" max="200" value="100" step="10">
+                    </div>
+                    <div>
+                        <label for="temperature">Temperature: <span id="temperatureValue">0.7</span></label>
+                        <input type="range" id="temperature" min="0.1" max="2.0" value="0.7" step="0.1">
+                    </div>
+                    <div>
+                        <label for="topP">Top-p: <span id="topPValue">0.9</span></label>
+                        <input type="range" id="topP" min="0.1" max="1.0" value="0.9" step="0.1">
+                    </div>
+                    <div>
+                        <label for="topK">Top-k: <span id="topKValue">50</span></label>
+                        <input type="range" id="topK" min="1" max="100" value="50" step="1">
+                    </div>
+                </div>
+
+                <div class="container">
+                    <button onclick="generateText()" id="generateBtn">Generate Text</button>
+                </div>
+
+                <div class="container">
+                    <label>Generated Text:</label>
+                    <div id="output" class="output">Generated text will appear here...</div>
+                </div>
+
+                <script>
+                    // Update slider value displays
+                    document.getElementById('maxLength').oninput = function() {
+                        document.getElementById('maxLengthValue').textContent = this.value;
+                    }
+                    document.getElementById('temperature').oninput = function() {
+                        document.getElementById('temperatureValue').textContent = this.value;
+                    }
+                    document.getElementById('topP').oninput = function() {
+                        document.getElementById('topPValue').textContent = this.value;
+                    }
+                    document.getElementById('topK').oninput = function() {
+                        document.getElementById('topKValue').textContent = this.value;
+                    }
+
+                    async function generateText() {
+                        const btn = document.getElementById('generateBtn');
+                        const output = document.getElementById('output');
+
+                        btn.disabled = true;
+                        btn.textContent = 'Generating...';
+                        output.textContent = 'Generating text...';
+
+                        const data = {
+                            prompt: document.getElementById('prompt').value,
+                            max_length: parseInt(document.getElementById('maxLength').value),
+                            temperature: parseFloat(document.getElementById('temperature').value),
+                            top_p: parseFloat(document.getElementById('topP').value),
+                            top_k: parseInt(document.getElementById('topK').value)
+                        };
+
+                        try {
+                            const response = await fetch('/generate', {
+                                method: 'POST',
+                                headers: {'Content-Type': 'application/json'},
+                                body: JSON.stringify(data)
+                            });
+
+                            const result = await response.json();
+
+                            if (result.error) {
+                                output.textContent = 'Error: ' + result.error;
+                            } else {
+                                output.textContent = result.generated_text;
+                            }
+                        } catch (error) {
+                            output.textContent = 'Error: ' + error.message;
+                        }
+
+                        btn.disabled = false;
+                        btn.textContent = 'Generate Text';
+                    }
+                </script>
+            </body>
+            </html>
+            '''
+            self.wfile.write(html.encode())
+
+        elif parsed_path.path == '/health':
+            # Health check endpoint
+            self._set_headers()
+            response = {'status': 'healthy', 'model': model_name}
+            self.wfile.write(json.dumps(response).encode())
+
+        else:
+            self._send_error(404, 'Not found')
+
+    def do_POST(self):
+        if self.path == '/generate':
+            try:
+                # Get request body
+                content_length = int(self.headers['Content-Length'])
+                post_data = self.rfile.read(content_length)
+                data = json.loads(post_data.decode())
+
+                # Extract parameters
+                prompt = data.get('prompt', '')
+                max_length = data.get('max_length', 100)
+                temperature = data.get('temperature', 0.7)
+                top_p = data.get('top_p', 0.9)
+                top_k = data.get('top_k', 50)
+
+                if not prompt:
+                    self._send_error(400, 'Prompt is required')
+                    return
+
+                # Generate text
+                logger.info(f"Generating text for prompt: {prompt[:50]}...")
+                generated_text = generate_text(prompt, max_length, temperature, top_p, top_k)
+
+                # Send response
+                self._set_headers()
+                response = {'generated_text': generated_text}
+                self.wfile.write(json.dumps(response).encode())
+
+            except json.JSONDecodeError:
+                self._send_error(400, 'Invalid JSON')
+            except Exception as e:
+                logger.error(f"Error in POST /generate: {str(e)}")
+                self._send_error(500, str(e))
+        else:
+            self._send_error(404, 'Not found')
+
+    def log_message(self, format, *args):
+        # Override to use our logger
+        logger.info(f"{self.address_string()} - {format % args}")
+
+
+def run_server(host='localhost', port=8000):
+    """Start the HTTP server"""
+    server_address = (host, port)
+    httpd = HTTPServer(server_address, GPT2Handler)
+
+    logger.info(f"Starting GPT-2 server on http://{host}:{port}")
+    logger.info(f"Web interface: http://{host}:{port}")
+    logger.info(f"API endpoint: http://{host}:{port}/generate")
+    logger.info(f"Health check: http://{host}:{port}/health")
+
+    try:
+        httpd.serve_forever()
+    except KeyboardInterrupt:
+        logger.info("Shutting down server...")
+        httpd.shutdown()
+
+
 if __name__ == "__main__":
-    demo.launch()
+    import sys
+
+    # Parse command line arguments
+    host = 'localhost'
+    port = 8000
+
+    if len(sys.argv) > 1:
+        port = int(sys.argv[1])
+    if len(sys.argv) > 2:
+        host = sys.argv[2]
+
+    run_server(host, port)
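For quick testing of the new JSON API, a client needs nothing beyond the standard library. A minimal sketch, assuming the server above was started locally with its defaults (python app.py, so localhost:8000); the prompt value is only an illustration:

import json
from urllib.request import Request, urlopen

# Sketch of a client for the /generate endpoint above.
# Assumes the server is running locally on the default localhost:8000.
payload = {
    "prompt": "Once upon a time in a distant galaxy,",
    "max_length": 60,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
}
req = Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.loads(resp.read())["generated_text"])

# The /health endpoint can be probed the same way:
with urlopen("http://localhost:8000/health") as resp:
    print(json.loads(resp.read()))  # e.g. {'status': 'healthy', 'model': 'gpt2'}

The parameter names and defaults mirror what do_POST reads with data.get(...), and the response keys match the handlers' JSON payloads.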
app1.txt
ADDED
@@ -0,0 +1,121 @@
+import gradio as gr
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+import torch
+
+
+# Load model and tokenizer (using smaller GPT-2 for free tier)
+model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+model = GPT2LMHeadModel.from_pretrained(model_name)
+
+
+# Set pad token
+tokenizer.pad_token = tokenizer.eos_token
+
+
+def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
+    """Generate text using GPT-2"""
+    try:
+        # Encode input
+        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+        # Decode output
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Return only the new generated part
+        return generated_text[len(prompt):].strip()
+
+    except Exception as e:
+        return f"Error generating text: {str(e)}"
+
+
+# Create Gradio interface
+with gr.Blocks(title="GPT-2 Text Generator") as demo:
+    gr.Markdown("# GPT-2 Text Generation Server")
+    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
+
+    with gr.Row():
+        with gr.Column():
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="Enter your text prompt here...",
+                lines=3
+            )
+
+            with gr.Row():
+                max_length = gr.Slider(
+                    minimum=10,
+                    maximum=200,
+                    value=100,
+                    step=10,
+                    label="Max Length"
+                )
+                temperature = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=0.7,
+                    step=0.1,
+                    label="Temperature"
+                )
+
+            with gr.Row():
+                top_p = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.1,
+                    label="Top-p"
+                )
+                top_k = gr.Slider(
+                    minimum=1,
+                    maximum=100,
+                    value=50,
+                    step=1,
+                    label="Top-k"
+                )
+
+            generate_btn = gr.Button("Generate Text", variant="primary")
+
+        with gr.Column():
+            output_text = gr.Textbox(
+                label="Generated Text",
+                lines=10,
+                placeholder="Generated text will appear here..."
+            )
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["Once upon a time in a distant galaxy,"],
+            ["The future of artificial intelligence is"],
+            ["In the heart of the ancient forest,"],
+            ["The detective walked into the room and noticed"],
+        ],
+        inputs=prompt_input
+    )
+
+    # Connect the function with explicit API endpoint name
+    generate_btn.click(
+        fn=generate_text,
+        inputs=[prompt_input, max_length, temperature, top_p, top_k],
+        outputs=output_text,
+        api_name="/predict"  # Explicit API endpoint for external calls
+    )
+
+
+# Launch the app
+if __name__ == "__main__":
+    demo.launch()
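Because the Gradio version registers its click handler with api_name="/predict", it can also be driven programmatically. A sketch using the gradio_client package, assuming the app is running at Gradio's default local address (http://127.0.0.1:7860):

from gradio_client import Client  # pip install gradio_client

# Sketch: call the /predict endpoint exposed by the gr.Blocks app above.
# Assumes the app is running locally at Gradio's default address.
client = Client("http://127.0.0.1:7860")
result = client.predict(
    "Once upon a time in a distant galaxy,",  # prompt_input
    100,   # max_length
    0.7,   # temperature
    0.9,   # top_p
    50,    # top_k
    api_name="/predict",
)
print(result)

The positional arguments follow the order of the inputs list passed to generate_btn.click.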