"""Semantic code-search API.

Embeds a fixed corpus of Ruby code snippets with a SentenceTransformer
code-search model at import time and serves:

* ``/``       -- health check (JSON by default, HTML when a browser prefers it)
* ``/search`` -- returns the single best-matching snippet for a query
"""
from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Initialize the Flask application.
app = Flask(__name__)

# Configure logging at INFO level and use a named logger for this service.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets forming the search corpus.
# Entries that contain backslashes are raw strings so Ruby regex escapes such
# as `\b` and `\d` are stored literally (in a plain Python string `\b` would
# become a backspace character).
# NOTE(review): several entries (`x*x`, `n*n`, `n*2`, `'='*(i+1)`, `**/*`,
# `[^>]*`) appear to have lost their `*` characters in transit and have been
# restored here; multi-statement entries also appear whitespace-collapsed.
# Confirm against the upstream corpus.
CODE_SNIPPETS = [
    "puts 'Hello, World!'",
    "def sum(a, b); a + b; end",
    "rand",
    "def even?(num); num.even?; end",
    "str.length",
    "Date.today",
    "File.exist?('file.txt')",
    "File.read('file.txt')",
    "File.write('file.txt', 'content')",
    "Time.now",
    "str.upcase",
    "str.downcase",
    "str.reverse",
    "list.size",
    "list.max",
    "list.min",
    "list.sort",
    "list1 + list2",
    "list.delete(element)",
    "list.empty?",
    "str.count(char)",
    "str.include?(substring)",
    "num.to_s",
    "str.to_i",
    r"str.match?(/^\d+$/)",
    "list.index(element)",
    "list.clear",
    "list.reverse",
    "list.uniq",
    "list.include?(value)",
    "{}",
    "hash[key] = value",
    "hash.delete(key)",
    "hash.keys",
    "hash.values",
    "hash1.merge(hash2)",
    "hash.empty?",
    "hash[key]",
    "hash.key?(key)",
    "hash.clear",
    "File.readlines('file.txt').size",
    "File.write('file.txt', list.join('\\n'))",
    "File.read('file.txt').split('\\n')",
    "File.read('file.txt').split.size",
    "def leap_year?(year); (year % 400 == 0) || (year % 100 != 0 && year % 4 == 0); end",
    "Time.now.strftime('%Y-%m-%d %H:%M:%S')",
    "(Date.today - Date.new(2023, 1, 1)).to_i",
    "Dir.pwd",
    "Dir.entries('.')",
    "Dir.mkdir('new_dir')",
    "Dir.rmdir('new_dir')",
    "File.file?('path')",
    "File.directory?('path')",
    "File.size('file.txt')",
    "File.rename('old.txt', 'new.txt')",
    "FileUtils.cp('source.txt', 'destination.txt')",
    "FileUtils.mv('source.txt', 'destination.txt')",
    "File.delete('file.txt')",
    "ENV['VAR_NAME']",
    "ENV['VAR_NAME'] = 'value'",
    "system('open https://example.com')",
    "require 'net/http'; Net::HTTP.get(URI('https://example.com'))",
    "require 'json'; JSON.parse(json_string)",
    "require 'json'; File.write('file.json', JSON.dump(data))",
    "require 'json'; JSON.parse(File.read('file.json'))",
    "list.join",
    "str.split(',')",
    "list.join(',')",
    "list.join('\\n')",
    "str.split",
    "str.split(delimiter)",
    "str.chars",
    "str.gsub(old, new)",
    "str.gsub(' ', '')",
    "str.gsub(/[^a-zA-Z0-9]/, '')",
    "str.empty?",
    "str == str.reverse",
    "require 'csv'; CSV.open('file.csv', 'w') { |csv| csv << ['data'] }",
    "require 'csv'; CSV.read('file.csv')",
    "require 'csv'; CSV.read('file.csv').size",
    "list.shuffle",
    "list.sample",
    "list.sample(n)",
    "rand(6) + 1",
    "rand(2) == 0 ? 'Heads' : 'Tails'",
    "SecureRandom.alphanumeric(8)",
    "format('#%06x', rand(0xffffff))",
    "SecureRandom.uuid",
    "class MyClass; end",
    "MyClass.new",
    "class MyClass; def my_method; end; end",
    "class MyClass; attr_accessor :my_attr; end",
    "class ChildClass < ParentClass; end",
    "class ChildClass < ParentClass; def my_method; super; end; end",
    "class MyClass; def self.class_method; end; end",
    "class MyClass; def self.static_method; end; end",
    "obj.is_a?(Class)",
    "obj.instance_variable_get(:@attr)",
    "obj.instance_variable_set(:@attr, value)",
    "obj.instance_variable_defined?(:@attr)",
    "begin; risky_operation; rescue => e; puts e; end",
    """class CustomError < StandardError end raise CustomError, 'error occurred'""",
    "begin; raise 'oops'; rescue => e; e.message; end",
    """require 'logger' logger = Logger.new('error.log') logger.error('error occurred')""",
    "start_time = Time.now",
    "Time.now - start_time",
    # Triple-quoted raw string: the snippet itself contains double quotes and `\r`.
    r"""20.times { |i| print "\r[#{'='*(i+1)}#{' '*(19-i)}]"; sleep(0.1) }""",
    "sleep(1)",
    "square = ->(x) { x*x }",
    "squares = [1,2,3].map { |n| n*n }",
    "evens = [1,2,3,4].select { |n| n.even? }",
    "sum = [1,2,3].reduce(0) { |acc,n| acc+n }",
    "doubles = [1,2,3,4,5].map { |n| n*2 }",
    "hash = [1,2,3].map { |n| [n, n*2] }.to_h",
    "require 'set'; s = Set.new([1,2,3].map { |n| n*2 })",
    "intersection = a & b",
    "union = a | b",
    "diff = a - b",
    "filtered = list.compact",
    "begin; File.open('file.txt'); rescue; false; end",
    "x.is_a?(String)",
    "bool = ['true','1'].include?(str.downcase)",
    "puts 'yes' if x > 0",
    "i=0; while i<5; i+=1; end",
    "for item in [1,2,3]; puts item; end",
    """h = {a:1, b:2} for k, v in h puts "#{k}:#{v}" end""",
    """for c in 'hello'.chars puts c end""",
    """for i in 1..5 break if i==3 puts i end""",
    """for i in 1..5 next if i==3 puts i end""",
    "def foo; end",
    "def foo(a=1); a; end",
    "def foo; [1,2]; end",
    "def foo(*args); args; end",
    "def foo(a:, b:); a+b; end",
    """def foo end start = Time.now foo puts Time.now - start""",
    """def decorate(f) ->(args) { puts 'before'; result = f.call(args); puts 'after'; result } end""",
    """def fib(n, memo={}) return memo[n] if memo[n] memo[n] = n<2 ? n : fib(n-1, memo) + fib(n-2, memo) end""",
    "gen = Enumerator.new { |y| i=0; loop { y << i; i+=1 } }",
    "def foo; yield 1; end",
    "gen.next",
    "itr = [1,2,3].each",
    """itr = [1,2,3].each loop do puts itr.next end""",
    "[1,2].each_with_index { |v, i| puts i, v }",
    "zipped = [1,2].zip(['a','b'])",
    "h = [1,2].zip(['a','b']).to_h",
    "[1,2] == [1,2]",
    "{a:1, b:2} == {b:2, a:1}",
    "require 'set'; Set.new([1,2]) == Set.new([2,1])",
    "unique = [1,2,1].uniq",
    "s.clear",
    "s.empty?",
    "s.add(1)",
    "s.delete(1)",
    "s.include?(1)",
    "s.size",
    "!(a & b).empty?",
    "[1,2].all? { |e| [1,2,3].include?(e) }",
    "'hi'.include?('h')",
    "str[0]",
    "str[-1]",
    "File.extname(path) == '.txt'",
    "['.png','.jpg','.jpeg','.gif'].include?(File.extname(path))",
    "x.round",
    "x.ceil",
    "x.floor",
    "sprintf('%.2f', x)",
    "require 'securerandom'; SecureRandom.alphanumeric(8)",
    "File.exist?('path')",
    "Dir['**/*'].each { |f| puts f }",
    "File.extname('path.txt')",
    "File.basename(path)",
    "File.expand_path(path)",
    "RUBY_VERSION",
    "RUBY_PLATFORM",
    "require 'etc'; Etc.nprocessors",
    "mem = grep MemTotal /proc/meminfo",
    "df = df -h /",
    """require 'socket' ip = Socket.ip_address_list.detect(&:ipv4_private).ip_address""",
    "system('ping -c1 8.8.8.8 > /dev/null 2>&1')",
    """require 'open-uri' File.open('file', 'wb') { |f| f.write open(url).read }""",
    """def upload(file) puts 'Uploading' end""",
    """require 'net/http' uri = URI(url) Net::HTTP.post_form(uri, key: 'value')""",
    """uri = URI(url) uri.query = URI.encode_www_form(params) Net::HTTP.get(uri)""",
    """require 'net/http' uri = URI(url) req = Net::HTTP::Get.new(uri) req['User-Agent'] = 'Custom' res = Net::HTTP.start(uri.hostname, uri.port) { |http| http.request(req) }""",
    "require 'nokogiri'; doc = Nokogiri::HTML(html)",
    "doc.at('title').text",
    "links = doc.css('a').map { |a| a['href'] }",
    """doc.css('img').each do |img| open(img['src']).each do |chunk| File.open(File.basename(img['src']), 'ab') { |f| f.write chunk } end end""",
    """freq = Hash.new(0) text.split.each { |w| freq[w] += 1 }""",
    """require 'net/http' res = Net::HTTP.post_form(URI(login_url), username: 'u', password: 'p')""",
    "Nokogiri::HTML(html).text",
    r"emails = text.scan(/\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b/)",
    r"phones = text.scan(/\b\d{3}-\d{3}-\d{4}\b/)",
    r"nums = text.scan(/\d+/)",
    "new = text.gsub(/foo/, 'bar')",
    "!!(text =~ /pattern/)",
    "clean = text.gsub(/<[^>]*>/, '')",
    "CGI.escapeHTML(text)",
    "CGI.unescapeHTML(text)",
    """require 'tk' root = TkRoot.new { title 'App' } Tk.mainloop""",
    "require 'tk'",
    """root = TkRoot.new button = TkButton.new(root) {text 'Click Me'; command { Tk.messageBox(message: 'Button Clicked!') }} button.pack""",
    """Tk.messageBox(message: 'Hello, World!')""",
    """entry = TkEntry.new(root).pack entry.get""",
    """root.title = 'My Window'""",
    """root.geometry('400x300')""",
    """root.geometry('+%d+%d' % [(root.winfo_screenwidth() - root.winfo_reqwidth()) / 2, (root.winfo_screenheight() - root.winfo_reqheight()) / 2])""",
    """menu = TkMenu.new(root) root['menu'] = menu menu.add('command', 'label' => 'File')""",
    """combobox = Tk::Tile::Combobox.new(root).pack""",
    """radio = TkRadioButton.new(root) {text 'Option 1'}.pack""",
    """check = TkCheckButton.new(root) {text 'Check Me'}.pack""",
    """image = TkPhotoImage.new(file: 'image.png') label = TkLabel.new(root) {image image}.pack""",
    """`afplay audio.mp3`""",
    """`ffplay video.mp4`""",
    """`ffmpeg -i video.mp4 -f null - 2>&1 | grep 'time=' | awk '{print $2}'`""",
    """`screencapture screen.png`""",
    """`ffmpeg -f avfoundation -i "1" -t 10 screen.mp4`""",
    """`cliclick p:.`""",
    """`cliclick kd:cmd kp:space ku:cmd`""",
    """`cliclick c:.`""",
    """Time.now.to_i""",
    """Time.at(timestamp).strftime('%Y-%m-%d')""",
    """Time.parse(date).to_i""",
    """Time.now.strftime('%A')""",
    """Time.days_in_month(Time.now.month, Time.now.year)""",
    """Time.new(Time.now.year, 1, 1)""",
    """Time.new(Time.now.year, 12, 31)""",
    """Time.new(year, month, 1)""",
    """Time.new(year, month, -1)""",
    """Time.now.wday.between?(1, 5)""",
    """Time.now.wday.between?(6, 7)""",
    """Time.now.hour""",
    """Time.now.min""",
    """Time.now.sec""",
    """sleep(1)""",
    """(Time.now.to_f * 1000).to_i""",
    """Time.now.strftime('%Y-%m-%d %H:%M:%S')""",
    """Time.parse(time_str)""",
    """Thread.new { puts 'Hello from thread' }""",
    """sleep(1)""",
    """threads = [] 3.times { threads << Thread.new { puts 'Hello from thread' } } threads.each(&:join)""",
    """Thread.current.name""",
    """thread = Thread.new { puts 'Hello from thread' } thread.abort_on_exception = true""",
    """mutex = Mutex.new mutex.synchronize { puts 'Hello from synchronized thread' }""",
    """pid = Process.spawn('sleep 5')""",
    """Process.pid""",
    """Process.kill(0, pid) rescue false""",
    """pids = [] 3.times { pids << Process.spawn('sleep 5') } pids.each { |pid| Process.wait(pid) }""",
    """queue = Queue.new queue.push('Hello') queue.pop""",
    """reader, writer = IO.pipe writer.puts 'Hello' reader.gets""",
    """Process.setrlimit(:CPU, 50)""",
    """`ls`""",
    """`ls`.chomp""",
    """$?.exitstatus""",
    """$?.success?""",
    """File.expand_path(__FILE__)""",
    """ARGV""",
    """require 'optparse' OptionParser.new { |opts| opts.on('-h', '--help', 'Show help') { puts opts } }.parse!""",
    """OptionParser.new { |opts| opts.on('-h', '--help', 'Show help') { puts opts } }.parse!""",
    """Gem.loaded_specs.keys""",
    """`gem install package_name`""",
    """`gem uninstall package_name`""",
    """Gem.loaded_specs['package_name'].version.to_s""",
    """`bundle exec ruby script.rb`""",
    """Gem::Specification.map(&:name)""",
    """`gem update package_name`""",
    """require 'sqlite3' db = SQLite3::Database.new('test.db')""",
    """db.execute('SELECT * FROM table')""",
    """db.execute('INSERT INTO table (column) VALUES (?)', 'value')""",
    """db.execute('DELETE FROM table WHERE id = ?', 1)""",
    """db.execute('UPDATE table SET column = ? WHERE id = ?', 'new_value', 1)""",
    """db.execute('SELECT * FROM table').each { |row| puts row }""",
    """db.execute('SELECT * FROM table WHERE column = ?', 'value')""",
    """db.close""",
    """db.execute('CREATE TABLE table (id INTEGER PRIMARY KEY, column TEXT)')""",
    """db.execute('DROP TABLE table')""",
    """db.table_info('table').any?""",
    """db.execute('SELECT name FROM sqlite_master WHERE type = "table"')""",
    """class Model < ActiveRecord::Base end Model.create(column: 'value')""",
    """Model.find_by(column: 'value')""",
    """Model.find_by(column: 'value').destroy""",
    """Model.find_by(column: 'value').update(column: 'new_value')""",
    """class Model < ActiveRecord::Base end""",
    """class ChildModel < ParentModel end""",
    """class Model < ActiveRecord::Base self.primary_key = 'id' end""",
    """class Model < ActiveRecord::Base validates_uniqueness_of :column end""",
    """class Model < ActiveRecord::Base attribute :column, default: 'value' end""",
    """require 'csv' CSV.open('data.csv', 'w') { |csv| csv << ['column1', 'column2'] }""",
    """require 'spreadsheet' book = Spreadsheet::Workbook.new sheet = book.create_worksheet sheet[0, 0] = 'Hello' book.write('data.xls')""",
    """require 'json' File.write('data.json', {key: 'value'}.to_json)""",
    """require 'spreadsheet' book = Spreadsheet.open('data.xls') sheet = book.worksheet(0) sheet.each { |row| puts row }""",
    """require 'spreadsheet' book1 = Spreadsheet.open('file1.xls') book2 = Spreadsheet.open('file2.xls') book1.worksheets.each { |sheet| book2.add_worksheet(sheet) } book2.write('merged.xls')""",
    """require 'spreadsheet' book = Spreadsheet::Workbook.new book.create_worksheet(name: 'New Sheet') book.write('data.xls')""",
    """require 'spreadsheet' book = Spreadsheet.open('data.xls') sheet = book.worksheet(0) new_sheet = book.create_worksheet new_sheet.format_with(sheet) book.write('data.xls')""",
    """require 'spreadsheet' book = Spreadsheet.open('data.xls') sheet = book.worksheet(0) sheet.row(0).set_format(0, Spreadsheet::Format.new(color: :red)) book.write('data.xls')""",
    """require 'spreadsheet' book = Spreadsheet.open('data.xls') sheet = book.worksheet(0) sheet.row(0).set_format(0, Spreadsheet::Format.new(weight: :bold)) book.write('data.xls')""",
    """require 'spreadsheet' book = Spreadsheet.open('data.xls') sheet = book.worksheet(0) sheet[0, 0]""",
    """require 'spreadsheet' book = Spreadsheet::Workbook.new sheet = book.create_worksheet sheet[0, 0] = 'Hello' book.write('data.xls')""",
    """require 'rmagick' image = Magick::Image.read('image.png').first [image.columns, image.rows]""",
    """require 'rmagick' image = Magick::Image.read('image.png').first image.resize!(100, 100)""",
]

# Global service state: becomes True once the model and corpus embeddings are ready.
service_ready = False


def handle_shutdown(signum, frame):
    """Signal handler: log the termination signal and exit with status 0.

    Args:
        signum: The received signal number (unused beyond the handler contract).
        frame: The interrupted stack frame (unused).
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Load the model and precompute the corpus embeddings at import time.
# A failure is logged (with traceback) and re-raised so the process does not
# start serving without a model.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode every snippet once, forced onto the CPU.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as e:
    app.logger.exception("初始化失败: %s", str(e))
    raise


@app.route('/')
def hf_health_check():
    """Hugging Face health-check endpoint; must respond on the root path.

    Returns JSON ``{"status": ...}`` by default. An HTML status page is
    served only when the client explicitly prefers ``text/html`` (e.g. a
    browser). Both forms return 200 when ready and 503 while initializing.
    """
    ready = service_ready
    status = "ready" if ready else "initializing"
    status_code = 200 if ready else 503

    # `accept_mimetypes.accept_html` is also true for `Accept: */*`, which
    # most probes send; `best_match` with JSON listed first picks JSON for
    # wildcard clients and HTML only when HTML is actually preferred.
    preferred = request.accept_mimetypes.best_match(["application/json", "text/html"])
    if preferred == "text/html":
        html = """<!DOCTYPE html>
<html>
<head><meta charset="utf-8"><title>CodeSearchAPI</title></head>
<body>
<p>服务状态:{{ status }}</p>
<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
</body>
</html>
"""
        return render_template_string(html, status=status), status_code

    return jsonify({"status": status}), status_code


@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the corpus snippet that best matches the query.

    The query is read from the ``query`` URL parameter (GET) or from the
    ``query`` key of a JSON object body (POST).

    Returns:
        200 with ``{"code": <snippet>, "score": <similarity rounded to 4 dp>}``;
        400 when the query is missing, empty, not a string, or the POST body
        is not a JSON object; 503 while the service is initializing;
        500 if encoding or searching fails.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    # Extract the query according to the request method.
    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True returns None for a missing/malformed JSON body instead
        # of raising, so a bad client request yields 400 rather than 500.
        data = request.get_json(silent=True)
        raw_query = data.get('query', '') if isinstance(data, dict) else ''
    query = raw_query.strip() if isinstance(raw_query, str) else ''

    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    app.logger.info("收到查询请求: %s", query)

    # Only the model calls can fail unexpectedly; keep the try body minimal.
    try:
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
    except Exception as e:
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500

    best = hits[0]
    result = {
        "code": CODE_SNIPPETS[best['corpus_id']],
        "score": round(float(best['score']), 4),
    }
    app.logger.info("返回结果: %s", result)
    return jsonify(result)


if __name__ == "__main__":
    # Install the graceful-shutdown handlers only when running the dev server
    # directly. gunicorn workers install their own SIGTERM/SIGINT handlers
    # (before loading the app); registering ours at import time would replace
    # gunicorn's graceful shutdown with an immediate sys.exit(0).
    signal.signal(signal.SIGTERM, handle_shutdown)
    signal.signal(signal.SIGINT, handle_shutdown)
    # For local testing; Hugging Face Spaces usually starts the app via gunicorn.
    app.run(host='0.0.0.0', port=7860)