File size: 3,291 Bytes
ac8da7f
1c67fcb
ac8da7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c67fcb
 
ac8da7f
 
 
 
 
 
 
 
 
 
 
 
 
1c67fcb
 
 
ac8da7f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import itertools
import gradio as gr
import requests
import os
from gradio.themes.utils import sizes
import json
import pandas as pd

import base64
import io
from PIL import Image
import numpy as np
    
def respond(message, history):
    """Turn a chat message into a generated image rendered as Markdown.

    Used as the ``gr.ChatInterface`` callback. The message is sent as a
    one-row ``dataframe_split`` payload to the endpoint named by the
    ``API_ENDPOINT`` environment variable, authenticated with the bearer
    token in ``API_TOKEN``. The ``predictions`` field of the JSON reply is
    converted to a uint8 RGB image, PNG-encoded, and embedded in the reply
    as a base64 ``data:`` URI so it renders inline in the chat.

    Args:
        message: The user's prompt text.
        history: Earlier chat turns passed by Gradio. It is not used, so
            every prompt is handled independently.

    Returns:
        A Markdown image string on success. Otherwise a plain-text message:
        an input prompt when ``message`` is empty, ``None`` or whitespace,
        or a string starting with ``"ERROR"`` in three cases:
        configuration is missing, the endpoint answers with an HTTP error
        status, or the request/response cannot be turned into an image.
    """
    if not message or not message.strip():
        return "質問を入力してください"  # "Please enter a question."

    local_token = os.getenv('API_TOKEN')
    local_endpoint = os.getenv('API_ENDPOINT')
    if local_token is None or local_endpoint is None:
        return "ERROR missing env variables"

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {local_token}',
    }

    # Single-row request payload; the scalar num_inference_steps is
    # broadcast by pandas onto the one prompt row.
    prompt = pd.DataFrame({"prompt": [message], "num_inference_steps": 25})
    print(prompt)
    data_json = json.dumps(
        {"dataframe_split": prompt.to_dict(orient="split")}, allow_nan=True
    )

    try:
        # Query the model serving endpoint. Without a timeout a stalled
        # endpoint would block this handler forever. The values are
        # (connect, read) seconds; the read timeout is generous because
        # image generation can be slow. Tune as needed.
        response = requests.post(
            local_endpoint, headers=headers, data=data_json, timeout=(10, 300)
        )
        if not response.ok:
            # Pass the endpoint's own error details through to the user
            # instead of trying to decode an error body as an image.
            return (
                f"ERROR status_code: {response.status_code}"
                f" response: {response.text}"
            )
        predictions = response.json()["predictions"]

        # NOTE(review): assumes predictions is a (height, width, 3) nested
        # list of 0-255 values; confirm against the served model's output.
        im_array = np.array(predictions, dtype=np.uint8)
        im = Image.fromarray(im_array, 'RGB')

        # PNG-encode in memory, then base64 it for the Markdown data URI.
        png_buffer = io.BytesIO()
        im.save(png_buffer, "PNG")
        image_encoded = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    except Exception as error:
        # UI/service boundary: log the failure server-side and show its
        # cause in the chat rather than returning an empty reply.
        print(f"respond failed: {error!r}")
        return f"ERROR {type(error).__name__}: {error}"

    return f"![](data:image/png;base64,{image_encoded})"


# Compact variant of Gradio's built-in "Soft" theme: small spacing,
# small corner radius and small text throughout the UI.
theme = gr.themes.Soft(
    spacing_size=sizes.spacing_sm,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_sm,
)


# Conversation pane: no label or surrounding container, a copy button on
# each message, and message bubbles stretched to the full width.
_chatbot = gr.Chatbot(
    show_label=False,
    container=False,
    show_copy_button=True,
    bubble_full_width=True,
)

# Prompt input box; the placeholder reads "Please enter a question".
_textbox = gr.Textbox(
    placeholder="質問を入力してください",
    container=False,
    scale=7,
)

# NOTE(review): the title, description ("TBD") and these example questions
# look carried over from a text-QA template, while `respond` returns a
# generated image; consider image-style prompts here.
_examples = [
    ["Databricksクラスターとは?"],
    ["Unity Catalogの有効化方法"],
    ["リネージの保持期間"],
]

# Chat front end: every submitted message is passed to `respond`, and the
# Markdown string it returns is shown as the bot's reply.
demo = gr.ChatInterface(
    respond,
    chatbot=_chatbot,
    textbox=_textbox,
    title="Databricks QAチャットボット",
    description="TBD",
    examples=_examples,
    cache_examples=False,
    theme=theme,
    # None removes the retry/undo buttons (Gradio 4.x ChatInterface API);
    # the clear button stays, labelled "Clear".
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)

if __name__ == "__main__":
    demo.launch()