|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments |
|
import torch |
|
import os |
|
import json |
|
|
|
# WSGI application object; every route below is registered on it.
app = Flask(__name__)

# Enable cross-origin requests for all routes with flask-cors defaults.
CORS(app)
|
|
|
|
|
# Hugging Face cache root.
# NOTE: transformers / huggingface_hub resolve their cache locations when they
# are imported (at the top of this file), so assigning HF_HOME here is too late
# to redirect the default cache. The variable is still exported for anything
# that reads it later, and the hub cache directory is passed explicitly to
# `from_pretrained` below so the intended location is actually used.
os.environ["HF_HOME"] = "/data/.huggingface"
hf_cache_dir = os.path.join(os.environ["HF_HOME"], "hub")
try:
    os.makedirs(hf_cache_dir, exist_ok=True)
except OSError:
    # Cache root is not writable/creatable (e.g. no persistent volume mounted):
    # fall back to the library's default cache instead of failing at startup.
    hf_cache_dir = None

model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Pick the device first so the weight dtype can follow it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision halves GPU memory use; on CPU many ops are slow or missing
# for float16, so full precision is used there.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=hf_cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=model_dtype,
    cache_dir=hf_cache_dir,
)
model.to(device)
|
|
|
|
|
# JSON file holding the accumulated training examples.
data_file = "data/train_data.json"

if not os.path.exists(data_file):
    # First run: start with no examples and create the file (and its parent
    # directory) so later writes have somewhere to go.
    train_texts = []
    os.makedirs(os.path.dirname(data_file), exist_ok=True)
    with open(data_file, 'w') as handle:
        json.dump(train_texts, handle)
else:
    # Subsequent runs: resume from whatever was persisted previously.
    with open(data_file, 'r') as handle:
        train_texts = json.load(handle)

print(f"Loaded {len(train_texts)} examples from {data_file}")

# Output directory used for Trainer checkpoints and the saved model.
model_save_dir = "./results/model"
|
|
|
@app.route('/')
def home():
    """Return a JSON overview of the service: status, model, device, routes and usage."""
    # Route catalogue advertised to clients.
    endpoint_help = {
        '/': 'GET - API status and information',
        '/adapt': 'POST - Adaptive model training and response',
        '/health': 'GET - Health check',
    }

    # How to call the /adapt endpoint, including a ready-made curl command.
    adapt_usage = {
        'method': 'POST',
        'content_type': 'application/json',
        'body': {'text': 'Your input text here'},
        'example': 'curl -X POST -H "Content-Type: application/json" -d \'{"text":"Hello world"}\' /adapt',
    }

    overview = {
        'status': 'SEAL Framework API is running',
        'version': '1.0.0',
        'model': model_name,
        'device': str(device),
        'training_examples': len(train_texts),
        'endpoints': endpoint_help,
        'usage': {'adapt_endpoint': adapt_usage},
    }
    return jsonify(overview)
|
|
|
@app.route('/health')
def health():
    """Run a tiny greedy generation as a smoke test and report the outcome.

    Returns:
        A 200 JSON body with status 'healthy' when tokenization and generation
        succeed, otherwise a 500 JSON body carrying the exception text.
    """
    try:
        # Short, deterministic probe: the generated tokens are discarded; only
        # the fact that generation completes matters.
        probe = tokenizer(
            "Health check",
            return_tensors="pt",
            truncation=True,
            max_length=32,
        ).to(device)
        with torch.no_grad():
            model.generate(
                **probe,
                max_length=40,
                num_return_sequences=1,
                do_sample=False,
            )

        report = {
            'status': 'healthy',
            'model_loaded': True,
            'device': str(device),
            'training_examples': len(train_texts),
        }
        return jsonify(report)
    except Exception as exc:
        failure = {'status': 'unhealthy', 'error': str(exc)}
        return jsonify(failure), 500
|
|
|
def _generate_text(prompt_text, max_length, strip_prompt=False):
    """Run one generation pass over *prompt_text* and decode the result.

    Args:
        prompt_text: Text fed to the model (truncated to 128 tokens).
        max_length: Upper bound on prompt + generated tokens for ``generate``.
        strip_prompt: When true, drop the echoed prompt tokens so only the
            newly generated continuation is returned.

    Returns:
        The decoded text with special tokens removed.
    """
    # Trainer.train() puts the model in training mode and does not switch it
    # back, so make sure inference always runs in eval mode.
    model.eval()
    encoded = tokenizer(
        prompt_text, return_tensors="pt", truncation=True, max_length=128
    ).to(device)
    with torch.no_grad():
        generated = model.generate(
            **encoded, max_length=max_length, num_return_sequences=1
        )
    token_ids = generated[0]
    if strip_prompt:
        # Decoder-only models return the prompt followed by the continuation.
        token_ids = token_ids[encoded["input_ids"].shape[1]:]
    return tokenizer.decode(token_ids, skip_special_tokens=True)


def _build_training_dataset(examples):
    """Tokenize prompt/completion pairs into a list of Trainer-ready samples."""
    texts = [ex["prompt"] + " " + ex["completion"] for ex in examples]
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=256,
        return_tensors="pt",
    )
    # Use a separate label tensor and mask padding with -100 (the loss
    # ignore index) so the model is not trained to predict pad tokens.
    labels = encodings["input_ids"].clone()
    labels[encodings["attention_mask"] == 0] = -100
    return [
        {"input_ids": ids, "attention_mask": mask, "labels": lab}
        for ids, mask, lab in zip(
            encodings["input_ids"], encodings["attention_mask"], labels
        )
    ]


@app.route('/adapt', methods=['POST'])
def adapt_model():
    """Generate a self-edit for the posted text, fine-tune on it, and respond.

    Expects a JSON object body with a non-empty string field ``text``.

    Steps: generate a rephrasing of the input, append it to the persisted
    training set, retrain for one epoch on the whole set, save the model and
    tokenizer, then generate a response to the original input.

    Returns:
        200 with the input, self-edit, response and example count on success;
        400 when the body is not a JSON object or ``text`` is missing, empty
        or not a string; 500 with the error message if anything else fails.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so client mistakes are reported as 400 rather than 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    user_input = data.get('text', '')
    # Reject non-strings up front: a non-string prompt would be persisted and
    # then break the string concatenation on every later training run.
    if not isinstance(user_input, str) or not user_input:
        return jsonify({'error': 'No input provided'}), 400

    try:
        # 1. Self-edit: keep only the continuation, not the echoed
        #    "Rephrase this: ..." instruction.
        self_edit = _generate_text(
            f"Rephrase this: {user_input}", max_length=150, strip_prompt=True
        )

        # 2. Persist the new example alongside all previous ones.
        train_texts.append({"prompt": user_input, "completion": self_edit})
        with open(data_file, 'w') as f:
            json.dump(train_texts, f, indent=2)

        # 3. Fine-tune on the full accumulated set.
        dataset = _build_training_dataset(train_texts)

        # Mixed-precision training needs full-precision master weights: with
        # float16 parameters and fp16=True the gradient scaler refuses to
        # unscale the gradients. Up-cast once before training.
        if next(model.parameters()).dtype == torch.float16:
            model.float()

        training_args = TrainingArguments(
            output_dir=model_save_dir,
            num_train_epochs=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            logging_steps=10,
            save_steps=10,
            save_total_limit=1,
            disable_tqdm=True,
            fp16=torch.cuda.is_available(),
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()

        # 4. Save the adapted weights and the tokenizer.
        trainer.save_model(model_save_dir)
        tokenizer.save_pretrained(model_save_dir)

        # 5. Answer with the adapted model (prompt echo kept, as before).
        response = _generate_text(user_input, max_length=200)

        return jsonify({
            'input': user_input,
            'self_edit': self_edit,
            'response': response,
            'training_examples': len(train_texts),
            'status': 'Model adapted successfully',
        })
    except Exception as e:
        # Top-level request boundary: log the traceback so failures are
        # diagnosable, then report the error to the caller.
        app.logger.exception("Model adaptation failed")
        return jsonify({'error': str(e)}), 500
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Answer unknown routes with a JSON 404 that lists the valid endpoints."""
    known_routes = {
        '/': 'GET - API information',
        '/health': 'GET - Health check',
        '/adapt': 'POST - Adaptive model training',
    }
    body = {
        'error': 'Endpoint not found',
        'available_endpoints': known_routes,
    }
    return jsonify(body), 404
|
|
|
@app.errorhandler(500)
def internal_error(error):
    """Answer unhandled server errors with a generic JSON 500 body."""
    body = {
        'error': 'Internal server error',
        'message': 'Please check the server logs for more details',
    }
    return jsonify(body), 500
|
|
|
if __name__ == "__main__":
    # Script entry point: serve on all interfaces, port 7860, debugger off.
    app.run(debug=False, port=7860, host="0.0.0.0")