# Seal / app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
import os
import json
import logging
import gc
from contextlib import contextmanager
from datetime import datetime, timezone
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
CORS(app)
# Global variables for model and tokenizer
model = None
tokenizer = None
device = None
# Configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
DATA_FILE = "data/train_data.json"
MODEL_SAVE_DIR = "./results/model"
# Set environment variables
os.environ["HF_HOME"] = "/data/.huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/data/.huggingface"
def initialize_model():
    """Initialize model and tokenizer with error handling"""
    global model, tokenizer, device
    try:
        logger.info("Initializing model and tokenizer...")

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Import here to avoid import errors during startup
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # Load tokenizer first (lighter)
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            cache_dir="/data/.huggingface"
        )

        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model...")
        # Load model with specific configuration for stability
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            device_map="auto" if device.type == "cuda" else None,
            trust_remote_code=True,
            cache_dir="/data/.huggingface",
            low_cpu_mem_usage=True
        )

        # Move to device if not using device_map
        if device.type == "cpu":
            model = model.to(device)

        logger.info("Model initialization completed successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        return False

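# Note: device_map="auto" above requires the `accelerate` package to be
# installed; transformers raises an error if a device_map is passed to
# from_pretrained without it.
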
def load_training_data():
    """Load or initialize training data"""
    try:
        if os.path.exists(DATA_FILE):
            with open(DATA_FILE, 'r') as f:
                train_texts = json.load(f)
        else:
            train_texts = []
            os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
            with open(DATA_FILE, 'w') as f:
                json.dump(train_texts, f)
        logger.info(f"Loaded {len(train_texts)} training examples")
        return train_texts
    except Exception as e:
        logger.error(f"Error loading training data: {str(e)}")
        return []

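# The data file is a JSON list of records appended by /adapt, e.g.:
# [
#   {
#     "prompt": "original user text",
#     "completion": "model-generated self-edit",
#     "timestamp": "2024-01-01T00:00:00+00:00"
#   }
# ]
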
@contextmanager
def torch_no_grad():
    """Context manager for torch.no_grad with error handling"""
    try:
        with torch.no_grad():
            yield
    except Exception as e:
        logger.error(f"Error in torch context: {str(e)}")
        raise

# Initialize data
train_texts = load_training_data()
@app.route('/')
def home():
    """Root endpoint with system information"""
    return jsonify({
        'status': 'SEAL Framework API is running',
        'version': '1.0.0',
        'model': MODEL_NAME,
        'model_loaded': model is not None,
        'device': str(device) if device else 'Not initialized',
        'training_examples': len(train_texts),
        'endpoints': {
            '/': 'GET - API status and information',
            '/adapt': 'POST - Adaptive model training and response',
            '/health': 'GET - Health check',
            '/init': 'POST - Initialize model (if not already loaded)'
        },
        'usage': {
            'adapt_endpoint': {
                'method': 'POST',
                'content_type': 'application/json',
                'body': {'text': 'Your input text here'},
                'example': 'curl -X POST -H "Content-Type: application/json" -d \'{"text":"Hello world"}\' /adapt'
            }
        }
    })

@app.route('/init', methods=['POST'])
def init_model():
    """Manual model initialization endpoint"""
    global model, tokenizer
    if model is not None:
        return jsonify({'status': 'Model already initialized', 'success': True})

    success = initialize_model()
    if success:
        return jsonify({'status': 'Model initialized successfully', 'success': True})
    else:
        return jsonify({'status': 'Model initialization failed', 'success': False}), 500

@app.route('/health')
def health():
    """Comprehensive health check"""
    try:
        # Check if model is loaded
        if model is None or tokenizer is None:
            return jsonify({
                'status': 'unhealthy',
                'error': 'Model not initialized',
                'model_loaded': False,
                'suggestion': 'Call /init endpoint to initialize model'
            }), 500

        # Simple model test
        test_input = "Health check"
        try:
            with torch_no_grad():
                inputs = tokenizer(
                    test_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=32,
                    padding=True
                ).to(device)
                outputs = model.generate(
                    **inputs,
                    max_length=40,
                    num_return_sequences=1,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id
                )
        except Exception as e:
            raise Exception(f"Model inference failed: {str(e)}")

        return jsonify({
            'status': 'healthy',
            'model_loaded': True,
            'device': str(device),
            'training_examples': len(train_texts),
            'torch_version': torch.__version__
        })
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return jsonify({
            'status': 'unhealthy',
            'error': str(e),
            'model_loaded': model is not None
        }), 500

@app.route('/adapt', methods=['POST'])
def adapt_model():
    """Simplified adaptive model endpoint"""
    global train_texts
    try:
        # Check if model is initialized
        if model is None or tokenizer is None:
            return jsonify({
                'error': 'Model not initialized. Call /init endpoint first.',
                'suggestion': 'POST to /init to initialize the model'
            }), 500

        # Get input (silent=True returns None instead of raising on a missing
        # or malformed JSON body, so the 400 below is reached as intended)
        data = request.get_json(silent=True)
        if not data or 'text' not in data:
            return jsonify({'error': 'No text provided in request body'}), 400

        user_input = data['text'].strip()
        if not user_input:
            return jsonify({'error': 'Empty text provided'}), 400

        logger.info(f"Processing input: {user_input[:50]}...")
        # Generate self-edit (simplified approach)
        try:
            with torch_no_grad():
                prompt = f"Rephrase this text: {user_input}"
                inputs = tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=128,
                    padding=True
                ).to(device)
                self_edit_output = model.generate(
                    **inputs,
                    max_length=200,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=tokenizer.pad_token_id
                )
                self_edit = tokenizer.decode(
                    self_edit_output[0],
                    skip_special_tokens=True
                ).replace(prompt, "").strip()
        except Exception as e:
            logger.error(f"Self-edit generation failed: {str(e)}")
            self_edit = f"Self-edit failed: {str(e)}"

        # Generate response (simplified)
        try:
            with torch_no_grad():
                response_inputs = tokenizer(
                    user_input,
                    return_tensors="pt",
                    truncation=True,
                    max_length=128,
                    padding=True
                ).to(device)
                response_output = model.generate(
                    **response_inputs,
                    max_length=256,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.8,
                    pad_token_id=tokenizer.pad_token_id
                )
                response = tokenizer.decode(
                    response_output[0],
                    skip_special_tokens=True
                ).replace(user_input, "").strip()
        except Exception as e:
            logger.error(f"Response generation failed: {str(e)}")
            response = f"Response generation failed: {str(e)}"
        # Save training data (simplified - no actual fine-tuning for stability)
        try:
            train_texts.append({
                "prompt": user_input,
                "completion": self_edit,
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            # Save to file
            with open(DATA_FILE, 'w') as f:
                json.dump(train_texts, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save training data: {str(e)}")

        # Clean up GPU memory
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return jsonify({
            'input': user_input,
            'self_edit': self_edit,
            'response': response,
            'training_examples': len(train_texts),
            'status': 'Processing completed successfully',
            'note': 'Fine-tuning disabled for stability - using generation only'
        })
    except Exception as e:
        logger.error(f"Adapt endpoint error: {str(e)}")
        return jsonify({
            'error': str(e),
            'type': type(e).__name__,
            'suggestion': 'Check logs for detailed error information'
        }), 500

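# The /adapt route above only records (prompt, completion) pairs; actual
# fine-tuning is disabled for stability. The helper below is a minimal,
# hypothetical sketch of what one adaptation pass over those pairs could look
# like. It is not wired into any route; the learning rate and max_length are
# illustrative assumptions, not tuned values, and a float16 model would
# normally also need a grad scaler or fp32 master weights for stable updates.
def fine_tune_step(examples, lr=5e-5, max_length=256):
    """Run one causal-LM gradient step per (prompt, completion) pair."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for ex in examples:
        text = f"{ex['prompt']}\n{ex['completion']}"
        batch = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length
        ).to(device)
        # For causal-LM fine-tuning the labels are the input ids themselves;
        # the model shifts them internally to compute the next-token loss.
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
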
@app.errorhandler(404)
def not_found(error):
    return jsonify({
        'error': 'Endpoint not found',
        'available_endpoints': ['/health', '/adapt', '/init', '/']
    }), 404


@app.errorhandler(500)
def internal_error(error):
    return jsonify({
        'error': 'Internal server error',
        'message': 'Check server logs for details'
    }), 500

# Initialize model on startup (with fallback)
if __name__ == '__main__':
    logger.info("Starting SEAL Framework API...")
    initialize_model()
    app.run(host='0.0.0.0', port=7860, debug=False)
else:
    # For production deployment
    logger.info("SEAL Framework API starting in production mode...")
    # Don't initialize model immediately in production to avoid startup timeouts
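
# Example usage (assuming the server is reachable at http://localhost:7860):
#   curl http://localhost:7860/                # API status and endpoint listing
#   curl -X POST http://localhost:7860/init    # load the model
#   curl http://localhost:7860/health          # smoke-test the loaded model
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"text": "Hello world"}' http://localhost:7860/adapt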