|
import base64 |
|
import gradio as gr |
|
import json |
|
import mimetypes |
|
import os |
|
import requests |
|
import time |
|
|
|
|
|
# --- Deployment configuration, read once at import time from the environment ---

# Required settings: indexing os.environ raises KeyError if any is missing,
# so the app refuses to start without them.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']

# Optional settings; each is None when the variable is unset.
SYSTEM_PROMPT = os.getenv('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.getenv('MULTIMODAL')

# JSON object holding the initial slider values; this file reads the keys
# 'tokens_to_generate', 'temperature' and 'top_p' from it.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])

# Optional per-role names keyed by message role (None when unset).
NAME_MAP = dict(
    system=os.getenv('SYSTEM_NAME'),
    user=os.getenv('USER_NAME'),
)
|
|
|
|
|
def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for ``message`` given the prior ``history``.

    Used as the ``gr.ChatInterface`` callback. ``history`` is a list of
    ``{'role': ..., 'content': ...}`` dicts (the interface uses
    ``type='messages'``); ``message`` is either a string or, in multimodal
    mode, a mapping with ``'text'`` / ``'files'`` entries.

    Args:
        message: The new user turn.
        history: Previous turns of the conversation.
        max_tokens: Forwarded as ``max_tokens`` in the request body.
        temperature: Forwarded as ``temperature`` in the request body.
        top_p: Forwarded as ``top_p`` in the request body.

    Yields:
        The reply text accumulated so far for streamed ``delta`` chunks, or
        the full reply when a chunk carries a complete ``message``.

    Raises:
        gr.Error: If the HTTP status is not successful, or a streamed payload
            has no ``choices`` key.
    """
    payload = {
        'model': MODEL_VERSION,
        'messages': _build_messages(message, history),
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    # The context manager closes the streamed connection even when Gradio
    # stops consuming this generator early (e.g. the user cancels).
    with requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(payload),
        stream=True,
    ) as r:
        # Error responses are usually plain JSON rather than SSE `data:` lines.
        # Without this check they would be silently ignored and no reply shown.
        if not r.ok:
            raise gr.Error('request failed ({})'.format(r.status_code))
        yield from _stream_reply(r)


def _build_messages(message, history):
    """Assemble the request ``messages``: system prompt, history, new user turn."""
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for turn in history:
        messages.append({
            'role': turn['role'],
            'content': convert_content(turn['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    for entry in messages:
        add_name_for_message(entry)
    return messages


def _stream_reply(response):
    """Yield progressively accumulated reply text from a server-sent-event stream."""
    reply = ''
    for row in response.iter_lines():
        # Blank separator lines and non-data SSE fields are skipped.
        if not row.startswith(b'data:'):
            continue
        body = row[5:].strip()
        # OpenAI-compatible servers end the stream with a non-JSON sentinel.
        if body == b'[DONE]':
            break
        chunk = json.loads(body)
        if 'choices' not in chunk:
            raise gr.Error('request failed')
        if not chunk['choices']:
            # e.g. a trailing usage-only chunk: nothing to render.
            continue
        choice = chunk['choices'][0]
        if 'delta' in choice:
            # Role-only and finish chunks may omit 'content' or set it to null.
            reply += choice['delta'].get('content') or ''
            yield reply
        elif 'message' in choice:
            yield choice['message']['content']
|
|
|
|
|
def add_name_for_message(message):
    """Attach the configured ``name`` for ``message['role']``, in place.

    Looks the role up in ``NAME_MAP``. The message dict is left untouched
    when no name is configured for that role (lookup yields None).
    """
    display_name = NAME_MAP.get(message['role'])
    if display_name is None:
        return
    message['name'] = display_name
|
|
|
|
|
def convert_content(content):
    """Translate a Gradio chat payload into request ``content``.

    - ``str``: returned unchanged.
    - ``tuple``: a file message whose first element is the file path; returned
      as a single ``image_url`` part.
    - anything else: treated as a mapping. Its ``'text'`` entry becomes a text
      part and each path in ``'files'`` becomes an ``image_url`` part, in the
      mapping's own key order. Other keys are ignored.

    Image parts carry a data URL produced by ``encode_base64``.
    """
    if isinstance(content, str):
        return content

    def image_part(path):
        # One OpenAI-style image part for the file at `path`.
        return {
            'type': 'image_url',
            'image_url': {'url': encode_base64(path)},
        }

    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for field, value in content.items():
        if field == 'text':
            parts.append({'type': 'text', 'text': value})
        elif field == 'files':
            parts.extend(image_part(path) for path in value)
    return parts
|
|
|
|
|
def encode_base64(path):
    """Read the image file at ``path`` and return it as a base64 ``data:`` URL.

    The MIME type is guessed from the file name with ``mimetypes.guess_type``.
    The file contents are not inspected.

    Returns:
        A string of the form ``'data:<mime>;base64,<payload>'``.

    Raises:
        gr.Error: If the guessed type is unknown or is not ``image/*``.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type() yields None for unrecognised extensions. Report that as
    # "not an image" rather than crashing with AttributeError on None.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode('ascii'),
    )
|
|
|
|
|
# Chat UI wired to respond(). The three sliders are passed to respond() after
# (message, history), i.e. as max_tokens, temperature and top_p in that order.
# Their initial values come from MODEL_CONTROL_DEFAULTS.
demo = gr.ChatInterface(
    respond,
    type='messages',
    multimodal=(MULTIMODAL_FLAG == 'ON'),
    additional_inputs=[
        gr.Slider(
            label='Tokens to generate',
            minimum=1,
            maximum=1000000,
            step=1,
            value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
        ),
        gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['temperature'],
        ),
        gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['top_p'],
        ),
    ],
)
|
|
|
|
|
if __name__ == '__main__':
    # Enable the request queue with a default concurrency limit of 50 for
    # event listeners, then start the server.
    demo.queue(default_concurrency_limit=50)
    demo.launch()
|
|