refactor some llm api using openai api format (#1692)
### What problem does this PR solve?

Refactor several LLM providers (chat, vision, and embedding models) to use the OpenAI-compatible API format instead of hand-rolled `requests` calls.
### Type of change
- [x] Refactoring
---------
Co-authored-by: Zhedong Cen <[email protected]>
Files changed:

- rag/llm/chat_model.py +28 -150
- rag/llm/cv_model.py +15 -67
- rag/llm/embedding_model.py +15 -23
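The refactor collapses each provider's hand-written `requests` plumbing into a thin constructor that delegates to the shared OpenAI-format `Base` class. That class is not shown in the diffs below; the following is a minimal sketch of its likely shape, inferred only from the `super().__init__(key, model_name, base_url)` calls and the `self.client.chat.completions.create(...)` usage visible in the hunks:

```python
# Minimal sketch (assumption, not the repository's exact code): the shared
# OpenAI-format base class that the refactored subclasses delegate to.
from abc import ABC

from openai import OpenAI


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        # Providers differ only in credentials, endpoint and model name;
        # chat/chat_streamly then go through the one OpenAI client below.
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
```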
rag/llm/chat_model.py
CHANGED
```diff
@@ -24,6 +24,7 @@ from volcengine.maas.v2 import MaasService
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
+import os
 import json
 import requests
 
@@ -60,9 +61,16 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices:continue
                 ans += resp.choices[0].delta.content
-                total_tokens += 1
+                total_tokens = (
+                    (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                    if not hasattr(resp, "usage")
+                    else resp.usage["total_tokens"]
+                )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
@@ -85,8 +93,13 @@ class MoonshotChat(Base):
         if not base_url: base_url="https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
+
 class XinferenceChat(Base):
     def __init__(self, key=None, model_name="", base_url=""):
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
         key = "xxx"
         super().__init__(key, model_name, base_url)
 
@@ -349,79 +362,13 @@ class OllamaChat(Base):
 
 class LocalAIChat(Base):
     def __init__(self, key, model_name, base_url):
-        if base_url
-
-
+        if not base_url:
+            raise ValueError("Local llm url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            self.base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=self.base_url)
         self.model_name = model_name.split("___")[0]
 
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        headers = {
-            "Content-Type": "application/json",
-        }
-        payload = json.dumps(
-            {"model": self.model_name, "messages": history, **gen_conf}
-        )
-        try:
-            response = requests.request(
-                "POST", url=self.base_url, headers=headers, data=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            if response["choices"][0]["finish_reason"] == "length":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        ans = ""
-        total_tokens = 0
-        try:
-            headers = {
-                "Content-Type": "application/json",
-            }
-            payload = json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": history,
-                    "stream": True,
-                    **gen_conf,
-                }
-            )
-            response = requests.request(
-                "POST",
-                url=self.base_url,
-                headers=headers,
-                data=payload,
-            )
-            for resp in response.content.decode("utf-8").split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "delta" in resp["choices"][0]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                total_tokens += 1
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
-
 
 class LocalLLM(Base):
     class RPCProxy:
@@ -892,9 +839,10 @@ class GroqChat:
 ## openrouter
 class OpenRouterChat(Base):
     def __init__(self, key, model_name, base_url="https://openrouter.ai/api/v1"):
-
-
-
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        super().__init__(key, model_name, base_url)
+
 
 class StepFunChat(Base):
     def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1"):
@@ -904,87 +852,17 @@ class StepFunChat(Base):
 
 
 class NvidiaChat(Base):
-    def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
-    ):
+    def __init__(self, key, model_name, base_url="https://integrate.api.nvidia.com/v1"):
         if not base_url:
-            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
-
-        self.model_name = model_name
-        self.api_key = key
-        self.headers = {
-            "accept": "application/json",
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json",
-        }
-
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        payload = {"model": self.model_name, "messages": history, **gen_conf}
-        try:
-            response = requests.post(
-                url=self.base_url, headers=self.headers, json=payload
-            )
-            response = response.json()
-            ans = response["choices"][0]["message"]["content"].strip()
-            return ans, response["usage"]["total_tokens"]
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
-
-    def chat_streamly(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        for k in list(gen_conf.keys()):
-            if k not in ["temperature", "top_p", "max_tokens"]:
-                del gen_conf[k]
-        ans = ""
-        total_tokens = 0
-        payload = {
-            "model": self.model_name,
-            "messages": history,
-            "stream": True,
-            **gen_conf,
-        }
-
-        try:
-            response = requests.post(
-                url=self.base_url,
-                headers=self.headers,
-                json=payload,
-            )
-            for resp in response.text.split("\n\n"):
-                if "choices" not in resp:
-                    continue
-                resp = json.loads(resp[6:])
-                if "content" in resp["choices"][0]["delta"]:
-                    text = resp["choices"][0]["delta"]["content"]
-                else:
-                    continue
-                ans += text
-                if "usage" in resp:
-                    total_tokens = resp["usage"]["total_tokens"]
-                yield ans
-
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-
-        yield total_tokens
+            base_url = "https://integrate.api.nvidia.com/v1"
+        super().__init__(key, model_name, base_url)
 
 
 class LmStudioChat(Base):
     def __init__(self, key, model_name, base_url):
-        from os.path import join
-
         if not base_url:
             raise ValueError("Local llm url cannot be None")
         if base_url.split("/")[-1] != "v1":
-            self.base_url = join(base_url, "v1")
+            self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
```
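Two behavioral points in the `Base.chat_streamly` hunk are worth calling out: chunks with empty `choices` are now skipped, and token accounting prefers the provider-reported `usage` field, falling back to local estimation only when a chunk carries none. A hedged, self-contained restatement of that rule (the helper name is ours, not the PR's; `resp` stands in for a streamed chunk):

```python
def updated_total_tokens(total_tokens, resp, num_tokens_from_string):
    # Mirror of the diff's conditional expression, unrolled for readability:
    # estimate locally from the delta text unless the chunk reports usage.
    if not hasattr(resp, "usage"):
        return total_tokens + num_tokens_from_string(resp.choices[0].delta.content)
    return resp.usage["total_tokens"]
```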
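`XinferenceChat`, `LocalAIChat` and `LmStudioChat` now share the same base-URL rule: reject an empty URL and append a trailing `v1` path segment when it is missing, since the OpenAI SDK expects the versioned API root. A self-contained sketch of that rule (the function name is ours, not the PR's):

```python
import os


def normalize_v1(base_url: str) -> str:
    # Same checks the three constructors perform inline.
    if not base_url:
        raise ValueError("Local llm url cannot be None")
    if base_url.split("/")[-1] != "v1":
        # os.path.join uses "/" on POSIX, so it also works for URLs there.
        base_url = os.path.join(base_url, "v1")
    return base_url


assert normalize_v1("http://localhost:9997") == "http://localhost:9997/v1"
assert normalize_v1("http://localhost:9997/v1") == "http://localhost:9997/v1"
```

Note that, as the hunks show, `XinferenceChat` assigns the joined URL to `self.base_url` but still passes the original `base_url` to `super().__init__`, so the appended `v1` only takes effect where `self.base_url` is the value actually handed to the client.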
rag/llm/cv_model.py
CHANGED
```diff
@@ -378,7 +378,7 @@ class OllamaCV(Base):
     def chat(self, system, history, gen_conf, image=""):
         if system:
             history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
-
+
         try:
             for his in history:
                 if his["role"] == "user":
@@ -433,27 +433,16 @@ class OllamaCV(Base):
                 yield 0
 
 
-class LocalAICV(Base):
+class LocalAICV(GptV4):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
+        if not base_url:
+            raise ValueError("Local cv model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
         self.lang = lang
 
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        prompt = self.prompt(b64)
-        for i in range(len(prompt)):
-            for c in prompt[i]["content"]:
-                if "text" in c:
-                    c["type"] = "text"
-
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=prompt,
-            max_tokens=max_tokens,
-        )
-        return res.choices[0].message.content.strip(), res.usage.total_tokens
-
 
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
@@ -549,60 +538,19 @@ class GeminiCV(Base):
             yield response._chunks[-1].usage_metadata.total_token_count
 
 
-class OpenRouterCV(Base):
+class OpenRouterCV(GptV4):
     def __init__(
         self,
         key,
         model_name,
         lang="Chinese",
-        base_url="https://openrouter.ai/api/v1/chat/completions",
+        base_url="https://openrouter.ai/api/v1",
     ):
+        if not base_url:
+            base_url = "https://openrouter.ai/api/v1"
+        self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
         self.lang = lang
-        self.base_url = "https://openrouter.ai/api/v1/chat/completions"
-        self.key = key
-
-    def describe(self, image, max_tokens=300):
-        b64 = self.image2base64(image)
-        response = requests.post(
-            url=self.base_url,
-            headers={
-                "Authorization": f"Bearer {self.key}",
-            },
-            data=json.dumps(
-                {
-                    "model": self.model_name,
-                    "messages": self.prompt(b64),
-                    "max_tokens": max_tokens,
-                }
-            ),
-        )
-        response = response.json()
-        return (
-            response["choices"][0]["message"]["content"].strip(),
-            response["usage"]["total_tokens"],
-        )
-
-    def prompt(self, b64):
-        return [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
-                    },
-                    {
-                        "type": "text",
-                        "text": (
-                            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
-                            if self.lang.lower() == "chinese"
-                            else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-                        ),
-                    },
-                ],
-            }
-        ]
 
 
 class LocalCV(Base):
@@ -675,12 +623,12 @@ class NvidiaCV(Base):
         ]
 
 
-class LmStudioCV(Base):
+class LmStudioCV(GptV4):
     def __init__(self, key, model_name, base_url, lang="Chinese"):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
-        if base_url.split("/")[-1] != "v1":
-            self.base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
         self.model_name = model_name
         self.lang = lang
```
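With `LocalAICV`, `OpenRouterCV` and `LmStudioCV` rebased onto `GptV4`, they drop their hand-rolled `describe`/`prompt` bodies and inherit the OpenAI-format `describe()` instead. A hypothetical usage sketch, assuming `GptV4.describe` keeps the `(image, max_tokens)` signature seen in the deleted code (model name, URL and `image` are placeholders):

```python
# Hypothetical usage; "llava___LocalAI" follows the codebase's "name___factory"
# convention that split("___")[0] strips, and `image` is whatever
# image2base64 accepts (e.g. raw image bytes).
cv = LocalAICV(key="empty", model_name="llava___LocalAI",
               base_url="http://localhost:8080")
description, total_tokens = cv.describe(image, max_tokens=300)
```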
rag/llm/embedding_model.py
CHANGED
```diff
@@ -113,21 +113,24 @@ class OpenAIEmbed(Base):
 
 class LocalAIEmbed(Base):
     def __init__(self, key, model_name, base_url):
-
-
-
-
+        if not base_url:
+            raise ValueError("Local embedding model url cannot be None")
+        if base_url.split("/")[-1] != "v1":
+            base_url = os.path.join(base_url, "v1")
+        self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=
-
-
-
-
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts, model=self.model_name)
+        return (
+            np.array([d.embedding for d in res.data]),
+            1024,
+        )  # local embedding for LmStudio donot count tokens
 
     def encode_queries(self, text):
-
-        return np.array(
+        res = self.client.embeddings.create(text, model=self.model_name)
+        return np.array(res.data[0].embedding), 1024
+
 
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
@@ -502,7 +505,7 @@ class NvidiaEmbed(Base):
         return np.array(embds[0]), cnt
 
 
-class LmStudioEmbed(Base):
+class LmStudioEmbed(LocalAIEmbed):
     def __init__(self, key, model_name, base_url):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
@@ -510,14 +513,3 @@ class LmStudioEmbed(Base):
         self.base_url = os.path.join(base_url, "v1")
         self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
         self.model_name = model_name
-
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts, model=self.model_name)
-        return (
-            np.array([d.embedding for d in res.data]),
-            1024,
-        )  # local embedding for LmStudio donot count tokens
-
-    def encode_queries(self, text):
-        res = self.client.embeddings.create(text, model=self.model_name)
-        return np.array(res.data[0].embedding), 1024
```
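`LmStudioEmbed` now subclasses `LocalAIEmbed` and inherits both `encode` and `encode_queries`, so the two local providers share one OpenAI-format code path; the hard-coded `1024` stands in for a token count these local servers do not report. A hypothetical usage sketch (model name and URL are placeholders):

```python
embedder = LmStudioEmbed(key="lm-studio", model_name="nomic-embed-text",
                         base_url="http://localhost:1234")
vectors, token_count = embedder.encode(["hello", "world"])
# token_count is the fixed 1024 placeholder, not a real usage figure.
```

One caveat worth flagging: the inherited `encode_queries` passes `text` positionally to `client.embeddings.create`, but the OpenAI v1 Python SDK declares that method's parameters keyword-only, so the call would likely raise a `TypeError`; `input=text` is the safer spelling.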