Kevin Hu
committed
Commit · 1e02591
Parent(s): c50cfc1

Fix @ in model name issue. (#3821)

### What problem does this PR solve?

#3814

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
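In short: model ids are stored as `name@factory`, and the old code split on `@` unconditionally, so a model whose name itself contains an `@` was torn apart at the wrong place. A minimal sketch of the failure mode (the model ids are made-up examples, not taken from issue #3814):

```python
def naive_split(model_name):
    # Old behavior (see the removed lines below): split on "@" unconditionally
    # and treat the second segment as the factory.
    arr = model_name.split("@")
    return (arr[0], arr[1]) if len(arr) > 1 else (model_name, None)

# A plain "name@factory" id still resolves correctly:
assert naive_split("gpt-4o@OpenAI") == ("gpt-4o", "OpenAI")

# But a model whose name itself contains "@" is mangled:
assert naive_split("user@custom-model@Ollama") == ("user", "custom-model")
# expected ("user@custom-model", "Ollama"): wrong name AND wrong factory
```

The fix routes every call site through a single helper that only treats the suffix as a factory when it matches a factory name from `conf/llm_factories.json`.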
api/db/services/dialog_service.py
CHANGED

```diff
@@ -120,7 +120,7 @@ def message_fit_in(msg, max_length=4000):
 
 
 def llm_id2llm_type(llm_id):
-    llm_id = llm_id.split("@")[0]
+    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:

@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    tmp = dialog.llm_id.split("@")
-    fid = None
-    llm_id = tmp[0]
-    if len(tmp)>1: fid = tmp[1]
-
+    llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
     llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
```
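Both call sites now delegate to `TenantLLMService.split_model_name_and_factory`, added in `llm_service.py` below. Here is a self-contained sketch of its contract, with the factory list hardcoded instead of loaded from `conf/llm_factories.json` (the factory names are placeholders):

```python
KNOWN_FACTORIES = {"OpenAI", "Ollama", "Tongyi-Qianwen"}  # stand-in for conf/llm_factories.json

def split_model_name_and_factory(model_name):
    arr = model_name.split("@")
    if len(arr) < 2:
        return model_name, None              # no "@": bare model name
    if len(arr) > 2:
        # several "@"s: everything before the last one is the model name
        return "@".join(arr[0:-1]), arr[-1]
    if arr[-1] not in KNOWN_FACTORIES:
        return model_name, None              # suffix is not a factory, keep it in the name
    return arr[0], arr[-1]

assert split_model_name_and_factory("gpt-4o") == ("gpt-4o", None)
assert split_model_name_and_factory("gpt-4o@OpenAI") == ("gpt-4o", "OpenAI")
assert split_model_name_and_factory("user@model@Ollama") == ("user@model", "Ollama")
assert split_model_name_and_factory("user@model") == ("user@model", None)
```

Note that with more than two `@`-separated segments the helper trusts the last one as the factory without consulting the list; only the exactly-one-`@` case is validated against it.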
api/db/services/llm_service.py
CHANGED

```diff
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import logging
+import os
+
 from api.db.services.user_service import TenantService
+from api.utils.file_utils import get_project_base_directory
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
 from api.db.db_models import DB

@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        arr = model_name.split("@")
-        if len(arr) < 2:
-            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
+        if not fid:
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
         else:
-            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
             return
         return objs[0]

@@ -61,6 +65,23 @@ class TenantLLMService(CommonService):
 
         return list(objs)
 
+    @staticmethod
+    def split_model_name_and_factory(model_name):
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            return model_name, None
+        if len(arr) > 2:
+            return "@".join(arr[0:-1]), arr[-1]
+        try:
+            fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            fact = set([f["name"] for f in fact])
+            if arr[-1] not in fact:
+                return model_name, None
+            return arr[0], arr[-1]
+        except Exception as e:
+            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
+            return model_name, None
+
     @classmethod
     @DB.connection_context()
     def model_instance(cls, tenant_id, llm_type,

@@ -85,9 +106,7 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        tmp = mdlnm.split("@")
-        fid = None if len(tmp) < 2 else tmp[1]
-        mdlnm = tmp[0]
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:

@@ -168,7 +187,7 @@ class TenantLLMService(CommonService):
         else:
             assert False, "LLM type error"
 
-        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
 
         num = 0
         try:

@@ -179,7 +198,7 @@ class TenantLLMService(CommonService):
                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                    .execute()
         else:
-            llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
+            if not llm_factory: llm_factory = mdlnm
             num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
```
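One consequence of the last hunk: when no factory can be parsed out of the id, `increase_usage` falls back to recording the whole id as `llm_factory`, preserving the pre-PR behavior. A small demo, reusing `split_model_name_and_factory` from the sketch after the first file's diff:

```python
def usage_record_fields(mdlnm):
    # Mirrors the last hunk of increase_usage: derive both columns of the
    # usage row from one id string.
    llm_name, llm_factory = split_model_name_and_factory(mdlnm)
    if not llm_factory:
        llm_factory = mdlnm  # fallback added by this PR's last hunk
    return llm_name, llm_factory

assert usage_record_fields("gpt-4o@OpenAI") == ("gpt-4o", "OpenAI")
# No recognizable factory suffix: the full id doubles as the factory value.
assert usage_record_fields("gpt-4o") == ("gpt-4o", "gpt-4o")
```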