pors tricktreat commited on
Commit
abe1f97
Β·
0 Parent(s):

Duplicate from microsoft/HuggingGPT

Browse files

Co-authored-by: Yongliang Shen <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ logs/
2
+ logs2/
3
+ models
4
+ public/*
5
+ *.pyc
6
+ !public/examples
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: HuggingGPT
3
+ emoji: 😻
4
+ colorFrom: gray
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.24.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: microsoft/HuggingGPT
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+ Link to paper https://arxiv.org/abs/2303.17580
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import gradio as gr
3
+ import re
4
+ from diffusers.utils import load_image
5
+ import requests
6
+ from awesome_chat import chat_huggingface
7
+ import os
8
+
9
+ os.makedirs("public/images", exist_ok=True)
10
+ os.makedirs("public/audios", exist_ok=True)
11
+ os.makedirs("public/videos", exist_ok=True)
12
+
13
+ HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
14
+ OPENAI_KEY = os.environ.get("OPENAI_KEY")
15
+
16
class Client:
    """Per-session chat state for the HuggingGPT Gradio demo.

    An instance keeps the OpenAI key and Hugging Face token currently in use,
    the conversation so far as a list of ``{"role", "content"}`` dicts
    (``all_messages``), and the callbacks the UI wires to its widgets.
    """

    # Media references are recognised purely by file extension.
    _IMAGE_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(jpg|jpeg|tiff|gif|png)")
    _AUDIO_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav)")
    _VIDEO_PATTERN = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(mp4)")

    def __init__(self) -> None:
        # Start from the module-level values read from the environment.
        self.OPENAI_KEY = OPENAI_KEY
        self.HUGGINGFACE_TOKEN = HUGGINGFACE_TOKEN
        self.all_messages = []

    def set_key(self, openai_key):
        """Store and return the OpenAI key entered in the UI."""
        self.OPENAI_KEY = openai_key
        return self.OPENAI_KEY

    def set_token(self, huggingface_token):
        """Store and return the Hugging Face token entered in the UI."""
        self.HUGGINGFACE_TOKEN = huggingface_token
        return self.HUGGINGFACE_TOKEN

    def add_message(self, content, role):
        """Append one ``{"role", "content"}`` entry to the conversation."""
        message = {"role": role, "content": content}
        self.all_messages.append(message)

    @staticmethod
    def _unique_matches(pattern, text):
        """Return the distinct full matches of ``pattern`` in first-seen order."""
        return list(dict.fromkeys(match.group(0) for match in pattern.finditer(text)))

    def extract_medias(self, message):
        """Find media references in ``message``.

        Returns:
            A 4-tuple ``(urls, image_urls, audio_urls, video_urls)``. ``urls``
            is always an empty list (generic URL extraction is disabled); the
            other three are de-duplicated lists of matched strings.
        """
        urls = []
        image_urls = self._unique_matches(self._IMAGE_PATTERN, message)
        audio_urls = self._unique_matches(self._AUDIO_PATTERN, message)
        video_urls = self._unique_matches(self._VIDEO_PATTERN, message)
        return urls, image_urls, audio_urls, video_urls

    def _has_credentials(self):
        """Return True when both credentials are set and carry the expected prefix."""
        return bool(
            self.OPENAI_KEY
            and self.OPENAI_KEY.startswith("sk-")
            and self.HUGGINGFACE_TOKEN
            and self.HUGGINGFACE_TOKEN.startswith("hf_")
        )

    @staticmethod
    def _resolve_location(location):
        """Prefix bare/relative references with ``public/``; leave URLs and ``public...`` paths alone."""
        if location.startswith("http") or location.startswith("public"):
            return location
        return "public/" + location

    @staticmethod
    def _save_copy(location, directory):
        """Copy the media at ``location`` into ``directory`` under a short random name.

        ``location`` is downloaded when it starts with ``http`` and read from
        the local filesystem otherwise (``requests`` cannot fetch plain paths).
        Returns the path of the file that was written.
        """
        ext = location.split(".")[-1]
        name = f"{directory}/{str(uuid.uuid4())[:4]}.{ext}"
        if location.startswith("http"):
            payload = requests.get(location).content
        else:
            with open(location, "rb") as source:
                payload = source.read()
        with open(name, "wb") as f:
            f.write(payload)
        return name

    def add_text(self, messages, message):
        """Record a user message and attach any media it references.

        Args:
            messages: Chat history as shown in the chatbot widget.
            message: Text the user submitted.

        Returns:
            ``(messages, textbox_value)``. Without valid credentials the
            history is returned untouched together with a reminder string;
            otherwise the extended history and ``""`` (clears the textbox).
        """
        if not self._has_credentials():
            return messages, "Please set your OpenAI API key and Hugging Face token first!!!"
        self.add_message(message, "user")
        messages = messages + [(message, None)]
        _, image_urls, audio_urls, video_urls = self.extract_medias(message)

        for image_url in image_urls:
            image = load_image(self._resolve_location(image_url))
            # Images are always re-encoded to JPEG.
            name = f"public/images/{str(uuid.uuid4())[:4]}.jpg"
            image.save(name)
            messages = messages + [((name,), None)]
        for audio_url in audio_urls:
            name = self._save_copy(self._resolve_location(audio_url), "public/audios")
            messages = messages + [((name,), None)]
        for video_url in video_urls:
            name = self._save_copy(self._resolve_location(video_url), "public/videos")
            messages = messages + [((name,), None)]
        return messages, ""

    def bot(self, messages):
        """Run the HuggingGPT pipeline on the conversation and append its answer.

        Returns:
            ``(messages, results)`` where ``results`` has its keys converted
            to strings. Without valid credentials returns ``(messages, {})``.
        """
        if not self._has_credentials():
            return messages, {}
        message, results = chat_huggingface(self.all_messages, self.OPENAI_KEY, self.HUGGINGFACE_TOKEN)
        _, image_urls, audio_urls, video_urls = self.extract_medias(message)
        self.add_message(message, "assistant")
        messages[-1][1] = message
        # Only non-http media referenced in the answer is attached; http URLs
        # are left as text. Order is images, then audios, then videos.
        for media_url in image_urls + audio_urls + video_urls:
            if not media_url.startswith("http"):
                media_url = media_url.replace("public/", "")
                messages = messages + [(None, (f"public/{media_url}",))]
        # replace int key to string key
        results = {str(k): v for k, v in results.items()}
        return messages, results
123
+
124
+ css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
125
+ with gr.Blocks(css=css) as demo:
126
+ state = gr.State(value={"client": Client()})
127
+ gr.Markdown("<h1><center>HuggingGPT</center></h1>")
128
+ gr.Markdown("<p align='center'><img src='https://i.ibb.co/qNH3Jym/logo.png' height='25' width='95'></p>")
129
+ gr.Markdown("<p align='center' style='font-size: 20px;'>A system to connect LLMs with ML community. See our <a href='https://github.com/microsoft/JARVIS'>Project</a> and <a href='http://arxiv.org/abs/2303.17580'>Paper</a>.</p>")
130
+ gr.HTML('''<center><a href="https://huggingface.co/spaces/microsoft/HuggingGPT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space and run securely with your OpenAI API Key and Hugging Face Token</center>''')
131
+ gr.HTML('''<center>Note: Only a few models are deployed in the local inference endpoint due to hardware limitations. In addition, online HuggingFace inference endpoints may sometimes not be available. Thus the capability of HuggingGPT is limited.</center>''')
132
+ if not OPENAI_KEY:
133
+ with gr.Row().style():
134
+ with gr.Column(scale=0.85):
135
+ openai_api_key = gr.Textbox(
136
+ show_label=False,
137
+ placeholder="Set your OpenAI API key here and press Enter",
138
+ lines=1,
139
+ type="password"
140
+ ).style(container=False)
141
+ with gr.Column(scale=0.15, min_width=0):
142
+ btn1 = gr.Button("Submit").style(full_height=True)
143
+
144
+ if not HUGGINGFACE_TOKEN:
145
+ with gr.Row().style():
146
+ with gr.Column(scale=0.85):
147
+ hugging_face_token = gr.Textbox(
148
+ show_label=False,
149
+ placeholder="Set your Hugging Face Token here and press Enter",
150
+ lines=1,
151
+ type="password"
152
+ ).style(container=False)
153
+ with gr.Column(scale=0.15, min_width=0):
154
+ btn3 = gr.Button("Submit").style(full_height=True)
155
+
156
+
157
+ with gr.Row().style():
158
+ with gr.Column(scale=0.6):
159
+ chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
160
+ with gr.Column(scale=0.4):
161
+ results = gr.JSON(elem_classes="json")
162
+
163
+
164
+ with gr.Row().style():
165
+ with gr.Column(scale=0.85):
166
+ txt = gr.Textbox(
167
+ show_label=False,
168
+ placeholder="Enter text and press enter. The url must contain the media type. e.g, https://example.com/example.jpg",
169
+ lines=1,
170
+ ).style(container=False)
171
+ with gr.Column(scale=0.15, min_width=0):
172
+ btn2 = gr.Button("Send").style(full_height=True)
173
+
174
+ def set_key(state, openai_api_key):
175
+ return state["client"].set_key(openai_api_key)
176
+
177
+ def add_text(state, chatbot, txt):
178
+ return state["client"].add_text(chatbot, txt)
179
+
180
+ def set_token(state, hugging_face_token):
181
+ return state["client"].set_token(hugging_face_token)
182
+
183
+ def bot(state, chatbot):
184
+ return state["client"].bot(chatbot)
185
+
186
+ if not OPENAI_KEY:
187
+ openai_api_key.submit(set_key, [state, openai_api_key], [openai_api_key])
188
+ btn1.click(set_key, [state, openai_api_key], [openai_api_key])
189
+
190
+ if not HUGGINGFACE_TOKEN:
191
+ hugging_face_token.submit(set_token, [state, hugging_face_token], [hugging_face_token])
192
+ btn3.click(set_token, [state, hugging_face_token], [hugging_face_token])
193
+
194
+ txt.submit(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
195
+ btn2.click(add_text, [state, chatbot, txt], [chatbot, txt]).then(bot, [state, chatbot], [chatbot, results])
196
+
197
+
198
+ gr.Examples(
199
+ examples=["Given a collection of image A: /examples/a.jpg, B: /examples/b.jpg, C: /examples/c.jpg, please tell me how many zebras in these picture?",
200
+ "Please generate a canny image based on /examples/f.jpg",
201
+ "show me a joke and an image of cat",
202
+ "what is in the examples/a.jpg",
203
+ "based on the /examples/a.jpg, please generate a video and audio",
204
+ "based on pose of /examples/d.jpg and content of /examples/e.jpg, please show me a new image",
205
+ ],
206
+ inputs=txt
207
+ )
208
+
209
+ demo.launch()
awesome_chat.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ import datetime
4
+ from io import BytesIO
5
+ import io
6
+ import os
7
+ import random
8
+ import time
9
+ import traceback
10
+ import uuid
11
+ import requests
12
+ import re
13
+ import json
14
+ import logging
15
+ import argparse
16
+ import yaml
17
+ from PIL import Image, ImageDraw
18
+ from diffusers.utils import load_image
19
+ from pydub import AudioSegment
20
+ import threading
21
+ from queue import Queue
22
+ from get_token_ids import get_token_ids_for_task_parsing, get_token_ids_for_choose_model, count_tokens, get_max_context_length
23
+ from huggingface_hub.inference_api import InferenceApi
24
+ from huggingface_hub.inference_api import ALL_TASKS
25
+ from models_server import models, status
26
+ from functools import partial
27
+ from huggingface_hub import Repository
28
+
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--config", type=str, default="config.yaml.dev")
31
+ parser.add_argument("--mode", type=str, default="cli")
32
+ args = parser.parse_args()
33
+
34
+ if __name__ != "__main__":
35
+ args.config = "config.gradio.yaml"
36
+
37
+ config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader)
38
+
39
+ if not os.path.exists("logs"):
40
+ os.mkdir("logs")
41
+
42
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
43
+
44
+ DATASET_REPO_URL = "https://huggingface.co/datasets/tricktreat/HuggingGPT_logs"
45
+ LOG_HF_TOKEN = os.environ.get("LOG_HF_TOKEN")
46
+ if LOG_HF_TOKEN:
47
+ repo = Repository(
48
+ local_dir="logs", clone_from=DATASET_REPO_URL, use_auth_token=LOG_HF_TOKEN
49
+ )
50
+
51
+ logger = logging.getLogger(__name__)
52
+ logger.setLevel(logging.CRITICAL)
53
+
54
+ handler = logging.StreamHandler()
55
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
56
+ handler.setFormatter(formatter)
57
+ if not config["debug"]:
58
+ handler.setLevel(logging.INFO)
59
+ logger.addHandler(handler)
60
+
61
+ log_file = config["log_file"]
62
+ if log_file:
63
+ log_file = log_file.replace("TIMESTAMP", now)
64
+ filehandler = logging.FileHandler(log_file)
65
+ filehandler.setLevel(logging.DEBUG)
66
+ filehandler.setFormatter(formatter)
67
+ logger.addHandler(filehandler)
68
+
69
+ LLM = config["model"]
70
+ use_completion = config["use_completion"]
71
+
72
+ # consistent: wrong msra model name
73
+ LLM_encoding = LLM
74
+ if LLM == "gpt-3.5-turbo":
75
+ LLM_encoding = "text-davinci-003"
76
+ task_parsing_highlight_ids = get_token_ids_for_task_parsing(LLM_encoding)
77
+ choose_model_highlight_ids = get_token_ids_for_choose_model(LLM_encoding)
78
+
79
+ # ENDPOINT MODEL NAME
80
+ # /v1/chat/completions gpt-4, gpt-4-0314, gpt-4-32k, gpt-4-32k-0314, gpt-3.5-turbo, gpt-3.5-turbo-0301
81
+ # /v1/completions text-davinci-003, text-davinci-002, text-curie-001, text-babbage-001, text-ada-001, davinci, curie, babbage, ada
82
+
83
+ if use_completion:
84
+ api_name = "completions"
85
+ else:
86
+ api_name = "chat/completions"
87
+
88
# Pick the completion endpoint and the default Authorization header.
if not config["dev"]:
    OPENAI_KEY = config["openai"]["key"]
    # "gradio" is accepted as a non-key value; presumably the real key is then
    # supplied per request by the web UI -- confirm against the Gradio config.
    if not OPENAI_KEY.startswith("sk-") and OPENAI_KEY != "gradio":
        raise ValueError("Incorrect OpenAI key. Please check your config.yaml file.")
    endpoint = f"https://api.openai.com/v1/{api_name}"
    if OPENAI_KEY.startswith("sk-"):
        HEADER = {
            "Authorization": f"Bearer {OPENAI_KEY}"
        }
    else:
        HEADER = None
else:
    # Dev mode talks to the locally configured endpoint without auth.
    endpoint = f"{config['local']['endpoint']}/v1/{api_name}"
    HEADER = None
102
+
103
+ PROXY = None
104
+ if config["proxy"]:
105
+ PROXY = {
106
+ "https": config["proxy"],
107
+ }
108
+
109
+ inference_mode = config["inference_mode"]
110
+
111
+ parse_task_demos_or_presteps = open(config["demos_or_presteps"]["parse_task"], "r").read()
112
+ choose_model_demos_or_presteps = open(config["demos_or_presteps"]["choose_model"], "r").read()
113
+ response_results_demos_or_presteps = open(config["demos_or_presteps"]["response_results"], "r").read()
114
+
115
+ parse_task_prompt = config["prompt"]["parse_task"]
116
+ choose_model_prompt = config["prompt"]["choose_model"]
117
+ response_results_prompt = config["prompt"]["response_results"]
118
+
119
+ parse_task_tprompt = config["tprompt"]["parse_task"]
120
+ choose_model_tprompt = config["tprompt"]["choose_model"]
121
+ response_results_tprompt = config["tprompt"]["response_results"]
122
+
123
# Load the model catalogue (one JSON object per line) and index it two ways:
# MODELS_MAP groups entries by their "task", METADATAS looks them up by "id".
with open("data/p0_models.jsonl", "r") as models_file:
    MODELS = [json.loads(line) for line in models_file]
MODELS_MAP = {}
METADATAS = {}
for model in MODELS:
    MODELS_MAP.setdefault(model["task"], []).append(model)
    METADATAS[model["id"]] = model
133
+
134
def convert_chat_to_completion(data):
    """Rewrite a chat-style request payload into a completion-style one, in place.

    The ``messages`` list is removed from ``data`` and flattened into a single
    ``prompt`` string: a leading system message is emitted verbatim, every
    other message is wrapped in ``<im_start>role\\n...<im_end>\\n`` (roles
    other than "user"/"assistant" are labelled "system"), and the prompt ends
    with ``<im_start>assistant``.

    ``stop`` and ``max_tokens`` are filled in only when the caller did not
    provide them; the default ``max_tokens`` is the model's remaining context
    budget, but never less than 1.

    Args:
        data: Request dict; mutated and returned.

    Returns:
        The same ``data`` dict.
    """
    messages = data.pop('messages', [])
    tprompt = ""
    # Guard against an empty conversation before peeking at the first message.
    if messages and messages[0]['role'] == "system":
        tprompt = messages[0]['content']
        messages = messages[1:]
    turns = []
    for message in messages:
        role = message['role'] if message['role'] in ("user", "assistant") else "system"
        turns.append("<im_start>" + role + "\n" + message['content'] + "<im_end>\n")
    final_prompt = tprompt + "".join(turns) + "<im_start>assistant"
    data["prompt"] = final_prompt
    data.setdefault('stop', ["<im_end>"])
    # Only count tokens when the caller left max_tokens unset.
    if 'max_tokens' not in data:
        data['max_tokens'] = max(get_max_context_length(LLM) - count_tokens(LLM_encoding, final_prompt), 1)
    return data
154
+
155
def send_request(data):
    """POST one request to the configured LLM endpoint and return its text.

    Args:
        data: Request payload. Must contain an ``"openaikey"`` entry, which is
            removed before sending; when it looks like an OpenAI key
            (``sk-`` prefix) it is used to authorise this request only.

    Returns:
        The stripped text of the first choice, or the whole decoded JSON
        response when it has no ``"choices"`` entry (e.g. an error payload).
    """
    openaikey = data.pop("openaikey")
    if use_completion:
        data = convert_chat_to_completion(data)
    # Build the header per request instead of overwriting the module-level
    # HEADER, so one caller's key is never reused for a later request.
    headers = HEADER
    if openaikey and openaikey.startswith("sk-"):
        headers = {
            "Authorization": f"Bearer {openaikey}"
        }

    response = requests.post(endpoint, json=data, headers=headers, proxies=PROXY)
    logger.debug(response.text.strip())
    payload = response.json()
    if "choices" not in payload:
        return payload
    first_choice = payload["choices"][0]
    if use_completion:
        return first_choice["text"].strip()
    return first_choice["message"]["content"].strip()
173
+
174
def replace_slot(text, entries):
    """Fill ``{{name}}`` placeholders in ``text`` from the ``entries`` mapping.

    Each value is converted with ``str`` if needed, has double quotes turned
    into single quotes and newlines removed, and then replaces every
    occurrence of its placeholder.
    """
    for slot_name, slot_value in entries.items():
        rendered = slot_value if isinstance(slot_value, str) else str(slot_value)
        rendered = rendered.replace('"', "'").replace('\n', "")
        placeholder = "{{" + slot_name + "}}"
        text = text.replace(placeholder, rendered)
    return text
180
+
181
def find_json(s):
    """Cut the outermost ``{...}`` span out of ``s``.

    Single quotes are first turned into double quotes; the text from the
    first ``{`` through the last ``}`` is returned with newlines removed.
    """
    normalized = s.replace("\'", "\"")
    first_brace = normalized.find("{")
    last_brace = normalized.rfind("}")
    return normalized[first_brace:last_brace + 1].replace("\n", "")
188
+
189
def field_extract(s, field):
    """Pull the first double-quoted value that follows ``field`` and a colon.

    Matching is case-insensitive. A loose pattern is tried first, then a
    strict ``field: "value"`` pattern.

    Args:
        s: Text to search (typically a not-quite-valid JSON string).
        field: Field name; interpolated into the regex as-is.

    Returns:
        The captured value.

    Raises:
        AttributeError: If neither pattern matches. The type is kept for
            callers that relied on the ``None.group`` failure of the
            previous implementation.
    """
    patterns = (
        rf'{field}.*?:.*?"(.*?)"',
        rf'{field}: *"(.*?)"',
    )
    for pattern in patterns:
        match = re.search(pattern, s, re.IGNORECASE)
        if match is not None:
            # The capture stops at the first '"', so it never contains one.
            return match.group(1)
    raise AttributeError(f"field {field!r} not found in input")
197
+
198
def get_id_reason(choose_str):
    """Extract the "id" and "reason" fields from a model-selection answer.

    Returns:
        ``(id, reason, choose)``: the first two are whitespace-stripped, and
        ``choose`` is a dict holding the unstripped values under "id" and
        "reason".
    """
    reason_text = field_extract(choose_str, "reason")
    model_id = field_extract(choose_str, "id")
    return model_id.strip(), reason_text.strip(), {"id": model_id, "reason": reason_text}
203
+
204
def record_case(success, **args):
    """Append a successful case to this run's JSONL log.

    Does nothing when ``success`` is falsy. Otherwise the keyword arguments
    are written as one JSON line to ``logs/log_success_<timestamp>.jsonl``
    and, when ``LOG_HF_TOKEN`` is set, the log repository is pushed without
    blocking.
    """
    if not success:
        return
    # The context manager closes the file even if serialisation fails.
    with open(f"logs/log_success_{now}.jsonl", "a") as f:
        f.write(json.dumps(args) + "\n")
    if LOG_HF_TOKEN:
        repo.push_to_hub(blocking=False)
213
+
214
def image_to_bytes(img_url):
    """Load the image at ``img_url`` and return it encoded as PNG bytes.

    The image is always re-encoded as PNG, whatever extension ``img_url`` has.
    """
    img_byte = io.BytesIO()
    load_image(img_url).save(img_byte, format="png")
    return img_byte.getvalue()
220
+
221
def resource_has_dep(command):
    """Return True if any value in ``command["args"]`` contains ``<GENERATED>``."""
    return any("<GENERATED>" in arg_value for arg_value in command["args"].values())
227
+
228
def fix_dep(tasks):
    """Recompute each task's ``dep`` list from its arguments, in place.

    Every argument value containing ``<GENERATED>`` contributes the integer
    found after its first ``-``; duplicates are dropped and first-seen order
    is kept. A task with no such argument gets ``[-1]``. Returns ``tasks``.
    """
    for task in tasks:
        deps = task["dep"] = []
        for arg_value in task["args"].values():
            if "<GENERATED>" not in arg_value:
                continue
            source_id = int(arg_value.split("-")[1])
            if source_id not in deps:
                deps.append(source_id)
        if not deps:
            task["dep"] = [-1]
    return tasks
240
+
241
def unfold(tasks):
    """Split tasks whose argument lists several generated resources.

    When an argument value contains ``<GENERATED>`` and is a comma-separated
    list of more than one item, the task is replaced by one deep copy per
    item; each copy gets that single item as the argument and a ``dep`` list
    holding the integer after the item's first ``-``. Copies are placed after
    all tasks that did not need splitting, and are themselves re-examined so
    a second multi-valued argument is split as well.

    ``tasks`` is updated in place and returned. If anything goes wrong the
    error is printed and logged and ``tasks`` is returned unchanged.
    """
    flag_unfold_task = False
    try:
        # Work on a separate queue: mutating ``tasks`` while iterating it
        # skips the element that follows each removed task.
        queue = list(tasks)
        expanded = []
        index = 0
        while index < len(queue):
            task = queue[index]
            index += 1
            fan_out = None
            for key, value in task["args"].items():
                if "<GENERATED>" in value:
                    generated_items = value.split(",")
                    if len(generated_items) > 1:
                        fan_out = (key, generated_items)
                        break
            if fan_out is None:
                expanded.append(task)
                continue
            flag_unfold_task = True
            key, generated_items = fan_out
            for item in generated_items:
                new_task = copy.deepcopy(task)
                dep_task_id = int(item.split("-")[1])
                new_task["dep"] = [dep_task_id]
                new_task["args"][key] = item
                queue.append(new_task)
        tasks[:] = expanded
    except Exception as e:
        print(e)
        traceback.print_exc()
        logger.debug("unfold task failed.")

    if flag_unfold_task:
        logger.debug(f"unfold tasks: {tasks}")

    return tasks
266
+
267
def chitchat(messages, openaikey=None):
    """Send ``messages`` to the configured LLM as-is and return its reply."""
    request = dict(model=LLM, messages=messages, openaikey=openaikey)
    return send_request(request)
274
+
275
def parse_task(context, input, openaikey=None):
    """Ask the LLM to break the user request into tasks.

    The request is made of the system prompt, the configured demonstrations
    and one user prompt rendered from ``input`` plus as much of ``context``
    as fits: history is dropped two entries at a time from the front until
    more than 800 tokens of the model's context remain.

    Args:
        context: Earlier chat messages, oldest first.
        input: The current user request.
        openaikey: Optional per-request OpenAI key.

    Returns:
        Whatever ``send_request`` returns for the assembled payload.
    """
    messages = json.loads(parse_task_demos_or_presteps)
    messages.insert(0, {"role": "system", "content": parse_task_tprompt})

    def build_user_message(history):
        prompt = replace_slot(parse_task_prompt, {
            "input": input,
            "context": history
        })
        return {"role": "user", "content": prompt}

    # cut chat logs
    context_limit = get_max_context_length(LLM)
    user_message = None
    for start in range(0, len(context) + 1, 2):
        candidate = build_user_message(context[start:])
        history_text = "<im_end>\nuser<im_start>".join([m["content"] for m in messages + [candidate]])
        num = count_tokens(LLM_encoding, history_text)
        if context_limit - num > 800:
            user_message = candidate
            break
    if user_message is None:
        # Nothing fit: still send the request itself, without any history,
        # rather than a payload that has no user message at all.
        user_message = build_user_message([])
    messages.append(user_message)

    logger.debug(messages)
    data = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {item: config["logit_bias"]["parse_task"] for item in task_parsing_highlight_ids},
        "openaikey": openaikey
    }
    return send_request(data)
305
+
306
def choose_model(input, task, metas, openaikey = None):
    """Ask the LLM which candidate model should handle ``task``.

    ``input``, ``task`` and ``metas`` are substituted into both the
    model-selection prompt and its demonstrations; the request is sent with
    temperature 0 and the configured logit bias on the highlighted token ids.
    """
    slots = {
        "input": input,
        "task": task,
        "metas": metas,
    }
    prompt = replace_slot(choose_model_prompt, slots)
    rendered_demos = replace_slot(choose_model_demos_or_presteps, slots)

    system_message = {"role": "system", "content": choose_model_tprompt}
    user_message = {"role": "user", "content": prompt}
    messages = [system_message] + json.loads(rendered_demos) + [user_message]
    logger.debug(messages)

    bias = config["logit_bias"]["choose_model"]
    request = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "logit_bias": {token_id: bias for token_id in choose_model_highlight_ids},
        "openaikey": openaikey
    }
    return send_request(request)
329
+
330
+
331
def response_results(input, results, openaikey=None):
    """Ask the LLM to write the final answer from the collected task results.

    ``results`` is a mapping; its values are passed to the demonstrations as
    "processes", ordered by key. The request is sent with temperature 0.
    """
    ordered_results = [results[key] for key in sorted(results)]
    prompt = replace_slot(response_results_prompt, {"input": input})
    rendered_demos = replace_slot(response_results_demos_or_presteps, {
        "input": input,
        "processes": ordered_results,
    })

    system_message = {"role": "system", "content": response_results_tprompt}
    user_message = {"role": "user", "content": prompt}
    messages = [system_message] + json.loads(rendered_demos) + [user_message]
    logger.debug(messages)

    request = {
        "model": LLM,
        "messages": messages,
        "temperature": 0,
        "openaikey": openaikey
    }
    return send_request(request)
351
+
352
def huggingface_model_inference(model_id, data, task, huggingfacetoken=None):
    """Run one task on a hosted Hugging Face model and return its result.

    Most tasks go through ``InferenceApi``; a few are posted straight to the
    model's inference URL. Generated images and audio are written under
    ``public/images`` / ``public/audios`` with a short random name and
    reported by their ``/images/...`` or ``/audios/...`` path.

    Args:
        model_id: Repository id of the model.
        data: Task arguments (``text``, ``image``, ``audio`` ... depending on
            the task).
        task: Task name selecting the branch below.
        huggingfacetoken: Optional token used for both call styles.

    Returns:
        The task result. A task handled by no branch leaves the result
        unassigned, so the final ``return`` raises ``UnboundLocalError``.
    """
    if huggingfacetoken is None:
        headers = {}
    else:
        headers = {"Authorization": f"Bearer {huggingfacetoken}"}
    task_url = f"https://api-inference.huggingface.co/models/{model_id}" # InferenceApi does not yet support some tasks
    inference = InferenceApi(repo_id=model_id, token=huggingfacetoken)

    plain_text_tasks = ("text-classification", "token-classification", "text2text-generation", "summarization", "translation", "conversational", "text-generation")
    audio_input_tasks = ("automatic-speech-recognition", "audio-to-audio", "audio-classification")

    # NLP tasks
    if task == "question-answering":
        result = inference({"question": data["text"], "context": (data["context"] if "context" in data else "")})
    elif task == "sentence-similarity":
        result = inference({"source_sentence": data["text1"], "target_sentence": data["text2"]})
    elif task in plain_text_tasks:
        result = inference(data["text"])

    # CV tasks
    elif task in ("visual-question-answering", "document-question-answering"):
        image_ref = data["image"]
        question = data["text"]
        encoded_image = base64.b64encode(image_to_bytes(image_ref)).decode("utf-8")
        json_data = {"inputs": {"question": question, "image": encoded_image}}
        # InferenceApi does not support this task; post directly.
        result = requests.post(task_url, headers=headers, json=json_data).json()

    elif task == "image-to-image":
        img_data = image_to_bytes(data["image"])
        # InferenceApi does not support this task; post the raw bytes.
        headers["Content-Length"] = str(len(img_data))
        result = requests.post(task_url, headers=headers, data=img_data).json()
        if "path" in result:
            result["generated image"] = result.pop("path")

    elif task == "text-to-image":
        img = inference(data["text"])
        name = str(uuid.uuid4())[:4]
        img.save(f"public/images/{name}.png")
        result = {"generated image": f"/images/{name}.png"}

    elif task == "image-segmentation":
        img_data = image_to_bytes(data["image"])
        image = Image.open(BytesIO(img_data))
        predicted = inference(data=img_data)
        # One random translucent colour per predicted segment.
        colors = [
            (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 155)
            for _ in range(len(predicted))
        ]
        for color, pred in zip(colors, predicted):
            label = pred["label"]
            # The mask arrives base64-encoded; it is removed from the prediction.
            mask_bytes = base64.b64decode(pred.pop("mask").encode("utf-8"))
            mask = Image.open(BytesIO(mask_bytes), mode='r').convert('L')

            layer = Image.new('RGBA', mask.size, color)
            image.paste(layer, (0, 0), mask)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {
            "generated image with segmentation mask": f"/images/{name}.jpg",
            "predicted": predicted,
        }

    elif task == "object-detection":
        img_data = image_to_bytes(data["image"])
        predicted = inference(data=img_data)
        image = Image.open(BytesIO(img_data))
        draw = ImageDraw.Draw(image)
        labels = [item['label'] for item in predicted]
        # One random colour per distinct label, shared by all its boxes.
        color_map = {}
        for label in labels:
            if label not in color_map:
                color_map[label] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for detection in predicted:
            box = detection["box"]
            box_color = color_map[detection["label"]]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=box_color, width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), detection["label"], fill=color_map[detection["label"]])
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {
            "generated image with predicted box": f"/images/{name}.jpg",
            "predicted": predicted,
        }

    elif task == "image-classification":
        result = inference(data=image_to_bytes(data["image"]))

    elif task == "image-to-text":
        img_data = image_to_bytes(data["image"])
        headers["Content-Length"] = str(len(img_data))
        r = requests.post(task_url, headers=headers, data=img_data)
        result = {}
        first_item = r.json()[0]
        if "generated_text" in first_item:
            result["generated text"] = first_item.pop("generated_text")

    # AUDIO tasks
    elif task == "text-to-speech":
        response = inference(data["text"], raw_response=True)
        name = str(uuid.uuid4())[:4]
        with open(f"public/audios/{name}.flac", "wb") as f:
            f.write(response.content)
        result = {"generated audio": f"/audios/{name}.flac"}

    elif task in audio_input_tasks:
        audio_data = requests.get(data["audio"], timeout=10).content
        response = inference(data=audio_data, raw_response=True)
        result = response.json()
        if task == "audio-to-audio":
            content = None
            audio_format = None
            for k, v in result[0].items():
                if k == "blob":
                    content = base64.b64decode(v.encode("utf-8"))
                if k == "content-type":
                    # The reported content type is ignored; flac is always used.
                    audio_format = "audio/flac".split("/")[-1]
            audio = AudioSegment.from_file(BytesIO(content))
            name = str(uuid.uuid4())[:4]
            audio.export(f"public/audios/{name}.{audio_format}", format=audio_format)
            result = {"generated audio": f"/audios/{name}.{audio_format}"}
    return result
489
+
490
def local_model_inference(model_id, data, task):
    """Run ``task`` on the locally served model ``model_id``.

    ``data`` is the task's argument mapping; depending on the task it is read
    for the keys ``"image"``, ``"text"`` and ``"audio"``. The request is sent
    through ``models(model_id, ...)``.

    For generative tasks, a ``"path"`` entry in the returned result is renamed
    to the task-specific key used by the rest of the pipeline (for example
    ``"generated image"``). Results without a ``"path"`` entry are returned
    unchanged, so an error payload reaches the caller intact.

    Returns the (possibly re-keyed) result, or an ``{"error": {"message": ...}}``
    mapping when the task is not handled here.
    """
    inference = partial(models, model_id)

    def rename_path(results, key):
        # Move results["path"] to results[key] when present; otherwise leave
        # the result as it is.
        if "path" in results:
            results[key] = results.pop("path")
        return results

    # ControlNet models are recognised by model id rather than by task.
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        img_url = data["image"]
        text = data["text"]
        return rename_path(inference({"img_url": img_url, "text": text}), "generated image")
    if model_id.endswith("-control"):
        return rename_path(inference({"img_url": data["image"]}), "generated image")

    if task == "text-to-video":
        return rename_path(inference(data), "generated video")

    # NLP tasks
    if task in [
        "question-answering", "sentence-similarity", "text-classification",
        "token-classification", "text2text-generation", "summarization",
        "translation", "conversational", "text-generation",
    ]:
        return inference(json=data)

    # CV tasks that take one image and produce one file.
    image_output_keys = {
        "depth-estimation": "generated depth image",
        "image-segmentation": "generated image with segmentation mask",
        "image-to-image": "generated image",
    }
    if task in image_output_keys:
        # image-segmentation used to pop "path" unconditionally, which raised
        # KeyError whenever the result carried no path.
        return rename_path(inference({"img_url": data["image"]}), image_output_keys[task])
    if task == "text-to-image":
        return rename_path(inference(data), "generated image")
    if task == "object-detection":
        img_url = data["image"]
        predicted = inference({"img_url": img_url})
        if "error" in predicted:
            return predicted
        image = load_image(img_url)
        draw = ImageDraw.Draw(image)
        # One random colour per distinct label, assigned at first occurrence.
        color_map = {}
        for item in predicted:
            if item["label"] not in color_map:
                color_map[item["label"]] = (random.randint(0, 255), random.randint(0, 100), random.randint(0, 255))
        for item in predicted:
            box = item["box"]
            color = color_map[item["label"]]
            draw.rectangle(((box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])), outline=color, width=2)
            draw.text((box["xmin"]+5, box["ymin"]-15), item["label"], fill=color)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        return {
            "generated image with predicted box": f"/images/{name}.jpg",
            "predicted": predicted,
        }
    if task in ["image-classification", "image-to-text", "document-question-answering", "visual-question-answering"]:
        img_url = data["image"]
        text = data["text"] if "text" in data else None
        return inference({"img_url": img_url, "text": text})

    # AUDIO tasks
    if task == "text-to-speech":
        return rename_path(inference(data), "generated audio")
    if task in ["automatic-speech-recognition", "audio-to-audio", "audio-classification"]:
        return inference({"audio_url": data["audio"]})

    # Previously this fell through and returned None, which callers cannot
    # test with `"error" in result`.
    return {"error": {"message": f"task {task} is not supported by local inference."}}
583
+
584
+
585
def model_inference(model_id, data, hosted_on, task, huggingfacetoken=None):
    """Run ``task`` with ``model_id`` on the endpoint named by ``hosted_on``.

    ``hosted_on`` is ``"local"``, ``"huggingface"`` or ``"unknown"``. For
    ``"unknown"`` the local status is checked first and the Hugging Face
    status endpoint second; the first one reporting ``"loaded"`` is used.

    Never raises: any failure, including an endpoint that cannot be resolved,
    is returned as ``{"error": {"message": ...}}``.
    """
    if huggingfacetoken:
        HUGGINGFACE_HEADERS = {
            "Authorization": f"Bearer {huggingfacetoken}",
        }
    else:
        HUGGINGFACE_HEADERS = None
    try:
        if hosted_on == "unknown":
            r = status(model_id)
            logger.debug("Local Server Status: " + str(r))
            if "loaded" in r and r["loaded"]:
                hosted_on = "local"
            else:
                huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
                # Inspect the decoded JSON body: a Response object itself
                # never contains the key "loaded".
                r = requests.get(huggingfaceStatusUrl, headers=HUGGINGFACE_HEADERS, proxies=PROXY).json()
                logger.debug("Huggingface Status: " + str(r))
                if "loaded" in r and r["loaded"]:
                    hosted_on = "huggingface"
        if hosted_on == "local":
            inference_result = local_model_inference(model_id, data, task)
        elif hosted_on == "huggingface":
            inference_result = huggingface_model_inference(model_id, data, task, huggingfacetoken)
        else:
            # Without this branch the function reached `return` with
            # inference_result unbound.
            inference_result = {"error": {"message": f"model {model_id} is not available on any inference endpoint."}}
    except Exception as e:
        print(e)
        traceback.print_exc()
        inference_result = {"error": {"message": str(e)}}
    return inference_result
613
+
614
+
615
def get_model_status(model_id, url, headers, queue = None):
    """Report whether ``model_id`` is loaded on the endpoint implied by ``url``.

    A ``url`` containing ``"huggingface"`` is fetched over HTTP with
    ``headers``; any other ``url`` means the local ``status(model_id)`` helper
    is used instead.

    When ``queue`` is given, exactly one tuple is always put on it:
    ``(model_id, True, endpoint_type)`` if loaded, else
    ``(model_id, False, None)``. This holds even when the probe fails, so a
    consumer counting one result per probe cannot block forever.

    Returns True when the endpoint reports the model as loaded, else False.
    """
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    try:
        if endpoint_type == "huggingface":
            # Decode the body; the Response object itself has no "loaded" key.
            r = requests.get(url, headers=headers, proxies=PROXY).json()
        else:
            r = status(model_id)
        loaded = bool("loaded" in r and r["loaded"])
    except Exception as e:
        # Best-effort probe: a failed lookup counts as "not loaded".
        logger.debug(f"status check failed for {model_id} ({endpoint_type}): {e}")
        loaded = False

    if queue is not None:
        queue.put((model_id, loaded, endpoint_type if loaded else None))
    return loaded
629
+
630
def get_avaliable_models(candidates, topk=10, huggingfacetoken = None):
    """Probe candidate models in parallel and collect the loaded ones.

    Each candidate is a mapping with an ``"id"`` entry. Depending on
    ``inference_mode`` and ``config["local_deployment"]``, one probe thread per
    endpoint is started for every candidate, each running ``get_model_status``.

    Collection stops early once ``topk`` loaded models have been gathered, but
    all probe threads are joined before returning.

    Returns ``{"local": [...], "huggingface": [...]}`` with the ids of the
    models reported as loaded on each endpoint.
    """
    all_available_models = {"local": [], "huggingface": []}
    threads = []
    result_queue = Queue()
    # Only send an Authorization header when a token was supplied, instead of
    # the literal "Bearer None".
    HUGGINGFACE_HEADERS = {"Authorization": f"Bearer {huggingfacetoken}"} if huggingfacetoken else None

    for candidate in candidates:
        model_id = candidate["id"]

        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            thread = threading.Thread(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            threads.append(thread)
            thread.start()

        if inference_mode != "huggingface" and config["local_deployment"] != "minimal":
            thread = threading.Thread(target=get_model_status, args=(model_id, "", {}, result_queue))
            threads.append(thread)
            thread.start()

    # Read at most one result per started probe.
    for _ in range(len(threads)):
        # `loaded` rather than `status`, to avoid shadowing the module-level
        # status() helper.
        model_id, loaded, endpoint_type = result_queue.get()
        # The former `model_id not in all_available_models` guard compared
        # against the dict keys "local"/"huggingface", so it never filtered a
        # real model id; each (model, endpoint) pair is probed only once.
        if loaded:
            all_available_models[endpoint_type].append(model_id)
        if len(all_available_models["local"]) + len(all_available_models["huggingface"]) >= topk:
            break

    for thread in threads:
        thread.join()

    return all_available_models
664
+
665
def collect_result(command, choose, inference_result):
    """Bundle a task, its inference output and the model-choice record.

    Logs the inference result at debug level and returns a dict with the keys
    ``"task"``, ``"inference result"`` and ``"choose model result"``.
    """
    logger.debug(f"inference result: {inference_result}")
    return {
        "task": command,
        "inference result": inference_result,
        "choose model result": choose,
    }
671
+
672
+
673
def run_task(input, command, results, openaikey = None, huggingfacetoken = None):
    """Execute one planned task and store its outcome in ``results``.

    ``command`` is a task mapping with the keys ``"id"``, ``"args"``,
    ``"task"`` and ``"dep"``. ``results`` maps task ids to the records produced
    by ``collect_result`` and is shared with the tasks this one depends on.

    Steps:
      1. Replace ``<GENERATED>`` placeholders in the args with resources
         produced by (or passed to) the dependency tasks.
      2. Prefix relative image/audio paths with ``public/``.
      3. Pick a model: ControlNet for ``*-control`` / ``*-text-to-image``
         tasks, ChatGPT for plain text-generation tasks, otherwise a candidate
         from ``MODELS_MAP`` chosen via ``choose_model``.
      4. Run inference and write ``results[id]``.

    Returns True on success and False when the task could not be run or the
    inference result contains an ``"error"`` entry.
    """
    task_id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    # An empty dep list is treated like the [-1] "no dependency" sentinel
    # rather than raising IndexError on deps[0].
    has_deps = bool(deps) and deps[0] != -1
    dep_tasks = [results[dep] for dep in deps] if has_deps else []

    logger.debug(f"Run task: {task_id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))

    if has_deps:
        # "<GENERATED>-N" names the task whose output should be substituted.
        for resource in ["image", "audio", "text"]:
            if resource in args and "<GENERATED>-" in args[resource]:
                resource_id = int(args[resource].split("-")[1])
                dep_output = results[resource_id]["inference result"]
                if f"generated {resource}" in dep_output:
                    args[resource] = dep_output[f"generated {resource}"]

    # Fallback: take the most recent resource of each type seen in any
    # dependency, preferring its generated output over its own input args.
    text = image = audio = None
    for dep_task in dep_tasks:
        if "generated text" in dep_task["inference result"]:
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results):" + text)
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)

    # Any placeholder still unresolved gets the fallback value, if there is one.
    for resource, fallback in (("image", image), ("audio", audio), ("text", text)):
        if resource in args and "<GENERATED>" in args[resource] and fallback:
            args[resource] = fallback

    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"

    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]

        if control == "seg":
            task = "image-segmentation"
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"
            command['task'] = task
        else:
            task = f"{control}-control"

    command["args"] = args
    logger.debug(f"parsed task: {command}")

    if task.endswith("-text-to-image") or task.endswith("-control"):
        if inference_mode != "huggingface":
            if task.endswith("-text-to-image"):
                control = task.split("-")[0]
                best_model_id = f"lllyasviel/sd-controlnet-{control}"
            else:
                best_model_id = task
            hosted_on = "local"
            reason = "ControlNet is the best model for this task."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            logger.warning(f"Task {command['task']} is not available. ControlNet need to be deployed locally.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"Task {command['task']} is not available. ControlNet need to be deployed locally.", "op":"message"})
            inference_result = {"error": "service related to ControlNet is not available."}
            results[task_id] = collect_result(command, "", inference_result)
            return False
    elif task in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:  # ChatGPT can do these directly
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        messages = [{
            "role": "user",
            "content": f"[ {input} ] contains a task in JSON format {command}, 'task' indicates the task type and 'args' indicates the arguments required for the task. Don't explain the task to me, just help me do it and give me the result. The result must be in text form without any urls."
        }]
        response = chitchat(messages, openaikey)
        results[task_id] = collect_result(command, choose, {"response": response})
        return True
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[task_id] = collect_result(command, "", inference_result)
            return False

        candidates = MODELS_MAP[task][:20]
        all_avaliable_models = get_avaliable_models(candidates, config["num_candidate_models"], huggingfacetoken)
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
        logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")

        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[task_id] = collect_result(command, "", inference_result)
            return False

        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        else:
            cand_models_info = [
                {
                    "id": model["id"],
                    # Report the endpoint name; this used to be the whole list
                    # of model ids available on that endpoint.
                    "inference endpoint": "local" if model["id"] in all_avaliable_models["local"] else "huggingface",
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:config["max_description_length"]],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]

            choose_str = choose_model(input, command, cand_models_info, openaikey)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception:
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'], huggingfacetoken)

    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
        results[task_id] = collect_result(command, choose, inference_result)
        return False

    results[task_id] = collect_result(command, choose, inference_result)
    return True
843
+
844
def chat_huggingface(messages, openaikey = None, huggingfacetoken = None, return_planning = False, return_results = False):
    """Plan, execute and summarise the tasks needed for the latest user message.

    ``messages`` is the chat history; the last entry's ``"content"`` is the
    request and the earlier entries are passed to ``parse_task`` as context.

    Return value depends on the flags and on how far processing gets:
      * ``return_planning=True``: the planned task list only.
      * ``return_results=True``: the dict of per-task results only.
      * otherwise: a ``(response_text, results)`` tuple. ``results`` is ``{}``
        when the request was answered by ``chitchat`` without running tasks.
    """
    start = time.time()
    context = messages[:-1]
    input = messages[-1]["content"]
    logger.info("*"*80)
    logger.info(f"input: {input}")

    task_str = parse_task(context, input, openaikey)
    logger.info(task_str)

    if "error" in task_str:
        return str(task_str), {}
    else:
        task_str = task_str.strip()

    try:
        tasks = json.loads(task_str)
    except Exception as e:
        logger.debug(e)
        response = chitchat(messages, openaikey)
        record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"})
        return response, {}

    # Compare the parsed value so variants such as "[ ]" also count as an
    # empty plan; fall back to a plain LLM reply.
    if tasks == []:
        record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        return response, {}

    if len(tasks)==1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]:
        record_case(success=True, **{"input": input, "task": tasks, "reason": "task parsing fail: empty", "op": "chitchat"})
        response = chitchat(messages, openaikey)
        best_model_id = "ChatGPT"
        reason = "ChatGPT performs well on some NLP tasks as well."
        choose = {"id": best_model_id, "reason": reason}
        return response, collect_result(tasks[0], choose, {"response": response})

    tasks = unfold(tasks)
    tasks = fix_dep(tasks)
    logger.debug(tasks)

    if return_planning:
        return tasks

    threads = []
    tasks = tasks[:]
    d = dict()  # task id -> result record, filled in by the run_task threads
    retry = 0
    while True:
        num_threads = len(threads)
        # Iterate over a snapshot: removing from the list being iterated
        # skips the element that follows each removed task.
        for task in list(tasks):
            dep = task["dep"]
            # A task may only depend on earlier tasks; otherwise drop its deps.
            for dep_id in dep:
                if dep_id >= task["id"]:
                    task["dep"] = [-1]
                    dep = [-1]
                    break
            # Ready when every dependency has a result (duplicates in dep are
            # harmless) or when the task has no dependency.
            if set(dep).issubset(d.keys()) or dep[0] == -1:
                tasks.remove(task)
                thread = threading.Thread(target=run_task, args=(input, task, d, openaikey, huggingfacetoken))
                thread.start()
                threads.append(thread)
        if num_threads == len(threads):
            # Nothing became runnable in this pass; wait for results.
            time.sleep(0.5)
            retry += 1
        if retry > 80:
            logger.debug("User has waited too long, Loop break.")
            break
        if len(tasks) == 0:
            break
    for thread in threads:
        thread.join()

    results = d.copy()

    logger.debug(results)
    if return_results:
        return results

    response = response_results(input, results, openaikey).strip()

    end = time.time()
    during = end - start

    record_case(success=True, **{"input": input, "task": task_str, "results": results, "response": response, "during": during, "op":"response"})
    logger.info(f"response: {response}")
    return response, results
config.gradio.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai:
2
+ key: gradio # "gradio" (set when request) or your_personal_key
3
+ huggingface:
4
+ token: # required: huggingface token @ https://huggingface.co/settings/tokens
5
+ dev: false
6
+ debug: true
7
+ log_file: logs/debug_TIMESTAMP.log
8
+ model: text-davinci-003 # text-davinci-003
9
+ use_completion: true
10
+ inference_mode: hybrid # local, huggingface or hybrid
11
+ local_deployment: standard # minimal, standard or full
12
+ num_candidate_models: 5
13
+ max_description_length: 100
14
+ proxy:
15
+ logit_bias:
16
+ parse_task: 0.5
17
+ choose_model: 5
18
+ tprompt:
19
+ parse_task: >-
20
+ #1 Task Planning Stage: The AI assistant can parse user input to several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": {"text": text or <GENERATED>-dep_id, "image": image_url or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refer to the one genereted text/image/audio in the dependency task (Please consider whether the dependency task generates resources of this type.) and "dep_id" must be in "dep" list. The "dep" field denotes the ids of the previous prerequisite tasks which generate a new resource that the current task relies on. The "args" field must in ["text", "image", "audio"], nothing else. The task MUST be selected from the following options: "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "text-to-video", "visual-question-answering", "document-question-answering", "image-segmentation", "depth-estimation", "text-to-speech", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image". There may be multiple tasks of the same type. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible while ensuring that the user request can be resolved. Pay attention to the dependencies and order among tasks. If the user input can't be parsed, you need to reply empty JSON [].
21
+ choose_model: >-
22
+ #2 Model Selection Stage: Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request. The assistant should focus more on the description of the model and find the model that has the most potential to solve requests and tasks. Also, prefer models with local inference endpoints for speed and stability.
23
+ response_results: >-
24
+ #4 Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
25
+ demos_or_presteps:
26
+ parse_task: demos/demo_parse_task.json
27
+ choose_model: demos/demo_choose_model.json
28
+ response_results: demos/demo_response_results.json
29
+ prompt:
30
+ parse_task: The chat log [ {{context}} ] may contain the resources I mentioned. Now I input { {{input}} }. Pay attention to the input and output types of tasks and the dependencies between tasks.
31
+ choose_model: >-
32
+ Please choose the most suitable model from {{metas}} for the task {{task}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
33
+ response_results: >-
34
+ Yes. Please first think carefully and directly answer my request based on the inference results. Some of the inferences may not always turn out to be correct and require you to make careful consideration in making decisions. Then please detail your workflow including the used models and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in inference results. If there is nothing in the results, please tell me you can't make it. }
data/p0_models.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
demos/demo_choose_model.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "{{task}}"
9
+ }
10
+ ]
demos/demo_parse_task.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"e1.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"How many sheep in the picture\"}}, {\"task\": \"image-to-text\", \"id\": 3, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"object-detection\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"e2.png\" }}, {\"task\": \"visual-question-answering\", \"id\": 5, \"dep\": [4], \"args\": {\"image\": \"<GENERATED>-4\", \"text\": \"How many sheep in the picture\"}}, {\"task\": \"image-to-text\", \"id\": 6, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"object-detection\", \"id\": 7, \"dep\": [-1], \"args\": {\"image\": \"e3.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 8, \"dep\": [7], \"args\": {\"image\": \"<GENERATED>-7\", \"text\": \"How many sheep in the picture\"}}]"
9
+ },
10
+
11
+ {
12
+ "role":"user",
13
+ "content":"Look at /e.jpg, can you tell me how many objects in the picture? Give me a picture and video similar to this one."
14
+ },
15
+ {
16
+ "role":"assistant",
17
+ "content":"[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"visual-question-answering\", \"id\": 2, \"dep\": [1], \"args\": {\"image\": \"<GENERATED>-1\", \"text\": \"how many objects in the picture?\" }}, {\"task\": \"text-to-image\", \"id\": 3, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"image-to-image\", \"id\": 4, \"dep\": [-1], \"args\": {\"image\": \"/e.jpg\" }}, {\"task\": \"text-to-video\", \"id\": 5, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}]"
18
+ },
19
+
20
+ {
21
+ "role":"user",
22
+ "content":"given a document /images/e.jpeg, answer me what is the student amount? And describe the image with your voice"
23
+ },
24
+ {
25
+ "role":"assistant",
26
+ "content":"[{\"task\": \"document-question-answering\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"visual-question-answering\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\", \"text\": \"what is the student amount?\" }}, {\"task\": \"image-to-text\", \"id\": 2, \"dep\": [-1], \"args\": {\"image\": \"/images/e.jpeg\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
27
+ },
28
+
29
+ {
30
+ "role": "user",
31
+ "content": "Given an image /example.jpg, first generate a hed image, then based on the hed image generate a new image where a girl is reading a book"
32
+ },
33
+ {
34
+ "role": "assistant",
35
+ "content": "[{\"task\": \"hed-control\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"hed-text-to-image\", \"id\": 1, \"dep\": [0], \"args\": {\"text\": \"a girl is reading a book\", \"image\": \"<GENERATED>-0\" }}]"
36
+ },
37
+
38
+ {
39
+ "role": "user",
40
+ "content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it"
41
+ },
42
+ {
43
+ "role": "assistant",
44
+ "content": "[{\"task\": \"text-to-video\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-speech\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}, {\"task\": \"text-to-image\", \"id\": 2, \"dep\": [-1], \"args\": {\"text\": \"a boy is running\" }}]"
45
+ },
46
+
47
+
48
+ {
49
+ "role": "user",
50
+ "content": "please show me a joke and an image of cat"
51
+ },
52
+ {
53
+ "role": "assistant",
54
+ "content": "[{\"task\": \"conversational\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"please show me a joke of cat\" }}, {\"task\": \"text-to-image\", \"id\": 1, \"dep\": [-1], \"args\": {\"text\": \"a photo of cat\" }}]"
55
+ },
56
+
57
+ {
58
+ "role": "user",
59
+ "content": "give me a picture about a cut dog, then describe the image to me and tell a story about it"
60
+ },
61
+ {
62
+ "role": "assistant",
63
+ "content": "[{\"task\": \"text-to-image\", \"id\": 0, \"dep\": [-1], \"args\": {\"text\": \"a picture of a cut dog\" }}, {\"task\": \"image-to-text\", \"id\": 1, \"dep\": [0], \"args\": {\"image\": \"<GENERATED>-0\" }}, {\"task\": \"text-generation\", \"id\": 2, \"dep\": [1], \"args\": {\"text\": \"<GENERATED>-1\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
64
+ },
65
+
66
+ {
67
+ "role": "user",
68
+ "content": "give you a picture /example.jpg, what's in it and tell me a joke about it"
69
+ },
70
+ {
71
+ "role": "assistant",
72
+ "content": "[{\"task\": \"image-to-text\", \"id\": 0, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"object-detection\", \"id\": 1, \"dep\": [-1], \"args\": {\"image\": \"/example.jpg\" }}, {\"task\": \"conversational\", \"id\": 2, \"dep\": [0], \"args\": {\"text\": \"<GENERATED>-0\" }}, {\"task\": \"text-to-speech\", \"id\": 3, \"dep\": [2], \"args\": {\"text\": \"<GENERATED>-2\" }}]"
73
+ }
74
+ ]
demos/demo_response_results.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "role": "user",
4
+ "content": "{{input}}"
5
+ },
6
+ {
7
+ "role": "assistant",
8
+ "content": "Before give you a response, I want to introduce my workflow for your request, which is shown in the following JSON data: {{processes}}. Do you have any demands regarding my response?"
9
+ }
10
+ ]
get_token_ids.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+
3
# Name of the tiktoken encoding used by each supported model.
_ENCODING_NAMES = {
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5-turbo-0301": "cl100k_base",
    "text-davinci-003": "p50k_base",
    "text-davinci-002": "p50k_base",
    "text-davinci-001": "r50k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
    "davinci": "r50k_base",
    "curie": "r50k_base",
    "babbage": "r50k_base",
    "ada": "r50k_base",
}

# Model name -> tiktoken encoding object, looked up once at import time.
encodings = {
    model_name: tiktoken.get_encoding(encoding_name)
    for model_name, encoding_name in _ENCODING_NAMES.items()
}
17
+
18
# Context length, in tokens, recorded for each supported model.
max_length = {
    **dict.fromkeys(
        [
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0301",
            "text-davinci-003",
            "text-davinci-002",
        ],
        4096,
    ),
    **dict.fromkeys(
        [
            "text-davinci-001",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
            "davinci",
            "curie",
            "babbage",
            "ada",
        ],
        2049,
    ),
}
32
+
33
def count_tokens(model_name, text):
    """Return the number of tokens ``text`` encodes to for ``model_name``.

    Raises KeyError when ``model_name`` has no entry in ``encodings``.
    """
    encoder = encodings[model_name]
    token_ids = encoder.encode(text)
    return len(token_ids)
35
+
36
def get_max_context_length(model_name):
    """Return the ``max_length`` entry for ``model_name``.

    Raises KeyError when the model is not listed.
    """
    context_limit = max_length[model_name]
    return context_limit
38
+
39
def get_token_ids_for_task_parsing(model_name):
    """Return the distinct token ids of the task-planning vocabulary.

    The vocabulary is a fixed string of task names and JSON field names,
    encoded with ``model_name``'s encoding. Duplicates are removed via a set,
    so the order of the returned list is unspecified.
    """
    text = '''{"task": "text-classification", "token-classification", "text2text-generation", "summarization", "translation", "question-answering", "conversational", "text-generation", "sentence-similarity", "tabular-classification", "object-detection", "image-classification", "image-to-image", "image-to-text", "text-to-image", "visual-question-answering", "document-question-answering", "image-segmentation", "text-to-speech", "text-to-video", "automatic-speech-recognition", "audio-to-audio", "audio-classification", "canny-control", "hed-control", "mlsd-control", "normal-control", "openpose-control", "canny-text-to-image", "depth-text-to-image", "hed-text-to-image", "mlsd-text-to-image", "normal-text-to-image", "openpose-text-to-image", "seg-text-to-image", "args", "text", "path", "dep", "id", "<GENERATED>-"}'''
    encoder = encodings[model_name]
    unique_ids = set(encoder.encode(text))
    return list(unique_ids)
44
+
45
def get_token_ids_for_choose_model(model_name):
    """Return the distinct token ids of the model-selection answer skeleton.

    Encodes the fixed string ``{"id": "reason"}`` with ``model_name``'s
    encoding and removes duplicates via a set, so the order of the returned
    list is unspecified.
    """
    text = '''{"id": "reason"}'''
    encoder = encodings[model_name]
    unique_ids = set(encoder.encode(text))
    return list(unique_ids)
models_server.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import random
4
+ import uuid
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
8
+ from diffusers.utils import load_image
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from diffusers.utils import export_to_video
11
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5ForSpeechToSpeech
12
+ from transformers import BlipProcessor, BlipForConditionalGeneration
13
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
14
+ from datasets import load_dataset
15
+ from PIL import Image
16
+ import io
17
+ from torchvision import transforms
18
+ import torch
19
+ import torchaudio
20
+ from speechbrain.pretrained import WaveformEnhancement
21
+ import joblib
22
+ from huggingface_hub import hf_hub_url, cached_download
23
+ from transformers import AutoImageProcessor, TimesformerForVideoClassification
24
+ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, AutoFeatureExtractor
25
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector, CannyDetector, MidasDetector
26
+ from controlnet_aux.open_pose.body import Body
27
+ from controlnet_aux.mlsd.models.mbv2_mlsd_large import MobileV2_MLSD_Large
28
+ from controlnet_aux.hed import Network
29
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
30
+ import warnings
31
+ import time
32
+ from espnet2.bin.tts_inference import Text2Speech
33
+ import soundfile as sf
34
+ from asteroid.models import BaseModel
35
+ import traceback
36
+ import os
37
+ import yaml
38
+
39
# Silence all warnings emitted by the imported model libraries.
warnings.filterwarnings("ignore")

# The configuration file path comes from the command line when run as a script.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()

# When this module is imported rather than executed, a fixed config is used.
if __name__ != "__main__":
    args.config = "config.gradio.yaml"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

# Read the YAML configuration; the context manager closes the file handle
# instead of leaving it open for the lifetime of the process.
with open(args.config, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Remote ("huggingface") inference disables local model loading entirely.
local_deployment = config["local_deployment"]
if config["inference_mode"] == "huggingface":
    local_deployment = "none"

# Optional HTTPS proxy mapping; stays None when no proxy is configured.
PROXY = None
if config["proxy"]:
    PROXY = {
        "https": config["proxy"],
    }

# Start of the model-loading timer.
start = time.time()

# Prefix prepended to model ids when loading; empty means ids are used as-is.
# local_models = "models/"
local_models = ""
72
+
73
+
74
def load_pipes(local_deployment):
    """Load the locally served models for the requested deployment tier.

    Three groups of models are built:

    * ``"full"`` loads the extra group, the standard group and the
      ControlNet group;
    * ``"standard"`` loads the standard group and the ControlNet group;
    * ``"minimal"`` loads only the ControlNet group;
    * any other value loads nothing.

    Returns a dict keyed by model id. Each value is a dict holding the loaded
    objects (``"model"`` plus, depending on the entry, ``"processor"``,
    ``"feature_extractor"``, ``"vocoder"``, ``"embeddings_dataset"`` or
    ``"control"``) and, for most entries, a ``"device"`` string.
    """
    gpu = "cuda:0"
    extra_group = {}
    standard_group = {}
    controlnet_group = {}

    if local_deployment == "full":
        # NOTE: entries for the following model ids were present only as
        # commented-out code and are not loaded:
        #   Salesforce/blip-image-captioning-large, facebook/maskformer-swin-large-ade,
        #   microsoft/trocr-base-printed, microsoft/trocr-base-handwritten,
        #   CompVis/stable-diffusion-v1-4, stabilityai/stable-diffusion-2-1,
        #   microsoft/speecht5_tts, speechbrain/mtl-mimic-voicebank,
        #   julien-c/wine-quality, facebook/timesformer-base-finetuned-k400
        extra_group["damo-vilab/text-to-video-ms-1.7b"] = {
            "model": DiffusionPipeline.from_pretrained(
                f"{local_models}damo-vilab/text-to-video-ms-1.7b",
                torch_dtype=torch.float16,
                variant="fp16",
            ),
            "device": gpu,
        }
        extra_group["JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"] = {
            "model": BaseModel.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k"),
            "device": gpu,
        }
        extra_group["microsoft/speecht5_vc"] = {
            "processor": SpeechT5Processor.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
            "model": SpeechT5ForSpeechToSpeech.from_pretrained(f"{local_models}microsoft/speecht5_vc"),
            "vocoder": SpeechT5HifiGan.from_pretrained(f"{local_models}microsoft/speecht5_hifigan"),
            "embeddings_dataset": load_dataset(f"{local_models}Matthijs/cmu-arctic-xvectors", split="validation"),
            "device": gpu,
        }
        extra_group["facebook/maskformer-swin-base-coco"] = {
            "feature_extractor": MaskFormerFeatureExtractor.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
            "model": MaskFormerForInstanceSegmentation.from_pretrained(f"{local_models}facebook/maskformer-swin-base-coco"),
            "device": gpu,
        }
        extra_group["Intel/dpt-hybrid-midas"] = {
            "model": DPTForDepthEstimation.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas", low_cpu_mem_usage=True),
            "feature_extractor": DPTFeatureExtractor.from_pretrained(f"{local_models}Intel/dpt-hybrid-midas"),
            "device": gpu,
        }

    if local_deployment in ("full", "standard"):
        # NOTE: entries for the following model ids were present only as
        # commented-out code and are not loaded:
        #   nlpconnect/vit-gpt2-image-captioning, lambdalabs/sd-image-variations-diffusers,
        #   superb/wav2vec2-base-superb-ks, microsoft/speecht5_asr,
        #   microsoft/beit-base-patch16-224-pt22k-ft22k, openai/clip-vit-large-patch14,
        #   google/owlvit-base-patch32, microsoft/DialoGPT-medium, bert-base-uncased,
        #   deepset/roberta-base-squad2, facebook/bart-large-cnn,
        #   google/tapas-base-finetuned-wtq,
        #   distilbert-base-uncased-finetuned-sst-2-english, gpt2,
        #   mrm8488/t5-base-finetuned-question-generation-ap,
        #   Jean-Baptiste/camembert-ner, t5-base, impira/layoutlm-document-qa
        standard_group["espnet/kan-bayashi_ljspeech_vits"] = {
            "model": Text2Speech.from_pretrained("espnet/kan-bayashi_ljspeech_vits"),
            "device": gpu,
        }
        standard_group["runwayml/stable-diffusion-v1-5"] = {
            "model": DiffusionPipeline.from_pretrained(f"{local_models}runwayml/stable-diffusion-v1-5"),
            "device": gpu,
        }
        standard_group["openai/whisper-base"] = {
            "model": pipeline(task="automatic-speech-recognition", model=f"{local_models}openai/whisper-base"),
            "device": gpu,
        }
        standard_group["Intel/dpt-large"] = {
            "model": pipeline(task="depth-estimation", model=f"{local_models}Intel/dpt-large"),
            "device": gpu,
        }
        standard_group["facebook/detr-resnet-50-panoptic"] = {
            "model": pipeline(task="image-segmentation", model=f"{local_models}facebook/detr-resnet-50-panoptic"),
            "device": gpu,
        }
        standard_group["facebook/detr-resnet-101"] = {
            "model": pipeline(task="object-detection", model=f"{local_models}facebook/detr-resnet-101"),
            "device": gpu,
        }
        standard_group["ydshieh/vit-gpt2-coco-en"] = {
            "model": pipeline(task="image-to-text", model=f"{local_models}ydshieh/vit-gpt2-coco-en"),
            "device": gpu,
        }
        standard_group["dandelin/vilt-b32-finetuned-vqa"] = {
            "model": pipeline(task="visual-question-answering", model=f"{local_models}dandelin/vilt-b32-finetuned-vqa"),
            "device": gpu,
        }

    if local_deployment in ("full", "standard", "minimal"):

        def load_control(kind):
            # Load one half-precision ControlNet checkpoint by its suffix.
            return ControlNetModel.from_pretrained(
                f"{local_models}lllyasviel/sd-controlnet-{kind}", torch_dtype=torch.float16
            )

        canny_control = load_control("canny")
        # A single Stable Diffusion pipeline object is stored under every
        # "lllyasviel/sd-controlnet-*" key; each key carries its own
        # ControlNet under "control".
        shared_sd_pipe = StableDiffusionControlNetPipeline.from_pretrained(
            f"{local_models}runwayml/stable-diffusion-v1-5", controlnet=canny_control, torch_dtype=torch.float16
        )

        # The same HED detector instance backs both "hed" and "scribble".
        hed_detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

        controlnet_group = {
            "openpose-control": {"model": OpenposeDetector.from_pretrained('lllyasviel/ControlNet')},
            "mlsd-control": {"model": MLSDdetector.from_pretrained('lllyasviel/ControlNet')},
            "hed-control": {"model": hed_detector},
            "scribble-control": {"model": hed_detector},
            "midas-control": {"model": MidasDetector.from_pretrained('lllyasviel/ControlNet')},
            "canny-control": {"model": CannyDetector()},
            "lllyasviel/sd-controlnet-canny": {
                "control": canny_control,
                "model": shared_sd_pipe,
                "device": gpu,
            },
        }
        for kind in ("depth", "hed", "mlsd", "openpose", "scribble", "seg"):
            controlnet_group[f"lllyasviel/sd-controlnet-{kind}"] = {
                "control": load_control(kind),
                "model": shared_sd_pipe,
                "device": gpu,
            }

    # Later groups win on key collisions: standard, then extra, then ControlNet.
    return {**standard_group, **extra_group, **controlnet_group}
334
+
335
# Load the models for the configured deployment tier once, at module load.
pipes = load_pipes(local_deployment)

# Report how long loading took, measured from the module-level `start`.
end = time.time()
during = end - start

print(f"[ ready ] {during}s")
341
+
342
def running():
    """Return a constant payload signalling that the server is up."""
    payload = {"running": True}
    return payload
344
+
345
def status(model_id):
    """Report whether `model_id` is loaded locally and not disabled.

    Returns ``{"loaded": True}`` when the id is a key of the module-level
    ``pipes`` table and is not one of the two disabled TrOCR ids; otherwise
    returns ``{"loaded": False}``. The outcome is also printed.
    """
    disabled = ("microsoft/trocr-base-printed", "microsoft/trocr-base-handwritten")
    if model_id not in pipes or model_id in disabled:
        print(f"[ check {model_id} ] failed")
        return {"loaded": False}
    print(f"[ check {model_id} ] success")
    return {"loaded": True}
353
+
354
def _move_pipe(pipe, device):
    """Move `pipe` to `device`, falling back to its wrapped ``.model``."""
    try:
        pipe.to(device)
    except Exception:
        # Objects whose ``.to`` is missing or fails (presumably transformers
        # pipeline wrappers -- TODO confirm) are relocated through their
        # ``device`` / ``model`` attributes instead.
        pipe.device = torch.device(device)
        pipe.model.to(device)


def _run_inference(model_id, data, entry, pipe):
    """Run the handler matching `model_id` and return its result, or None.

    `entry` is the ``pipes[model_id]`` dict and `pipe` its ``"model"`` value.
    Handlers read ``data["text"]``, ``data["img_url"]`` or
    ``data["audio_url"]`` depending on the task. Generated media are written
    under ``public/`` and reported as ``{"path": ...}``. ``None`` means no
    handler exists for `model_id`.
    """
    result = None

    # text to video
    if model_id == "damo-vilab/text-to-video-ms-1.7b":
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        # pipe.enable_model_cpu_offload()
        prompt = data["text"]
        video_frames = pipe(prompt, num_inference_steps=50, num_frames=40).frames
        file_name = str(uuid.uuid4())[:4]
        video_path = export_to_video(video_frames, f"public/videos/{file_name}.mp4")

        # Re-encode with libx264; if ffmpeg produced nothing, serve the
        # originally exported file instead.
        new_file_name = str(uuid.uuid4())[:4]
        os.system(f"ffmpeg -i {video_path} -vcodec libx264 public/videos/{new_file_name}.mp4")

        if os.path.exists(f"public/videos/{new_file_name}.mp4"):
            result = {"path": f"/videos/{new_file_name}.mp4"}
        else:
            result = {"path": f"/videos/{file_name}.mp4"}

    # controlnet
    if model_id.startswith("lllyasviel/sd-controlnet-"):
        # The pipeline object is shared: park the previously attached
        # ControlNet on the CPU and attach the one for this model id.
        pipe.controlnet.to('cpu')
        pipe.controlnet = entry["control"].to(entry["device"])
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        control_image = load_image(data["img_url"])
        # generator = torch.manual_seed(66)
        out_image = pipe(data["text"], num_inference_steps=20, image=control_image).images[0]
        file_name = str(uuid.uuid4())[:4]
        out_image.save(f"public/images/{file_name}.png")
        result = {"path": f"/images/{file_name}.png"}

    if model_id.endswith("-control"):
        image = load_image(data["img_url"])
        if "scribble" in model_id:
            control = pipe(image, scribble = True)
        elif "canny" in model_id:
            control = pipe(image, low_threshold=100, high_threshold=200)
        else:
            control = pipe(image)
        file_name = str(uuid.uuid4())[:4]
        control.save(f"public/images/{file_name}.png")
        result = {"path": f"/images/{file_name}.png"}

    # image to image
    if model_id == "lambdalabs/sd-image-variations-diffusers":
        im = load_image(data["img_url"])
        file_name = str(uuid.uuid4())[:4]
        # Keep a copy of the input image. (Previously the request dict itself
        # was written to a binary file, which raises TypeError.)
        im.save(f"public/images/{file_name}.png")
        tform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(
                (224, 224),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=False,
            ),
            transforms.Normalize(
                [0.48145466, 0.4578275, 0.40821073],
                [0.26862954, 0.26130258, 0.27577711]),
        ])
        inp = tform(im).to(entry["device"]).unsqueeze(0)
        out = pipe(inp, guidance_scale=3)
        out["images"][0].save(f"public/images/{file_name}.jpg")
        result = {"path": f"/images/{file_name}.jpg"}

    # image to text
    if model_id == "Salesforce/blip-image-captioning-large":
        raw_image = load_image(data["img_url"]).convert('RGB')
        text = data["text"]
        inputs = entry["processor"](raw_image, return_tensors="pt").to(entry["device"])
        out = pipe.generate(**inputs)
        caption = entry["processor"].decode(out[0], skip_special_tokens=True)
        result = {"generated text": caption}
    if model_id == "ydshieh/vit-gpt2-coco-en":
        img_url = data["img_url"]
        generated_text = pipe(img_url)[0]['generated_text']
        result = {"generated text": generated_text}
    if model_id == "nlpconnect/vit-gpt2-image-captioning":
        image = load_image(data["img_url"]).convert("RGB")
        pixel_values = entry["feature_extractor"](images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(entry["device"])
        generated_ids = pipe.generate(pixel_values, **{"max_length": 200, "num_beams": 1})
        generated_text = entry["tokenizer"].batch_decode(generated_ids, skip_special_tokens=True)[0]
        result = {"generated text": generated_text}
    # image to text: OCR
    if model_id == "microsoft/trocr-base-printed" or model_id == "microsoft/trocr-base-handwritten":
        image = load_image(data["img_url"]).convert("RGB")
        pixel_values = entry["processor"](image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(entry["device"])
        generated_ids = pipe.generate(pixel_values)
        generated_text = entry["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
        result = {"generated text": generated_text}

    # text to image
    if model_id == "runwayml/stable-diffusion-v1-5":
        file_name = str(uuid.uuid4())[:4]
        text = data["text"]
        out = pipe(prompt=text)
        out["images"][0].save(f"public/images/{file_name}.jpg")
        result = {"path": f"/images/{file_name}.jpg"}

    # object detection
    if model_id == "google/owlvit-base-patch32" or model_id == "facebook/detr-resnet-101":
        img_url = data["img_url"]
        open_types = ["cat", "couch", "person", "car", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird"]
        result = pipe(img_url, candidate_labels=open_types)

    # VQA
    if model_id == "dandelin/vilt-b32-finetuned-vqa":
        question = data["text"]
        img_url = data["img_url"]
        result = pipe(question=question, image=img_url)

    # DQA
    if model_id == "impira/layoutlm-document-qa":
        question = data["text"]
        img_url = data["img_url"]
        result = pipe(img_url, question)

    # depth-estimation
    if model_id == "Intel/dpt-large":
        output = pipe(data["img_url"])
        image = output['depth']
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    # The guard used to require the id to equal two different strings at
    # once, so this branch could never run.
    if model_id == "Intel/dpt-hybrid-midas":
        image = load_image(data["img_url"])
        # Inputs must live on the same device the model was moved to.
        inputs = entry["feature_extractor"](images=image, return_tensors="pt").to(entry["device"])
        with torch.no_grad():
            outputs = pipe(**inputs)
            predicted_depth = outputs.predicted_depth
        # Resize the prediction back to the input image size (height, width).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )
        output = prediction.squeeze().cpu().numpy()
        formatted = (output * 255 / np.max(output)).astype("uint8")
        image = Image.fromarray(formatted)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    # TTS
    if model_id == "espnet/kan-bayashi_ljspeech_vits":
        text = data["text"]
        wav = pipe(text)["wav"]
        name = str(uuid.uuid4())[:4]
        sf.write(f"public/audios/{name}.wav", wav.cpu().numpy(), pipe.fs, "PCM_16")
        result = {"path": f"/audios/{name}.wav"}

    if model_id == "microsoft/speecht5_tts":
        text = data["text"]
        inputs = entry["processor"](text=text, return_tensors="pt")
        embeddings_dataset = entry["embeddings_dataset"]
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(entry["device"])
        entry["vocoder"].to(entry["device"])
        speech = pipe.generate_speech(inputs["input_ids"].to(entry["device"]), speaker_embeddings, vocoder=entry["vocoder"])
        name = str(uuid.uuid4())[:4]
        sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
        result = {"path": f"/audios/{name}.wav"}

    # ASR
    if model_id == "openai/whisper-base" or model_id == "microsoft/speecht5_asr":
        audio_url = data["audio_url"]
        result = { "text": pipe(audio_url)["text"]}

    # audio to audio
    if model_id == "JorisCos/DCCRNet_Libri1Mix_enhsingle_16k":
        audio_url = data["audio_url"]
        wav, sr = torchaudio.load(audio_url)
        with torch.no_grad():
            result_wav = pipe(wav.to(entry["device"]))
        name = str(uuid.uuid4())[:4]
        sf.write(f"public/audios/{name}.wav", result_wav.cpu().squeeze().numpy(), sr)
        result = {"path": f"/audios/{name}.wav"}

    if model_id == "microsoft/speecht5_vc":
        audio_url = data["audio_url"]
        wav, sr = torchaudio.load(audio_url)
        inputs = entry["processor"](audio=wav, sampling_rate=sr, return_tensors="pt")
        embeddings_dataset = entry["embeddings_dataset"]
        # Speaker embedding must sit on the model's device like the audio.
        speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(entry["device"])
        entry["vocoder"].to(entry["device"])
        # For audio input the processor returns "input_values" (not "input_ids").
        speech = pipe.generate_speech(inputs["input_values"].to(entry["device"]), speaker_embeddings, vocoder=entry["vocoder"])
        name = str(uuid.uuid4())[:4]
        sf.write(f"public/audios/{name}.wav", speech.cpu().numpy(), samplerate=16000)
        result = {"path": f"/audios/{name}.wav"}

    # segmentation
    if model_id == "facebook/detr-resnet-50-panoptic":
        segments = pipe(data["img_url"])
        image = load_image(data["img_url"])

        for segment in segments:
            # One random translucent colour per segment; previously every
            # mask was painted with the same (last generated) colour.
            color = (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255), 50)
            mask = segment["mask"].convert('L')
            layer = Image.new('RGBA', mask.size, color)
            image.paste(layer, (0, 0), mask)
        name = str(uuid.uuid4())[:4]
        image.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    if model_id == "facebook/maskformer-swin-base-coco" or model_id == "facebook/maskformer-swin-large-ade":
        image = load_image(data["img_url"])
        inputs = entry["feature_extractor"](images=image, return_tensors="pt").to(entry["device"])
        outputs = pipe(**inputs)
        panoptic = entry["feature_extractor"].post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_panoptic_map = panoptic["segmentation"].cpu().numpy()
        predicted_panoptic_map = Image.fromarray(predicted_panoptic_map.astype(np.uint8))
        name = str(uuid.uuid4())[:4]
        predicted_panoptic_map.save(f"public/images/{name}.jpg")
        result = {"path": f"/images/{name}.jpg"}

    return result


def models(model_id, data):
    """Run inference with the locally loaded model `model_id`.

    Args:
        model_id: Key into the module-level ``pipes`` table. An id that is
            not in the table raises ``KeyError``.
        data: Request payload dict; depending on the model the handler reads
            ``"text"``, ``"img_url"`` and/or ``"audio_url"``.

    Returns:
        The handler's result (a dict, or the raw pipeline output for some
        tasks). ``{"error": {"message": ...}}`` is returned when the handler
        raises or when no handler exists for `model_id`.

    The entry's ``"using"`` flag serialises access by polling every 0.1 s;
    it is always cleared on exit, even if moving the model between devices
    fails. Entries with a ``"device"`` key are moved there for the call and
    back to the CPU afterwards.
    """
    entry = pipes[model_id]
    while "using" in entry and entry["using"]:
        print(f"[ inference {model_id} ] waiting")
        time.sleep(0.1)
    entry["using"] = True
    print(f"[ inference {model_id} ] start")

    start = time.time()
    result = None
    try:
        pipe = entry["model"]

        if "device" in entry:
            _move_pipe(pipe, entry["device"])

        try:
            result = _run_inference(model_id, data, entry, pipe)
        except Exception as e:
            print(e)
            traceback.print_exc()
            result = {"error": {"message": "Error when running the model inference."}}

        if "device" in entry:
            _move_pipe(pipe, "cpu")
            torch.cuda.empty_cache()
    finally:
        # Always release the entry so a failure cannot block later callers.
        entry["using"] = False

    if result is None:
        result = {"error": {"message": "model not found"}}

    end = time.time()
    during = end - start
    print(f"[ complete {model_id} ] {during}s")
    print(f"[ result {model_id} ] {result}")

    return result
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ tesseract-ocr
public/examples/a.jpg ADDED
public/examples/b.jpg ADDED
public/examples/c.jpg ADDED
public/examples/d.jpg ADDED
public/examples/e.jpg ADDED
public/examples/f.jpg ADDED
public/examples/g.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git@8c530fc2f6a76a2aefb6b285dce6df1675092ac6#egg=diffusers
2
+ git+https://github.com/huggingface/transformers@c612628045822f909020f7eb6784c79700813eda#egg=transformers
3
+ git+https://github.com/patrickvonplaten/controlnet_aux@78efc716868a7f5669c288233d65b471f542ce40#egg=controlnet_aux
4
+ tiktoken==0.3.3
5
+ pydub==0.25.1
6
+ espnet==202301
7
+ espnet_model_zoo==0.1.7
8
+ flask==2.2.3
9
+ flask_cors==3.0.10
10
+ waitress==2.1.2
11
+ datasets==2.11.0
12
+ asteroid==0.6.0
13
+ speechbrain==0.5.14
14
+ timm==0.6.13
15
+ typeguard==2.13.3
16
+ accelerate==0.18.0
17
+ pytesseract==0.3.10
18
+ basicsr==1.4.2