import base64
import io
import json
import os

import gradio as gr
import numpy as np
import pandas as pd
import requests
from gradio.themes.utils import sizes
from PIL import Image

def respond(message, history):

    if len(message.strip()) == 0:
        return "指示を入力してください"

    local_token = os.getenv('API_TOKEN')
    local_endpoint = os.getenv('API_ENDPOINT')
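    # Assumption: both variables are injected by the hosting environment (e.g. as secrets);
    # API_TOKEN is a token authorized to query the serving endpoint and API_ENDPOINT is
    # the endpoint's invocations URL.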

    if local_token is None or local_endpoint is None:
        return "ERROR missing env variables"

    # Add your API token to the headers
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {local_token}'
    }

    # Build the request as a single-row DataFrame (prompt plus inference settings)
    prompt = pd.DataFrame(
        {"prompt": [message], "num_inference_steps": 25}
    )

    print(prompt)
    ds_dict = {"dataframe_split": prompt.to_dict(orient="split")}
    data_json = json.dumps(ds_dict, allow_nan=True)
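    # Assumption: the endpoint accepts the standard MLflow serving "dataframe_split"
    # payload, so data_json looks roughly like:
    #   {"dataframe_split": {"index": [0],
    #                        "columns": ["prompt", "num_inference_steps"],
    #                        "data": [["A photo of a chair", 25]]}}
    # The exact input schema depends on the signature of the served model.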

    embed_image_markdown = ""

    try:
        # Query the model serving endpoint
        response = requests.request(method="POST", headers=headers, url=local_endpoint, data=data_json)
        response_data = response.json()
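        # Assumption: "predictions" holds one generated image as a nested list of
        # RGB pixel values with shape (height, width, 3) in the 0-255 range, which
        # is what the uint8 conversion and Image.fromarray(..., 'RGB') below expect.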

        # Convert the predictions to a numpy array of pixel values
        im_array = np.array(response_data["predictions"], dtype=np.uint8)
        # Convert the array to a PIL image
        im = Image.fromarray(im_array, 'RGB')

        rawBytes = io.BytesIO()
        im.save(rawBytes, "PNG")
        rawBytes.seek(0)  # rewind to the start of the buffer
        # Encode the PNG bytes as base64
        image_encoded = base64.b64encode(rawBytes.read()).decode('ascii')

        # Embed the image in the reply as a Markdown data URI
        embed_image_markdown = f"![](data:image/png;base64,{image_encoded})"
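        # gr.Chatbot renders Markdown, so the data URI above is displayed inline as an
        # image; note that base64 encoding makes the chat message fairly large.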
    except Exception as error:
        response_data = f"ERROR status_code: {type(error).__name__}" 
        #+ str(response.status_code) + " response:" + response.text    

    return embed_image_markdown


theme = gr.themes.Soft(
    text_size=sizes.text_sm, radius_size=sizes.radius_sm, spacing_size=sizes.spacing_sm,
)


demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(show_label=False, container=False, show_copy_button=True, bubble_full_width=True),
    textbox=gr.Textbox(placeholder="Describe the image to generate",
                       container=False, scale=7),
    title="Databricks image generation demo - generating personalized images with a model serving endpoint",
    description="[Generating brand-aligned images with generative AI on Databricks](https://qiita.com/taka_yayoi/items/8d3473847d9ccc8ca00c)<br>**Images used for fine-tuning**<br>![](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/d4e3422b-107e-1bd4-ff84-28e0ea0ac156.png)",
    examples=[["A photo of an orange bcnchr chair"],
              ["A photo of a blue hsmnchr chair"],
              ["A photo of a red rckchr chair"]],
    cache_examples=False,
    theme=theme,
    retry_btn=None,
    undo_btn=None,
    clear_btn="Clear",
)
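# Note: the unusual tokens in the example prompts ("bcnchr", "hsmnchr", "rckchr") appear to
# be the unique identifiers the chair images were fine-tuned on; prompts without them are
# unlikely to reproduce the personalized chairs.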

if __name__ == "__main__":
    demo.launch()