|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
import torch |
|
import os |
|
import json |
|
import logging |
|
import gc |
|
from contextlib import contextmanager |
|
|
|
|
|
# Root logging at INFO (basicConfig is a no-op if the root logger already has
# handlers) and a module-scoped logger for this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# The WSGI application. CORS is applied with flask_cors defaults, i.e. no
# origin restrictions are configured here.
app = Flask(import_name=__name__)
CORS(app)
|
|
|
|
|
# Filled in by initialize_model(); all three stay None until a load succeeds.
model = tokenizer = device = None

# Hugging Face model id, path of the JSON training-data store, and an output
# directory that is not referenced anywhere else in this file.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
DATA_FILE = "data/train_data.json"
MODEL_SAVE_DIR = "./results/model"

# Point the Hugging Face caches at /data. This runs at import time; in this
# file transformers is only imported lazily inside initialize_model().
os.environ.update({
    "HF_HOME": "/data/.huggingface",
    "TRANSFORMERS_CACHE": "/data/.huggingface",
})
|
|
|
def initialize_model():
    """Load the tokenizer and model into the module-level globals.

    Chooses ``device`` (CUDA when ``torch.cuda.is_available()``, else CPU),
    then loads ``tokenizer`` and ``model`` for ``MODEL_NAME`` using the cache
    directory ``/data/.huggingface``. On CUDA the weights are loaded in
    float16 with ``device_map="auto"``; on CPU they are loaded in float32 and
    moved to the CPU device explicitly.

    Returns:
        True on success. False if any step raised; the exception is logged
        with its traceback and ``model``/``tokenizer`` are reset to None so a
        half-initialized pair (tokenizer loaded, model not) is never left
        behind.
    """
    global model, tokenizer, device

    cache_dir = "/data/.huggingface"

    try:
        logger.info("Initializing model and tokenizer...")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Imported inside the try so a missing/broken transformers install is
        # reported as a failed initialization instead of an import-time crash.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )

        # Callers tokenize with padding=True and pass pad_token_id to
        # generate(), so a pad token must exist; fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model...")
        on_cuda = device.type == "cuda"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            device_map="auto" if on_cuda else None,
            trust_remote_code=True,
            cache_dir=cache_dir,
            low_cpu_mem_usage=True,
        )

        # device_map="auto" handles GPU placement; on CPU place explicitly.
        if not on_cuda:
            model = model.to(device)

        logger.info("Model initialization completed successfully")
        return True

    except Exception:
        # Top-level boundary for a best-effort load: log with traceback and
        # report failure to the caller instead of crashing the server.
        logger.exception("Failed to initialize model")
        model = None
        tokenizer = None
        return False
|
|
|
def load_training_data():
    """Load the persisted training examples, creating an empty store if absent.

    Reads ``DATA_FILE`` as UTF-8 JSON. If the file does not exist, its parent
    directory (when there is one) is created and an empty JSON list is
    written so later appends have a file to update.

    Returns:
        The stored list of examples. Returns ``[]`` (and logs the problem) when
        the file cannot be read or parsed, or when it holds JSON that is not a
        list. Callers append to the result, so anything other than a list is
        unusable.
    """
    try:
        if os.path.exists(DATA_FILE):
            with open(DATA_FILE, 'r', encoding='utf-8') as f:
                train_texts = json.load(f)
            if not isinstance(train_texts, list):
                logger.error(
                    "Training data in %s is a %s, expected a JSON list; ignoring it",
                    DATA_FILE, type(train_texts).__name__,
                )
                return []
        else:
            train_texts = []
            parent = os.path.dirname(DATA_FILE)
            # os.makedirs("") raises, so only create a directory if the path has one.
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(DATA_FILE, 'w', encoding='utf-8') as f:
                json.dump(train_texts, f)

        logger.info(f"Loaded {len(train_texts)} training examples")
        return train_texts

    except (OSError, ValueError):
        # OSError: filesystem problems; ValueError covers JSONDecodeError and
        # UnicodeDecodeError. Best effort: start with no examples.
        logger.exception("Error loading training data")
        return []
|
|
|
@contextmanager
def torch_no_grad():
    """Run the enclosed block with autograd disabled, logging any failure.

    Any ``Exception`` raised while entering ``torch.no_grad()`` or inside the
    ``with`` body is logged and then re-raised unchanged.
    """
    try:
        with torch.no_grad():
            yield
    except Exception as exc:
        logger.error(f"Error in torch context: {exc}")
        raise
|
|
|
|
|
# In-memory copy of the training examples, loaded once at import time.
# /adapt appends to this list and rewrites DATA_FILE from it; / and /health
# report its length.
train_texts = load_training_data()
|
|
|
@app.route('/')
def home():
    """Describe the service: model, load state, device, data size and endpoints."""
    endpoint_help = {
        '/': 'GET - API status and information',
        '/adapt': 'POST - Adaptive model training and response',
        '/health': 'GET - Health check',
        '/init': 'POST - Initialize model (if not already loaded)',
    }

    adapt_usage = {
        'method': 'POST',
        'content_type': 'application/json',
        'body': {'text': 'Your input text here'},
        'example': 'curl -X POST -H "Content-Type: application/json" -d \'{"text":"Hello world"}\' /adapt',
    }

    device_label = str(device) if device else 'Not initialized'

    info = {
        'status': 'SEAL Framework API is running',
        'version': '1.0.0',
        'model': MODEL_NAME,
        'model_loaded': model is not None,
        'device': device_label,
        'training_examples': len(train_texts),
        'endpoints': endpoint_help,
        'usage': {'adapt_endpoint': adapt_usage},
    }
    return jsonify(info)
|
|
|
@app.route('/init', methods=['POST'])
def init_model():
    """Load the model on demand; report success without reloading if already loaded.

    Returns 200 when the model is (or already was) loaded, 500 when
    initialize_model() reports failure.
    """
    if model is not None:
        return jsonify({'status': 'Model already initialized', 'success': True})

    if not initialize_model():
        return jsonify({'status': 'Model initialization failed', 'success': False}), 500

    return jsonify({'status': 'Model initialized successfully', 'success': True})
|
|
|
@app.route('/health')
def health():
    """Health check that exercises the model with a tiny greedy generation.

    Returns:
        200 with device, training-example count and torch version when a short
        generation succeeds; 500 with a hint when the model or tokenizer is not
        loaded; 500 with the error text when inference fails.
    """
    if model is None or tokenizer is None:
        return jsonify({
            'status': 'unhealthy',
            'error': 'Model not initialized',
            'model_loaded': False,
            'suggestion': 'Call /init endpoint to initialize model'
        }), 500

    try:
        with torch_no_grad():
            inputs = tokenizer(
                "Health check",
                return_tensors="pt",
                truncation=True,
                max_length=32,
                padding=True,
            ).to(device)

            # The output is discarded: this only proves a forward pass works,
            # so cap the work at a few new tokens rather than generating up to
            # a 40-token total length on every probe.
            model.generate(
                **inputs,
                max_new_tokens=8,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
    except Exception as e:
        error = f"Model inference failed: {str(e)}"
        logger.exception("Health check failed: %s", error)
        return jsonify({
            'status': 'unhealthy',
            'error': error,
            'model_loaded': model is not None
        }), 500

    return jsonify({
        'status': 'healthy',
        'model_loaded': True,
        'device': str(device),
        'training_examples': len(train_texts),
        'torch_version': torch.__version__
    })
|
|
|
def _generate_continuation(prompt, *, max_length, temperature):
    """Sample one continuation of *prompt* and return only the newly generated text.

    The prompt is truncated to 128 tokens and ``max_length`` is the total
    (prompt + new tokens) length passed to ``generate``. Decoding only the
    token ids after the prompt avoids string-replacing the prompt out of the
    decoded text. That approach fails when the prompt was truncated or
    detokenizes differently, and it also deletes legitimate repeats of the
    prompt from the output.
    """
    with torch_no_grad():
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        ).to(device)

        output_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
        )

    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()


def _record_training_example(prompt, completion):
    """Append one prompt/completion pair to ``train_texts`` and persist the list.

    The JSON is written to a temporary file in the same directory, which is
    then moved over ``DATA_FILE`` with ``os.replace``. An interrupted write
    therefore cannot leave a truncated data file behind.
    """
    import tempfile
    from datetime import datetime, timezone

    train_texts.append({
        "prompt": prompt,
        "completion": completion,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })

    directory = os.path.dirname(DATA_FILE) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(train_texts, f, indent=2)
        os.replace(tmp_path, DATA_FILE)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


@app.route('/adapt', methods=['POST'])
def adapt_model():
    """Generate a rephrasing ("self-edit") and a reply for the posted text.

    Expects a JSON object body ``{"text": "<non-empty string>"}``. Fine-tuning
    is not performed. When the self-edit is generated successfully, it is
    stored with the input as a training example in ``DATA_FILE``.

    Returns:
        200 with input, self_edit, response and training_examples on success.
        Generation errors are reported inline in ``self_edit``/``response``.
        400 for a missing, non-string, or blank ``text``.
        500 if the model is not loaded or an unexpected error occurs.
    """
    try:
        if model is None or tokenizer is None:
            return jsonify({
                'error': 'Model not initialized. Call /init endpoint first.',
                'suggestion': 'POST to /init to initialize the model'
            }), 500

        # silent=True: a missing/malformed JSON body or wrong Content-Type
        # yields None (-> 400) instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'text' not in data:
            return jsonify({'error': 'No text provided in request body'}), 400
        if not isinstance(data['text'], str):
            return jsonify({'error': 'Field "text" must be a string'}), 400

        user_input = data['text'].strip()
        if not user_input:
            return jsonify({'error': 'Empty text provided'}), 400

        logger.info(f"Processing input: {user_input[:50]}...")

        # NOTE(review): this is an instruct model, but prompts are passed as
        # raw text rather than through tokenizer.apply_chat_template.
        self_edit_ok = True
        try:
            self_edit = _generate_continuation(
                f"Rephrase this text: {user_input}", max_length=200, temperature=0.7
            )
        except Exception as e:
            logger.exception("Self-edit generation failed")
            self_edit = f"Self-edit failed: {str(e)}"
            self_edit_ok = False

        try:
            response = _generate_continuation(user_input, max_length=256, temperature=0.8)
        except Exception as e:
            logger.exception("Response generation failed")
            response = f"Response generation failed: {str(e)}"

        # Only persist genuine self-edits; error text must not become training data.
        if self_edit_ok:
            try:
                _record_training_example(user_input, self_edit)
            except Exception:
                logger.exception("Failed to save training data")

        # Release cached GPU blocks and Python garbage between requests.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return jsonify({
            'input': user_input,
            'self_edit': self_edit,
            'response': response,
            'training_examples': len(train_texts),
            'status': 'Processing completed successfully',
            'note': 'Fine-tuning disabled for stability - using generation only'
        })

    except Exception as e:
        logger.exception("Adapt endpoint error")
        return jsonify({
            'error': str(e),
            'type': type(e).__name__,
            'suggestion': 'Check logs for detailed error information'
        }), 500
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON body listing the valid endpoints."""
    known_routes = ['/health', '/adapt', '/init', '/']
    payload = {
        'error': 'Endpoint not found',
        'available_endpoints': known_routes,
    }
    return jsonify(payload), 404
|
|
|
@app.errorhandler(500)
def internal_error(error):
    """Replace Flask's default 500 page with a generic JSON error body."""
    payload = {
        'error': 'Internal server error',
        'message': 'Check server logs for details',
    }
    return jsonify(payload), 500
|
|
|
|
|
if __name__ != '__main__':
    # Imported by another process (e.g. a WSGI server): the model is not
    # loaded here, so /adapt and /health report it as uninitialized until
    # POST /init is called.
    logger.info("SEAL Framework API starting in production mode...")
else:
    logger.info("Starting SEAL Framework API...")
    # The return value is ignored: a failed load is logged inside
    # initialize_model() and the server still starts so /init can retry.
    initialize_model()
    app.run(host='0.0.0.0', port=7860, debug=False)
|
|