# Mock implementations of the xinference client, installed via monkeypatch
# when the MOCK_SWITCH environment variable is enabled (see bottom of file).
import os | |
import re | |
from typing import Union | |
import pytest | |
from _pytest.monkeypatch import MonkeyPatch | |
from requests import Response | |
from requests.sessions import Session | |
from xinference_client.client.restful.restful_client import ( | |
Client, | |
RESTfulChatModelHandle, | |
RESTfulEmbeddingModelHandle, | |
RESTfulGenerateModelHandle, | |
RESTfulRerankModelHandle, | |
) | |
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage | |
class MockXinferenceClass:
    """Drop-in replacements for xinference client/session methods.

    Each method here is monkeypatched onto the corresponding real class
    (see ``setup_xinference_mock``), which is why the ``self`` annotations
    name foreign types: at call time ``self`` is an instance of the patched
    class, not of ``MockXinferenceClass``.
    """

    def get_chat_model(
        self: Client, model_uid: str
    ) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
        """Mock ``Client.get_model``: return a handle matching *model_uid*.

        Raises RuntimeError("404 Not Found") for a malformed base URL or an
        unknown model uid, mirroring the real client's failure mode.
        """
        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
            raise RuntimeError("404 Not Found")
        if "generate" == model_uid:
            return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "chat" == model_uid:
            return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "embedding" == model_uid:
            return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        if "rerank" == model_uid:
            return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
        raise RuntimeError("404 Not Found")

    def get(self: Session, url: str, **kwargs):
        """Mock ``requests.Session.get`` for the xinference REST endpoints.

        Serves canned model-info JSON for the ``v1/models/<uid>`` endpoint and
        a positive auth reply for ``v1/cluster/auth``.

        Fix: the original fell off the end (implicitly returning ``None``)
        both for ``model_uid == "rerank"`` and for any unrecognized URL; a
        patched ``Session.get`` must always return a ``Response``, so those
        paths now yield a rerank model description and a 404 respectively.
        """
        response = Response()
        if "v1/models/" in url:
            # get model uid
            model_uid = url.split("/")[-1] or ""
            # Accept either a UUID-shaped uid or one of the well-known names.
            if not re.match(
                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
            ) and model_uid not in {"generate", "chat", "embedding", "rerank"}:
                response.status_code = 404
                response._content = b"{}"
                return response
            # check if url is valid
            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                response.status_code = 404
                response._content = b"{}"
                return response
            if model_uid in {"generate", "chat"}:
                response.status_code = 200
                response._content = b"""{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}"""
                return response
            elif model_uid == "embedding":
                response.status_code = 200
                response._content = b"""{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}"""
                return response
            elif model_uid == "rerank":
                # Fix: previously this uid passed the whitelist above but had
                # no branch, so the method returned None.
                response.status_code = 200
                response._content = b"""{
    "model_type": "rerank",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge-reranker-base",
    "model_lang": [
        "en"
    ],
    "revision": null
}"""
                return response
        elif "v1/cluster/auth" in url:
            response.status_code = 200
            response._content = b"""{
    "auth": true
}"""
            return response
        # Fix: any other URL now gets an explicit 404 instead of implicit None.
        response.status_code = 404
        response._content = b"{}"
        return response

    def _check_cluster_authenticated(self):
        # Pretend the cluster handshake already succeeded.
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle,
        documents: list[str],
        query: str,
        top_n: int,
        return_documents: bool,
    ) -> dict:
        """Mock rerank: echo back the first *top_n* documents with a fixed score.

        *query* and *return_documents* are accepted for signature parity but
        do not influence the canned result.
        """
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "rerank"
        ):
            raise RuntimeError("404 Not Found")
        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
            raise RuntimeError("404 Not Found")
        if top_n is None:
            top_n = 1
        return {
            "results": [
                {"index": i, "document": doc, "relevance_score": 0.9}
                for i, doc in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle, input: Union[str, list[str]], **kwargs
    ) -> dict:
        """Mock create_embedding: one fixed 768-dim vector per input string.

        Fix: the ``self`` annotation was ``RESTfulGenerateModelHandle``; this
        method is patched onto ``RESTfulEmbeddingModelHandle`` (see
        ``setup_xinference_mock``), so the annotation now matches.
        """
        # check if self._model_uid is a valid uuid
        if (
            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
            and self._model_uid != "embedding"
        ):
            raise RuntimeError("404 Not Found")
        if isinstance(input, str):
            input = [input]
        ipt_len = len(input)
        embedding = Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                for i in range(ipt_len)
            ],
            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
        )
        return embedding
# Install the mocks only when the test run explicitly opts in.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Yield-style pytest fixture patching the xinference client with mocks.

    Fix: the ``@pytest.fixture`` decorator was missing — without it pytest
    never drives the generator, so the mocks were never installed (``pytest``
    was imported but unused, confirming the decorator was intended).

    Setup patches the client/session/handle methods with the
    ``MockXinferenceClass`` equivalents; teardown undoes every patch.
    """
    if MOCK:
        monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
        monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
        monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
        monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
        monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
    yield
    if MOCK:
        monkeypatch.undo()