"""Semantic code-search API.

Encodes a fixed corpus of Python snippets once at start-up and serves the
closest match for a free-text query through a small Flask application.
"""

from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Initialise the Flask application.
app = Flask(__name__)

# Configure logging at INFO level and give the app a named logger.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined snippet corpus that queries are matched against.
# Every entry is kept on a single line; each entry must be followed by a
# comma, otherwise adjacent literals are silently concatenated into one entry.
CODE_SNIPPETS = [
    "print('Hello, World!')",
    "def add(a, b): return a + b",
    "import random; def generate_random(): return random.randint(1, 100)",
    "def is_even(n): return n % 2 == 0",
    "def string_length(s): return len(s)",
    "from datetime import date; def get_current_date(): return date.today()",
    "import os; def file_exists(path): return os.path.exists(path)",
    "def read_file(path): return open(path, 'r').read()",
    "def write_file(path, content): open(path, 'w').write(content)",
    "from datetime import datetime; def get_current_time(): return datetime.now()",
    "def to_upper(s): return s.upper()",
    "def to_lower(s): return s.lower()",
    "def reverse_string(s): return s[::-1]",
    "def list_length(lst): return len(lst)",
    "def list_max(lst): return max(lst)",
    "def list_min(lst): return min(lst)",
    "def sort_list(lst): return sorted(lst)",
    "def merge_lists(lst1, lst2): return lst1 + lst2",
    "def remove_element(lst, element): lst.remove(element)",
    "def is_list_empty(lst): return len(lst) == 0",
    "def count_char(s, char): return s.count(char)",
    "def contains_substring(s, sub): return sub in s",
    "def int_to_str(n): return str(n)",
    "def str_to_int(s): return int(s)",
    "def is_numeric(s): return s.isdigit()",
    "def get_index(lst, element): return lst.index(element)",
    "def clear_list(lst): lst.clear()",
    "def reverse_list(lst): lst.reverse()",
    "def remove_duplicates(lst): return list(set(lst))",
    "def is_in_list(lst, value): return value in lst",
    "def create_dict(): return {}",
    "def add_to_dict(d, key, value): d[key] = value",
    "def delete_key(d, key): del d[key]",
    "def get_keys(d): return list(d.keys())",
    "def get_values(d): return list(d.values())",
    "def merge_dicts(d1, d2): return {**d1, **d2}",
    "def is_dict_empty(d): return len(d) == 0",
    "def get_value(d, key): return d[key]",
    "def key_exists(d, key): return key in d",
    "def clear_dict(d): d.clear()",
    "def count_lines(path): return len(open(path).readlines())",
    "def write_list_to_file(path, lst): open(path, 'w').write('\\n'.join(map(str, lst)))",
    "def read_list_from_file(path): return open(path, 'r').read().splitlines()",
    "def count_words(path): return len(open(path, 'r').read().split())",
    "def is_leap_year(year): return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)",
    "from datetime import datetime; def format_time(dt): return dt.strftime('%Y-%m-%d %H:%M:%S')",
    "from datetime import date; def days_between(d1, d2): return (d2 - d1).days",
    "import os; def get_current_dir(): return os.getcwd()",
    "import os; def list_files(path): return os.listdir(path)",
    "import os; def create_dir(path): os.mkdir(path)",
    "import os; def remove_dir(path): os.rmdir(path)",
    "import os; def is_file(path): return os.path.isfile(path)",
    "import os; def is_dir(path): return os.path.isdir(path)",
    "import os; def get_file_size(path): return os.path.getsize(path)",
    "import os; def rename_file(src, dst): os.rename(src, dst)",
    "import shutil; def copy_file(src, dst): shutil.copy(src, dst)",
    "import shutil; def move_file(src, dst): shutil.move(src, dst)",
    "import os; def delete_file(path): os.remove(path)",
    "import os; def get_env_var(key): return os.getenv(key)",
    "import os; def set_env_var(key, value): os.environ[key] = value",
    "import webbrowser; def open_url(url): webbrowser.open(url)",
    "import requests; def send_get_request(url): return requests.get(url).text",
    "import json; def parse_json(data): return json.loads(data)",
    "import json; def write_json(data, path): open(path, 'w').write(json.dumps(data))",
    "import json; def read_json(path): return json.loads(open(path, 'r').read())",
    "def list_to_string(lst): return ''.join(lst)",
    "def string_to_list(s): return list(s)",
    "def join_with_comma(lst): return ','.join(lst)",
    "def join_with_newline(lst): return '\\n'.join(lst)",
    "def split_by_space(s): return s.split()",
    "def split_by_char(s, char): return s.split(char)",
    "def split_to_chars(s): return list(s)",
    "def replace_string(s, old, new): return s.replace(old, new)",
    "def remove_spaces(s): return s.replace(' ', '')",
    "import string; def remove_punctuation(s): return s.translate(str.maketrans('', '', string.punctuation))",
    "def is_string_empty(s): return len(s) == 0",
    "def is_palindrome(s): return s == s[::-1]",
    "import csv; def write_csv(data, path): open(path, 'w', newline='').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import csv; def read_csv(path): return [row for row in csv.reader(open(path, 'r'))]",
    "def count_csv_lines(path): return len(open(path).readlines())",
    "import random; def shuffle_list(lst): random.shuffle(lst)",
    "import random; def random_choice(lst): return random.choice(lst)",
    "import random; def random_sample(lst, k): return random.sample(lst, k)",
    "import random; def roll_dice(): return random.randint(1, 6)",
    "import random; def flip_coin(): return random.choice(['Heads', 'Tails'])",
    "import random; import string; def generate_password(length=8): return ''.join(random.choices(string.ascii_letters + string.digits, k=length))",
    "import random; def generate_color(): return '#%06x' % random.randint(0, 0xFFFFFF)",
    "import uuid; def generate_uuid(): return str(uuid.uuid4())",
    "class MyClass: pass",
    "def create_instance(): return MyClass()",
    "class MyClass: def my_method(self): pass",
    "class MyClass: def __init__(self): self.my_attr = None",
    "class ChildClass(MyClass): pass",
    "class ChildClass(MyClass): def my_method(self): pass",
    "class MyClass: @classmethod def my_class_method(cls): pass",
    "class MyClass: @staticmethod def my_static_method(): pass",
    "def check_type(obj): return type(obj)",
    "def get_attr(obj, attr): return getattr(obj, attr)",
    "def set_attr(obj, attr, value): setattr(obj, attr, value)",
    "def del_attr(obj, attr): delattr(obj, attr)",
    """try: x = 1 / 0 except ZeroDivisionError: pass""",
    """class CustomError(Exception): pass def raise_custom_error(): raise CustomError('Error occurred')""",
    """try: x = 1 / 0 except Exception as e: return str(e)""",
    """import logging; logging.basicConfig(filename='error.log', level=logging.ERROR); logging.error('Error occurred')""",
    """import time; def timer(): start = time.time(); return lambda: time.time() - start""",
    """import time; def run_time(): start = time.time(); return lambda: time.time() - start""",
    """import sys; def print_progress(progress): sys.stdout.write(f'\\rProgress: {progress}%'); sys.stdout.flush()""",
    """import time; def delay(seconds): time.sleep(seconds)""",
    "lambda x: x * 2",
    "map(lambda x: x * 2, [1, 2, 3])",
    "filter(lambda x: x > 2, [1, 2, 3])",
    "from functools import reduce; reduce(lambda x, y: x + y, [1, 2, 3])",
    "[x * 2 for x in [1, 2, 3]]",
    "{x: x * 2 for x in [1, 2, 3]}",
    "{x for x in [1, 2, 3]}",
    "set1 & set2",
    "set1 | set2",
    "set1 - set2",
    "[x for x in lst if x is not None]",
    """try: with open('file.txt', 'r') as f: pass except IOError: pass""",
    "type(var)",
    "bool(s)",
    "if condition: pass",
    "while condition: pass",
    "for item in lst: pass",
    "for key, value in d.items(): pass",
    "for char in s: pass",
    "for item in lst: if condition: break",
    "for item in lst: if condition: continue",
    "def my_func(): pass",
    "def my_func(param=1): pass",
    "def my_func(): return 1, 2",
    "def my_func(*args): pass",
    "def my_func(**kwargs): pass",
    """import time; def timer(func): def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) print(f'Time: {time.time() - start}'); return result return wrapper""",
    """def decorator(func): def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper""",
    """from functools import lru_cache; @lru_cache(maxsize=None) def my_func(): pass""",
    "def my_generator(): yield 1",
    "gen = my_generator(); next(gen)",
    "class MyIterator: def __iter__(self): return self; def __next__(self): pass",
    "it = iter([1, 2, 3]); next(it)",
    "for i, val in enumerate(lst): pass",
    "list(zip(lst1, lst2))",
    "dict(zip(keys, values))",
    "lst1 == lst2",
    "dict1 == dict2",
    "set1 == set2",
    "set(lst)",
    "set.clear()",
    "len(set) == 0",
    "set.add(item)",
    "set.remove(item)",
    "item in set",
    "len(set)",
    "set1 & set2",
    "set(lst1).issubset(lst2)",
    "sub in s",
    "s[0]",
    "s[-1]",
    "import mimetypes; mimetypes.guess_type(path)[0] == 'text/plain'",
    "import mimetypes; mimetypes.guess_type(path)[0].startswith('image/')",
    "round(num)",
    "import math; math.ceil(num)",
    "import math; math.floor(num)",
    "f'{num:.2f}'",
    "import random; import string; ''.join(random.choices(string.ascii_letters + string.digits, k=8))",
    "import os; os.path.exists(path)",
    "import os; for root, dirs, files in os.walk(path): pass",
    "import os; os.path.splitext(path)[1]",
    "import os; os.path.basename(path)",
    "import os; os.path.abspath(path)",
    "import platform; platform.python_version()",
    "import platform; platform.system()",
    "import multiprocessing; multiprocessing.cpu_count()",
    "import psutil; psutil.virtual_memory().total",
    "import psutil; psutil.disk_usage('/')",
    "import socket; socket.gethostbyname(socket.gethostname())",
    "import requests; try: requests.get('http://www.google.com'); return True; except: return False",
    "import requests; def download_file(url, path): with open(path, 'wb') as f: f.write(requests.get(url).content)",
    "def upload_file(path): with open(path, 'rb') as f: requests.post('http://example.com/upload', files={'file': f})",
    "import requests; requests.post(url, data={'key': 'value'})",
    "import requests; requests.get(url, params={'key': 'value'})",
    "import requests; requests.get(url, headers={'key': 'value'})",
    "from bs4 import BeautifulSoup; BeautifulSoup(html, 'html.parser')",
    "from bs4 import BeautifulSoup; soup.title.text",
    "from bs4 import BeautifulSoup; [a['href'] for a in soup.find_all('a')]",
    "from bs4 import BeautifulSoup; import requests; for img in soup.find_all('img'): requests.get(img['src']).content",
    "from collections import Counter; Counter(text.split())",
    "import requests; session = requests.Session(); session.post(login_url, data={'username': 'user', 'password': 'pass'})",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import re; re.findall(r'[\\w.-]+@[\\w.-]+', text)",
    "import re; re.findall(r'\\+?\\d[\\d -]{8,12}\\d', text)",
    "import re; re.findall(r'\\d+', text)",
    "import re; re.sub(pattern, repl, text)",
    "import re; re.match(pattern, text)",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import html; html.escape(text)",
    "import html; html.unescape(text)",
    "import tkinter as tk; root = tk.Tk(); root.mainloop()",
    "import tkinter as tk; def add_button(window, text): return tk.Button(window, text=text)",
    """def bind_click(button, func): button.config(command=func)""",
    """import tkinter.messagebox; def show_alert(message): tkinter.messagebox.showinfo('Info', message)""",
    """def get_entry_text(entry): return entry.get()""",
    "def set_title(window, title): window.title(title)",
    "def set_size(window, width, height): window.geometry(f'{width}x{height}')",
    """def center_window(window): window.update_idletasks() width = window.winfo_width() height = window.winfo_height() x = (window.winfo_screenwidth() // 2) - (width // 2) y = (window.winfo_screenheight() // 2) - (height // 2) window.geometry(f'{width}x{height}+{x}+{y}')""",
    """def create_menu(window): return tk.Menu(window)""",
    "def create_combobox(window): return ttk.Combobox(window)",
    "def create_radiobutton(window, text): return tk.Radiobutton(window, text=text)",
    "def create_checkbutton(window, text): return tk.Checkbutton(window, text=text)",
    """from PIL import ImageTk, Image; def show_image(window, path): img = Image.open(path) photo = ImageTk.PhotoImage(img) label = tk.Label(window, image=photo) label.image = photo return label""",
    "import pygame; def play_audio(path): pygame.mixer.init(); pygame.mixer.music.load(path); pygame.mixer.music.play()",
    "import cv2; def play_video(path): cap = cv2.VideoCapture(path); while cap.isOpened(): ret, frame = cap.read()",
    "def get_playback_time(): return pygame.mixer.music.get_pos()",
    "import pyautogui; def screenshot(): return pyautogui.screenshot()",
    "import pyautogui; import time; def record_screen(duration): return [pyautogui.screenshot() for _ in range(duration)]",
    "def get_mouse_pos(): return pyautogui.position()",
    "import pyautogui; def type_text(text): pyautogui.write(text)",
    "import pyautogui; def click_mouse(x, y): pyautogui.click(x, y)",
    "import time; def get_timestamp(): return int(time.time())",
    "import datetime; def timestamp_to_date(ts): return datetime.datetime.fromtimestamp(ts)",
    "import time; def date_to_timestamp(dt): return int(time.mktime(dt.timetuple()))",
    "def get_weekday(): return datetime.datetime.now().strftime('%A')",
    "import calendar; def get_month_days(): return calendar.monthrange(datetime.datetime.now().year, datetime.datetime.now().month)[1]",
    "def first_day_of_year(): return datetime.date(datetime.datetime.now().year, 1, 1)",
    "def last_day_of_year(): return datetime.date(datetime.datetime.now().year, 12, 31)",
    "def first_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, 1)",
    "import calendar; def last_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, calendar.monthrange(datetime.datetime.now().year, month)[1])",
    "def is_weekday(): return datetime.datetime.now().weekday() < 5",
    "def is_weekend(): return datetime.datetime.now().weekday() >= 5",
    "def current_hour(): return datetime.datetime.now().hour",
    "def current_minute(): return datetime.datetime.now().minute",
    "def current_second(): return datetime.datetime.now().second",
    "import time; def delay_1s(): time.sleep(1)",
    "import time; def millis_timestamp(): return int(time.time() * 1000)",
    "def format_time(dt, fmt='%Y-%m-%d %H:%M:%S'): return dt.strftime(fmt)",
    "from dateutil.parser import parse; def parse_time(s): return parse(s)",
    "import threading; def create_thread(target): return threading.Thread(target=target)",
    "import time; def thread_pause(seconds): time.sleep(seconds)",
    "def run_threads(*threads): [t.start() for t in threads]",
    "import threading; def current_thread_name(): return threading.current_thread().name",
    "def set_daemon(thread): thread.daemon = True",
    "import threading; lock = threading.Lock()",
    "import multiprocessing; def create_process(target): return multiprocessing.Process(target=target)",
    "import os; def get_pid(): return os.getpid()",
    "import psutil; def is_process_alive(pid): return psutil.pid_exists(pid)",
    "def run_processes(*procs): [p.start() for p in procs]",
    "from queue import Queue; q = Queue()",
    "from multiprocessing import Pipe; parent_conn, child_conn = Pipe()",
    "import os; def limit_cpu_usage(percent): os.system(f'cpulimit -p {os.getpid()} -l {percent}')",
    "import subprocess; def run_command(cmd): subprocess.run(cmd, shell=True)",
    "import subprocess; def get_command_output(cmd): return subprocess.check_output(cmd, shell=True).decode()",
    "def get_exit_code(cmd): return subprocess.call(cmd, shell=True)",
    "def is_success(code): return code == 0",
    "import os; def script_path(): return os.path.realpath(__file__)",
    "import sys; def get_cli_args(): return sys.argv[1:]",
    "import argparse; parser = argparse.ArgumentParser()",
    "parser.print_help()",
    "help('modules')",
    "import pip; def install_pkg(pkg): pip.main(['install', pkg])",
    "import pip; def uninstall_pkg(pkg): pip.main(['uninstall', pkg])",
    "import pkg_resources; def get_pkg_version(pkg): return pkg_resources.get_distribution(pkg).version",
    "import venv; def create_venv(path): venv.create(path)",
    "import pip; def list_pkgs(): return pip.get_installed_distributions()",
    "import pip; def upgrade_pkg(pkg): pip.main(['install', '--upgrade', pkg])",
    "import sqlite3; conn = sqlite3.connect(':memory:')",
    "def execute_query(conn, query): return conn.execute(query)",
    """def insert_record(conn, table, data): conn.execute(f'INSERT INTO {table} VALUES ({",".join("?"*len(data))})', data)""",
    "def delete_record(conn, table, condition): conn.execute(f'DELETE FROM {table} WHERE {condition}')",
    "def update_record(conn, table, set_clause, condition): conn.execute(f'UPDATE {table} SET {set_clause} WHERE {condition}')",
    "def fetch_all(conn, query): return conn.execute(query).fetchall()",
    "def safe_query(conn, query, params): return conn.execute(query, params)",
    "def close_db(conn): conn.close()",
    "def create_table(conn, name, columns): conn.execute(f'CREATE TABLE {name} ({columns})')",
    "def drop_table(conn, name): conn.execute(f'DROP TABLE {name}')",
    "def table_exists(conn, name): return conn.execute(f\"SELECT name FROM sqlite_master WHERE type='table' AND name='{name}'\").fetchone()",
    "def list_tables(conn): return conn.execute(\"SELECT name FROM sqlite_master WHERE type='table'\").fetchall()",
    """from sqlalchemy import Column, Integer, String class User(Base): __tablename__ = 'users' id = Column(Integer, primary_key=True) name = Column(String)""",
    "session.add(User(name='John'))",
    "session.query(User).filter_by(name='John')",
    "session.query(User).filter_by(name='John').delete()",
    "session.query(User).filter_by(name='John').update({'name': 'Bob'})",
    "Base = declarative_base()",
    "class Admin(User): pass",
    "id = Column(Integer, primary_key=True)",
    "name = Column(String, unique=True)",
    "name = Column(String, default='Unknown')",
    "import csv; def export_csv(data, path): open(path, 'w').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import pandas as pd; pd.DataFrame(data).to_excel(path)",
    "import json; json.dump(data, open(path, 'w'))",
    "pd.read_excel(path).values.tolist()",
    "pd.concat([pd.read_excel(f) for f in files])",
    "with pd.ExcelWriter(path, mode='a') as writer: df.to_excel(writer, sheet_name='New')",
    "from openpyxl.styles import copy; copy.copy(style)",
    "from openpyxl.styles import PatternFill; cell.fill = PatternFill(start_color='FFFF00', fill_type='solid')",
    "from openpyxl.styles import Font; cell.font = Font(bold=True)",
    "sheet['A1'].value",
    "sheet['A1'] = value",
    "from PIL import Image; Image.open(path).size",
    "from PIL import Image; Image.open(path).resize((w, h))",
]

# Global service state; set to True once the model and corpus embeddings exist.
service_ready = False


def handle_shutdown(signum, frame):
    """Signal handler: log the termination request and exit with status 0.

    Args:
        signum: Number of the received signal (unused).
        frame: Current stack frame supplied by the signal module (unused).

    Raises:
        SystemExit: Always, via ``sys.exit(0)``.
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Graceful shutdown on SIGTERM and SIGINT.
signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)

# Load the model and pre-compute the corpus embeddings at import time.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Pre-compute the snippet embeddings (CPU is requested explicitly).
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as e:
    # logger.exception keeps the traceback; the failure is re-raised so the
    # process does not start in a half-initialised state.
    app.logger.exception("初始化失败: %s", e)
    raise


# Hugging Face health-check endpoint; it must answer on the root path.
@app.route('/')
def hf_health_check():
    """Report service status on the root path.

    Returns:
        An HTML status page (HTTP 200) when the client accepts HTML;
        otherwise a JSON body ``{"status": ...}`` with HTTP 200 when the
        service is ready and HTTP 503 while it is still initialising.
    """
    # Clients that accept HTML get a small page with a hint on how to test.
    if request.accept_mimetypes.accept_html:
        html = """
服务状态:{{ status }}
你可以在地址栏输入 /search?query=你的查询 来测试接口
"""
        status = "ready" if service_ready else "initializing"
        return render_template_string(html, status=status)
    # Everyone else gets a JSON health check.
    if service_ready:
        return jsonify({"status": "ready"}), 200
    return jsonify({"status": "initializing"}), 503


def _extract_query():
    """Return the stripped query string of the current request.

    GET requests are read from the ``query`` URL parameter, other methods
    from the ``query`` key of a JSON object body. A missing or malformed
    body, a non-object JSON body, or a non-string value all yield ``""`` so
    that the caller can answer with a client error instead of crashing.
    """
    if request.method == 'GET':
        raw = request.args.get('query', '')
    else:
        # silent=True: return None on a bad/missing JSON body rather than raise.
        payload = request.get_json(silent=True)
        raw = payload.get('query', '') if isinstance(payload, dict) else ''
    return raw.strip() if isinstance(raw, str) else ''


# Search API endpoint; supports both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the corpus snippet that best matches the submitted query.

    Returns:
        JSON ``{"code": <snippet>, "score": <similarity rounded to 4 dp>}``
        on success; ``{"error": ...}`` with HTTP 503 while the service is
        not ready, HTTP 400 for an empty or invalid query, and HTTP 500 if
        encoding or searching fails.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    try:
        query = _extract_query()
        if not query:
            app.logger.info("收到空的查询请求")
            return jsonify({"error": "查询不能为空"}), 400

        # Log the received query.
        app.logger.info("收到查询请求: %s", query)

        # Encode the query and run the semantic search for the single best hit.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }

        # Log the returned result.
        app.logger.info("返回结果: %s", result)
        return jsonify(result)
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, hide it from clients.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500


if __name__ == "__main__":
    # For local testing only; Hugging Face Spaces normally starts the app via gunicorn.
    app.run(host='0.0.0.0', port=7860)