KevinHuSh
commited on
Commit
·
db713d9
1
Parent(s):
30d6885
Add 2 embeding models from OpenAI (#812)
Browse files### What problem does this PR solve?
#810
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/db/init_data.py +30 -0
- api/db/services/llm_service.py +10 -0
api/db/init_data.py
CHANGED
|
@@ -16,6 +16,7 @@
|
|
| 16 |
import os
|
| 17 |
import time
|
| 18 |
import uuid
|
|
|
|
| 19 |
|
| 20 |
from api.db import LLMType, UserTenantRole
|
| 21 |
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
|
@@ -166,6 +167,18 @@ def init_llm_factory():
|
|
| 166 |
"tags": "TEXT EMBEDDING,8K",
|
| 167 |
"max_tokens": 8191,
|
| 168 |
"model_type": LLMType.EMBEDDING.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
}, {
|
| 170 |
"fid": factory_infos[0]["name"],
|
| 171 |
"llm_name": "whisper-1",
|
|
@@ -376,6 +389,23 @@ def init_llm_factory():
|
|
| 376 |
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
| 377 |
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
| 378 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
"""
|
| 380 |
drop table llm;
|
| 381 |
drop table llm_factories;
|
|
|
|
| 16 |
import os
|
| 17 |
import time
|
| 18 |
import uuid
|
| 19 |
+
from copy import deepcopy
|
| 20 |
|
| 21 |
from api.db import LLMType, UserTenantRole
|
| 22 |
from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
|
|
|
|
| 167 |
"tags": "TEXT EMBEDDING,8K",
|
| 168 |
"max_tokens": 8191,
|
| 169 |
"model_type": LLMType.EMBEDDING.value
|
| 170 |
+
}, {
|
| 171 |
+
"fid": factory_infos[0]["name"],
|
| 172 |
+
"llm_name": "text-embedding-3-small",
|
| 173 |
+
"tags": "TEXT EMBEDDING,8K",
|
| 174 |
+
"max_tokens": 8191,
|
| 175 |
+
"model_type": LLMType.EMBEDDING.value
|
| 176 |
+
}, {
|
| 177 |
+
"fid": factory_infos[0]["name"],
|
| 178 |
+
"llm_name": "text-embedding-3-large",
|
| 179 |
+
"tags": "TEXT EMBEDDING,8K",
|
| 180 |
+
"max_tokens": 8191,
|
| 181 |
+
"model_type": LLMType.EMBEDDING.value
|
| 182 |
}, {
|
| 183 |
"fid": factory_infos[0]["name"],
|
| 184 |
"llm_name": "whisper-1",
|
|
|
|
| 389 |
LLMFactoriesService.filter_delete([LLMFactoriesService.model.name == "QAnything"])
|
| 390 |
LLMService.filter_delete([LLMService.model.fid == "QAnything"])
|
| 391 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
| 392 |
+
## insert openai two embedding models to the current openai user.
|
| 393 |
+
print("Start to insert 2 OpenAI embedding models...")
|
| 394 |
+
tenant_ids = set([row.tenant_id for row in TenantLLMService.get_openai_models()])
|
| 395 |
+
for tid in tenant_ids:
|
| 396 |
+
for row in TenantLLMService.get_openai_models(llm_factory="OpenAI", tenant_id=tid):
|
| 397 |
+
row = row.to_dict()
|
| 398 |
+
row["model_type"] = LLMType.EMBEDDING.value
|
| 399 |
+
row["llm_name"] = "text-embedding-3-small"
|
| 400 |
+
row["used_tokens"] = 0
|
| 401 |
+
try:
|
| 402 |
+
TenantLLMService.save(**row)
|
| 403 |
+
row = deepcopy(row)
|
| 404 |
+
row["llm_name"] = "text-embedding-3-large"
|
| 405 |
+
TenantLLMService.save(**row)
|
| 406 |
+
except Exception as e:
|
| 407 |
+
pass
|
| 408 |
+
break
|
| 409 |
"""
|
| 410 |
drop table llm;
|
| 411 |
drop table llm_factories;
|
api/db/services/llm_service.py
CHANGED
|
@@ -135,6 +135,16 @@ class TenantLLMService(CommonService):
|
|
| 135 |
.execute()
|
| 136 |
return num
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
class LLMBundle(object):
|
| 140 |
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|
|
|
|
| 135 |
.execute()
|
| 136 |
return num
|
| 137 |
|
| 138 |
+
@classmethod
|
| 139 |
+
@DB.connection_context()
|
| 140 |
+
def get_openai_models(cls):
|
| 141 |
+
objs = cls.model.select().where(
|
| 142 |
+
(cls.model.llm_factory == "OpenAI"),
|
| 143 |
+
~(cls.model.llm_name == "text-embedding-3-small"),
|
| 144 |
+
~(cls.model.llm_name == "text-embedding-3-large")
|
| 145 |
+
).dicts()
|
| 146 |
+
return list(objs)
|
| 147 |
+
|
| 148 |
|
| 149 |
class LLMBundle(object):
|
| 150 |
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|