Upload folder using huggingface_hub
Browse files- api.py +10 -8
- artifact.py +11 -2
- dataset.py +0 -1
- dataset_utils.py +8 -5
- inference.py +86 -29
- metric.py +0 -1
- metrics.py +76 -32
- operators.py +47 -0
- serializers.py +1 -6
- struct_data_operators.py +21 -1
- tool_calling.py +0 -119
- type_utils.py +15 -3
- types.py +12 -6
- version.py +1 -1
api.py
CHANGED
@@ -37,12 +37,11 @@ def short_hex_hash(value, length=8):
|
|
37 |
return h[:length]
|
38 |
|
39 |
|
40 |
-
def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
|
41 |
-
dataset_query = dataset_query.replace("sys_prompt", "instruction")
|
42 |
try:
|
43 |
-
dataset_stream, _ = fetch_artifact(dataset_query)
|
44 |
except:
|
45 |
-
dataset_stream = get_dataset_artifact(dataset_query)
|
46 |
return dataset_stream
|
47 |
|
48 |
|
@@ -82,14 +81,15 @@ def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
|
|
82 |
if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
|
83 |
return dataset_query
|
84 |
|
85 |
-
_verify_dataset_args(dataset_query, kwargs)
|
86 |
-
|
87 |
if dataset_query:
|
88 |
-
recipe = _get_recipe_from_query(dataset_query)
|
89 |
|
90 |
-
|
91 |
recipe = _get_recipe_from_dict(kwargs)
|
92 |
|
|
|
|
|
|
|
93 |
return recipe
|
94 |
|
95 |
|
@@ -187,6 +187,8 @@ def load_dataset(
|
|
187 |
Alternatively, dataset is loaded from a provided card based on explicitly
|
188 |
given parameters.
|
189 |
|
|
|
|
|
190 |
Args:
|
191 |
dataset_query (str, optional):
|
192 |
A string query which specifies a dataset to load from
|
|
|
37 |
return h[:length]
|
38 |
|
39 |
|
40 |
+
def _get_recipe_from_query(dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> DatasetRecipe:
|
|
|
41 |
try:
|
42 |
+
dataset_stream, _ = fetch_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
|
43 |
except:
|
44 |
+
dataset_stream = get_dataset_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
|
45 |
return dataset_stream
|
46 |
|
47 |
|
|
|
81 |
if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
|
82 |
return dataset_query
|
83 |
|
|
|
|
|
84 |
if dataset_query:
|
85 |
+
recipe = _get_recipe_from_query(dataset_query, kwargs)
|
86 |
|
87 |
+
elif kwargs:
|
88 |
recipe = _get_recipe_from_dict(kwargs)
|
89 |
|
90 |
+
else:
|
91 |
+
raise UnitxtError("Specify either dataset recipe string artifact name or recipe args.")
|
92 |
+
|
93 |
return recipe
|
94 |
|
95 |
|
|
|
187 |
Alternatively, dataset is loaded from a provided card based on explicitly
|
188 |
given parameters.
|
189 |
|
190 |
+
If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.
|
191 |
+
|
192 |
Args:
|
193 |
dataset_query (str, optional):
|
194 |
A string query which specifies a dataset to load from
|
artifact.py
CHANGED
@@ -22,7 +22,7 @@ from .parsing_utils import (
|
|
22 |
separate_inside_and_outside_square_brackets,
|
23 |
)
|
24 |
from .settings_utils import get_constants, get_settings
|
25 |
-
from .text_utils import camel_to_snake_case, is_camel_case
|
26 |
from .type_utils import isoftype, issubtype
|
27 |
from .utils import (
|
28 |
artifacts_json_cache,
|
@@ -369,6 +369,10 @@ class Artifact(Dataclass):
|
|
369 |
data = self.to_dict()
|
370 |
return json_dump(data)
|
371 |
|
|
|
|
|
|
|
|
|
372 |
def serialize(self):
|
373 |
if self.__id__ is not None:
|
374 |
return self.__id__
|
@@ -528,7 +532,7 @@ class UnitxtArtifactNotFoundError(UnitxtError):
|
|
528 |
super().__init__(msg)
|
529 |
|
530 |
|
531 |
-
def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
|
532 |
"""Loads an artifict from one of possible representations.
|
533 |
|
534 |
(1) If artifact representation is already an Artifact object, return it.
|
@@ -553,6 +557,11 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]
|
|
553 |
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
|
554 |
if is_name_legal_for_catalog(name):
|
555 |
catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
|
|
|
|
|
|
|
|
|
|
|
556 |
artifact_to_return = catalog.get_with_overwrite(
|
557 |
artifact_rep, overwrite_args=args
|
558 |
)
|
|
|
22 |
separate_inside_and_outside_square_brackets,
|
23 |
)
|
24 |
from .settings_utils import get_constants, get_settings
|
25 |
+
from .text_utils import camel_to_snake_case, is_camel_case, print_dict_as_yaml
|
26 |
from .type_utils import isoftype, issubtype
|
27 |
from .utils import (
|
28 |
artifacts_json_cache,
|
|
|
369 |
data = self.to_dict()
|
370 |
return json_dump(data)
|
371 |
|
372 |
+
def to_yaml(self):
|
373 |
+
data = self.to_dict()
|
374 |
+
return print_dict_as_yaml(data)
|
375 |
+
|
376 |
def serialize(self):
|
377 |
if self.__id__ is not None:
|
378 |
return self.__id__
|
|
|
532 |
super().__init__(msg)
|
533 |
|
534 |
|
535 |
+
def fetch_artifact(artifact_rep, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> Tuple[Artifact, Union[AbstractCatalog, None]]:
|
536 |
"""Loads an artifict from one of possible representations.
|
537 |
|
538 |
(1) If artifact representation is already an Artifact object, return it.
|
|
|
557 |
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
|
558 |
if is_name_legal_for_catalog(name):
|
559 |
catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
|
560 |
+
if overwrite_kwargs is not None:
|
561 |
+
if args is None:
|
562 |
+
args = overwrite_kwargs
|
563 |
+
else:
|
564 |
+
args.update(overwrite_kwargs)
|
565 |
artifact_to_return = catalog.get_with_overwrite(
|
566 |
artifact_rep, overwrite_args=args
|
567 |
)
|
dataset.py
CHANGED
@@ -68,7 +68,6 @@ from .system_prompts import __file__ as _
|
|
68 |
from .task import __file__ as _
|
69 |
from .templates import __file__ as _
|
70 |
from .text_utils import __file__ as _
|
71 |
-
from .tool_calling import __file__ as _
|
72 |
from .type_utils import __file__ as _
|
73 |
from .types import __file__ as _
|
74 |
from .utils import __file__ as _
|
|
|
68 |
from .task import __file__ as _
|
69 |
from .templates import __file__ as _
|
70 |
from .text_utils import __file__ as _
|
|
|
71 |
from .type_utils import __file__ as _
|
72 |
from .types import __file__ as _
|
73 |
from .utils import __file__ as _
|
dataset_utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from json.decoder import JSONDecodeError
|
|
|
2 |
|
3 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
4 |
from .logging_utils import get_logger
|
@@ -11,19 +12,19 @@ logger = get_logger()
|
|
11 |
settings = get_settings()
|
12 |
|
13 |
|
14 |
-
def fetch(artifact_name):
|
15 |
try:
|
16 |
-
artifact, _ = fetch_artifact(artifact_name)
|
17 |
return artifact
|
18 |
except (UnitxtArtifactNotFoundError, JSONDecodeError):
|
19 |
return None
|
20 |
|
21 |
|
22 |
-
def parse(query: str):
|
23 |
return parse_key_equals_value_string_to_dict(query)
|
24 |
|
25 |
|
26 |
-
def get_dataset_artifact(dataset):
|
27 |
if isinstance(dataset, DatasetRecipe):
|
28 |
return dataset
|
29 |
assert isinstance(
|
@@ -31,10 +32,12 @@ def get_dataset_artifact(dataset):
|
|
31 |
), "dataset should be string description of recipe, or recipe object."
|
32 |
_reset_env_local_catalogs()
|
33 |
register_all_artifacts()
|
34 |
-
recipe = fetch(dataset)
|
35 |
if recipe is None:
|
36 |
args = parse(dataset)
|
37 |
if "__type__" not in args:
|
38 |
args["__type__"] = settings.default_recipe
|
|
|
|
|
39 |
recipe = Artifact.from_dict(args)
|
40 |
return recipe
|
|
|
1 |
from json.decoder import JSONDecodeError
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
|
4 |
from .artifact import Artifact, UnitxtArtifactNotFoundError, fetch_artifact
|
5 |
from .logging_utils import get_logger
|
|
|
12 |
settings = get_settings()
|
13 |
|
14 |
|
15 |
+
def fetch(artifact_name: str, overwrite_kwargs: Optional[Dict[str, Any]]=None):
|
16 |
try:
|
17 |
+
artifact, _ = fetch_artifact(artifact_name, overwrite_kwargs=overwrite_kwargs)
|
18 |
return artifact
|
19 |
except (UnitxtArtifactNotFoundError, JSONDecodeError):
|
20 |
return None
|
21 |
|
22 |
|
23 |
+
def parse(query: str) -> dict:
|
24 |
return parse_key_equals_value_string_to_dict(query)
|
25 |
|
26 |
|
27 |
+
def get_dataset_artifact(dataset, overwrite_kwargs: Optional[Dict[str, Any]]=None):
|
28 |
if isinstance(dataset, DatasetRecipe):
|
29 |
return dataset
|
30 |
assert isinstance(
|
|
|
32 |
), "dataset should be string description of recipe, or recipe object."
|
33 |
_reset_env_local_catalogs()
|
34 |
register_all_artifacts()
|
35 |
+
recipe = fetch(dataset, overwrite_kwargs=overwrite_kwargs)
|
36 |
if recipe is None:
|
37 |
args = parse(dataset)
|
38 |
if "__type__" not in args:
|
39 |
args["__type__"] = settings.default_recipe
|
40 |
+
if overwrite_kwargs is not None:
|
41 |
+
args.update(overwrite_kwargs)
|
42 |
recipe = Artifact.from_dict(args)
|
43 |
return recipe
|
inference.py
CHANGED
@@ -344,6 +344,8 @@ class InferenceEngine(Artifact):
|
|
344 |
|
345 |
def to_tools(self, instance):
|
346 |
task_data = instance.get("task_data")
|
|
|
|
|
347 |
if isinstance(task_data, str):
|
348 |
task_data = json.loads(task_data)
|
349 |
if "__tools__" in task_data:
|
@@ -445,6 +447,8 @@ class HFInferenceEngineBase(
|
|
445 |
model: Any = InternalField(default=None, name="Inference object")
|
446 |
processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
|
447 |
|
|
|
|
|
448 |
_requirements_list = {
|
449 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers",
|
450 |
"torch": "Install torch, go on PyTorch website for mode details.",
|
@@ -655,8 +659,6 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
655 |
truncation: bool = True
|
656 |
padding_side: str = "left" # for decoder only models
|
657 |
|
658 |
-
chat_kwargs_dict: dict = {}
|
659 |
-
|
660 |
def _init_processor(self):
|
661 |
from transformers import AutoTokenizer
|
662 |
|
@@ -712,10 +714,9 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
712 |
trust_remote_code=True,
|
713 |
**model_args,
|
714 |
)
|
715 |
-
if self.device_map is None:
|
716 |
-
self.model.to(self.device)
|
717 |
|
718 |
def prepare_inputs(self, data: Iterable) -> Mapping:
|
|
|
719 |
if isinstance(data[0], list):
|
720 |
data = self.processor.apply_chat_template(
|
721 |
data,
|
@@ -723,6 +724,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
723 |
add_generation_prompt=True,
|
724 |
**self.chat_kwargs_dict,
|
725 |
)
|
|
|
726 |
|
727 |
if self.processor.pad_token is None:
|
728 |
self.processor.pad_token_id = self.model.config.eos_token_id[0]
|
@@ -733,6 +735,8 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
733 |
padding=self.padding,
|
734 |
truncation=self.truncation,
|
735 |
padding_side=self.padding_side,
|
|
|
|
|
736 |
).to(self.device or self.device_map)
|
737 |
|
738 |
def _infer_fn(
|
@@ -755,13 +759,14 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
755 |
"""
|
756 |
all_final_outputs = [] # List to store results from all batches
|
757 |
|
758 |
-
for
|
759 |
-
|
760 |
desc=f"Running inference in batches of {self.batch_size}",
|
|
|
761 |
):
|
|
|
762 |
# Get the current batch
|
763 |
-
|
764 |
-
batch_sources = [instance["source"] for instance in batch_data]
|
765 |
|
766 |
# --- Process the current batch ---
|
767 |
# 1. Tokenize inputs for the batch
|
@@ -800,7 +805,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
800 |
j
|
801 |
], # Output for the j-th item in the batch
|
802 |
output_tokens=len(string_tokens_batch[j]),
|
803 |
-
inp=
|
804 |
inp_tokens=len(tokenized_inputs.encodings[j].tokens)
|
805 |
if tokenized_inputs.encodings is not None
|
806 |
else None,
|
@@ -1840,15 +1845,26 @@ class OpenAiInferenceEngine(
|
|
1840 |
@run_with_imap
|
1841 |
def _get_chat_completion(self, instance, return_meta_data):
|
1842 |
import openai
|
1843 |
-
|
1844 |
messages = self.to_messages(instance)
|
1845 |
try:
|
1846 |
response = self.client.chat.completions.create(
|
1847 |
messages=messages,
|
|
|
1848 |
model=self.get_client_model_name(),
|
1849 |
**self._get_completion_kwargs(),
|
|
|
1850 |
)
|
1851 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1852 |
return self.get_return_object(prediction, response, return_meta_data)
|
1853 |
# catch in case of content_filtering failure
|
1854 |
except openai.BadRequestError as e:
|
@@ -2742,14 +2758,37 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2742 |
# images as SDK allows sending only one image per message.
|
2743 |
return [messages]
|
2744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2745 |
def _handle_async_requests(
|
2746 |
self,
|
2747 |
-
|
2748 |
params: Dict[str, Any],
|
2749 |
) -> List[Dict[str, Any]]:
|
2750 |
async def handle_async_requests(start_idx, end_idx):
|
2751 |
coroutines = [
|
2752 |
-
self._model.achat(
|
|
|
|
|
|
|
|
|
|
|
2753 |
for idx in range(start_idx, end_idx)
|
2754 |
]
|
2755 |
batch_results = await asyncio.gather(*coroutines)
|
@@ -2758,10 +2797,10 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2758 |
loop = asyncio.get_event_loop()
|
2759 |
results = []
|
2760 |
|
2761 |
-
for batch_idx in range(0, len(
|
2762 |
batch_results = loop.run_until_complete(
|
2763 |
handle_async_requests(
|
2764 |
-
batch_idx, min(batch_idx + self.concurrency_limit, len(
|
2765 |
)
|
2766 |
)
|
2767 |
results.extend(batch_results)
|
@@ -2783,25 +2822,43 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
2783 |
output_type = "message"
|
2784 |
params["logprobs"] = False
|
2785 |
|
2786 |
-
|
2787 |
-
|
|
|
|
|
|
|
|
|
2788 |
for i in range(len(dataset))
|
2789 |
for message in self.to_messages(dataset[i])
|
2790 |
]
|
2791 |
|
2792 |
-
|
2793 |
-
[msg[1] for msg in indexed_messages], params
|
2794 |
-
)
|
2795 |
|
2796 |
-
|
2797 |
-
|
2798 |
-
|
2799 |
-
|
2800 |
-
|
2801 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2802 |
)
|
2803 |
-
|
2804 |
-
|
2805 |
|
2806 |
def get_return_object(self, predict_result, result, input_text, return_meta_data):
|
2807 |
if return_meta_data:
|
@@ -3439,7 +3496,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3439 |
"aws": LiteLLMInferenceEngine,
|
3440 |
"ollama": OllamaInferenceEngine,
|
3441 |
"bam": IbmGenAiInferenceEngine,
|
3442 |
-
"watsonx-sdk":
|
3443 |
"rits": RITSInferenceEngine,
|
3444 |
"azure": LiteLLMInferenceEngine,
|
3445 |
"vertex-ai": LiteLLMInferenceEngine,
|
|
|
344 |
|
345 |
def to_tools(self, instance):
|
346 |
task_data = instance.get("task_data")
|
347 |
+
if task_data is None:
|
348 |
+
return None
|
349 |
if isinstance(task_data, str):
|
350 |
task_data = json.loads(task_data)
|
351 |
if "__tools__" in task_data:
|
|
|
447 |
model: Any = InternalField(default=None, name="Inference object")
|
448 |
processor: Any = InternalField(default=None, name="Input processor (tokenizer)")
|
449 |
|
450 |
+
chat_kwargs_dict: dict = {}
|
451 |
+
|
452 |
_requirements_list = {
|
453 |
"transformers": "Install huggingface package using 'pip install --upgrade transformers",
|
454 |
"torch": "Install torch, go on PyTorch website for mode details.",
|
|
|
659 |
truncation: bool = True
|
660 |
padding_side: str = "left" # for decoder only models
|
661 |
|
|
|
|
|
662 |
def _init_processor(self):
|
663 |
from transformers import AutoTokenizer
|
664 |
|
|
|
714 |
trust_remote_code=True,
|
715 |
**model_args,
|
716 |
)
|
|
|
|
|
717 |
|
718 |
def prepare_inputs(self, data: Iterable) -> Mapping:
|
719 |
+
tokenizer_kargs = {}
|
720 |
if isinstance(data[0], list):
|
721 |
data = self.processor.apply_chat_template(
|
722 |
data,
|
|
|
724 |
add_generation_prompt=True,
|
725 |
**self.chat_kwargs_dict,
|
726 |
)
|
727 |
+
tokenizer_kargs["add_special_tokens"] = False
|
728 |
|
729 |
if self.processor.pad_token is None:
|
730 |
self.processor.pad_token_id = self.model.config.eos_token_id[0]
|
|
|
735 |
padding=self.padding,
|
736 |
truncation=self.truncation,
|
737 |
padding_side=self.padding_side,
|
738 |
+
**tokenizer_kargs
|
739 |
+
|
740 |
).to(self.device or self.device_map)
|
741 |
|
742 |
def _infer_fn(
|
|
|
759 |
"""
|
760 |
all_final_outputs = [] # List to store results from all batches
|
761 |
|
762 |
+
for batch in tqdm(
|
763 |
+
batched(dataset, self.batch_size),
|
764 |
desc=f"Running inference in batches of {self.batch_size}",
|
765 |
+
total=len(dataset) // self.batch_size,
|
766 |
):
|
767 |
+
|
768 |
# Get the current batch
|
769 |
+
batch_sources = [instance["source"] for instance in batch]
|
|
|
770 |
|
771 |
# --- Process the current batch ---
|
772 |
# 1. Tokenize inputs for the batch
|
|
|
805 |
j
|
806 |
], # Output for the j-th item in the batch
|
807 |
output_tokens=len(string_tokens_batch[j]),
|
808 |
+
inp=batch[j]["source"], # Original input for the j-th item
|
809 |
inp_tokens=len(tokenized_inputs.encodings[j].tokens)
|
810 |
if tokenized_inputs.encodings is not None
|
811 |
else None,
|
|
|
1845 |
@run_with_imap
|
1846 |
def _get_chat_completion(self, instance, return_meta_data):
|
1847 |
import openai
|
1848 |
+
tools = self.to_tools(instance)
|
1849 |
messages = self.to_messages(instance)
|
1850 |
try:
|
1851 |
response = self.client.chat.completions.create(
|
1852 |
messages=messages,
|
1853 |
+
tools=tools,
|
1854 |
model=self.get_client_model_name(),
|
1855 |
**self._get_completion_kwargs(),
|
1856 |
+
# tool_choice="auto"
|
1857 |
)
|
1858 |
+
|
1859 |
+
if tools is None:
|
1860 |
+
prediction = response.choices[0].message.content
|
1861 |
+
else:
|
1862 |
+
try:
|
1863 |
+
func_call = response.choices[0].message.tool_calls[0].function
|
1864 |
+
prediction = f'{{"name": "{func_call.name}", "arguments": {func_call.arguments}}}'
|
1865 |
+
except:
|
1866 |
+
prediction = response.choices[0].message.content or ""
|
1867 |
+
|
1868 |
return self.get_return_object(prediction, response, return_meta_data)
|
1869 |
# catch in case of content_filtering failure
|
1870 |
except openai.BadRequestError as e:
|
|
|
2758 |
# images as SDK allows sending only one image per message.
|
2759 |
return [messages]
|
2760 |
|
2761 |
+
def to_tools(
|
2762 |
+
self,
|
2763 |
+
instance: Dict[str, Any]
|
2764 |
+
) -> Dict[str, Union[Optional[List[Dict[str, str]]], Optional[Dict[str, str]]]]:
|
2765 |
+
"""watsonx.ai chat also allows specifying which tools models must use."""
|
2766 |
+
task_data = instance.get("task_data")
|
2767 |
+
if task_data is None:
|
2768 |
+
return {"tools": None, "tool_choice": None}
|
2769 |
+
|
2770 |
+
if isinstance(task_data, str):
|
2771 |
+
task_data = json.loads(task_data)
|
2772 |
+
if "__tools__" in task_data:
|
2773 |
+
tools: List[Dict[str, str]] = task_data["__tools__"]
|
2774 |
+
tool_choice: Optional[Dict[str, str]] = task_data.get("__tool_choice__")
|
2775 |
+
return {"tools": tools, "tool_choice": tool_choice}
|
2776 |
+
|
2777 |
+
return {"tools": None, "tool_choice": None}
|
2778 |
+
|
2779 |
def _handle_async_requests(
|
2780 |
self,
|
2781 |
+
data: List[Dict[str, Any]],
|
2782 |
params: Dict[str, Any],
|
2783 |
) -> List[Dict[str, Any]]:
|
2784 |
async def handle_async_requests(start_idx, end_idx):
|
2785 |
coroutines = [
|
2786 |
+
self._model.achat(
|
2787 |
+
messages=data[idx]["msg"],
|
2788 |
+
params=params,
|
2789 |
+
tools=data[idx]["tools"]["tools"],
|
2790 |
+
tool_choice=data[idx]["tools"]["tool_choice"],
|
2791 |
+
)
|
2792 |
for idx in range(start_idx, end_idx)
|
2793 |
]
|
2794 |
batch_results = await asyncio.gather(*coroutines)
|
|
|
2797 |
loop = asyncio.get_event_loop()
|
2798 |
results = []
|
2799 |
|
2800 |
+
for batch_idx in range(0, len(data), self.concurrency_limit):
|
2801 |
batch_results = loop.run_until_complete(
|
2802 |
handle_async_requests(
|
2803 |
+
batch_idx, min(batch_idx + self.concurrency_limit, len(data))
|
2804 |
)
|
2805 |
)
|
2806 |
results.extend(batch_results)
|
|
|
2822 |
output_type = "message"
|
2823 |
params["logprobs"] = False
|
2824 |
|
2825 |
+
data = [
|
2826 |
+
{
|
2827 |
+
"idx": i,
|
2828 |
+
"msg": message,
|
2829 |
+
"tools": self.to_tools(dataset[i]),
|
2830 |
+
}
|
2831 |
for i in range(len(dataset))
|
2832 |
for message in self.to_messages(dataset[i])
|
2833 |
]
|
2834 |
|
2835 |
+
responses = self._handle_async_requests(data, params)
|
|
|
|
|
2836 |
|
2837 |
+
results = []
|
2838 |
+
for inp, response in zip(data, responses):
|
2839 |
+
idx = inp["idx"]
|
2840 |
+
tool_call = data[idx]["tools"]["tools"] is not None
|
2841 |
+
|
2842 |
+
output = response["choices"][0][output_type]
|
2843 |
+
if tool_call:
|
2844 |
+
if "tool_calls" in output:
|
2845 |
+
func = output["tool_calls"][0]["function"]
|
2846 |
+
prediction = f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}'
|
2847 |
+
else:
|
2848 |
+
prediction = output["content"]
|
2849 |
+
else:
|
2850 |
+
prediction = output["content"]
|
2851 |
+
|
2852 |
+
results.append(
|
2853 |
+
self.get_return_object(
|
2854 |
+
prediction,
|
2855 |
+
response,
|
2856 |
+
str(inp),
|
2857 |
+
return_meta_data,
|
2858 |
+
)
|
2859 |
)
|
2860 |
+
|
2861 |
+
return results
|
2862 |
|
2863 |
def get_return_object(self, predict_result, result, input_text, return_meta_data):
|
2864 |
if return_meta_data:
|
|
|
3496 |
"aws": LiteLLMInferenceEngine,
|
3497 |
"ollama": OllamaInferenceEngine,
|
3498 |
"bam": IbmGenAiInferenceEngine,
|
3499 |
+
"watsonx-sdk": WMLInferenceEngineChat,
|
3500 |
"rits": RITSInferenceEngine,
|
3501 |
"azure": LiteLLMInferenceEngine,
|
3502 |
"vertex-ai": LiteLLMInferenceEngine,
|
metric.py
CHANGED
@@ -65,7 +65,6 @@ from .system_prompts import __file__ as _
|
|
65 |
from .task import __file__ as _
|
66 |
from .templates import __file__ as _
|
67 |
from .text_utils import __file__ as _
|
68 |
-
from .tool_calling import __file__ as _
|
69 |
from .type_utils import __file__ as _
|
70 |
from .types import __file__ as _
|
71 |
from .utils import __file__ as _
|
|
|
65 |
from .task import __file__ as _
|
66 |
from .templates import __file__ as _
|
67 |
from .text_utils import __file__ as _
|
|
|
68 |
from .type_utils import __file__ as _
|
69 |
from .types import __file__ as _
|
70 |
from .utils import __file__ as _
|
metrics.py
CHANGED
@@ -63,7 +63,6 @@ from .operators import ArtifactFetcherMixin, Copy, Set
|
|
63 |
from .random_utils import get_seed
|
64 |
from .settings_utils import get_settings
|
65 |
from .stream import MultiStream, Stream
|
66 |
-
from .tool_calling import convert_chat_api_format_to_tool
|
67 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
68 |
from .types import ToolCall
|
69 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
@@ -789,74 +788,92 @@ class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
|
|
789 |
return result
|
790 |
|
791 |
class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
|
|
792 |
main_score = "exact_match"
|
793 |
reduction = MeanReduction()
|
794 |
prediction_type = ToolCall
|
|
|
|
|
|
|
|
|
|
|
|
|
795 |
|
796 |
def map(
|
797 |
self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
|
798 |
) -> Dict[str, float]:
|
799 |
|
800 |
-
|
801 |
exact_match = float(
|
802 |
-
|
803 |
)
|
804 |
|
805 |
-
|
806 |
str(prediction["name"]) in [str(reference["name"]) for reference in references]
|
807 |
)
|
808 |
|
809 |
-
|
810 |
for reference in references:
|
811 |
-
if len(
|
|
|
|
|
|
|
|
|
|
|
812 |
|
|
|
|
|
|
|
813 |
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
|
814 |
-
|
815 |
score = 1.0
|
816 |
-
|
817 |
-
|
|
|
|
|
|
|
818 |
|
|
|
819 |
|
820 |
-
parameter_values = 0.0
|
821 |
for reference in references:
|
822 |
value_matches = 0
|
|
|
823 |
for key, val in prediction["arguments"].items():
|
824 |
try:
|
825 |
-
|
|
|
|
|
826 |
value_matches += 1
|
827 |
except:
|
828 |
pass
|
829 |
|
830 |
if len(prediction["arguments"]) > 0:
|
831 |
-
|
832 |
score = value_matches / len(prediction["arguments"])
|
833 |
else:
|
834 |
score = 1.0
|
835 |
-
if score >
|
836 |
-
|
837 |
|
|
|
838 |
for tool in task_data["__tools__"]:
|
839 |
-
tool
|
840 |
-
|
841 |
-
for param in tool["parameters"]:
|
842 |
-
tool_params_types[param["name"]] = param["type"]
|
843 |
-
correct_parameters_types = 0
|
844 |
-
for key, value in prediction["arguments"].items():
|
845 |
-
typing_type = tool_params_types.get(key, Any)
|
846 |
-
if isoftype(value, typing_type):
|
847 |
-
correct_parameters_types += 1
|
848 |
-
if len(prediction["arguments"]) > 0:
|
849 |
-
parameters_types = correct_parameters_types / len(prediction["arguments"])
|
850 |
-
else:
|
851 |
-
parameters_types = 1.0
|
852 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
853 |
|
854 |
return {
|
855 |
self.main_score: exact_match,
|
856 |
-
"
|
857 |
-
"
|
858 |
-
"
|
859 |
-
"
|
|
|
860 |
}
|
861 |
|
862 |
|
@@ -3499,7 +3516,7 @@ class CustomF1(GlobalMetric):
|
|
3499 |
class KeyValueExtraction(GlobalMetric):
|
3500 |
prediction_type = Dict[str, str]
|
3501 |
metric: Metric
|
3502 |
-
single_reference_per_prediction =
|
3503 |
main_score = ""
|
3504 |
|
3505 |
def prepare(self):
|
@@ -3575,6 +3592,33 @@ class KeyValueExtraction(GlobalMetric):
|
|
3575 |
|
3576 |
return result
|
3577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3578 |
|
3579 |
class NER(CustomF1):
|
3580 |
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
|
|
63 |
from .random_utils import get_seed
|
64 |
from .settings_utils import get_settings
|
65 |
from .stream import MultiStream, Stream
|
|
|
66 |
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
67 |
from .types import ToolCall
|
68 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
|
|
788 |
return result
|
789 |
|
790 |
class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
791 |
+
"""Compares each predicted tool call with list of references tool call."""
|
792 |
main_score = "exact_match"
|
793 |
reduction = MeanReduction()
|
794 |
prediction_type = ToolCall
|
795 |
+
_requirements_list = ["jsonschema-rs"]
|
796 |
+
|
797 |
+
def prepare(self):
|
798 |
+
super().prepare()
|
799 |
+
import jsonschema_rs
|
800 |
+
self._schema = jsonschema_rs
|
801 |
|
802 |
def map(
|
803 |
self, prediction: ToolCall, references: List[ToolCall], task_data: Dict[str, Any]
|
804 |
) -> Dict[str, float]:
|
805 |
|
|
|
806 |
exact_match = float(
|
807 |
+
json.dumps(prediction, sort_keys=True) in [json.dumps(reference, sort_keys=True) for reference in references]
|
808 |
)
|
809 |
|
810 |
+
tool_name_accuracy = float(
|
811 |
str(prediction["name"]) in [str(reference["name"]) for reference in references]
|
812 |
)
|
813 |
|
814 |
+
argument_name_recall = 0.0
|
815 |
for reference in references:
|
816 |
+
if len(reference["arguments"]) > 0:
|
817 |
+
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(reference["arguments"]))
|
818 |
+
else:
|
819 |
+
score = 1.0
|
820 |
+
if score > argument_name_recall:
|
821 |
+
argument_name_recall = score
|
822 |
|
823 |
+
argument_name_precision = 0.0
|
824 |
+
for reference in references:
|
825 |
+
if len(prediction["arguments"]) > 0:
|
826 |
score = len(set(prediction["arguments"]).intersection(set(reference["arguments"]))) / len(set(prediction["arguments"]))
|
827 |
+
elif len(reference["arguments"]) == 0:
|
828 |
score = 1.0
|
829 |
+
else:
|
830 |
+
score = 0.0
|
831 |
+
if score > argument_name_precision:
|
832 |
+
argument_name_precision = score
|
833 |
+
|
834 |
|
835 |
+
argument_value_precision = 0.0
|
836 |
|
|
|
837 |
for reference in references:
|
838 |
value_matches = 0
|
839 |
+
|
840 |
for key, val in prediction["arguments"].items():
|
841 |
try:
|
842 |
+
predicted = json.dumps(val, sort_keys=True)
|
843 |
+
target = json.dumps(reference["arguments"][key], sort_keys=True)
|
844 |
+
if predicted == target:
|
845 |
value_matches += 1
|
846 |
except:
|
847 |
pass
|
848 |
|
849 |
if len(prediction["arguments"]) > 0:
|
|
|
850 |
score = value_matches / len(prediction["arguments"])
|
851 |
else:
|
852 |
score = 1.0
|
853 |
+
if score > argument_value_precision:
|
854 |
+
argument_value_precision = score
|
855 |
|
856 |
+
parameters = None
|
857 |
for tool in task_data["__tools__"]:
|
858 |
+
if tool["function"]["name"] == prediction["name"]:
|
859 |
+
parameters = tool["function"]["parameters"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
860 |
|
861 |
+
if parameters is None:
|
862 |
+
argument_schema_validation = 0.0
|
863 |
+
else:
|
864 |
+
try:
|
865 |
+
self._schema.validate(parameters, prediction["arguments"], )
|
866 |
+
argument_schema_validation = 1.0
|
867 |
+
except self._schema.ValidationError:
|
868 |
+
argument_schema_validation = 0.0
|
869 |
|
870 |
return {
|
871 |
self.main_score: exact_match,
|
872 |
+
"tool_name_accuracy": tool_name_accuracy,
|
873 |
+
"argument_name_recall": argument_name_recall,
|
874 |
+
"argument_name_precision": argument_name_precision,
|
875 |
+
"argument_value_precision": argument_value_precision,
|
876 |
+
"argument_schema_validation": argument_schema_validation,
|
877 |
}
|
878 |
|
879 |
|
|
|
3516 |
class KeyValueExtraction(GlobalMetric):
|
3517 |
prediction_type = Dict[str, str]
|
3518 |
metric: Metric
|
3519 |
+
single_reference_per_prediction = False
|
3520 |
main_score = ""
|
3521 |
|
3522 |
def prepare(self):
|
|
|
3592 |
|
3593 |
return result
|
3594 |
|
3595 |
+
class ToolCallKeyValueExtraction(KeyValueExtraction):
|
3596 |
+
prediction_type = ToolCall
|
3597 |
+
|
3598 |
+
def flatten_dict(self,nested_dict, parent_key="", sep="."):
|
3599 |
+
flat_dict = {}
|
3600 |
+
for k, v in nested_dict.items():
|
3601 |
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
3602 |
+
if isinstance(v, list):
|
3603 |
+
for e in v:
|
3604 |
+
if isinstance(e,dict):
|
3605 |
+
flat_dict.update(self.flatten_dict(e, new_key, sep=sep))
|
3606 |
+
elif isinstance(v, dict):
|
3607 |
+
flat_dict.update(self.flatten_dict(v, new_key, sep=sep))
|
3608 |
+
else:
|
3609 |
+
flat_dict[new_key] = v
|
3610 |
+
return flat_dict
|
3611 |
+
|
3612 |
+
def compute(
|
3613 |
+
self,
|
3614 |
+
references: List[List[ToolCall]],
|
3615 |
+
predictions: List[ToolCall],
|
3616 |
+
task_data: List[Dict],
|
3617 |
+
) -> dict:
|
3618 |
+
return super().compute([[ self.flatten_dict(r) for r in ref ] for ref in references],
|
3619 |
+
[ self.flatten_dict(p) for p in predictions],task_data)
|
3620 |
+
|
3621 |
+
|
3622 |
|
3623 |
class NER(CustomF1):
|
3624 |
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
|
operators.py
CHANGED
@@ -283,6 +283,53 @@ class Set(InstanceOperator):
|
|
283 |
dict_set(instance, key, value)
|
284 |
return instance
|
285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
@deprecation(version="2.0.0", alternative=Set)
|
288 |
class AddFields(Set):
|
|
|
283 |
dict_set(instance, key, value)
|
284 |
return instance
|
285 |
|
286 |
+
def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
287 |
+
"""Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
data: The data structure (dict or list) to traverse.
|
291 |
+
target_key: The specific key whose value needs to be checked and replaced or removed.
|
292 |
+
value_map: A dictionary mapping old values to new values.
|
293 |
+
value_remove: A list of values to completely remove if found as values of target_key.
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
The modified data structure. Modification is done in-place.
|
297 |
+
"""
|
298 |
+
if value_remove is None:
|
299 |
+
value_remove = []
|
300 |
+
|
301 |
+
if isinstance(data, dict):
|
302 |
+
keys_to_delete = []
|
303 |
+
for key, value in data.items():
|
304 |
+
if key == target_key:
|
305 |
+
if isinstance(value, list):
|
306 |
+
data[key] = [
|
307 |
+
value_map.get(item, item)
|
308 |
+
for item in value
|
309 |
+
if not isinstance(item, dict) and item not in value_remove
|
310 |
+
]
|
311 |
+
elif isinstance(value, dict):
|
312 |
+
pass # Skip or handle dict values if needed
|
313 |
+
elif value in value_remove:
|
314 |
+
keys_to_delete.append(key)
|
315 |
+
elif value in value_map:
|
316 |
+
data[key] = value_map[value]
|
317 |
+
else:
|
318 |
+
recursive_key_value_replace(value, target_key, value_map, value_remove)
|
319 |
+
for key in keys_to_delete:
|
320 |
+
del data[key]
|
321 |
+
elif isinstance(data, list):
|
322 |
+
for item in data:
|
323 |
+
recursive_key_value_replace(item, target_key, value_map, value_remove)
|
324 |
+
return data
|
325 |
+
|
326 |
+
class RecursiveReplace(InstanceOperator):
|
327 |
+
key: str
|
328 |
+
map_values: dict
|
329 |
+
remove_values: Optional[list] = None
|
330 |
+
|
331 |
+
def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
|
332 |
+
return recursive_key_value_replace(instance, self.key, self.map_values, self.remove_values)
|
333 |
|
334 |
@deprecation(version="2.0.0", alternative=Set)
|
335 |
class AddFields(Set):
|
serializers.py
CHANGED
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Union
|
|
7 |
from .dataclass import AbstractField, Field
|
8 |
from .operators import InstanceFieldOperator
|
9 |
from .settings_utils import get_constants
|
10 |
-
from .tool_calling import convert_to_chat_api_format
|
11 |
from .type_utils import isoftype, to_type_string
|
12 |
from .types import (
|
13 |
Dialog,
|
@@ -168,24 +167,20 @@ class MultiDocumentSerializer(DocumentSerializer):
|
|
168 |
class ToolsSerializer(SingleTypeSerializer):
|
169 |
|
170 |
serialized_type = List[Tool]
|
171 |
-
_requirements_list: List[str] = ["pydantic"]
|
172 |
|
173 |
def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
|
174 |
if "__tools__" not in instance:
|
175 |
instance["__tools__"] = []
|
176 |
tool = []
|
177 |
for tool in value:
|
178 |
-
chat_api_tool = convert_to_chat_api_format(tool=tool)
|
179 |
instance["__tools__"].append(
|
180 |
-
|
181 |
)
|
182 |
-
tool["parameters"] = chat_api_tool["function"]["parameters"]
|
183 |
return json.dumps(instance["__tools__"], indent=4)
|
184 |
|
185 |
class ToolCallSerializer(SingleTypeSerializer):
|
186 |
|
187 |
serialized_type = ToolCall
|
188 |
-
_requirements_list: List[str] = ["pydantic"]
|
189 |
|
190 |
def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
|
191 |
return json.dumps(value)
|
|
|
7 |
from .dataclass import AbstractField, Field
|
8 |
from .operators import InstanceFieldOperator
|
9 |
from .settings_utils import get_constants
|
|
|
10 |
from .type_utils import isoftype, to_type_string
|
11 |
from .types import (
|
12 |
Dialog,
|
|
|
167 |
class ToolsSerializer(SingleTypeSerializer):
|
168 |
|
169 |
serialized_type = List[Tool]
|
|
|
170 |
|
171 |
def serialize(self, value: List[Tool], instance: Dict[str, Any]) -> str:
|
172 |
if "__tools__" not in instance:
|
173 |
instance["__tools__"] = []
|
174 |
tool = []
|
175 |
for tool in value:
|
|
|
176 |
instance["__tools__"].append(
|
177 |
+
{"type": "function", "function": tool}
|
178 |
)
|
|
|
179 |
return json.dumps(instance["__tools__"], indent=4)
|
180 |
|
181 |
class ToolCallSerializer(SingleTypeSerializer):
|
182 |
|
183 |
serialized_type = ToolCall
|
|
|
184 |
|
185 |
def serialize(self, value: ToolCall, instance: Dict[str, Any]) -> str:
|
186 |
return json.dumps(value)
|
struct_data_operators.py
CHANGED
@@ -43,7 +43,7 @@ from .operators import FieldOperator, InstanceOperator
|
|
43 |
from .random_utils import new_random_generator
|
44 |
from .serializers import ImageSerializer, TableSerializer
|
45 |
from .type_utils import isoftype
|
46 |
-
from .types import Table
|
47 |
from .utils import recursive_copy
|
48 |
|
49 |
|
@@ -754,6 +754,26 @@ class LoadJson(FieldOperator):
|
|
754 |
return json.loads(value, strict=False)
|
755 |
|
756 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
757 |
class DumpJson(FieldOperator):
|
758 |
def process_value(self, value: str) -> str:
|
759 |
return json.dumps(value)
|
|
|
43 |
from .random_utils import new_random_generator
|
44 |
from .serializers import ImageSerializer, TableSerializer
|
45 |
from .type_utils import isoftype
|
46 |
+
from .types import Table, ToolCall
|
47 |
from .utils import recursive_copy
|
48 |
|
49 |
|
|
|
754 |
return json.loads(value, strict=False)
|
755 |
|
756 |
|
757 |
+
class ToolCallPostProcessor(FieldOperator):
|
758 |
+
failure_value: Any = None
|
759 |
+
allow_failure: bool = False
|
760 |
+
def process_value(self, value: str) -> ToolCall:
|
761 |
+
if self.allow_failure:
|
762 |
+
try:
|
763 |
+
result = json.loads(value)
|
764 |
+
except json.JSONDecodeError:
|
765 |
+
return self.failure_value
|
766 |
+
else:
|
767 |
+
result = json.loads(value, strict=False)
|
768 |
+
if isoftype(result, List[ToolCall]):
|
769 |
+
if len(result) > 1:
|
770 |
+
UnitxtWarning(f"More than one tool returned from model: {result}" )
|
771 |
+
return self.failure_value
|
772 |
+
return result[0]
|
773 |
+
if not isoftype(result, ToolCall):
|
774 |
+
return self.failure_value
|
775 |
+
return result
|
776 |
+
|
777 |
class DumpJson(FieldOperator):
|
778 |
def process_value(self, value: str) -> str:
|
779 |
return json.dumps(value)
|
tool_calling.py
DELETED
@@ -1,119 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, List, Type
|
2 |
-
|
3 |
-
from .operators import FieldOperator
|
4 |
-
from .types import Parameter, Tool
|
5 |
-
|
6 |
-
|
7 |
-
def convert_to_chat_api_format(tool: Tool) -> Dict[str, Any]:
|
8 |
-
|
9 |
-
from pydantic import create_model
|
10 |
-
|
11 |
-
field_definitions = {}
|
12 |
-
for param in tool["parameters"]:
|
13 |
-
param_name = param["name"]
|
14 |
-
param_type = param.get("type", Any)
|
15 |
-
field_definitions[param_name] = (param_type, ...) # ... means required in Pydantic
|
16 |
-
|
17 |
-
model = create_model(f"{tool['name']}Params", **field_definitions)
|
18 |
-
|
19 |
-
schema = model.model_json_schema()
|
20 |
-
|
21 |
-
return {
|
22 |
-
"type": "function",
|
23 |
-
"function": {
|
24 |
-
"name": tool["name"],
|
25 |
-
"description": tool["description"],
|
26 |
-
"parameters": schema
|
27 |
-
}
|
28 |
-
}
|
29 |
-
|
30 |
-
|
31 |
-
def convert_chat_api_format_to_tool(chat_api_tool: Dict[str, Any]) -> Tool:
|
32 |
-
"""Convert a Chat API formatted tool back to the original Tool structure.
|
33 |
-
|
34 |
-
Args:
|
35 |
-
chat_api_tool: A dictionary representing a tool in Chat API format
|
36 |
-
|
37 |
-
Returns:
|
38 |
-
A Tool dictionary with name, description, and parameters
|
39 |
-
"""
|
40 |
-
# Extract function information
|
41 |
-
function_info = chat_api_tool.get("function", {})
|
42 |
-
name = function_info.get("name", chat_api_tool.get("name", ""))
|
43 |
-
description = function_info.get("description", chat_api_tool.get("description", ""))
|
44 |
-
|
45 |
-
# Extract parameters from schema
|
46 |
-
parameters: List[Parameter] = []
|
47 |
-
schema = function_info.get("parameters", chat_api_tool.get("parameters", ""))
|
48 |
-
properties = schema.get("properties", {})
|
49 |
-
|
50 |
-
for param_name, param_schema in properties.items():
|
51 |
-
# Map JSON schema type to Python type
|
52 |
-
param_type = json_schema_to_python_type(param_schema)
|
53 |
-
|
54 |
-
parameter: Parameter = {
|
55 |
-
"name": param_name,
|
56 |
-
"type": param_type
|
57 |
-
}
|
58 |
-
parameters.append(parameter)
|
59 |
-
|
60 |
-
# Construct and return the Tool
|
61 |
-
tool: Tool = {
|
62 |
-
"name": name,
|
63 |
-
"description": description,
|
64 |
-
"parameters": parameters
|
65 |
-
}
|
66 |
-
|
67 |
-
return tool
|
68 |
-
|
69 |
-
def json_schema_to_python_type(schema: Dict[str, Any]) -> Type:
|
70 |
-
"""Convert JSON schema type to Python type."""
|
71 |
-
from typing import Any, Dict, List, Union
|
72 |
-
|
73 |
-
schema_type = schema.get("type")
|
74 |
-
|
75 |
-
# Handle simple types
|
76 |
-
simple_types = {
|
77 |
-
"string": str,
|
78 |
-
"integer": int,
|
79 |
-
"number": float,
|
80 |
-
"boolean": bool,
|
81 |
-
"null": type(None)
|
82 |
-
}
|
83 |
-
|
84 |
-
if schema_type in simple_types:
|
85 |
-
return simple_types[schema_type]
|
86 |
-
|
87 |
-
# Handle arrays
|
88 |
-
if schema_type == "array":
|
89 |
-
items = schema.get("items", {})
|
90 |
-
if not items:
|
91 |
-
return List[Any]
|
92 |
-
|
93 |
-
item_type = json_schema_to_python_type(items)
|
94 |
-
return List[item_type]
|
95 |
-
|
96 |
-
# Handle objects
|
97 |
-
if schema_type == "object":
|
98 |
-
return Dict[str, Any]
|
99 |
-
|
100 |
-
# Handle unions with anyOf/oneOf
|
101 |
-
if "anyOf" in schema or "oneOf" in schema:
|
102 |
-
union_schemas = schema.get("anyOf", []) or schema.get("oneOf", [])
|
103 |
-
union_types = [json_schema_to_python_type(s) for s in union_schemas]
|
104 |
-
# Use Union for Python 3.9+ or create Union using typing module
|
105 |
-
return Union[tuple(union_types)] if union_types else Any
|
106 |
-
|
107 |
-
# Handle references (simplified)
|
108 |
-
if "$ref" in schema:
|
109 |
-
# In a real implementation, you'd resolve references
|
110 |
-
return Any
|
111 |
-
|
112 |
-
# Default to Any for unrecognized schema types
|
113 |
-
return Any
|
114 |
-
|
115 |
-
|
116 |
-
class ToTool(FieldOperator):
|
117 |
-
|
118 |
-
def process_value(self, value: Dict[str, Any]) -> Tool:
|
119 |
-
return convert_chat_api_format_to_tool(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type_utils.py
CHANGED
@@ -27,7 +27,7 @@ _registered_types = {
|
|
27 |
def register_type(new_type):
|
28 |
assert is_new_type(new_type) or is_typed_dict(
|
29 |
new_type
|
30 |
-
), "Can register only typing.NewType or typing.TypedDict"
|
31 |
_registered_types[new_type.__name__] = new_type
|
32 |
|
33 |
|
@@ -489,6 +489,9 @@ def isoftype(object, typing_type):
|
|
489 |
if not is_type(typing_type):
|
490 |
raise UnsupportedTypeError(typing_type)
|
491 |
|
|
|
|
|
|
|
492 |
if typing_type is typing.Type:
|
493 |
return is_type(object)
|
494 |
|
@@ -1066,9 +1069,18 @@ def verify_required_schema(
|
|
1066 |
f"{class_name} description: {description}"
|
1067 |
) from e
|
1068 |
|
1069 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1070 |
raise ValueError(
|
1071 |
-
f"Passed value
|
1072 |
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
1073 |
f"{class_name} description: {description}"
|
1074 |
)
|
|
|
27 |
def register_type(new_type):
|
28 |
assert is_new_type(new_type) or is_typed_dict(
|
29 |
new_type
|
30 |
+
) or hasattr(new_type, "__verify_type__"), "Can register only typing.NewType or typing.TypedDict or object with __verify_type__ class function"
|
31 |
_registered_types[new_type.__name__] = new_type
|
32 |
|
33 |
|
|
|
489 |
if not is_type(typing_type):
|
490 |
raise UnsupportedTypeError(typing_type)
|
491 |
|
492 |
+
if hasattr(typing_type, "__verify_type__"):
|
493 |
+
return typing_type.__verify_type__(object)
|
494 |
+
|
495 |
if typing_type is typing.Type:
|
496 |
return is_type(object)
|
497 |
|
|
|
1069 |
f"{class_name} description: {description}"
|
1070 |
) from e
|
1071 |
|
1072 |
+
try:
|
1073 |
+
valid = isoftype(value, data_type)
|
1074 |
+
except Exception as e:
|
1075 |
+
raise ValueError(
|
1076 |
+
f"Passed value {value} of field '{field_name}' is not "
|
1077 |
+
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
1078 |
+
f"{class_name} description: {description}\nReason:\n{e}"
|
1079 |
+
) from e
|
1080 |
+
|
1081 |
+
if not valid:
|
1082 |
raise ValueError(
|
1083 |
+
f"Passed value {value} of field '{field_name}' is not "
|
1084 |
f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
|
1085 |
f"{class_name} description: {description}"
|
1086 |
)
|
types.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Any, Dict, List, Literal, NewType, Optional,
|
2 |
|
3 |
from .type_utils import register_type
|
4 |
|
@@ -51,14 +51,20 @@ class SQLDatabase(TypedDict):
|
|
51 |
dbms: Optional[str]
|
52 |
data: Optional[Dict[str, Dict]]
|
53 |
|
54 |
-
class
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
class Tool(TypedDict):
|
59 |
name: str
|
60 |
description: str
|
61 |
-
parameters:
|
62 |
|
63 |
class ToolCall(TypedDict):
|
64 |
name: str
|
@@ -76,7 +82,7 @@ register_type(Document)
|
|
76 |
register_type(MultiDocument)
|
77 |
register_type(RagResponse)
|
78 |
register_type(SQLDatabase)
|
79 |
-
register_type(Parameter)
|
80 |
register_type(Tool)
|
|
|
81 |
register_type(ToolCall)
|
82 |
|
|
|
1 |
+
from typing import Any, Dict, List, Literal, NewType, Optional, TypedDict, Union
|
2 |
|
3 |
from .type_utils import register_type
|
4 |
|
|
|
51 |
dbms: Optional[str]
|
52 |
data: Optional[Dict[str, Dict]]
|
53 |
|
54 |
+
class JsonSchema:
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def __verify_type__(cls, object):
|
58 |
+
if not isinstance(object, dict):
|
59 |
+
return False
|
60 |
+
import jsonschema_rs
|
61 |
+
jsonschema_rs.meta.validate(object)
|
62 |
+
return True
|
63 |
|
64 |
class Tool(TypedDict):
|
65 |
name: str
|
66 |
description: str
|
67 |
+
parameters: JsonSchema
|
68 |
|
69 |
class ToolCall(TypedDict):
|
70 |
name: str
|
|
|
82 |
register_type(MultiDocument)
|
83 |
register_type(RagResponse)
|
84 |
register_type(SQLDatabase)
|
|
|
85 |
register_type(Tool)
|
86 |
+
register_type(JsonSchema)
|
87 |
register_type(ToolCall)
|
88 |
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.
|
|
|
1 |
+
version = "1.23.0"
|