File size: 4,432 Bytes
4930d77
 
c4c046d
 
 
9a44a33
c4c046d
4930d77
 
 
c4c046d
 
 
 
 
 
 
 
 
9a44a33
c4c046d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a44a33
9ad727f
 
 
9a44a33
 
 
 
 
 
 
c4c046d
 
 
 
 
 
 
9ad727f
 
 
c4c046d
 
9ad727f
c4c046d
 
 
 
 
9a44a33
c4c046d
 
 
 
 
9ad727f
c4c046d
 
 
9ad727f
c4c046d
46ca356
9ad727f
 
9a44a33
9ad727f
c4c046d
46ca356
9ad727f
 
 
c4c046d
46ca356
9a44a33
 
 
 
46ca356
9a44a33
 
 
46ca356
9a44a33
 
9ad727f
 
83b518f
9ad727f
c4c046d
 
9ad727f
c4c046d
9ad727f
c4c046d
9ad727f
 
c4c046d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
-- Build the cgemma scheduler from the `scheduler` section of the configuration.
-- A failure is only logged; `sched` then stays falsy and is still handed to the
-- gemma configuration by gemma_inst() below.
local sched, err = require("cgemma").scheduler(config().scheduler)
if not sched then
  ngx.log(ngx.ERR, "cgemma error: ", err)
end

--- Return the cached gemma instance, creating it on first use.
-- The instance is stored in the global `worker_gemma_inst` so later calls reuse
-- it. Creation reads the `gemma` section of the configuration and attaches the
-- file-level scheduler to it before calling `cgemma.new`.
-- If creation fails the error is logged and the request is terminated with
-- HTTP 500 via `ngx.exit`.
-- @return the gemma instance
local function gemma_inst()
  -- Fast path: already built.
  if worker_gemma_inst then
    return worker_gemma_inst
  end

  local cfg = config().gemma
  cfg.scheduler = sched

  local inst, reason = require("cgemma").new(cfg)
  if not inst then
    ngx.log(ngx.ERR, "cgemma error: ", reason)
    ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
  end

  worker_gemma_inst = inst
  return worker_gemma_inst
end

--- Turn a base64-encoded image into a model embedding.
-- The image is decoded, stretched to `config().vlm.resize_to` on both axes
-- (independent horizontal and vertical scale factors, linear kernel), encoded
-- as PPM in memory and passed to the gemma instance's `embed_image`.
-- @param image base64 text of the image file, or nil
-- @return whatever `gemma_inst():embed_image(blob)` returns on success
-- @return nil, "image missing" when `image` is nil/false
-- @return nil, "invalid image format" when the input cannot be base64-decoded
--   or libvips cannot load/resize/encode it
local function embed_image(image)
  if not image then
    return nil, "image missing"
  end
  local img_buf = ngx.decode_base64(image)
  if not img_buf then
    return nil, "invalid image format"
  end

  local vips = require("vips")
  local side = config().vlm.resize_to

  -- lua-vips reports failures by raising, not by returning nil, so the whole
  -- vips pipeline runs under pcall to keep this function's `nil, err` contract.
  -- gemma_inst() is deliberately kept outside the protected call.
  local ok, blob = pcall(function()
    local img = vips.Image.new_from_buffer(img_buf)
    img = img:resize(side / img:width(), {vscale = side / img:height(), kernel = "linear"})
    local ppm = vips.Target.new_to_memory()
    img:write_to_target(ppm, ".ppm")
    return ppm:vobject():get("blob")
  end)
  if not ok then
    -- Keep the libvips detail in the error log; callers get the generic message.
    ngx.log(ngx.ERR, "vips error: ", blob)
    return nil, "invalid image format"
  end
  if not blob then
    return nil, "invalid image format"
  end

  return gemma_inst():embed_image(blob)
end

--- Serialize `data` as JSON and send it to the peer as one text frame.
-- @param ws object exposing `send_text`
-- @param data value to encode with `cjson.safe`
-- @return the results of `ws:send_text`
local function send_resp(ws, data)
  local payload = require("cjson.safe").encode(data)
  return ws:send_text(payload)
end

--- Session callback used while feeding the prompt.
-- Returns true as long as `pos` is still below `prompt_size`, and false once
-- `pos` has reached it. The first argument is ignored.
-- @param pos current position
-- @param prompt_size size of the prompt
-- @return boolean
local function prefill_fn(_, pos, prompt_size)
  -- Written as a negated `>=` so the result matches an if/else on that same
  -- comparison for every input.
  return not (pos >= prompt_size)
end

-- Prompt templates keyed by the `lang` field of a client request.
-- Each template opens a user turn and contains one `%s` placeholder that
-- gemma_loop fills with the requested style through string.format; the turn is
-- closed later by gemma_loop when it starts generation.
local prompts = {
  en = [[<start_of_turn>user
You are a talented poet who is very good at writing poems in the style of %s. Your task is to create a poem based on the content of this picture. Please only reply with the poem you created, and do not reply with any description, explanation, or other content.]],
  zh = [[<start_of_turn>user
你是一名才华横溢的诗人,非常擅长创作%s风格的诗歌。你的任务是根据这张图片中的内容创作一首诗歌,请只回复你创作的诗歌,不要回复任何描述、解释或其它内容。]]
}

--- Serve one WebSocket connection until the peer closes it or the session
-- stops being ready.
--
-- Protocol (JSON text frames, field `op`):
--   * `create`    - requires `lang`, `style` and `image`; embeds the image,
--                   prefills the prompt and streams a generated reply.
--   * `retry`     - without `lang`/`style`: restores the state saved after the
--                   last prefill and generates again; with both set: prefills
--                   a new prompt against the previously embedded image.
--   * `keepalive` - ignored.
-- Progress is reported with `{op = "status", id = N}`: 0 before the image is
-- embedded, 1 before the prompt is prefilled, 2 before generation starts.
-- Generated tokens are sent as `{op = "stream", token = ...}`.
--
-- Any protocol violation or I/O failure raises a Lua error (via assert/error).
-- @param ws WebSocket object providing `recv_frame`, `send_text`, `send_pong`
function gemma_loop(ws)
  local session = assert(gemma_inst():session(config().session))
  assert(send_resp(ws, {op = "init", vlm = config().vlm}))

  -- State shared by the `create` and `retry` operations.
  local image, retry_state

  -- Forward generated tokens; returning false aborts when the send fails.
  local function stream_fn(token, pos, prompt_size)
    if pos >= prompt_size then
      if not send_resp(ws, {op = "stream", token = token}) then
        return false
      end
    end
    return true
  end

  -- Reset the session, feed image + formatted prompt, and snapshot the state
  -- so that a bare `retry` can regenerate from this exact point.
  local function prefill(prompt, style)
    assert(send_resp(ws, {op = "status", id = 1}))
    session:reset()
    assert(session(image, string.format(prompt, style), prefill_fn))
    retry_state = assert(session:dumps())
  end

  -- Close the user turn, open the model turn and stream the reply.
  local function generate()
    assert(send_resp(ws, {op = "status", id = 2}))
    assert(session("<end_of_turn>\n<start_of_turn>model\n", stream_fn))
  end

  while session:ready() do
    local data, tp, err = ws:recv_frame()
    if err == "again" then
      -- Fragmented message: gather continuation frames and join them once.
      local parts = {data}
      repeat
        local frag, ct
        frag, ct, err = ws:recv_frame()
        assert(ct == "continuation", err)
        parts[#parts + 1] = frag
      until err ~= "again"
      data = table.concat(parts)
    end

    if tp == "text" then
      local msg = assert(require("cjson.safe").decode(data))
      if msg.op == "create" then
        assert(msg.lang and msg.style, "create error: both `lang` and `style` MUST be set")
        local prompt = assert(prompts[msg.lang], string.format("create error: invalid `lang` %s", tostring(msg.lang)))
        assert(send_resp(ws, {op = "status", id = 0}))
        image = assert(embed_image(msg.image))
        prefill(prompt, msg.style)
        generate()
      elseif msg.op == "retry" then
        if not msg.lang and not msg.style then
          assert(retry_state, "retry error: there is no existing request")
          assert(session:loads(retry_state))
        else
          assert(image, "retry error: there is no existing image")
          assert(msg.lang and msg.style, "retry error: both `lang` and `style` MUST be set")
          local prompt = assert(prompts[msg.lang], string.format("retry error: invalid `lang` %s", tostring(msg.lang)))
          prefill(prompt, msg.style)
        end
        generate()
      elseif msg.op ~= "keepalive" then
        error(string.format("unknown protocol: %s", tostring(msg.op)))
      end
    elseif tp == "ping" then
      -- Reply on the connection object `ws` and echo the ping payload back.
      assert(ws:send_pong(data))
    elseif tp == "close" then
      break
    elseif tp ~= "pong" then
      -- Unknown frame type or receive failure: raise on error, otherwise stop.
      assert(not err, err)
      break
    end
  end
end