# codesearchBase / app.py
# Author: Forrest99 · commit edbccc3 ("Update app.py", verified) · 20.3 kB
# (header lines recovered from the file-view page this copy was taken from)
from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Root logging at INFO level for the whole process.
logging.basicConfig(level=logging.INFO)

# Flask application instance. Its logger is swapped for a named one so every
# record emitted through app.logger is tagged "CodeSearchAPI".
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
# Corpus of predefined JavaScript snippets searched by /search. Each entry is
# embedded once at start-up and returned verbatim as the "code" field of a hit.
#
# Entries whose JavaScript text contains backslash sequences (\n, \r, \s, \d,
# \w, \W, \+) are raw strings (r"...") so the backslash reaches the snippet
# unchanged. In a plain Python string '\n' would become a real newline (which
# breaks the JS string literal), and '\s', '\d', ... are invalid escapes that
# raise SyntaxWarning on Python 3.12+. Snippets must not contain Markdown
# escapes such as \[ or \: -- those are not valid JavaScript.
# The multi-line class snippets below intentionally contain real newlines.
CODE_SNIPPETS = [
    "console.log('Hello, World!')",
    "const sum = (a, b) => a + b",
    "const randomNum = Math.random()",
    "const isEven = num => num % 2 === 0",
    "const strLength = str => str.length",
    "const currentDate = new Date().toLocaleDateString()",
    "const fs = require('fs'); const fileExists = path => fs.existsSync(path)",
    "const readFile = path => fs.readFileSync(path, 'utf8')",
    "const writeFile = (path, content) => fs.writeFileSync(path, content)",
    "const currentTime = new Date().toLocaleTimeString()",
    "const toUpperCase = str => str.toUpperCase()",
    "const toLowerCase = str => str.toLowerCase()",
    "const reverseStr = str => str.split('').reverse().join('')",
    "const countElements = list => list.length",
    "const maxInList = list => Math.max(...list)",
    "const minInList = list => Math.min(...list)",
    "const sortList = list => list.sort()",
    "const mergeLists = (list1, list2) => list1.concat(list2)",
    "const removeElement = (list, element) => list.filter(e => e !== element)",
    "const isListEmpty = list => list.length === 0",
    "const countChar = (str, char) => str.split(char).length - 1",
    "const containsSubstring = (str, substring) => str.includes(substring)",
    "const numToString = num => num.toString()",
    "const strToNum = str => Number(str)",
    "const isNumeric = str => !isNaN(str)",
    "const getIndex = (list, element) => list.indexOf(element)",
    "const clearList = list => list.length = 0",
    "const reverseList = list => list.reverse()",
    "const removeDuplicates = list => [...new Set(list)]",
    "const isInList = (list, value) => list.includes(value)",
    "const createDict = () => ({})",
    "const addToDict = (dict, key, value) => dict[key] = value",
    "const removeKey = (dict, key) => delete dict[key]",
    "const getDictKeys = dict => Object.keys(dict)",
    "const getDictValues = dict => Object.values(dict)",
    "const mergeDicts = (dict1, dict2) => ({ ...dict1, ...dict2 })",
    "const isDictEmpty = dict => Object.keys(dict).length === 0",
    "const getDictValue = (dict, key) => dict[key]",
    "const keyExists = (dict, key) => key in dict",
    "const clearDict = dict => Object.keys(dict).forEach(key => delete dict[key])",
    r"const countFileLines = path => fs.readFileSync(path, 'utf8').split('\n').length",
    r"const writeListToFile = (path, list) => fs.writeFileSync(path, list.join('\n'))",
    r"const readListFromFile = path => fs.readFileSync(path, 'utf8').split('\n')",
    r"const countFileWords = path => fs.readFileSync(path, 'utf8').split(/\s+/).length",
    "const isLeapYear = year => (year % 4 === 0 && year % 100 !== 0) || year % 400 === 0",
    "const formatTime = (date, format) => date.toLocaleTimeString('en-US', format)",
    "const daysBetween = (date1, date2) => Math.abs(date1 - date2) / (1000 * 60 * 60 * 24)",
    "const currentDir = process.cwd()",
    "const listFiles = path => fs.readdirSync(path)",
    "const createDir = path => fs.mkdirSync(path)",
    "const removeDir = path => fs.rmdirSync(path)",
    "const isFile = path => fs.statSync(path).isFile()",
    "const isDirectory = path => fs.statSync(path).isDirectory()",
    "const getFileSize = path => fs.statSync(path).size",
    "const renameFile = (oldPath, newPath) => fs.renameSync(oldPath, newPath)",
    "const copyFile = (src, dest) => fs.copyFileSync(src, dest)",
    "const moveFile = (src, dest) => fs.renameSync(src, dest)",
    "const deleteFile = path => fs.unlinkSync(path)",
    "const getEnvVar = key => process.env[key]",
    "const setEnvVar = (key, value) => process.env[key] = value",
    "const openLink = url => require('open')(url)",
    "const sendGetRequest = async url => await (await fetch(url)).text()",
    "const parseJSON = json => JSON.parse(json)",
    "const writeJSON = (path, data) => fs.writeFileSync(path, JSON.stringify(data))",
    "const readJSON = path => JSON.parse(fs.readFileSync(path, 'utf8'))",
    "const listToString = list => list.join('')",
    "const stringToList = str => str.split('')",
    "const joinWithComma = list => list.join(',')",
    r"const joinWithNewline = list => list.join('\n')",
    "const splitBySpace = str => str.split(' ')",
    "const splitByChar = (str, char) => str.split(char)",
    "const splitToChars = str => str.split('')",
    "const replaceInStr = (str, old, newStr) => str.replace(old, newStr)",
    r"const removeSpaces = str => str.replace(/\s+/g, '')",
    r"const removePunctuation = str => str.replace(/[^\w\s]/g, '')",
    "const isStrEmpty = str => str.length === 0",
    "const isPalindrome = str => str === str.split('').reverse().join('')",
    r"const writeCSV = (path, data) => fs.writeFileSync(path, data.map(row => row.join(',')).join('\n'))",
    r"const readCSV = path => fs.readFileSync(path, 'utf8').split('\n').map(row => row.split(','))",
    r"const countCSVLines = path => fs.readFileSync(path, 'utf8').split('\n').length",
    "const shuffleList = list => list.sort(() => Math.random() - 0.5)",
    "const randomElement = list => list[Math.floor(Math.random() * list.length)]",
    "const randomElements = (list, count) => list.sort(() => Math.random() - 0.5).slice(0, count)",
    "const rollDice = () => Math.floor(Math.random() * 6) + 1",
    "const flipCoin = () => Math.random() < 0.5 ? 'Heads' : 'Tails'",
    "const randomPassword = length => Array.from({ length }, () => Math.random().toString(36).charAt(2)).join('')",
    "const randomColor = () => `#${Math.floor(Math.random() * 16777215).toString(16)}`",
    "const uniqueID = () => Math.random().toString(36).substring(2) + Date.now().toString(36)",
    "class MyClass {\nconstructor() {}\n}",
    "const myInstance = new MyClass()",
    "class MyClass {\nmyMethod() {}\n}",
    "class MyClass {\nconstructor() {\nthis.myProp = 'value'\n}\n}",
    "class ChildClass extends MyClass {\nconstructor() {\nsuper()\n}\n}",
    "class ChildClass extends MyClass {\nmyMethod() {\nsuper.myMethod()\n}\n}",
    "const instance = new MyClass(); instance.myMethod()",
    "class MyClass {\nstatic myStaticMethod() {}\n}",
    "const typeOf = obj => typeof obj",
    "const getProp = (obj, prop) => obj[prop]",
    "const setProp = (obj, prop, value) => obj[prop] = value",
    "const deleteProp = (obj, prop) => delete obj[prop]",
    "try{foo();}catch(e){}",
    "throw new Error('CustomError')",
    "try{foo();}catch(e){const info=e.message;}",
    "console.error(err)",
    "const timer={start(){this.s=Date.now()},stop(){return Date.now()-this.s}}",
    "const runtime=(s)=>Date.now()-s",
    r"const progress=(i,n)=>process.stdout.write(Math.floor(i/n*100)+'%\r')",
    "const delay=(ms)=>new Promise(r=>setTimeout(r,ms))",
    "const f=(x)=>x*2",
    "const m=arr.map(x=>x*2)",
    "const f2=arr.filter(x=>x>0)",
    "const r=arr.reduce((a,x)=>a+x,0)",
    "const a=[1,2,3].map(x=>x)",
    # JavaScript has no dict/set comprehensions; these are the valid equivalents.
    "const o={a:1,b:2};const d=Object.fromEntries(Object.entries(o).map(([k,v])=>[k,v]))",
    "const s=new Set([1,2,3]);const p=new Set([...s].map(x=>x))",
    "const inter=new Set([...a].filter(x=>b.has(x)))",
    "const uni=new Set([...a,...b])",
    "const diff=new Set([...a].filter(x=>!b.has(x)))",
    "const noNone=list.filter(x=>x!=null)",
    "try{fs.openSync(path)}catch{}",
    "typeof x==='string'",
    "const b=!!str",
    "if(cond)doSomething()",
    "while(cond){}",
    "for(const x of arr){}",
    "for(const k in obj){}",
    "for(const c of str){}",
    "for(...){if(cond)break}",
    "for(...){if(cond)continue}",
    "function fn(){}",
    "function fn(a=1){}",
    "function fn(){return [1,2]}",
    "function fn(...a){}",
    "function fn(kwargs){const{a,b}=kwargs}",
    "function timed(fn){return(...a)=>{const s=Date.now();const r=fn(...a);console.log(Date.now()-s);return r}}",
    "const deco=fn=>(...a)=>fn(...a)",
    "const memo=fn=>{const c={};return x=>c[x]||=(fn(x))}",
    "function*gen(){yield 1;yield 2}",
    "const g=gen();",
    "const it={i:0,next(){return this.i<2?{value:this.i++,done:false}:{done:true}}}",
    "for(const x of it){}",
    "for(const [i,x] of arr.entries()){}",
    "const z=arr1.map((v,i)=>[v,arr2[i]])",
    "const dict=Object.fromEntries(arr1.map((v,i)=>[v,arr2[i]]))",
    "JSON.stringify(arr1)===JSON.stringify(arr2)",
    "JSON.stringify(obj1)===JSON.stringify(obj2)",
    "JSON.stringify(new Set(a))===JSON.stringify(new Set(b))",
    "const uniq=[...new Set(arr)]",
    "set.clear()",
    "set.size===0",
    "set.add(x)",
    "set.delete(x)",
    "set.has(x)",
    "set.size",
    "const hasInt=([...a].some(x=>b.has(x)))",
    "arr1.every(x=>arr2.includes(x))",
    "str.includes(sub)",
    "str[0]",
    "str[str.length-1]",
    "const isText=path=>['.txt','.md'].includes(require('path').extname(path))",
    "const isImage=path=>['.png','.jpg','.jpeg','.gif'].includes(require('path').extname(path))",
    "Math.round(n)",
    "Math.ceil(n)",
    "Math.floor(n)",
    "n.toFixed(2)",
    "const randStr=(l)=>[...Array(l)].map(()=>Math.random().toString(36).charAt(2)).join('')",
    "const exists=require('fs').existsSync(path)",
    "const walk=(d)=>require('fs').readdirSync(d).flatMap(f=>{const p=require('path').join(d,f);return require('fs').statSync(p).isDirectory()?walk(p):p})",
    "const ext=require('path').extname(fp)",
    "const name=require('path').basename(fp)",
    "const full=require('path').resolve(fp)",
    "process.version",
    "process.platform",
    "require('os').cpus().length",
    "require('os').totalmem()",
    "const d=require('os').diskUsageSync?require('os').diskUsageSync('/'):null",
    "require('os').networkInterfaces()",
    "require('dns').resolve('www.google.com',e=>console.log(!e))",
    "require('https').get(url,res=>res.pipe(require('fs').createWriteStream(dest)))",
    "const upload=async f=>Promise.resolve('ok')",
    "require('https').request({method:'POST',host,u:path},()=>{}).end(data)",
    "require('https').get(url+'?'+new URLSearchParams(params),res=>{})",
    "const req=()=>fetch(url,{headers})",
    "const jsdom=require('jsdom');const d=new jsdom.JSDOM(html)",
    # JSDOM is a class: it must be constructed with `new`.
    "const title=new jsdom.JSDOM(html).window.document.querySelector('title').textContent",
    "const links=[...d.window.document.querySelectorAll('a')].map(a=>a.href)",
    "Promise.all(links.map(u=>fetch(u).then(r=>r.blob()).then(b=>require('fs').writeFileSync(require('path').basename(u),Buffer.from(b)))))",
    r"const freq=html.split(/\W+/).reduce((c,w)=>{c[w]=(c[w]||0)+1;return c},{})",
    "const login=()=>fetch(url,{method:'POST',body:creds})",
    "const text=html.replace(/<[^>]+>/g,'')",
    r"const emails=html.match(/[\w\.-]+@[\w\.-]+/g)",
    r"const phones=html.match(/\+?\d[\d -]{7,}\d/g)",
    r"const nums=html.match(/\d+/g)",
    "const newHtml=html.replace(/foo/g,'bar')",
    r"const ok=/^\d{3}$/.test(str)",
    "const noTags=html.replace(/<[^>]*>/g,'')",
    "const enc=html.replace(/./g,c=>'&#'+c.charCodeAt(0)+';')",
    r"const dec=enc.replace(/&#(\d+);/g,(m,n)=>String.fromCharCode(n))",
    "const {app,BrowserWindow}=require('electron');app.on('ready',()=>new BrowserWindow().loadURL('about:blank'))",
    "const button = document.createElement('button'); button.textContent = 'Click Me'; document.body.appendChild(button)",
    "button.addEventListener('click', () => alert('Button Clicked!'))",
    "const input = document.createElement('input'); input.type = 'text'; document.body.appendChild(input)",
    "const inputValue = input.value",
    "document.title = 'New Title'",
    "window.resizeTo(800, 600)",
    "window.moveTo((window.screen.width - window.outerWidth) / 2, (window.screen.height - window.outerHeight) / 2)",
    "const menuBar = document.createElement('menu'); document.body.appendChild(menuBar)",
    "const select = document.createElement('select'); document.body.appendChild(select)",
    "const radio = document.createElement('input'); radio.type = 'radio'; document.body.appendChild(radio)",
    "const checkbox = document.createElement('input'); checkbox.type = 'checkbox'; document.body.appendChild(checkbox)",
    "const img = document.createElement('img'); img.src = 'image.png'; document.body.appendChild(img)",
    "const audio = new Audio('audio.mp3'); audio.play()",
    "const video = document.createElement('video'); video.src = 'video.mp4'; document.body.appendChild(video); video.play()",
    "const currentTime = audio.currentTime",
    "navigator.mediaDevices.getDisplayMedia().then(stream => {})",
    "navigator.mediaDevices.getUserMedia({ video: true }).then(stream => {})",
    "document.addEventListener('mousemove', (event) => { const x = event.clientX; const y = event.clientY })",
    "document.execCommand('insertText', false, 'Hello World')",
    "document.elementFromPoint(100, 100).click()",
    "const timestamp = Date.now()",
    "const date = new Date(timestamp)",
    "const timestampFromDate = date.getTime()",
    "const dayOfWeek = new Date().getDay()",
    "const daysInMonth = new Date(new Date().getFullYear(), new Date().getMonth() + 1, 0).getDate()",
    "const firstDayOfYear = new Date(new Date().getFullYear(), 0, 1)",
    "const lastDayOfYear = new Date(new Date().getFullYear(), 11, 31)",
    "const firstDayOfMonth = new Date(new Date().getFullYear(), new Date().getMonth(), 1)",
    "const lastDayOfMonth = new Date(new Date().getFullYear(), new Date().getMonth() + 1, 0)",
    "const isWeekday = new Date().getDay() !== 0 && new Date().getDay() !== 6",
    "const isWeekend = new Date().getDay() === 0 || new Date().getDay() === 6",
    "const currentHour = new Date().getHours()",
    "const currentMinute = new Date().getMinutes()",
    "const currentSecond = new Date().getSeconds()",
    "setTimeout(() => {}, 1000)",
    "const millisecondsTimestamp = Date.now()",
    "const formattedTime = new Date().toLocaleTimeString()",
    "const parsedTime = Date.parse('2023-10-01T00:00:00Z')",
    "const worker = new Worker('worker.js')",
    "worker.postMessage('pause')",
    "new Worker('worker.js').postMessage('start')",
    "const threadName = self.name",
    "worker.terminate()",
    "const lock = new Mutex(); lock.acquire()",
    "const process = new Worker('process.js')",
    "const pid = process.pid",
    "const isAlive = process.terminated === false",
    "new Worker('process.js').postMessage('start')",
    "const queue = new MessageChannel()",
    "const pipe = new MessageChannel()",
    "const cpuUsage = performance.now()",
    "const output = new Response('ls -la').text()",
    "const statusCode = new Response('ls -la').status",
    "const isSuccess = new Response('ls -la').ok",
    "const scriptPath = import.meta.url",
    "const args = process.argv",
    "const parser = new ArgumentParser(); parser.parse_args()",
    "parser.print_help()",
    "Object.keys(require.cache).forEach(module => console.log(module))",
    "const { exec } = require('child_process'); exec('pip install package')",
    "const { exec } = require('child_process'); exec('pip uninstall package')",
    "const packageVersion = require('package').version",
    "const { exec } = require('child_process'); exec('python -m venv venv')",
    "const { exec } = require('child_process'); exec('pip list')",
    "const { exec } = require('child_process'); exec('pip install --upgrade package')",
    "const db = require('sqlite3').Database('db.sqlite')",
    "db.all('SELECT * FROM table', (err, rows) => {})",
    "db.run('INSERT INTO table (column) VALUES (?)', ['value'])",
    "db.run('DELETE FROM table WHERE id = ?', [1])",
    "db.run('UPDATE table SET column = ? WHERE id = ?', ['new_value', 1])",
    "db.all('SELECT * FROM table', (err, rows) => {})",
    "db.run('SELECT * FROM table WHERE column = ?', ['value'], (err, row) => {})",
    "db.close()",
    "db.run('CREATE TABLE table (column TEXT)')",
    "db.run('DROP TABLE table')",
    "db.get('SELECT name FROM sqlite_master WHERE type = \"table\" AND name = ?', ['table'], (err, row) => {})",
    "db.all('SELECT name FROM sqlite_master WHERE type = \"table\"', (err, rows) => {})",
    "const { Model } = require('sequelize'); Model.create({ column: 'value' })",
    "Model.findAll({ where: { column: 'value' } })",
    "Model.destroy({ where: { id: 1 } })",
    "Model.update({ column: 'new_value' }, { where: { id: 1 } })",
    "class Table extends Model {}",
    "class ChildTable extends ParentTable {}",
    "Model.init({ id: { type: DataTypes.INTEGER, primaryKey: true } }, { sequelize })",
    "Model.init({ column: { type: DataTypes.STRING, unique: true } }, { sequelize })",
    "Model.init({ column: { type: DataTypes.STRING, defaultValue: 'default' } }, { sequelize })",
    "const csv = require('csv-parser'); fs.createReadStream('data.csv').pipe(csv())",
    "const xlsx = require('xlsx'); xlsx.writeFile(data, 'data.xlsx')",
    "const json = JSON.stringify(data)",
    "const workbook = xlsx.readFile('data.xlsx')",
    "const mergedWorkbook = xlsx.utils.book_append_sheet(workbook1, workbook2)",
    "xlsx.utils.book_append_sheet(workbook, worksheet, 'New Sheet')",
    "const style = workbook.Sheets['Sheet1']['A1'].s",
    "const color = workbook.Sheets['Sheet1']['A1'].s.fill.fgColor",
    "const font = workbook.Sheets['Sheet1']['A1'].s.font",
    "const cellValue = workbook.Sheets['Sheet1']['A1'].v",
    "workbook.Sheets['Sheet1']['A1'].v = 'New Value'",
    "const { width, height } = require('image-size')('image.png')",
    "const sharp = require('sharp'); sharp('image.png').resize(200, 200)",
]
# Global readiness flag read by the HTTP handlers; it becomes True only after
# the model load and corpus encoding succeed.
service_ready = False


def handle_shutdown(signum, frame):
    """Signal handler: log that shutdown started, then exit the process.

    Args:
        signum: number of the delivered signal (not used).
        frame: interrupted stack frame at delivery time (not used).

    Raises:
        SystemExit: always, with exit status 0.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)


# Route both SIGTERM and SIGINT (Ctrl+C) to the same exit handler.
# NOTE(review): if this module is served by gunicorn, registering handlers at
# import time may replace the worker's own SIGTERM/SIGINT handling -- confirm
# that is intended.
for _sig in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_sig, handle_shutdown)
# Load the embedding model and pre-compute embeddings for every snippet once,
# at import time. Any failure is logged and re-raised, so the module never
# finishes importing with a half-initialised model.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode the whole corpus up front, pinned to the CPU.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
except Exception as e:
    app.logger.error("初始化失败: %s", str(e))
    raise
else:
    # Only reached when loading and encoding both succeeded.
    service_ready = True
    app.logger.info("服务初始化完成")
# Health-check endpoint; Hugging Face requires the root path to respond.
@app.route('/')
def hf_health_check():
    """Report service readiness at ``/``.

    Clients whose Accept header admits HTML get a small status page with a
    hint on how to call ``/search``. Every other client gets JSON:
    ``{"status": "ready"}`` with 200, or ``{"status": "initializing"}`` with 503.

    NOTE(review): the HTML page is always served with status 200, even when
    the service is not ready.
    """
    state = "ready" if service_ready else "initializing"

    if not request.accept_mimetypes.accept_html:
        # Machine client: readiness is also conveyed by the HTTP status code.
        return jsonify({"status": state}), (200 if service_ready else 503)

    page = (
        "\n"
        "<h2>CodeSearch API</h2>\n"
        "<p>服务状态:{{ status }}</p>\n"
        "<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>\n"
    )
    return render_template_string(page, status=state)
# Search endpoint; accepts both GET and POST.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the corpus snippet that best matches a natural-language query.

    The query is read from ``?query=...`` on GET and from the JSON body
    ``{"query": "..."}`` on POST; surrounding whitespace is stripped.

    Returns:
        ``{"code": str, "score": float}`` with status 200 for the top-1
        match (score rounded to 4 decimals);
        ``{"error": ...}`` with 400 when the query is missing, blank, or not
        a string (including an unparsable or non-object JSON body);
        ``{"error": ...}`` with 503 while the service is not ready;
        ``{"error": ...}`` with 500 on any unexpected failure (logged with
        its traceback).
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503
    try:
        if request.method == 'GET':
            raw_query = request.args.get('query', '')
        else:
            # get_json() raises on malformed JSON (and, in newer Werkzeug, on
            # a non-JSON Content-Type); the broad except below would turn that
            # client error into a 500. silent=True yields None instead.
            data = request.get_json(silent=True)
            # A valid JSON body that is not an object (list, string, number)
            # has no .get(); treat it like a missing query.
            raw_query = data.get('query', '') if isinstance(data, dict) else ''
        # Non-string values such as {"query": 123} or {"query": null} are
        # treated as empty instead of crashing on .strip().
        query = raw_query.strip() if isinstance(raw_query, str) else ''
        if not query:
            app.logger.info("收到空的查询请求")
            return jsonify({"error": "查询不能为空"}), 400

        app.logger.info("收到查询请求: %s", query)

        # Embed the query on the CPU and take the single closest snippet.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }

        app.logger.info("返回结果: %s", result)
        return jsonify(result)
    except Exception as e:
        # logger.exception records the traceback as well as the message.
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # Local testing only; on Hugging Face Spaces the app is normally launched
    # through gunicorn instead of this development server.
    app.run(port=7860, host='0.0.0.0')