Update app.py
app.py
CHANGED
Removed (old version; most deleted lines were lost in extraction): the eager module-level
"from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments"
import, the inline model/tokenizer setup and training-data load that ended with
print(f"Loaded {len(train_texts)} examples from {data_file}"), and the previous inline
bodies of home(), health(), and adapt_model(). The new version of both hunks follows in
full; '+' marks added lines, unmarked lines are unchanged context.

@@ -1,55 +1,127 @@
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import os
import json
+import logging
+import gc
+from datetime import datetime  # used for the record timestamp in /adapt
+from contextlib import contextmanager
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

app = Flask(__name__)
+CORS(app)

+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+device = None

+# Configuration
+MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+DATA_FILE = "data/train_data.json"
+MODEL_SAVE_DIR = "./results/model"
+
+# Set environment variables
+os.environ["HF_HOME"] = "/data/.huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/data/.huggingface"

+def initialize_model():
+    """Initialize model and tokenizer with error handling"""
+    global model, tokenizer, device
+
+    try:
+        logger.info("Initializing model and tokenizer...")
+
+        # Set device
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Using device: {device}")
+
+        # Import here to avoid import errors during startup
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # Load tokenizer first (lighter)
+        logger.info("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_NAME,
+            trust_remote_code=True,
+            cache_dir="/data/.huggingface"
+        )
+
+        # Add padding token if it doesn't exist
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        logger.info("Loading model...")
+        # Load model with specific configuration for stability
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
+            device_map="auto" if device.type == "cuda" else None,
+            trust_remote_code=True,
+            cache_dir="/data/.huggingface",
+            low_cpu_mem_usage=True
+        )
+
+        # Move to device if not using device_map
+        if device.type == "cpu":
+            model = model.to(device)
+
+        logger.info("Model initialization completed successfully")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to initialize model: {str(e)}")
+        return False

+def load_training_data():
+    """Load or initialize training data"""
+    try:
+        if os.path.exists(DATA_FILE):
+            with open(DATA_FILE, 'r') as f:
+                train_texts = json.load(f)
+        else:
+            train_texts = []
+            os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
+            with open(DATA_FILE, 'w') as f:
+                json.dump(train_texts, f)
+
+        logger.info(f"Loaded {len(train_texts)} training examples")
+        return train_texts
+    except Exception as e:
+        logger.error(f"Error loading training data: {str(e)}")
+        return []

+@contextmanager
+def torch_no_grad():
+    """Context manager for torch.no_grad with error handling"""
+    try:
+        with torch.no_grad():
+            yield
+    except Exception as e:
+        logger.error(f"Error in torch context: {str(e)}")
+        raise

+# Initialize data
+train_texts = load_training_data()

@app.route('/')
def home():
+    """Root endpoint with system information"""
    return jsonify({
        'status': 'SEAL Framework API is running',
        'version': '1.0.0',
+        'model': MODEL_NAME,
+        'model_loaded': model is not None,
+        'device': str(device) if device else 'Not initialized',
        'training_examples': len(train_texts),
        'endpoints': {
            '/': 'GET - API status and information',
            '/adapt': 'POST - Adaptive model training and response',
+            '/health': 'GET - Health check',
+            '/init': 'POST - Initialize model (if not already loaded)'
        },
        'usage': {
            'adapt_endpoint': {

@@ -61,122 +133,211 @@ def home():
        }
    })

+@app.route('/init', methods=['POST'])
+def init_model():
+    """Manual model initialization endpoint"""
+    global model, tokenizer
+
+    if model is not None:
+        return jsonify({'status': 'Model already initialized', 'success': True})
+
+    success = initialize_model()
+    if success:
+        return jsonify({'status': 'Model initialized successfully', 'success': True})
+    else:
+        return jsonify({'status': 'Model initialization failed', 'success': False}), 500
+
@app.route('/health')
def health():
+    """Comprehensive health check"""
    try:
+        # Check if model is loaded
+        if model is None or tokenizer is None:
+            return jsonify({
+                'status': 'unhealthy',
+                'error': 'Model not initialized',
+                'model_loaded': False,
+                'suggestion': 'Call /init endpoint to initialize model'
+            }), 500
+
        # Simple model test
        test_input = "Health check"
+        try:
+            with torch_no_grad():
+                inputs = tokenizer(
+                    test_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=32,
+                    padding=True
+                ).to(device)
+
+                outputs = model.generate(
+                    **inputs,
+                    max_length=40,
+                    num_return_sequences=1,
+                    do_sample=False,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+        except Exception as e:
+            raise Exception(f"Model inference failed: {str(e)}")

        return jsonify({
            'status': 'healthy',
            'model_loaded': True,
            'device': str(device),
+            'training_examples': len(train_texts),
+            'torch_version': torch.__version__
        })
+
    except Exception as e:
+        logger.error(f"Health check failed: {str(e)}")
        return jsonify({
            'status': 'unhealthy',
+            'error': str(e),
+            'model_loaded': model is not None
        }), 500

@app.route('/adapt', methods=['POST'])
def adapt_model():
+    """Simplified adaptive model endpoint"""
+    global train_texts
+
    try:
+        # Check if model is initialized
+        if model is None or tokenizer is None:
+            return jsonify({
+                'error': 'Model not initialized. Call /init endpoint first.',
+                'suggestion': 'POST to /init to initialize the model'
+            }), 500
+
+        # Get input
        data = request.json
+        if not data or 'text' not in data:
+            return jsonify({'error': 'No text provided in request body'}), 400
+
+        user_input = data['text'].strip()
        if not user_input:
+            return jsonify({'error': 'Empty text provided'}), 400
+
+        logger.info(f"Processing input: {user_input[:50]}...")
+
+        # Generate self-edit (simplified approach)
+        try:
+            with torch_no_grad():
+                prompt = f"Rephrase this text: {user_input}"
+                inputs = tokenizer(
+                    prompt,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=128,
+                    padding=True
+                ).to(device)
+
+                self_edit_output = model.generate(
+                    **inputs,
+                    max_length=200,
+                    num_return_sequences=1,
+                    do_sample=True,
+                    temperature=0.7,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+
+                self_edit = tokenizer.decode(
+                    self_edit_output[0],
+                    skip_special_tokens=True
+                ).replace(prompt, "").strip()
+
+        except Exception as e:
+            logger.error(f"Self-edit generation failed: {str(e)}")
+            self_edit = f"Self-edit failed: {str(e)}"
+
+        # Generate response (simplified)
+        try:
+            with torch_no_grad():
+                response_inputs = tokenizer(
+                    user_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=128,
+                    padding=True
+                ).to(device)
+
+                response_output = model.generate(
+                    **response_inputs,
+                    max_length=256,
+                    num_return_sequences=1,
+                    do_sample=True,
+                    temperature=0.8,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+
+                response = tokenizer.decode(
+                    response_output[0],
+                    skip_special_tokens=True
+                ).replace(user_input, "").strip()
+
+        except Exception as e:
+            logger.error(f"Response generation failed: {str(e)}")
+            response = f"Response generation failed: {str(e)}"
+
+        # Save training data (simplified - no actual fine-tuning for stability)
+        try:
+            train_texts.append({
+                "prompt": user_input,
+                "completion": self_edit,
+                # datetime, not torch, provides the wall-clock time (torch has no now())
+                "timestamp": datetime.now().isoformat()
+            })
+
+            # Save to file
+            with open(DATA_FILE, 'w') as f:
+                json.dump(train_texts, f, indent=2)
+
+        except Exception as e:
+            logger.error(f"Failed to save training data: {str(e)}")
+
+        # Clean up GPU memory
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
        return jsonify({
            'input': user_input,
            'self_edit': self_edit,
            'response': response,
            'training_examples': len(train_texts),
+            'status': 'Processing completed successfully',
+            'note': 'Fine-tuning disabled for stability - using generation only'
        })
+
    except Exception as e:
+        logger.error(f"Adapt endpoint error: {str(e)}")
+        return jsonify({
+            'error': str(e),
+            'type': type(e).__name__,
+            'suggestion': 'Check logs for detailed error information'
+        }), 500

@app.errorhandler(404)
def not_found(error):
    return jsonify({
        'error': 'Endpoint not found',
+        'available_endpoints': ['/health', '/adapt', '/init', '/']
    }), 404

@app.errorhandler(500)
def internal_error(error):
    return jsonify({
        'error': 'Internal server error',
+        'message': 'Check server logs for details'
    }), 500

+# Initialize model on startup (with fallback)
if __name__ == '__main__':
+    logger.info("Starting SEAL Framework API...")
+    initialize_model()
+    app.run(host='0.0.0.0', port=7860, debug=False)
+else:
+    # For production deployment
+    logger.info("SEAL Framework API starting in production mode...")
+    # Don't initialize model immediately in production to avoid startup timeouts
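
For reference, a minimal client sketch (not part of the commit) that exercises the new
endpoints. It assumes the API is reachable on localhost at port 7860, as set in app.run()
above, uses the 'text' key that adapt_model() expects, and relies on the third-party
requests package:

import requests  # assumed available; any HTTP client works

BASE = "http://localhost:7860"

# Load the model once (a no-op if it is already initialized)
print(requests.post(f"{BASE}/init").json())

# Confirm the model can run a small generation
print(requests.get(f"{BASE}/health").json())

# Submit a text for the self-edit + response round trip
result = requests.post(f"{BASE}/adapt", json={"text": "The sky is blue."}).json()
print(result["self_edit"])
print(result["response"])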
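
A note on the decode step: adapt_model() removes the prompt from the decoded output with
.replace(prompt, ""), which fails quietly whenever decoding does not reproduce the prompt
byte for byte (special tokens, whitespace normalization). A sketch of a common alternative,
not what the commit does, is to slice off the prompt by token count before decoding; this
assumes the 1-D output tensors produced by generate() above:

def decode_continuation(tokenizer, output_ids, input_ids):
    """Decode only the tokens generated after the prompt."""
    prompt_len = input_ids.shape[1]  # number of prompt tokens
    return tokenizer.decode(output_ids[prompt_len:], skip_special_tokens=True).strip()

# e.g. inside adapt_model():
# self_edit = decode_continuation(tokenizer, self_edit_output[0], inputs["input_ids"])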
|