"""Flask service exposing a tiny semantic code-search API.

A fixed list of code snippets is embedded once at import time with a
SentenceTransformer model; ``/search`` embeds the incoming query and returns
the closest snippet together with its similarity score.
"""
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Initialise the Flask application.
app = Flask(__name__)

# Configure logging at INFO level and replace Flask's default application
# logger with a named one.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets that form the search corpus.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    (
        "def count_above_threshold(elements, threshold=0): "
        "return sum(1 for e in elements if e > threshold)"
    ),
    "def find_min_max(elements): return min(elements), max(elements)",
]

# Global readiness flag; flipped to True once the model and corpus
# embeddings have been loaded successfully.
service_ready = False


def handle_shutdown(signum, frame):
    """Log the termination request and exit the process with status 0.

    Args:
        signum: Number of the signal that was received (unused).
        frame: Current stack frame at the time of the signal (unused).
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Graceful shutdown on SIGTERM / SIGINT.
signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)

# Load the model and pre-compute the corpus embeddings.
try:
    # Cache directory intended for the Hugging Face Spaces deployment.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Pre-compute the snippet embeddings once (forced onto the CPU).
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as e:
    # Log with traceback, then re-raise so the process fails loudly instead
    # of serving without a model.
    app.logger.exception("初始化失败: %s", e)
    raise


@app.route('/')
def hf_health_check():
    """Health-check endpoint on the root path.

    Returns:
        A ``(response, status)`` tuple: 200 with ``{"status": "ready"}`` once
        initialisation has finished, otherwise 503 with
        ``{"status": "initializing"}``.
    """
    if service_ready:
        return jsonify({"status": "ready"}), 200
    return jsonify({"status": "initializing"}), 503


def _extract_query():
    """Return the stripped ``query`` string of the current request, or ``''``.

    GET requests read the URL parameter; other methods read the JSON body.
    A missing/malformed JSON body, a body that is not a JSON object, or a
    ``query`` value that is not a string all yield ``''`` so that the caller
    answers with a 400 instead of crashing into the 500 handler.
    """
    if request.method == 'GET':
        raw = request.args.get('query', '')
    else:
        # silent=True: bad JSON or a non-JSON content type gives None rather
        # than raising an HTTP error.
        payload = request.get_json(silent=True)
        raw = payload.get('query', '') if isinstance(payload, dict) else ''
    return raw.strip() if isinstance(raw, str) else ''


@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Search the snippet corpus for the best match to the request's query.

    The query comes from the ``query`` URL parameter (GET) or the ``query``
    key of the JSON body (POST).

    Returns:
        * 503 with an ``error`` message while the service is not ready.
        * 400 with an ``error`` message when the query is missing, empty,
          or not a string.
        * 200 with ``{"code": <snippet>, "score": <float rounded to 4 places>}``
          for the top hit.
        * 500 with an ``error`` message if encoding or searching fails.
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    query = _extract_query()
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Record the received query.
    app.logger.info("收到查询请求: %s", query)

    try:
        # Embed the query and look up the single closest snippet.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
    except Exception as e:
        # Top-level request boundary: keep the traceback in the log and hide
        # the details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

    # Record the returned result.
    app.logger.info("返回结果: %s", result)
    return jsonify(result)


if __name__ == "__main__":
    # Hugging Face Spaces starts the app through gunicorn; this entry point
    # is kept only for local testing.
    app.run(host='0.0.0.0', port=7860)