mike23415 committed
Commit 75cbe07 · verified · 1 Parent(s): 3bd81ff

Update app.py

Files changed (1):
  1. app.py +267 -106

app.py CHANGED
@@ -1,55 +1,127 @@
 from flask import Flask, request, jsonify
 from flask_cors import CORS
-from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
 import torch
 import os
 import json

 app = Flask(__name__)
-CORS(app)  # Enable CORS for all routes

-# Set Hugging Face cache to ephemeral storage
-os.environ["HF_HOME"] = "/data/.huggingface"

-# Load Qwen2.5-1.5B model and tokenizer
-model_name = "Qwen/Qwen2.5-1.5B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

-# Move to GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)

-# Data file for preloaded and dynamic data
-data_file = "data/train_data.json"

-# Load or initialize dataset
-if os.path.exists(data_file):
-    with open(data_file, 'r') as f:
-        train_texts = json.load(f)
-else:
-    train_texts = []
-    os.makedirs(os.path.dirname(data_file), exist_ok=True)
-    with open(data_file, 'w') as f:
-        json.dump(train_texts, f)
-print(f"Loaded {len(train_texts)} examples from {data_file}")

-# Model save directory
-model_save_dir = "./results/model"

 @app.route('/')
 def home():
-    """Root endpoint to show API status and usage"""
     return jsonify({
         'status': 'SEAL Framework API is running',
         'version': '1.0.0',
-        'model': model_name,
-        'device': str(device),
         'training_examples': len(train_texts),
         'endpoints': {
             '/': 'GET - API status and information',
             '/adapt': 'POST - Adaptive model training and response',
-            '/health': 'GET - Health check'
         },
         'usage': {
             'adapt_endpoint': {
@@ -61,122 +133,211 @@ def home():
         }
     })

 @app.route('/health')
 def health():
-    """Health check endpoint"""
     try:
         # Simple model test
         test_input = "Health check"
-        inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=32).to(device)
-        with torch.no_grad():
-            outputs = model.generate(**inputs, max_length=40, num_return_sequences=1, do_sample=False)

         return jsonify({
             'status': 'healthy',
             'model_loaded': True,
             'device': str(device),
-            'training_examples': len(train_texts)
         })
     except Exception as e:
         return jsonify({
             'status': 'unhealthy',
-            'error': str(e)
         }), 500

 @app.route('/adapt', methods=['POST'])
 def adapt_model():
     try:
         data = request.json
-        user_input = data.get('text', '')
-
         if not user_input:
-            return jsonify({'error': 'No input provided'}), 400
-
-        # Generate self-edit
-        prompt = f"Rephrase this: {user_input}"
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)
-        self_edit_output = model.generate(**inputs, max_length=150, num_return_sequences=1)
-        self_edit = tokenizer.decode(self_edit_output[0], skip_special_tokens=True)
-
-        # Add to training data and save to disk
-        train_texts.append({"prompt": user_input, "completion": self_edit})
-        with open(data_file, 'w') as f:
-            json.dump(train_texts, f, indent=2)
-
-        # Prepare dataset for fine-tuning
-        encodings = tokenizer(
-            [t["prompt"] + " " + t["completion"] for t in train_texts],
-            truncation=True,
-            padding=True,
-            max_length=256,
-            return_tensors="pt"
-        )
-        dataset = [
-            {
-                "input_ids": encodings["input_ids"][i],
-                "attention_mask": encodings["attention_mask"][i],
-                "labels": encodings["input_ids"][i]
-            } for i in range(len(train_texts))
-        ]
-
-        # Fine-tune model
-        training_args = TrainingArguments(
-            output_dir=model_save_dir,
-            num_train_epochs=1,
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=4,
-            logging_steps=10,
-            save_steps=10,
-            save_total_limit=1,  # Keep only latest checkpoint
-            disable_tqdm=True,
-            fp16=True if torch.cuda.is_available() else False
-        )
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=dataset
-        )
-        trainer.train()
-
-        # Save model weights
-        trainer.save_model(model_save_dir)
-        tokenizer.save_pretrained(model_save_dir)
-
-        # Generate response
-        response_inputs = tokenizer(user_input, return_tensors="pt", truncation=True, max_length=128).to(device)
-        response_output = model.generate(**response_inputs, max_length=200, num_return_sequences=1)
-        response = tokenizer.decode(response_output[0], skip_special_tokens=True)
-
         return jsonify({
             'input': user_input,
             'self_edit': self_edit,
             'response': response,
             'training_examples': len(train_texts),
-            'status': 'Model adapted successfully'
         })
-
     except Exception as e:
-        return jsonify({'error': str(e)}), 500

 @app.errorhandler(404)
 def not_found(error):
-    """Custom 404 handler"""
     return jsonify({
         'error': 'Endpoint not found',
-        'available_endpoints': {
-            '/': 'GET - API information',
-            '/health': 'GET - Health check',
-            '/adapt': 'POST - Adaptive model training'
-        }
     }), 404

 @app.errorhandler(500)
 def internal_error(error):
-    """Custom 500 handler"""
     return jsonify({
         'error': 'Internal server error',
-        'message': 'Please check the server logs for more details'
     }), 500

 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=7860, debug=False)
@@ -1,55 +1,127 @@
 from flask import Flask, request, jsonify
 from flask_cors import CORS
 import torch
 import os
 import json
+import logging
+import gc
+from contextlib import contextmanager
+from datetime import datetime  # added: used for the /adapt timestamp (torch has no now())
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 app = Flask(__name__)
+CORS(app)

+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+device = None

+# Configuration
+MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+DATA_FILE = "data/train_data.json"
+MODEL_SAVE_DIR = "./results/model"
+
+# Set environment variables
+os.environ["HF_HOME"] = "/data/.huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/data/.huggingface"

+def initialize_model():
+    """Initialize model and tokenizer with error handling"""
+    global model, tokenizer, device
+
+    try:
+        logger.info("Initializing model and tokenizer...")
+
+        # Set device
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.info(f"Using device: {device}")
+
+        # Import here to avoid import errors during startup
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # Load tokenizer first (lighter)
+        logger.info("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_NAME,
+            trust_remote_code=True,
+            cache_dir="/data/.huggingface"
+        )
+
+        # Add padding token if it doesn't exist
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        logger.info("Loading model...")
+        # Load model with specific configuration for stability
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_NAME,
+            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
+            device_map="auto" if device.type == "cuda" else None,
+            trust_remote_code=True,
+            cache_dir="/data/.huggingface",
+            low_cpu_mem_usage=True
+        )
+
+        # Move to device if not using device_map
+        if device.type == "cpu":
+            model = model.to(device)
+
+        logger.info("Model initialization completed successfully")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to initialize model: {str(e)}")
+        return False

+def load_training_data():
+    """Load or initialize training data"""
+    try:
+        if os.path.exists(DATA_FILE):
+            with open(DATA_FILE, 'r') as f:
+                train_texts = json.load(f)
+        else:
+            train_texts = []
+            os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
+            with open(DATA_FILE, 'w') as f:
+                json.dump(train_texts, f)
+
+        logger.info(f"Loaded {len(train_texts)} training examples")
+        return train_texts
+    except Exception as e:
+        logger.error(f"Error loading training data: {str(e)}")
+        return []

+@contextmanager
+def torch_no_grad():
+    """Context manager for torch.no_grad with error handling"""
+    try:
+        with torch.no_grad():
+            yield
+    except Exception as e:
+        logger.error(f"Error in torch context: {str(e)}")
+        raise

+# Initialize data
+train_texts = load_training_data()

 @app.route('/')
 def home():
+    """Root endpoint with system information"""
     return jsonify({
         'status': 'SEAL Framework API is running',
         'version': '1.0.0',
+        'model': MODEL_NAME,
+        'model_loaded': model is not None,
+        'device': str(device) if device else 'Not initialized',
         'training_examples': len(train_texts),
         'endpoints': {
             '/': 'GET - API status and information',
             '/adapt': 'POST - Adaptive model training and response',
+            '/health': 'GET - Health check',
+            '/init': 'POST - Initialize model (if not already loaded)'
         },
         'usage': {
             'adapt_endpoint': {
@@ -61,122 +133,211 @@ def home():
         }
     })

+@app.route('/init', methods=['POST'])
+def init_model():
+    """Manual model initialization endpoint"""
+    global model, tokenizer
+
+    if model is not None:
+        return jsonify({'status': 'Model already initialized', 'success': True})
+
+    success = initialize_model()
+    if success:
+        return jsonify({'status': 'Model initialized successfully', 'success': True})
+    else:
+        return jsonify({'status': 'Model initialization failed', 'success': False}), 500
+
 @app.route('/health')
 def health():
+    """Comprehensive health check"""
     try:
+        # Check if model is loaded
+        if model is None or tokenizer is None:
+            return jsonify({
+                'status': 'unhealthy',
+                'error': 'Model not initialized',
+                'model_loaded': False,
+                'suggestion': 'Call /init endpoint to initialize model'
+            }), 500
+
         # Simple model test
         test_input = "Health check"
+        try:
+            with torch_no_grad():
+                inputs = tokenizer(
+                    test_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=32,
+                    padding=True
+                ).to(device)
+
+                outputs = model.generate(
+                    **inputs,
+                    max_length=40,
+                    num_return_sequences=1,
+                    do_sample=False,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+        except Exception as e:
+            raise Exception(f"Model inference failed: {str(e)}")

         return jsonify({
             'status': 'healthy',
             'model_loaded': True,
             'device': str(device),
+            'training_examples': len(train_texts),
+            'torch_version': torch.__version__
         })
+
     except Exception as e:
+        logger.error(f"Health check failed: {str(e)}")
         return jsonify({
             'status': 'unhealthy',
+            'error': str(e),
+            'model_loaded': model is not None
         }), 500

 @app.route('/adapt', methods=['POST'])
 def adapt_model():
+    """Simplified adaptive model endpoint"""
+    global train_texts
+
     try:
+        # Check if model is initialized
+        if model is None or tokenizer is None:
+            return jsonify({
+                'error': 'Model not initialized. Call /init endpoint first.',
+                'suggestion': 'POST to /init to initialize the model'
+            }), 500
+
+        # Get input
         data = request.json
+        if not data or 'text' not in data:
+            return jsonify({'error': 'No text provided in request body'}), 400
+
+        user_input = data['text'].strip()
         if not user_input:
+            return jsonify({'error': 'Empty text provided'}), 400
+
+        logger.info(f"Processing input: {user_input[:50]}...")
+
+        # Generate self-edit (simplified approach)
+        try:
+            with torch_no_grad():
+                prompt = f"Rephrase this text: {user_input}"
+                inputs = tokenizer(
+                    prompt,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=128,
+                    padding=True
+                ).to(device)
+
+                self_edit_output = model.generate(
+                    **inputs,
+                    max_length=200,
+                    num_return_sequences=1,
+                    do_sample=True,
+                    temperature=0.7,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+
+                self_edit = tokenizer.decode(
+                    self_edit_output[0],
+                    skip_special_tokens=True
+                ).replace(prompt, "").strip()
+
+        except Exception as e:
+            logger.error(f"Self-edit generation failed: {str(e)}")
+            self_edit = f"Self-edit failed: {str(e)}"
+
+        # Generate response (simplified)
+        try:
+            with torch_no_grad():
+                response_inputs = tokenizer(
+                    user_input,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=128,
+                    padding=True
+                ).to(device)
+
+                response_output = model.generate(
+                    **response_inputs,
+                    max_length=256,
+                    num_return_sequences=1,
+                    do_sample=True,
+                    temperature=0.8,
+                    pad_token_id=tokenizer.pad_token_id
+                )
+
+                response = tokenizer.decode(
+                    response_output[0],
+                    skip_special_tokens=True
+                ).replace(user_input, "").strip()
+
+        except Exception as e:
+            logger.error(f"Response generation failed: {str(e)}")
+            response = f"Response generation failed: {str(e)}"
+
+        # Save training data (simplified - no actual fine-tuning for stability)
+        try:
+            train_texts.append({
+                "prompt": user_input,
+                "completion": self_edit,
+                "timestamp": datetime.utcnow().isoformat()  # fixed: torch.now() does not exist
+            })
+
+            # Save to file
+            with open(DATA_FILE, 'w') as f:
+                json.dump(train_texts, f, indent=2)
+
+        except Exception as e:
+            logger.error(f"Failed to save training data: {str(e)}")
+
+        # Clean up GPU memory
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
         return jsonify({
             'input': user_input,
             'self_edit': self_edit,
             'response': response,
             'training_examples': len(train_texts),
+            'status': 'Processing completed successfully',
+            'note': 'Fine-tuning disabled for stability - using generation only'
         })
+
     except Exception as e:
+        logger.error(f"Adapt endpoint error: {str(e)}")
+        return jsonify({
+            'error': str(e),
+            'type': type(e).__name__,
+            'suggestion': 'Check logs for detailed error information'
+        }), 500

 @app.errorhandler(404)
 def not_found(error):
     return jsonify({
         'error': 'Endpoint not found',
+        'available_endpoints': ['/health', '/adapt', '/init', '/']
     }), 404

 @app.errorhandler(500)
 def internal_error(error):
     return jsonify({
         'error': 'Internal server error',
+        'message': 'Check server logs for details'
     }), 500

+# Initialize model on startup (with fallback)
 if __name__ == '__main__':
+    logger.info("Starting SEAL Framework API...")
+    initialize_model()
+    app.run(host='0.0.0.0', port=7860, debug=False)
+else:
+    # For production deployment
+    logger.info("SEAL Framework API starting in production mode...")
+    # Don't initialize model immediately in production to avoid startup timeouts
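
After this change, the intended flow is: POST /init to load the model (in production mode nothing is loaded at startup), GET /health to confirm a generation pass works, then POST /adapt with a text payload. A minimal client sketch of that flow, assuming the server is reachable at http://localhost:7860 and the third-party requests package is installed (the URL and sample text here are illustrative, not part of the commit):

import requests

BASE = "http://localhost:7860"  # assumed local deployment; adjust host/port as needed

# Load the model first; in production mode it is not initialized at startup.
print(requests.post(f"{BASE}/init").json())

# Confirm the model can run a short generation pass.
print(requests.get(f"{BASE}/health").json())

# Submit text: the server generates a self-edit and a response, and appends
# the pair to data/train_data.json (no fine-tuning is performed in this version).
r = requests.post(f"{BASE}/adapt", json={"text": "The quick brown fox jumps over the lazy dog."})
print(r.json())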