import logging
import signal
import sys

from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util

# Create the Flask application.
app = Flask(__name__)

# Configure logging at INFO level and use a dedicated named logger for the API.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Search corpus: predefined JavaScript snippets. Search results index into this
# list by position, so its order is significant.
# Entries that contain JavaScript escape sequences (regexes such as /\s+/,
# string escapes such as '\n') are raw strings, so Python passes the backslash
# through verbatim instead of turning it into a control character or emitting an
# "invalid escape sequence" warning.
CODE_SNIPPETS = [
    "console.log('Hello, World!')",
    "const sum = (a, b) => a + b",
    "const randomNum = Math.random()",
    "const isEven = num => num % 2 === 0",
    "const strLength = str => str.length",
    "const currentDate = new Date().toLocaleDateString()",
    "const fs = require('fs'); const fileExists = path => fs.existsSync(path)",
    "const readFile = path => fs.readFileSync(path, 'utf8')",
    "const writeFile = (path, content) => fs.writeFileSync(path, content)",
    "const currentTime = new Date().toLocaleTimeString()",
    "const toUpperCase = str => str.toUpperCase()",
    "const toLowerCase = str => str.toLowerCase()",
    "const reverseStr = str => str.split('').reverse().join('')",
    "const countElements = list => list.length",
    "const maxInList = list => Math.max(...list)",
    "const minInList = list => Math.min(...list)",
    "const sortList = list => list.sort()",
    "const mergeLists = (list1, list2) => list1.concat(list2)",
    "const removeElement = (list, element) => list.filter(e => e !== element)",
    "const isListEmpty = list => list.length === 0",
    "const countChar = (str, char) => str.split(char).length - 1",
    "const containsSubstring = (str, substring) => str.includes(substring)",
    "const numToString = num => num.toString()",
    "const strToNum = str => Number(str)",
    "const isNumeric = str => !isNaN(str)",
    "const getIndex = (list, element) => list.indexOf(element)",
    "const clearList = list => list.length = 0",
    "const reverseList = list => list.reverse()",
    "const removeDuplicates = list => [...new Set(list)]",
    "const isInList = (list, value) => list.includes(value)",
    "const createDict = () => ({})",
    "const addToDict = (dict, key, value) => dict[key] = value",
    "const removeKey = (dict, key) => delete dict[key]",
    "const getDictKeys = dict => Object.keys(dict)",
    "const getDictValues = dict => Object.values(dict)",
    "const mergeDicts = (dict1, dict2) => ({ ...dict1, ...dict2 })",
    "const isDictEmpty = dict => Object.keys(dict).length === 0",
    "const getDictValue = (dict, key) => dict[key]",
    "const keyExists = (dict, key) => key in dict",
    "const clearDict = dict => Object.keys(dict).forEach(key => delete dict[key])",
    r"const countFileLines = path => fs.readFileSync(path, 'utf8').split('\n').length",
    r"const writeListToFile = (path, list) => fs.writeFileSync(path, list.join('\n'))",
    r"const readListFromFile = path => fs.readFileSync(path, 'utf8').split('\n')",
    r"const countFileWords = path => fs.readFileSync(path, 'utf8').split(/\s+/).length",
    "const isLeapYear = year => (year % 4 === 0 && year % 100 !== 0) || year % 400 === 0",
    "const formatTime = (date, format) => date.toLocaleTimeString('en-US', format)",
    "const daysBetween = (date1, date2) => Math.abs(date1 - date2) / (1000 * 60 * 60 * 24)",
    "const currentDir = process.cwd()",
    "const listFiles = path => fs.readdirSync(path)",
    "const createDir = path => fs.mkdirSync(path)",
    "const removeDir = path => fs.rmdirSync(path)",
    "const isFile = path => fs.statSync(path).isFile()",
    "const isDirectory = path => fs.statSync(path).isDirectory()",
    "const getFileSize = path => fs.statSync(path).size",
    "const renameFile = (oldPath, newPath) => fs.renameSync(oldPath, newPath)",
    "const copyFile = (src, dest) => fs.copyFileSync(src, dest)",
    "const moveFile = (src, dest) => fs.renameSync(src, dest)",
    "const deleteFile = path => fs.unlinkSync(path)",
    "const getEnvVar = key => process.env[key]",
    "const setEnvVar = (key, value) => process.env[key] = value",
    "const openLink = url => require('open')(url)",
    "const sendGetRequest = async url => await (await fetch(url)).text()",
    "const parseJSON = json => JSON.parse(json)",
    "const writeJSON = (path, data) => fs.writeFileSync(path, JSON.stringify(data))",
    "const readJSON = path => JSON.parse(fs.readFileSync(path, 'utf8'))",
    "const listToString = list => list.join('')",
    "const stringToList = str => str.split('')",
    "const joinWithComma = list => list.join(',')",
    r"const joinWithNewline = list => list.join('\n')",
    "const splitBySpace = str => str.split(' ')",
    "const splitByChar = (str, char) => str.split(char)",
    "const splitToChars = str => str.split('')",
    "const replaceInStr = (str, old, newStr) => str.replace(old, newStr)",
    r"const removeSpaces = str => str.replace(/\s+/g, '')",
    r"const removePunctuation = str => str.replace(/[^\w\s]/g, '')",
    "const isStrEmpty = str => str.length === 0",
    "const isPalindrome = str => str === str.split('').reverse().join('')",
    r"const writeCSV = (path, data) => fs.writeFileSync(path, data.map(row => row.join(',')).join('\n'))",
    r"const readCSV = path => fs.readFileSync(path, 'utf8').split('\n').map(row => row.split(','))",
    r"const countCSVLines = path => fs.readFileSync(path, 'utf8').split('\n').length",
    "const shuffleList = list => list.sort(() => Math.random() - 0.5)",
    "const randomElement = list => list[Math.floor(Math.random() * list.length)]",
    "const randomElements = (list, count) => list.sort(() => Math.random() - 0.5).slice(0, count)",
    "const rollDice = () => Math.floor(Math.random() * 6) + 1",
    "const flipCoin = () => Math.random() < 0.5 ? 'Heads' : 'Tails'",
    "const randomPassword = length => Array.from({ length }, () => Math.random().toString(36).charAt(2)).join('')",
    "const randomColor = () => `#${Math.floor(Math.random() * 16777215).toString(16)}`",
    "const uniqueID = () => Math.random().toString(36).substring(2) + Date.now().toString(36)",
    "class MyClass { constructor() {} }",
    "const myInstance = new MyClass()",
    "class MyClass { myMethod() {} }",
    "class MyClass { constructor() { this.myProp = 'value' } }",
    "class ChildClass extends MyClass { constructor() { super() } }",
    "class ChildClass extends MyClass { myMethod() { super.myMethod() } }",
    "const instance = new MyClass(); instance.myMethod()",
    "class MyClass { static myStaticMethod() {} }",
    "const typeOf = obj => typeof obj",
    "const getProp = (obj, prop) => obj[prop]",
    "const setProp = (obj, prop, value) => obj[prop] = value",
    "const deleteProp = (obj, prop) => delete obj[prop]",
    "try{foo();}catch(e){}",
    "throw new Error('CustomError')",
    "try{foo();}catch(e){const info=e.message;}",
    "console.error(err)",
    "const timer={start(){this.s=Date.now()},stop(){return Date.now()-this.s}}",
    "const runtime=(s)=>Date.now()-s",
    r"const progress=(i,n)=>process.stdout.write(Math.floor(i/n*100)+'%\r')",
    "const delay=(ms)=>new Promise(r=>setTimeout(r,ms))",
    "const f=(x)=>x*2",
    "const m=arr.map(x=>x*2)",
    "const f2=arr.filter(x=>x>0)",
    "const r=arr.reduce((a,x)=>a+x,0)",
    "const a=[1,2,3].map(x=>x)",
    "const o={a:1,b:2};const d={k:v for([k,v] of Object.entries(o))}",
    "const s=new Set([1,2,3]);const p=new Set(x for(x of s))",
    "const inter=new Set([...a].filter(x=>b.has(x)))",
    "const uni=new Set([...a,...b])",
    "const diff=new Set([...a].filter(x=>!b.has(x)))",
    "const noNone=list.filter(x=>x!=null)",
    "try{fs.openSync(path)}catch{}",
    "typeof x==='string'",
    "const b=!!str",
    "if(cond)doSomething()",
    "while(cond){}",
    "for(const x of arr){}",
    "for(const k in obj){}",
    "for(const c of str){}",
    "for(...){if(cond)break}",
    "for(...){if(cond)continue}",
    "function fn(){}",
    "function fn(a=1){}",
    "function fn(){return [1,2]}",
    "function fn(...a){}",
    "function fn(kwargs){const{a,b}=kwargs}",
    "function timed(fn){return(...a)=>{const s=Date.now();const r=fn(...a);console.log(Date.now()-s);return r}}",
    "const deco=fn=>(...a)=>fn(...a)",
    "const memo=fn=>{const c={};return x=>c[x]||=(fn(x))}",
    "function*gen(){yield 1;yield 2}",
    "const g=gen();",
    "const it={i:0,next(){return this.i<2?{value:this.i++,done:false}:{done:true}}}",
    "for(const x of it){}",
    "for(const [i,x] of arr.entries()){}",
    "const z=arr1.map((v,i)=>[v,arr2[i]])",
    "const dict=Object.fromEntries(arr1.map((v,i)=>[v,arr2[i]]))",
    "JSON.stringify(arr1)===JSON.stringify(arr2)",
    "JSON.stringify(obj1)===JSON.stringify(obj2)",
    "JSON.stringify(new Set(a))===JSON.stringify(new Set(b))",
    "const uniq=[...new Set(arr)]",
    "set.clear()",
    "set.size===0",
    "set.add(x)",
    "set.delete(x)",
    "set.has(x)",
    "set.size",
    "const hasInt=([...a].some(x=>b.has(x)))",
    "arr1.every(x=>arr2.includes(x))",
    "str.includes(sub)",
    "str[0]",
    "str[str.length-1]",
    "const isText=path=>['.txt','.md'].includes(require('path').extname(path))",
    "const isImage=path=>['.png','.jpg','.jpeg','.gif'].includes(require('path').extname(path))",
    "Math.round(n)",
    "Math.ceil(n)",
    "Math.floor(n)",
    "n.toFixed(2)",
    "const randStr=(l)=>[...Array(l)].map(()=>Math.random().toString(36).charAt(2)).join('')",
    "const exists=require('fs').existsSync(path)",
    "const walk=(d)=>require('fs').readdirSync(d).flatMap(f=>{const p=require('path').join(d,f);return require('fs').statSync(p).isDirectory()?walk(p):p})",
    "const ext=require('path').extname(fp)",
    "const name=require('path').basename(fp)",
    "const full=require('path').resolve(fp)",
    "process.version",
    "process.platform",
    "require('os').cpus().length",
    "require('os').totalmem()",
    "const d=require('os').diskUsageSync?require('os').diskUsageSync('/'):null",
    "require('os').networkInterfaces()",
    "require('dns').resolve('www.google.com',e=>console.log(!e))",
    "require('https').get(url,res=>res.pipe(require('fs').createWriteStream(dest)))",
    "const upload=async f=>Promise.resolve('ok')",
    "require('https').request({method:'POST',host,u:path},()=>{}).end(data)",
    "require('https').get(url+'?'+new URLSearchParams(params),res=>{})",
    "const req=()=>fetch(url,{headers})",
    "const jsdom=require('jsdom');const d=new jsdom.JSDOM(html)",
    "const title=jsdom.JSDOM(html).window.document.querySelector('title').textContent",
    "const links=[...d.window.document.querySelectorAll('a')].map(a=>a.href)",
    "Promise.all(links.map(u=>fetch(u).then(r=>r.blob()).then(b=>require('fs').writeFileSync(require('path').basename(u),Buffer.from(b)))))",
    r"const freq=html.split(/\W+/).reduce((c,w)=>{c[w]=(c[w]||0)+1;return c},{})",
    "const login=()=>fetch(url,{method:'POST',body:creds})",
    "const text=html.replace(/<[^>]+>/g,'')",
    r"const emails=html.match(/[\w\.-]+@[\w\.-]+/g)",
    r"const phones=html.match(/\+?\d[\d -]{7,}\d/g)",
    r"const nums=html.match(/\d+/g)",
    "const newHtml=html.replace(/foo/g,'bar')",
    r"const ok=/^\d{3}$/.test(str)",
    "const noTags=html.replace(/<[^>]*>/g,'')",
    "const enc=html.replace(/./g,c=>'&#'+c.charCodeAt(0)+';')",
    r"const dec=enc.replace(/&#(\d+);/g,(m,n)=>String.fromCharCode(n))",
    "const {app,BrowserWindow}=require('electron');app.on('ready',()=>new BrowserWindow().loadURL('about:blank'))",
    "const button = document.createElement('button'); button.textContent = 'Click Me'; document.body.appendChild(button)",
    "button.addEventListener('click', () => alert('Button Clicked!'))",
    "const input = document.createElement('input'); input.type = 'text'; document.body.appendChild(input)",
    "const inputValue = input.value",
    "document.title = 'New Title'",
    "window.resizeTo(800, 600)",
    "window.moveTo((window.screen.width - window.outerWidth) / 2, (window.screen.height - window.outerHeight) / 2)",
    "const menuBar = document.createElement('menu'); document.body.appendChild(menuBar)",
    "const select = document.createElement('select'); document.body.appendChild(select)",
    "const radio = document.createElement('input'); radio.type = 'radio'; document.body.appendChild(radio)",
    "const checkbox = document.createElement('input'); checkbox.type = 'checkbox'; document.body.appendChild(checkbox)",
    "const img = document.createElement('img'); img.src = 'image.png'; document.body.appendChild(img)",
    "const audio = new Audio('audio.mp3'); audio.play()",
    "const video = document.createElement('video'); video.src = 'video.mp4'; document.body.appendChild(video); video.play()",
    "const currentTime = audio.currentTime",
    "navigator.mediaDevices.getDisplayMedia().then(stream => {})",
    "navigator.mediaDevices.getUserMedia({ video: true }).then(stream => {})",
    "document.addEventListener('mousemove', (event) => { const x = event.clientX; const y = event.clientY })",
    "document.execCommand('insertText', false, 'Hello World')",
    "document.elementFromPoint(100, 100).click()",
    "const timestamp = Date.now()",
    "const date = new Date(timestamp)",
    "const timestampFromDate = date.getTime()",
    "const dayOfWeek = new Date().getDay()",
    "const daysInMonth = new Date(new Date().getFullYear(), new Date().getMonth() + 1, 0).getDate()",
    "const firstDayOfYear = new Date(new Date().getFullYear(), 0, 1)",
    "const lastDayOfYear = new Date(new Date().getFullYear(), 11, 31)",
    "const firstDayOfMonth = new Date(new Date().getFullYear(), new Date().getMonth(), 1)",
    "const lastDayOfMonth = new Date(new Date().getFullYear(), new Date().getMonth() + 1, 0)",
    "const isWeekday = new Date().getDay() !== 0 && new Date().getDay() !== 6",
    "const isWeekend = new Date().getDay() === 0 || new Date().getDay() === 6",
    "const currentHour = new Date().getHours()",
    "const currentMinute = new Date().getMinutes()",
    "const currentSecond = new Date().getSeconds()",
    "setTimeout(() => {}, 1000)",
    "const millisecondsTimestamp = Date.now()",
    "const formattedTime = new Date().toLocaleTimeString()",
    "const parsedTime = Date.parse('2023-10-01T00:00:00Z')",
    "const worker = new Worker('worker.js')",
    "worker.postMessage('pause')",
    "new Worker('worker.js').postMessage('start')",
    "const threadName = self.name",
    "worker.terminate()",
    "const lock = new Mutex(); lock.acquire()",
    "const process = new Worker('process.js')",
    "const pid = process.pid",
    "const isAlive = process.terminated === false",
    "new Worker('process.js').postMessage('start')",
    "const queue = new MessageChannel()",
    "const pipe = new MessageChannel()",
    "const cpuUsage = performance.now()",
    "const output = new Response('ls -la').text()",
    "const statusCode = new Response('ls -la').status",
    "const isSuccess = new Response('ls -la').ok",
    "const scriptPath = import.meta.url",
    "const args = process.argv",
    "const parser = new ArgumentParser(); parser.parse_args()",
    "parser.print_help()",
    "Object.keys(require.cache).forEach(module => console.log(module))",
    "const { exec } = require('child_process'); exec('pip install package')",
    "const { exec } = require('child_process'); exec('pip uninstall package')",
    "const packageVersion = require('package').version",
    "const { exec } = require('child_process'); exec('python -m venv venv')",
    "const { exec } = require('child_process'); exec('pip list')",
    "const { exec } = require('child_process'); exec('pip install --upgrade package')",
    "const db = require('sqlite3').Database('db.sqlite')",
    "db.all('SELECT * FROM table', (err, rows) => {})",
    "db.run('INSERT INTO table (column) VALUES (?)', ['value'])",
    "db.run('DELETE FROM table WHERE id = ?', [1])",
    "db.run('UPDATE table SET column = ? WHERE id = ?', ['new_value', 1])",
    "db.all('SELECT * FROM table', (err, rows) => {})",
    "db.run('SELECT * FROM table WHERE column = ?', ['value'], (err, row) => {})",
    "db.close()",
    "db.run('CREATE TABLE table (column TEXT)')",
    "db.run('DROP TABLE table')",
    "db.get('SELECT name FROM sqlite_master WHERE type = \"table\" AND name = ?', ['table'], (err, row) => {})",
    "db.all('SELECT name FROM sqlite_master WHERE type = \"table\"', (err, rows) => {})",
    "const { Model } = require('sequelize'); Model.create({ column: 'value' })",
    "Model.findAll({ where: { column: 'value' } })",
    "Model.destroy({ where: { id: 1 } })",
    "Model.update({ column: 'new_value' }, { where: { id: 1 } })",
    "class Table extends Model {}",
    "class ChildTable extends ParentTable {}",
    "Model.init({ id: { type: DataTypes.INTEGER, primaryKey: true } }, { sequelize })",
    "Model.init({ column: { type: DataTypes.STRING, unique: true } }, { sequelize })",
    "Model.init({ column: { type: DataTypes.STRING, defaultValue: 'default' } }, { sequelize })",
    "const csv = require('csv-parser'); fs.createReadStream('data.csv').pipe(csv())",
    "const xlsx = require('xlsx'); xlsx.writeFile(data, 'data.xlsx')",
    "const json = JSON.stringify(data)",
    "const workbook = xlsx.readFile('data.xlsx')",
    "const mergedWorkbook = xlsx.utils.book_append_sheet(workbook1, workbook2)",
    "xlsx.utils.book_append_sheet(workbook, worksheet, 'New Sheet')",
    "const style = workbook.Sheets['Sheet1']['A1'].s",
    "const color = workbook.Sheets['Sheet1']['A1'].s.fill.fgColor",
    "const font = workbook.Sheets['Sheet1']['A1'].s.font",
    "const cellValue = workbook.Sheets['Sheet1']['A1'].v",
    "workbook.Sheets['Sheet1']['A1'].v = 'New Value'",
    "const { width, height } = require('image-size')('image.png')",
    "const sharp = require('sharp'); sharp('image.png').resize(200, 200)",
]

# Global service state: False until the model is loaded and the snippet
# embeddings have been computed.
service_ready = False


def handle_shutdown(signum, frame):
    """Handle a termination signal by logging it and exiting the process.

    Args:
        signum: The signal number delivered (unused beyond the handler contract).
        frame: The interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with exit status 0 (raised by ``sys.exit``).
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Graceful shutdown on SIGTERM.
# NOTE(review): when the app is served by a process manager such as gunicorn,
# registering handlers at import time may replace the server's own shutdown
# handling, and signal.signal() only works from the main thread -- confirm
# this is intended for the deployment target.
signal.signal(signal.SIGTERM, handle_shutdown)
# Route Ctrl-C (SIGINT) through the same shutdown handler as SIGTERM.
signal.signal(signal.SIGINT, handle_shutdown)

# Load the embedding model and pre-compute the snippet embeddings once at
# import time. A failure is logged and re-raised so the process never starts
# serving in a half-initialised state.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Pre-compute the embeddings of every snippet (forced onto the CPU).
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as e:
    app.logger.error("初始化失败: %s", str(e))
    raise


@app.route('/')
def hf_health_check():
    """Health-check endpoint for Hugging Face Spaces; must answer on ``/``.

    Clients that accept HTML get a small status page with a hint on how to try
    ``/search``; all other clients get ``{"status": ...}`` as JSON. Both
    variants use the same status code: 200 when ready, 503 while initializing.
    """
    ready = service_ready
    status = "ready" if ready else "initializing"
    status_code = 200 if ready else 503

    if request.accept_mimetypes.accept_html:
        html = (
            "\n"
            "服务状态:{{ status }}\n"
            "你可以在地址栏输入 /search?query=你的查询 来测试接口\n"
        )
        return render_template_string(html, status=status), status_code

    return jsonify({"status": status}), status_code


@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the snippet most similar to the query.

    The query is read from ``?query=`` for GET requests, or from the ``"query"``
    key of a JSON object body for POST requests.

    Returns:
        200 with ``{"code": <snippet>, "score": <cosine score rounded to 4 dp>}``;
        400 when the query is missing, empty, not a string, or the POST body is
        not a JSON object; 503 if the service is not ready; 500 if encoding or
        searching fails.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    # Extract the query. Client mistakes must produce a 400 here rather than
    # escaping into the generic 500 handler below.
    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True: a missing/wrong Content-Type or malformed JSON yields
        # None instead of raising BadRequest / UnsupportedMediaType.
        data = request.get_json(silent=True)
        raw_query = data.get('query', '') if isinstance(data, dict) else ''
    query = raw_query.strip() if isinstance(raw_query, str) else ''

    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and run a top-1 semantic search over the corpus.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        best = util.semantic_search(query_emb, code_emb, top_k=1)[0][0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
    except Exception as e:
        # Log with the traceback so server-side failures are diagnosable.
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500

    app.logger.info("返回结果: %s", result)
    return jsonify(result)


if __name__ == "__main__":
    # Local testing only; Hugging Face Spaces usually starts the app via gunicorn.
    app.run(host='0.0.0.0', port=7860)