Spaces:
Sleeping
Sleeping
-- Build the cgemma scheduler from the `scheduler` section of the configuration.
-- `sched` is later attached to the Gemma configuration by gemma_inst(); when
-- creation fails the error is only logged and `sched` stays nil.
local sched, err
do
  local cgemma = require("cgemma")
  sched, err = cgemma.scheduler(config().scheduler)
end
if not sched then
  ngx.log(ngx.ERR, "cgemma error: ", err)
end
--- Return the cached Gemma instance, creating it on first use.
-- The instance is stored in the global `worker_gemma_inst`. On first use the
-- `gemma` section of the configuration gets the scheduler attached and is
-- handed to `cgemma.new`.
-- If creation fails the error is logged and
-- `ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)` is called.
-- @return the Gemma instance (nil if creation failed and ngx.exit returned)
local function gemma_inst()
  -- Fast path: an instance has already been created.
  if worker_gemma_inst then
    return worker_gemma_inst
  end

  local options = config().gemma
  options.scheduler = sched

  local instance, reason = require("cgemma").new(options)
  if not instance then
    ngx.log(ngx.ERR, "cgemma error: ", reason)
    ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
  end

  worker_gemma_inst = instance
  return worker_gemma_inst
end
--- Decode a base64-encoded image, rescale it and embed it with Gemma.
-- The image is stretched to `config().vlm.resize_to` pixels on both axes
-- (independent horizontal and vertical scale factors), serialised as PPM into
-- memory and passed to the Gemma instance's `embed_image`.
--
-- Failures caused by the supplied image are reported as `nil, err` rather than
-- raised, so the `nil, err` contract holds for every bad-input path.
-- @param image base64 string holding the encoded image
-- @return whatever `gemma_inst():embed_image(blob)` returns on success
-- @return nil, error message when the image is missing or cannot be processed
local function embed_image(image)
  if not image then
    return nil, "image missing"
  end
  -- ngx.decode_base64 only accepts strings; a JSON null or number would raise.
  if type(image) ~= "string" then
    return nil, "invalid image format"
  end
  local img_buf = ngx.decode_base64(image)
  if not img_buf then
    return nil, "invalid image format"
  end

  local vips = require("vips")
  -- lua-vips signals failures by raising a Lua error instead of returning nil,
  -- so the loader is protected to turn undecodable data into `nil, err`.
  local loaded, img = pcall(vips.Image.new_from_buffer, img_buf)
  if not loaded or not img then
    return nil, "invalid image format"
  end

  -- Configuration errors are deliberately left outside the protected call.
  local side = config().vlm.resize_to
  local converted, blob = pcall(function()
    local scaled = img:resize(side / img:width(), {vscale = side / img:height(), kernel = "linear"})
    local ppm = vips.Target.new_to_memory()
    scaled:write_to_target(ppm, ".ppm")
    return ppm:vobject():get("blob")
  end)
  if not converted then
    return nil, "image conversion failed: " .. tostring(blob)
  end

  return gemma_inst():embed_image(blob)
end
--- JSON-encode `data` and send it to the client as one WebSocket text frame.
-- @param ws WebSocket object providing `send_text`
-- @param data value to encode with `cjson.safe`
-- @return the return values of `ws:send_text`
local function send_resp(ws, data)
  local json = require("cjson.safe")
  return ws:send_text(json.encode(data))
end
--- Session callback used for prompt prefill.
-- Returns true while `pos` is below `prompt_size` and false once `pos` has
-- reached it; the token argument is ignored.
-- @param _ unused (token)
-- @param pos current position
-- @param prompt_size size of the prompt
-- @return boolean
local function prefill_fn(_, pos, prompt_size)
  return not (pos >= prompt_size)
end
-- Prompt templates keyed by language code. Each template opens a user turn and
-- contains one `%s` placeholder that is filled with the requested poem style
-- via string.format. The literals must hold only the prompt text: stray
-- characters inside the long brackets would be sent to the model verbatim.
local prompts = {
  en = [[<start_of_turn>user
You are a talented poet who is very good at writing poems in the style of %s. Your task is to create a poem based on the content of this picture. Please only reply with the poem you created, and do not reply with any description, explanation, or other content.]],
  zh = [[<start_of_turn>user
你是一名才华横溢的诗人,非常擅长创作%s风格的诗歌。你的任务是根据这张图片中的内容创作一首诗歌,请只回复你创作的诗歌,不要回复任何描述、解释或其它内容。]]
}
--- Serve one WebSocket client until it disconnects or the session stops being ready.
--
-- A session is created from `config().session` and an `init` message carrying
-- `config().vlm` is sent. Afterwards JSON text frames are processed by `op`:
--   * "create": requires `lang`, `style` and `image`; embeds the image,
--     prefills the prompt, snapshots the session and streams a poem.
--   * "retry": without `lang`/`style` the snapshot taken after the last
--     prefill is restored; with both set the previously embedded image is
--     prefilled again with the new prompt. Either way a new poem is streamed.
--   * "keepalive": ignored.
-- Progress is reported with `status` messages: id 0 before the image is
-- embedded, id 1 before the prompt is prefilled, id 2 before generation.
-- Generated tokens are sent as `stream` messages.
--
-- Protocol violations and I/O failures raise a Lua error (via assert/error).
-- @param ws WebSocket server object (recv_frame / send_text / send_pong)
function gemma_loop(ws)
  local session = assert(gemma_inst():session(config().session))
  assert(send_resp(ws, {op = "init", vlm = config().vlm}))

  -- Forwards tokens at or beyond the prompt size to the client; returning
  -- false (when sending fails) asks the session to stop.
  local function stream_fn(token, pos, prompt_size)
    if pos >= prompt_size then
      if not send_resp(ws, {op = "stream", token = token}) then
        return false
      end
    end
    return true
  end

  local image, retry_state
  while session:ready() do
    local data, tp, err = ws:recv_frame()
    -- "again" marks a fragmented message: gather the continuation frames and
    -- join them once instead of re-concatenating a growing string each time.
    if err == "again" then
      local parts = {data}
      repeat
        local frag, ct
        frag, ct, err = ws:recv_frame()
        assert(ct == "continuation", err)
        parts[#parts + 1] = frag
      until err ~= "again"
      data = table.concat(parts)
    end

    if tp == "text" then
      local msg = assert(require("cjson.safe").decode(data))
      if msg.op == "create" then
        assert(msg.lang and msg.style, "create error: both `lang` and `style` MUST be set")
        local prompt = assert(prompts[msg.lang], string.format("create error: invalid `lang` %s", tostring(msg.lang)))
        assert(send_resp(ws, {op = "status", id = 0}))
        image = assert(embed_image(msg.image))
        assert(send_resp(ws, {op = "status", id = 1}))
        session:reset()
        assert(session(image, string.format(prompt, msg.style), prefill_fn))
        -- Snapshot taken right after prefill so a plain "retry" can regenerate
        -- without prefilling again.
        retry_state = assert(session:dumps())
        assert(send_resp(ws, {op = "status", id = 2}))
        assert(session("<end_of_turn>\n<start_of_turn>model\n", stream_fn))
      elseif msg.op == "retry" then
        if not msg.lang and not msg.style then
          assert(retry_state, "retry error: there is no existing request")
          assert(session:loads(retry_state))
        else
          assert(image, "retry error: there is no existing image")
          assert(msg.lang and msg.style, "retry error: both `lang` and `style` MUST be set")
          local prompt = assert(prompts[msg.lang], string.format("retry error: invalid `lang` %s", tostring(msg.lang)))
          assert(send_resp(ws, {op = "status", id = 1}))
          session:reset()
          assert(session(image, string.format(prompt, msg.style), prefill_fn))
          retry_state = assert(session:dumps())
        end
        assert(send_resp(ws, {op = "status", id = 2}))
        assert(session("<end_of_turn>\n<start_of_turn>model\n", stream_fn))
      elseif msg.op ~= "keepalive" then
        error(string.format("unknown protocol: %s", tostring(msg.op)))
      end
    elseif tp == "ping" then
      -- Reply on the connection object `ws` (the original referenced an
      -- undefined `wb`) and echo the ping payload back in the pong.
      assert(ws:send_pong(data))
    elseif tp == "close" then
      break
    elseif tp ~= "pong" then
      -- Any other frame type, or a receive failure: raise the error if there
      -- is one, otherwise leave the loop.
      assert(not err, err)
      break
    end
  end
end