# codesearchBase / app.py: Hugging Face Space by Forrest99.
# Commit 9284d6e ("Update app.py", verified); file size 16.5 kB.
from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Create the Flask application.
app = Flask(import_name=__name__)
# Configure root logging at INFO level (the name "INFO" resolves to
# logging.INFO), then replace Flask's default app logger with a logger
# named "CodeSearchAPI".
logging.basicConfig(level="INFO")
app.logger = logging.getLogger(name="CodeSearchAPI")
# Corpus of Ruby snippets that queries are matched against. The embeddings
# are computed from this list in this order, so a search hit's corpus_id
# indexes straight back into it. Order and count (300 entries) matter.
# Entries containing regex escapes (\b, \d, \.) or Ruby string escapes (\r)
# are raw strings (r"..."), so Python keeps the backslash instead of turning
# "\b" into a backspace or warning about an invalid escape.
CODE_SNIPPETS = [
    "puts 'Hello, World!'",
    "def sum(a, b); a + b; end",
    "rand",
    "def even?(num); num.even?; end",
    "str.length",
    "Date.today",
    "File.exist?('file.txt')",
    "File.read('file.txt')",
    "File.write('file.txt', 'content')",
    "Time.now",
    "str.upcase",
    "str.downcase",
    "str.reverse",
    "list.size",
    "list.max",
    "list.min",
    "list.sort",
    "list1 + list2",
    "list.delete(element)",
    "list.empty?",
    "str.count(char)",
    "str.include?(substring)",
    "num.to_s",
    "str.to_i",
    r"str.match?(/^\d+$/)",
    "list.index(element)",
    "list.clear",
    "list.reverse",
    "list.uniq",
    "list.include?(value)",
    "{}",
    "hash[key] = value",
    "hash.delete(key)",
    "hash.keys",
    "hash.values",
    "hash1.merge(hash2)",
    "hash.empty?",
    "hash[key]",
    "hash.key?(key)",
    "hash.clear",
    "File.readlines('file.txt').size",
    "File.write('file.txt', list.join('\\n'))",
    "File.read('file.txt').split('\\n')",
    "File.read('file.txt').split.size",
    "def leap_year?(year); (year % 400 == 0) || (year % 100 != 0 && year % 4 == 0); end",
    "Time.now.strftime('%Y-%m-%d %H:%M:%S')",
    "(Date.today - Date.new(2023, 1, 1)).to_i",
    "Dir.pwd",
    "Dir.entries('.')",
    "Dir.mkdir('new_dir')",
    "Dir.rmdir('new_dir')",
    "File.file?('path')",
    "File.directory?('path')",
    "File.size('file.txt')",
    "File.rename('old.txt', 'new.txt')",
    "FileUtils.cp('source.txt', 'destination.txt')",
    "FileUtils.mv('source.txt', 'destination.txt')",
    "File.delete('file.txt')",
    "ENV['VAR_NAME']",
    "ENV['VAR_NAME'] = 'value'",
    "system('open https://example.com')",
    "require 'net/http'; Net::HTTP.get(URI('https://example.com'))",
    "require 'json'; JSON.parse(json_string)",
    "require 'json'; File.write('file.json', JSON.dump(data))",
    "require 'json'; JSON.parse(File.read('file.json'))",
    "list.join",
    "str.split(',')",
    "list.join(',')",
    "list.join('\\n')",
    "str.split",
    "str.split(delimiter)",
    "str.chars",
    "str.gsub(old, new)",
    "str.gsub(' ', '')",
    "str.gsub(/[^a-zA-Z0-9]/, '')",
    "str.empty?",
    "str == str.reverse",
    "require 'csv'; CSV.open('file.csv', 'w') { |csv| csv << ['data'] }",
    "require 'csv'; CSV.read('file.csv')",
    "require 'csv'; CSV.read('file.csv').size",
    "list.shuffle",
    "list.sample",
    "list.sample(n)",
    "rand(6) + 1",
    "rand(2) == 0 ? 'Heads' : 'Tails'",
    "SecureRandom.alphanumeric(8)",
    "format('#%06x', rand(0xffffff))",
    "SecureRandom.uuid",
    "class MyClass; end",
    "MyClass.new",
    "class MyClass; def my_method; end; end",
    "class MyClass; attr_accessor :my_attr; end",
    "class ChildClass < ParentClass; end",
    "class ChildClass < ParentClass; def my_method; super; end; end",
    "class MyClass; def self.class_method; end; end",
    "class MyClass; def self.static_method; end; end",
    "obj.is_a?(Class)",
    "obj.instance_variable_get(:@attr)",
    "obj.instance_variable_set(:@attr, value)",
    "obj.instance_variable_defined?(:@attr)",
    "begin; risky_operation; rescue => e; puts e; end",
    """class CustomError < StandardError
end
raise CustomError, 'error occurred'""",
    "begin; raise 'oops'; rescue => e; e.message; end",
    """require 'logger'
logger = Logger.new('error.log')
logger.error('error occurred')""",
    "start_time = Time.now",
    "Time.now - start_time",
    # Progress bar: raw triple-quoted so the inner double quotes and the Ruby
    # "\r" escape survive as written.
    r"""20.times { |i| print "\r[#{'='*(i+1)}#{' '*(19-i)}]"; sleep(0.1) }""",
    "sleep(1)",
    "square = ->(x) { x*x }",
    "squares = [1,2,3].map { |n| n*n }",
    "evens = [1,2,3,4].select { |n| n.even? }",
    "sum = [1,2,3].reduce(0) { |acc,n| acc+n }",
    "doubles = [1,2,3,4,5].map { |n| n*2 }",
    "hash = [1,2,3].map { |n| [n, n*2] }.to_h",
    "require 'set'; s = Set.new([1,2,3].map { |n| n*2 })",
    "intersection = a & b",
    "union = a | b",
    "diff = a - b",
    "filtered = list.compact",
    "begin; File.open('file.txt'); rescue; false; end",
    "x.is_a?(String)",
    "bool = ['true','1'].include?(str.downcase)",
    "puts 'yes' if x > 0",
    "i=0; while i<5; i+=1; end",
    "for item in [1,2,3]; puts item; end",
    """h = {a:1, b:2}
for k, v in h
puts "#{k}:#{v}"
end""",
    """for c in 'hello'.chars
puts c
end""",
    """for i in 1..5
break if i==3
puts i
end""",
    """for i in 1..5
next if i==3
puts i
end""",
    "def foo; end",
    "def foo(a=1); a; end",
    "def foo; [1,2]; end",
    "def foo(*args); args; end",
    "def foo(a:, b:); a+b; end",
    """def foo
end
start = Time.now
foo
puts Time.now - start""",
    """def decorate(f)
->(args) { puts 'before'; result = f.call(args); puts 'after'; result }
end""",
    """def fib(n, memo={})
return memo[n] if memo[n]
memo[n] = n<2 ? n : fib(n-1, memo) + fib(n-2, memo)
end""",
    "gen = Enumerator.new { |y| i=0; loop { y << i; i+=1 } }",
    "def foo; yield 1; end",
    "gen.next",
    "itr = [1,2,3].each",
    """itr = [1,2,3].each
loop do
puts itr.next
end""",
    "[1,2].each_with_index { |v, i| puts i, v }",
    "zipped = [1,2].zip(['a','b'])",
    "h = [1,2].zip(['a','b']).to_h",
    "[1,2] == [1,2]",
    "{a:1, b:2} == {b:2, a:1}",
    "require 'set'; Set.new([1,2]) == Set.new([2,1])",
    "unique = [1,2,1].uniq",
    "s.clear",
    "s.empty?",
    "s.add(1)",
    "s.delete(1)",
    "s.include?(1)",
    "s.size",
    "!(a & b).empty?",
    "[1,2].all? { |e| [1,2,3].include?(e) }",
    "'hi'.include?('h')",
    "str[0]",
    "str[-1]",
    "File.extname(path) == '.txt'",
    "['.png','.jpg','.jpeg','.gif'].include?(File.extname(path))",
    "x.round",
    "x.ceil",
    "x.floor",
    "sprintf('%.2f', x)",
    "require 'securerandom'; SecureRandom.alphanumeric(8)",
    "File.exist?('path')",
    "Dir['**/*'].each { |f| puts f }",
    "File.extname('path.txt')",
    "File.basename(path)",
    "File.expand_path(path)",
    "RUBY_VERSION",
    "RUBY_PLATFORM",
    "require 'etc'; Etc.nprocessors",
    "mem = `grep MemTotal /proc/meminfo`",
    "df = `df -h /`",
    """require 'socket'
ip = Socket.ip_address_list.detect(&:ipv4_private?).ip_address""",
    "system('ping -c1 8.8.8.8 > /dev/null 2>&1')",
    """require 'open-uri'
File.open('file', 'wb') { |f| f.write URI.open(url).read }""",
    """def upload(file)
puts 'Uploading'
end""",
    """require 'net/http'
uri = URI(url)
Net::HTTP.post_form(uri, key: 'value')""",
    """uri = URI(url)
uri.query = URI.encode_www_form(params)
Net::HTTP.get(uri)""",
    """require 'net/http'
uri = URI(url)
req = Net::HTTP::Get.new(uri)
req['User-Agent'] = 'Custom'
res = Net::HTTP.start(uri.hostname, uri.port) { |http| http.request(req) }""",
    "require 'nokogiri'; doc = Nokogiri::HTML(html)",
    "doc.at('title').text",
    "links = doc.css('a').map { |a| a['href'] }",
    """doc.css('img').each do |img|
URI.open(img['src']).each do |chunk|
File.open(File.basename(img['src']), 'ab') { |f| f.write chunk }
end
end""",
    """freq = Hash.new(0)
text.split.each { |w| freq[w] += 1 }""",
    """require 'net/http'
res = Net::HTTP.post_form(URI(login_url), username: 'u', password: 'p')""",
    "Nokogiri::HTML(html).text",
    r"emails = text.scan(/\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b/)",
    r"phones = text.scan(/\b\d{3}-\d{3}-\d{4}\b/)",
    r"nums = text.scan(/\d+/)",
    "new = text.gsub(/foo/, 'bar')",
    "!!(text =~ /pattern/)",
    "clean = text.gsub(/<[^>]*>/, '')",
    "CGI.escapeHTML(text)",
    "CGI.unescapeHTML(text)",
    """require 'tk'
root = TkRoot.new { title 'App' }
Tk.mainloop""",
    "require 'tk'",
    """root = TkRoot.new
button = TkButton.new(root) {text 'Click Me'; command { Tk.messageBox(message: 'Button Clicked!') }}
button.pack""",
    """Tk.messageBox(message: 'Hello, World!')""",
    """entry = TkEntry.new(root).pack
entry.get""",
    """root.title = 'My Window'""",
    """root.geometry('400x300')""",
    """root.geometry('+%d+%d' % [(root.winfo_screenwidth() - root.winfo_reqwidth()) / 2, (root.winfo_screenheight() - root.winfo_reqheight()) / 2])""",
    """menu = TkMenu.new(root)
root['menu'] = menu
menu.add('command', 'label' => 'File')""",
    """combobox = Tk::Tile::Combobox.new(root).pack""",
    """radio = TkRadioButton.new(root) {text 'Option 1'}.pack""",
    """check = TkCheckButton.new(root) {text 'Check Me'}.pack""",
    """image = TkPhotoImage.new(file: 'image.png')
label = TkLabel.new(root) {image image}.pack""",
    """`afplay audio.mp3`""",
    """`ffplay video.mp4`""",
    """`ffmpeg -i video.mp4 -f null - 2>&1 | grep 'time=' | awk '{print $2}'`""",
    """`screencapture screen.png`""",
    """`ffmpeg -f avfoundation -i "1" -t 10 screen.mp4`""",
    """`cliclick p:.`""",
    """`cliclick kd:cmd kp:space ku:cmd`""",
    """`cliclick c:.`""",
    """Time.now.to_i""",
    """Time.at(timestamp).strftime('%Y-%m-%d')""",
    """Time.parse(date).to_i""",
    """Time.now.strftime('%A')""",
    """Time.days_in_month(Time.now.month, Time.now.year)""",
    """Time.new(Time.now.year, 1, 1)""",
    """Time.new(Time.now.year, 12, 31)""",
    """Time.new(year, month, 1)""",
    # Last day of a month: Date.new accepts a negative day, Time.new does not.
    """Date.new(year, month, -1)""",
    """Time.now.wday.between?(1, 5)""",
    # Weekend check: Time#wday is 0..6 with Sunday == 0 and Saturday == 6.
    """[0, 6].include?(Time.now.wday)""",
    """Time.now.hour""",
    """Time.now.min""",
    """Time.now.sec""",
    """sleep(1)""",
    """(Time.now.to_f * 1000).to_i""",
    """Time.now.strftime('%Y-%m-%d %H:%M:%S')""",
    """Time.parse(time_str)""",
    """Thread.new { puts 'Hello from thread' }""",
    """sleep(1)""",
    """threads = []
3.times { threads << Thread.new { puts 'Hello from thread' } }
threads.each(&:join)""",
    """Thread.current.name""",
    """thread = Thread.new { puts 'Hello from thread' }
thread.abort_on_exception = true""",
    """mutex = Mutex.new
mutex.synchronize { puts 'Hello from synchronized thread' }""",
    """pid = Process.spawn('sleep 5')""",
    """Process.pid""",
    """Process.kill(0, pid) rescue false""",
    """pids = []
3.times { pids << Process.spawn('sleep 5') }
pids.each { |pid| Process.wait(pid) }""",
    """queue = Queue.new
queue.push('Hello')
queue.pop""",
    """reader, writer = IO.pipe
writer.puts 'Hello'
reader.gets""",
    """Process.setrlimit(:CPU, 50)""",
    """`ls`""",
    """`ls`.chomp""",
    """$?.exitstatus""",
    """$?.success?""",
    """File.expand_path(__FILE__)""",
    """ARGV""",
    """require 'optparse'
OptionParser.new { |opts| opts.on('-h', '--help', 'Show help') { puts opts } }.parse!""",
    """OptionParser.new { |opts| opts.on('-h', '--help', 'Show help') { puts opts } }.parse!""",
    """Gem.loaded_specs.keys""",
    """`gem install package_name`""",
    """`gem uninstall package_name`""",
    """Gem.loaded_specs['package_name'].version.to_s""",
    """`bundle exec ruby script.rb`""",
    """Gem::Specification.map(&:name)""",
    """`gem update package_name`""",
    """require 'sqlite3'
db = SQLite3::Database.new('test.db')""",
    """db.execute('SELECT * FROM table')""",
    """db.execute('INSERT INTO table (column) VALUES (?)', 'value')""",
    """db.execute('DELETE FROM table WHERE id = ?', 1)""",
    """db.execute('UPDATE table SET column = ? WHERE id = ?', 'new_value', 1)""",
    """db.execute('SELECT * FROM table').each { |row| puts row }""",
    """db.execute('SELECT * FROM table WHERE column = ?', 'value')""",
    """db.close""",
    """db.execute('CREATE TABLE table (id INTEGER PRIMARY KEY, column TEXT)')""",
    """db.execute('DROP TABLE table')""",
    """db.table_info('table').any?""",
    """db.execute('SELECT name FROM sqlite_master WHERE type = "table"')""",
    """class Model < ActiveRecord::Base
end
Model.create(column: 'value')""",
    """Model.find_by(column: 'value')""",
    """Model.find_by(column: 'value').destroy""",
    """Model.find_by(column: 'value').update(column: 'new_value')""",
    """class Model < ActiveRecord::Base
end""",
    """class ChildModel < ParentModel
end""",
    """class Model < ActiveRecord::Base
self.primary_key = 'id'
end""",
    """class Model < ActiveRecord::Base
validates_uniqueness_of :column
end""",
    """class Model < ActiveRecord::Base
attribute :column, default: 'value'
end""",
    """require 'csv'
CSV.open('data.csv', 'w') { |csv| csv << ['column1', 'column2'] }""",
    """require 'spreadsheet'
book = Spreadsheet::Workbook.new
sheet = book.create_worksheet
sheet[0, 0] = 'Hello'
book.write('data.xls')""",
    """require 'json'
File.write('data.json', {key: 'value'}.to_json)""",
    """require 'spreadsheet'
book = Spreadsheet.open('data.xls')
sheet = book.worksheet(0)
sheet.each { |row| puts row }""",
    """require 'spreadsheet'
book1 = Spreadsheet.open('file1.xls')
book2 = Spreadsheet.open('file2.xls')
book1.worksheets.each { |sheet| book2.add_worksheet(sheet) }
book2.write('merged.xls')""",
    """require 'spreadsheet'
book = Spreadsheet::Workbook.new
book.create_worksheet(name: 'New Sheet')
book.write('data.xls')""",
    """require 'spreadsheet'
book = Spreadsheet.open('data.xls')
sheet = book.worksheet(0)
new_sheet = book.create_worksheet
new_sheet.format_with(sheet)
book.write('data.xls')""",
    """require 'spreadsheet'
book = Spreadsheet.open('data.xls')
sheet = book.worksheet(0)
sheet.row(0).set_format(0, Spreadsheet::Format.new(color: :red))
book.write('data.xls')""",
    """require 'spreadsheet'
book = Spreadsheet.open('data.xls')
sheet = book.worksheet(0)
sheet.row(0).set_format(0, Spreadsheet::Format.new(weight: :bold))
book.write('data.xls')""",
    """require 'spreadsheet'
book = Spreadsheet.open('data.xls')
sheet = book.worksheet(0)
sheet[0, 0]""",
    """require 'spreadsheet'
book = Spreadsheet::Workbook.new
sheet = book.create_worksheet
sheet[0, 0] = 'Hello'
book.write('data.xls')""",
    """require 'rmagick'
image = Magick::Image.read('image.png').first
[image.columns, image.rows]""",
    """require 'rmagick'
image = Magick::Image.read('image.png').first
image.resize!(100, 100)""",
]
# Module-level readiness flag read by the route handlers. It starts out False
# and is flipped to True once the model and snippet embeddings have loaded.
service_ready: bool = False
def handle_shutdown(signum, frame):
    """Shut down gracefully on a termination signal.

    Logs the shutdown, then exits the process with status 0.

    Args:
        signum: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: always, with exit code 0 (raised by ``sys.exit``).
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Install the same handler for a termination request (SIGTERM) and for an
# interrupt (SIGINT, e.g. Ctrl+C).
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)
del _shutdown_signal
# Load the embedding model and encode every snippet once, at import time, so
# each request only has to encode its own query.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode the whole corpus explicitly on CPU; position i of code_emb
    # belongs to CODE_SNIPPETS[i].
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
except Exception as exc:
    # Log and re-raise: a failed load aborts the import rather than letting
    # the app serve without a model.
    app.logger.error("初始化失败: %s", str(exc))
    raise
else:
    service_ready = True
    app.logger.info("服务初始化完成")
# Hugging Face health-check endpoint: it must respond on the root path.
@app.route('/')
def hf_health_check():
    """Report service readiness as a small HTML page or as JSON.

    Clients whose Accept header admits HTML get a page showing the status and
    a usage hint for /search. All other clients get ``{"status": ...}``.
    Note that werkzeug's ``accept_html`` is also true for a wildcard ``*/*``
    Accept header, so generic HTTP clients and probes usually get the HTML
    page.

    Returns:
        ``(body, 200)`` when the service is ready and ``(body, 503)`` while it
        is initializing. Both representations use the same status code, so a
        probe can rely on it whatever Accept header it sends.
    """
    # Read the flag once so the status text and the status code always agree.
    ready = service_ready
    status = "ready" if ready else "initializing"
    http_status = 200 if ready else 503

    if request.accept_mimetypes.accept_html:
        html = """
<h2>CodeSearch API</h2>
<p>服务状态:{{ status }}</p>
<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
"""
        # Previously always 200, even while initializing.
        return render_template_string(html, status=status), http_status

    return jsonify({"status": status}), http_status
# Search API endpoint; accepts both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the CODE_SNIPPETS entry that best matches a natural-language query.

    The query comes from the ``query`` URL parameter on GET, or from the
    ``query`` key of a JSON object body on POST.

    Returns:
        A Flask response:
        * 200 with ``{"code": <snippet>, "score": <similarity rounded to 4 places>}``;
        * 400 if the query is missing, blank, or not a string. This includes
          POST bodies that are not JSON or not a JSON object;
        * 503 if initialization has not completed;
        * 500 if encoding or searching fails unexpectedly.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True: a non-JSON Content-Type or a malformed body yields None
        # instead of raising an HTTP error. Previously the broad handler below
        # caught that error and reported it as a 500.
        data = request.get_json(silent=True)
        # Only a JSON object can carry a "query" key. Lists, strings, numbers
        # and null count as "no query" instead of raising AttributeError.
        raw_query = data.get('query', '') if isinstance(data, dict) else ''
    query = raw_query.strip() if isinstance(raw_query, str) else ''
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    app.logger.info("收到查询请求: %s", query)
    try:
        # Encode on CPU, the same device used for the snippet embeddings.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        # One query in, so [0] selects its result list. With top_k=1, that
        # list's first entry is the single best match.
        best = util.semantic_search(query_emb, code_emb, top_k=1)[0][0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
    except Exception as e:
        # logger.exception also records the traceback, which .error(str(e)) lost.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

    app.logger.info("返回结果: %s", result)
    return jsonify(result)
if __name__ == "__main__":
    # Local testing only. On Hugging Face Spaces the app is usually started by
    # gunicorn, which imports this module, so this branch does not run there.
    app.run(port=7860, host='0.0.0.0')