Upload folder using huggingface_hub
Browse files- api.py +21 -13
- artifact.py +29 -21
- augmentors.py +2 -1
- dataclass.py +8 -1
- dialog_operators.py +7 -5
- dict_utils.py +20 -15
- formats.py +20 -32
- image_operators.py +10 -7
- inference.py +107 -102
- llm_as_judge.py +40 -12
- loaders.py +26 -11
- metrics.py +172 -35
- operators.py +123 -105
- span_lableing_operators.py +22 -16
- struct_data_operators.py +125 -86
- task.py +16 -12
- templates.py +17 -8
- type_utils.py +13 -12
- utils.py +2 -2
- version.py +1 -1
api.py
CHANGED
|
@@ -93,31 +93,39 @@ def load_dataset(
|
|
| 93 |
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
|
| 94 |
"""Loads dataset.
|
| 95 |
|
| 96 |
-
If the 'dataset_query' argument is provided, then dataset is loaded from a card
|
| 97 |
-
catalog based on parameters specified in the query.
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
|
| 100 |
Args:
|
| 101 |
dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
|
| 102 |
-
|
| 103 |
-
|
| 104 |
streaming (bool, False): When True yields the data as Unitxt streams dictionary
|
|
|
|
| 105 |
split (str, optional): The split of the data to load
|
|
|
|
| 106 |
disable_cache (str, optional): Disable caching process of the data
|
|
|
|
| 107 |
**kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
|
| 108 |
|
| 109 |
Returns:
|
| 110 |
DatasetDict
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
card = TaskCard(...)
|
| 118 |
-
template = Template(...)
|
| 119 |
-
loader_limit = 10
|
| 120 |
-
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
|
| 121 |
"""
|
| 122 |
recipe = load_recipe(dataset_query, **kwargs)
|
| 123 |
|
|
|
|
| 93 |
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
|
| 94 |
"""Loads dataset.
|
| 95 |
|
| 96 |
+
If the 'dataset_query' argument is provided, then dataset is loaded from a card
|
| 97 |
+
in local catalog based on parameters specified in the query.
|
| 98 |
+
|
| 99 |
+
Alternatively, dataset is loaded from a provided card based on explicitly
|
| 100 |
+
given parameters.
|
| 101 |
|
| 102 |
Args:
|
| 103 |
dataset_query (str, optional): A string query which specifies a dataset to load from local catalog or name of specific recipe or benchmark in the catalog.
|
| 104 |
+
For example: ``"card=cards.wnli,template=templates.classification.multi_class.relation.default".``
|
| 105 |
+
|
| 106 |
streaming (bool, False): When True yields the data as Unitxt streams dictionary
|
| 107 |
+
|
| 108 |
split (str, optional): The split of the data to load
|
| 109 |
+
|
| 110 |
disable_cache (str, optional): Disable caching process of the data
|
| 111 |
+
|
| 112 |
**kwargs: Arguments used to load dataset from provided card, which is not present in local catalog.
|
| 113 |
|
| 114 |
Returns:
|
| 115 |
DatasetDict
|
| 116 |
|
| 117 |
+
Example:
|
| 118 |
+
.. code-block:: python
|
| 119 |
+
|
| 120 |
+
dataset = load_dataset(
|
| 121 |
+
dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
|
| 122 |
+
) # card must be present in local catalog
|
| 123 |
+
|
| 124 |
+
card = TaskCard(...)
|
| 125 |
+
template = Template(...)
|
| 126 |
+
loader_limit = 10
|
| 127 |
+
dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
"""
|
| 130 |
recipe = load_recipe(dataset_query, **kwargs)
|
| 131 |
|
artifact.py
CHANGED
|
@@ -89,16 +89,18 @@ class Catalogs:
|
|
| 89 |
self.catalogs = []
|
| 90 |
|
| 91 |
|
| 92 |
-
def
|
| 93 |
-
if
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def get_closest_artifact_type(type):
|
|
@@ -150,8 +152,12 @@ class Artifact(Dataclass):
|
|
| 150 |
)
|
| 151 |
|
| 152 |
@classmethod
|
| 153 |
-
def is_artifact_dict(cls,
|
| 154 |
-
return isinstance(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
@classmethod
|
| 157 |
def verify_artifact_dict(cls, d):
|
|
@@ -292,7 +298,7 @@ class Artifact(Dataclass):
|
|
| 292 |
field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]]
|
| 293 |
):
|
| 294 |
value = getattr(self, field.name)
|
| 295 |
-
value =
|
| 296 |
setattr(self, field.name, value)
|
| 297 |
|
| 298 |
self.verify_data_classification_policy()
|
|
@@ -343,15 +349,18 @@ class Artifact(Dataclass):
|
|
| 343 |
|
| 344 |
Args:
|
| 345 |
instance (Dict[str, Any]): data which should contain its allowed data
|
| 346 |
-
|
|
|
|
| 347 |
name (Optional[str]): name of artifact which should be used to retrieve
|
| 348 |
-
|
| 349 |
-
|
| 350 |
|
| 351 |
Returns:
|
| 352 |
Dict[str, Any]: unchanged instance.
|
| 353 |
|
| 354 |
Examples:
|
|
|
|
|
|
|
| 355 |
instance = {"x": "some_text", "data_classification_policy": ["pii"]}
|
| 356 |
|
| 357 |
# Will raise an error as "pii" is not included policy
|
|
@@ -574,11 +583,10 @@ def reset_artifacts_json_cache():
|
|
| 574 |
artifacts_json_cache.cache_clear()
|
| 575 |
|
| 576 |
|
| 577 |
-
def maybe_recover_artifact(
|
| 578 |
-
if
|
| 579 |
-
return verbosed_fetch_artifact(
|
| 580 |
-
|
| 581 |
-
return artifact
|
| 582 |
|
| 583 |
|
| 584 |
def register_all_artifacts(path):
|
|
|
|
| 89 |
self.catalogs = []
|
| 90 |
|
| 91 |
|
| 92 |
+
def maybe_recover_artifacts_structure(obj):
|
| 93 |
+
if Artifact.is_possible_identifier(obj):
|
| 94 |
+
return verbosed_fetch_artifact(obj)
|
| 95 |
+
if isinstance(obj, dict):
|
| 96 |
+
for key, value in obj.items():
|
| 97 |
+
obj[key] = maybe_recover_artifact(value)
|
| 98 |
+
return obj
|
| 99 |
+
if isinstance(obj, list):
|
| 100 |
+
for i in range(len(obj)):
|
| 101 |
+
obj[i] = maybe_recover_artifact(obj[i])
|
| 102 |
+
return obj
|
| 103 |
+
return obj
|
| 104 |
|
| 105 |
|
| 106 |
def get_closest_artifact_type(type):
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
@classmethod
|
| 155 |
+
def is_artifact_dict(cls, obj):
|
| 156 |
+
return isinstance(obj, dict) and "__type__" in obj
|
| 157 |
+
|
| 158 |
+
@classmethod
|
| 159 |
+
def is_possible_identifier(cls, obj):
|
| 160 |
+
return isinstance(obj, str) or cls.is_artifact_dict(obj)
|
| 161 |
|
| 162 |
@classmethod
|
| 163 |
def verify_artifact_dict(cls, d):
|
|
|
|
| 298 |
field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]]
|
| 299 |
):
|
| 300 |
value = getattr(self, field.name)
|
| 301 |
+
value = maybe_recover_artifacts_structure(value)
|
| 302 |
setattr(self, field.name, value)
|
| 303 |
|
| 304 |
self.verify_data_classification_policy()
|
|
|
|
| 349 |
|
| 350 |
Args:
|
| 351 |
instance (Dict[str, Any]): data which should contain its allowed data
|
| 352 |
+
classification policies under key 'data_classification_policy'.
|
| 353 |
+
|
| 354 |
name (Optional[str]): name of artifact which should be used to retrieve
|
| 355 |
+
data classification from env. If not specified, then either ``__id__`` or
|
| 356 |
+
``__class__.__name__``, are used instead, respectively.
|
| 357 |
|
| 358 |
Returns:
|
| 359 |
Dict[str, Any]: unchanged instance.
|
| 360 |
|
| 361 |
Examples:
|
| 362 |
+
.. code-block:: python
|
| 363 |
+
|
| 364 |
instance = {"x": "some_text", "data_classification_policy": ["pii"]}
|
| 365 |
|
| 366 |
# Will raise an error as "pii" is not included policy
|
|
|
|
| 583 |
artifacts_json_cache.cache_clear()
|
| 584 |
|
| 585 |
|
| 586 |
+
def maybe_recover_artifact(obj):
|
| 587 |
+
if Artifact.is_possible_identifier(obj):
|
| 588 |
+
return verbosed_fetch_artifact(obj)
|
| 589 |
+
return obj
|
|
|
|
| 590 |
|
| 591 |
|
| 592 |
def register_all_artifacts(path):
|
augmentors.py
CHANGED
|
@@ -97,7 +97,8 @@ class AugmentPrefixSuffix(TextAugmentor):
|
|
| 97 |
To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
|
| 98 |
``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.
|
| 99 |
|
| 100 |
-
To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
|
|
|
|
| 101 |
``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
|
| 102 |
which will append ``\n``-s twice as often as ``\t``-s.
|
| 103 |
|
|
|
|
| 97 |
To prepend the input with a prefix made of 4 ``\n``-s or ``\t``-s, employ
|
| 98 |
``AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)``.
|
| 99 |
|
| 100 |
+
To append the input with a suffix made of 3 ``\n``-s or ``\t``-s, with ``\n`` being preferred over ``\t``,
|
| 101 |
+
at 2:1 ratio, employ
|
| 102 |
``AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)``
|
| 103 |
which will append ``\n``-s twice as often as ``\t``-s.
|
| 104 |
|
dataclass.py
CHANGED
|
@@ -533,6 +533,13 @@ class Dataclass(metaclass=DataclassMeta):
|
|
| 533 |
if keep_empty or value is not None
|
| 534 |
}
|
| 535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
def __repr__(self) -> str:
|
| 537 |
"""String representation."""
|
| 538 |
-
return f"{self.__class__.__name__}({', '.join([f'{
|
|
|
|
| 533 |
if keep_empty or value is not None
|
| 534 |
}
|
| 535 |
|
| 536 |
+
def get_repr_dict(self):
|
| 537 |
+
result = {}
|
| 538 |
+
for field in fields(self):
|
| 539 |
+
if not field.internal:
|
| 540 |
+
result[field.name] = getattr(self, field.name)
|
| 541 |
+
return result
|
| 542 |
+
|
| 543 |
def __repr__(self) -> str:
|
| 544 |
"""String representation."""
|
| 545 |
+
return f"{self.__class__.__name__}({', '.join([f'{key}={val!r}' for key, val in self.get_repr_dict().items()])})"
|
dialog_operators.py
CHANGED
|
@@ -5,11 +5,13 @@ text that can be fed to the model.
|
|
| 5 |
|
| 6 |
The format of the dialog is:
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
"""
|
| 14 |
from typing import Any, Dict, List, Optional
|
| 15 |
|
|
|
|
| 5 |
|
| 6 |
The format of the dialog is:
|
| 7 |
|
| 8 |
+
.. code-block:: text
|
| 9 |
+
|
| 10 |
+
dialog = [
|
| 11 |
+
{"user": "hello", "system": "hi"},
|
| 12 |
+
{"user": "kkk", "system": ""},
|
| 13 |
+
{"user": "kkk", "system": ""},
|
| 14 |
+
]
|
| 15 |
"""
|
| 16 |
from typing import Any, Dict, List, Optional
|
| 17 |
|
dict_utils.py
CHANGED
|
@@ -24,29 +24,32 @@ def is_wildcard(string):
|
|
| 24 |
# formal definition of qpath syntax by which a query is specified:
|
| 25 |
# qpath -> A (/A)*
|
| 26 |
# A -> name | * | non-neg-int
|
| 27 |
-
# name ->
|
| 28 |
-
# *
|
| 29 |
#
|
| 30 |
-
#
|
| 31 |
-
#
|
| 32 |
-
#
|
| 33 |
-
# (
|
| 34 |
-
#
|
| 35 |
-
#
|
|
|
|
|
|
|
|
|
|
| 36 |
# and hence no path in dic matches query qpath. (E.g., when el is a list, A must match indx, and its
|
| 37 |
# int value should be smaller than len(el) in order for the path in dic leading to element el[A] to match pref/A)
|
| 38 |
-
# (3) Denoting as in (2), now with A == *
|
| 39 |
# {el[0], el[1], .. , el[len(el)-1]} is said to be lead to by a path matching pref/*
|
| 40 |
# and when el is a dict, each and every element in the set {el[k] for k being a key in el} is said to be lead
|
| 41 |
# to by a path matching pref/*
|
| 42 |
#
|
| 43 |
# An element el lead to by path p that matches qpath as a whole is thus either a list member (when indx.match the last
|
| 44 |
-
# component of p
|
| 45 |
-
# of el
|
| 46 |
#
|
| 47 |
# Thus, for a query with no *, dic contains at most one element the path to which matches the query.
|
| 48 |
# If there is such one in dic - the function (either dict_get, dict_set, or dict_delete) operates on
|
| 49 |
-
# that element according to its arguments, other than not_exist_ok
|
| 50 |
# If there is not any such element in dic - the function throws or does not throw an exception, depending
|
| 51 |
# on flag not_exist_ok.
|
| 52 |
# For a query with *, there could be up to as many as there are values to match the *
|
|
@@ -54,9 +57,9 @@ def is_wildcard(string):
|
|
| 54 |
# for more than one * in the query -- this effect multiplies)
|
| 55 |
# Each of the three functions below (dict_get, dict_set, dict_delete) applies the requested
|
| 56 |
# operation (read, set, or delete) to each and every element el in dic, the path to which matches the query in whole,
|
| 57 |
-
# and reads a value from, or sets a new value to, or pops
|
| 58 |
#
|
| 59 |
-
# If no path in dic matches the query, then
|
| 60 |
# but if not_exist_ok=True, the function returns a default value (dict_get) or does nothing (dict_delete)
|
| 61 |
# or generates all the needed missing suffixes (dict_set, see details below).
|
| 62 |
#
|
|
@@ -444,7 +447,9 @@ def dict_get(
|
|
| 444 |
)
|
| 445 |
if len(components) > 1:
|
| 446 |
try:
|
| 447 |
-
success, values = get_values(
|
|
|
|
|
|
|
| 448 |
if success:
|
| 449 |
return values
|
| 450 |
except Exception as e:
|
|
|
|
| 24 |
# formal definition of qpath syntax by which a query is specified:
|
| 25 |
# qpath -> A (/A)*
|
| 26 |
# A -> name | * | non-neg-int
|
| 27 |
+
# name -> a string satisfying is_name above.
|
| 28 |
+
# * -> ALL members (each and every) of a list or a dictionary element in the input dictionary,
|
| 29 |
#
|
| 30 |
+
# A path p in dictionary dic, leading to element (aka subfield) el, is said to match query qpath
|
| 31 |
+
# (alternatively said: query qpath matches path p in dic),
|
| 32 |
+
# if the following recursively defined condition is satisfied:
|
| 33 |
+
# (1) the prefix of length 0 of qpath (i.e., pref = "") matches the empty path in dic, the path leading to the whole of dic.
|
| 34 |
+
# (2) Denoting by el the element in dic lead to by the path in dic that matches the prefix pref of qpath
|
| 35 |
+
# (el must be a list or dictionary, since led to by a path matching a prefix of qpath, and not the whole of qpath),
|
| 36 |
+
# and by A (as the definition above) the component, DIFFERENT from *, in qpath, that follows pref, then the element
|
| 37 |
+
# lead to by the path in dic matching query pref/A is el[A]. If el[A] is missing from dic, then no path in dic matches
|
| 38 |
+
# pref/A, that is either a longer prefix of qpath, or the whole of qpath,
|
| 39 |
# and hence no path in dic matches query qpath. (E.g., when el is a list, A must match indx, and its
|
| 40 |
# int value should be smaller than len(el) in order for the path in dic leading to element el[A] to match pref/A)
|
| 41 |
+
# (3) Denoting as in (2), now with A == * , then when el is a list, each and every element in the set:
|
| 42 |
# {el[0], el[1], .. , el[len(el)-1]} is said to be lead to by a path matching pref/*
|
| 43 |
# and when el is a dict, each and every element in the set {el[k] for k being a key in el} is said to be lead
|
| 44 |
# to by a path matching pref/*
|
| 45 |
#
|
| 46 |
# An element el lead to by path p that matches qpath as a whole is thus either a list member (when indx.match the last
|
| 47 |
+
# component of p) or a dictionary item (the key of which equals the last component of p). The value
|
| 48 |
+
# of el is returned (dic_get) or el is popped (dic_delete) or el's value is replaced by a new value (dic_set).
|
| 49 |
#
|
| 50 |
# Thus, for a query with no *, dic contains at most one element the path to which matches the query.
|
| 51 |
# If there is such one in dic - the function (either dict_get, dict_set, or dict_delete) operates on
|
| 52 |
+
# that element according to its arguments, other than not_exist_ok.
|
| 53 |
# If there is not any such element in dic - the function throws or does not throw an exception, depending
|
| 54 |
# on flag not_exist_ok.
|
| 55 |
# For a query with *, there could be up to as many as there are values to match the *
|
|
|
|
| 57 |
# for more than one * in the query -- this effect multiplies)
|
| 58 |
# Each of the three functions below (dict_get, dict_set, dict_delete) applies the requested
|
| 59 |
# operation (read, set, or delete) to each and every element el in dic, the path to which matches the query in whole,
|
| 60 |
+
# and reads a value from, or sets a new value to, or pops el out from dic.
|
| 61 |
#
|
| 62 |
+
# If no path in dic matches the query, then if not_exist_ok=False, the function throws an exception;
|
| 63 |
# but if not_exist_ok=True, the function returns a default value (dict_get) or does nothing (dict_delete)
|
| 64 |
# or generates all the needed missing suffixes (dict_set, see details below).
|
| 65 |
#
|
|
|
|
| 447 |
)
|
| 448 |
if len(components) > 1:
|
| 449 |
try:
|
| 450 |
+
success, values = get_values(
|
| 451 |
+
dic, components, -1 * len(components), allow_int_index=allow_int_index
|
| 452 |
+
)
|
| 453 |
if success:
|
| 454 |
return values
|
| 455 |
except Exception as e:
|
formats.py
CHANGED
|
@@ -28,26 +28,26 @@ class Format(InstanceOperator):
|
|
| 28 |
def apply_capital_new_line_notation(text: str) -> str:
|
| 29 |
r"""Transforms a given string by applying the Capital New Line Notation.
|
| 30 |
|
| 31 |
-
The Capital New Line Notation (\N) is designed to manage newline behavior in a string efficiently.
|
| 32 |
-
This custom notation aims to consolidate multiple newline characters (\n) into a single newline under
|
| 33 |
specific conditions, with tailored handling based on whether there's preceding text. The function
|
| 34 |
distinguishes between two primary scenarios:
|
| 35 |
|
| 36 |
-
1. If there's text (referred to as a prefix) followed by any number of
|
| 37 |
-
more
|
| 38 |
newlines and notation characters into a single newline when there's preceding text.
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
applicable when the notation should not introduce any newlines due to the absence of preceding text.
|
| 42 |
|
| 43 |
Args:
|
| 44 |
-
text (str): The input string to be transformed, potentially containing the Capital New Line Notation
|
| 45 |
-
(\N) mixed with actual newline characters (\n).
|
| 46 |
|
| 47 |
Returns:
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
|
| 52 |
Examples:
|
| 53 |
>>> apply_capital_new_line_notation("Hello World\\n\\n\N")
|
|
@@ -131,27 +131,26 @@ class BaseFormat(Format):
|
|
| 131 |
class SystemFormat(BaseFormat):
|
| 132 |
r"""Generates the whole input to the model, from constant strings that are given as args, and from values found in specified fields of the instance.
|
| 133 |
|
| 134 |
-
Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before.
|
| 135 |
|
| 136 |
SystemFormat expects the input instance to contain:
|
| 137 |
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
|
| 138 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
| 139 |
from the source dataset), in the context of the underlying task.
|
| 140 |
3. A field named "instruction" that contains a (non-None) string.
|
| 141 |
-
4. A field named with the value in arg 'demos_field'
|
| 142 |
and "target", representing a single demo.
|
| 143 |
5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt
|
| 144 |
|
| 145 |
-
SystemFormat formats the above fields into a single string to be
|
| 146 |
-
field "source" of the instance. Formatting is driven by two args: 'demo_format' and 'model_input_format'
|
| 147 |
SystemFormat also pops fields "system_prompt", "instruction", "target_prefix", and the field containing the demos out from the input instance.
|
| 148 |
|
| 149 |
Args:
|
| 150 |
demos_field (str): the name of the field that contains the demos, being a list of dicts, each with "source" and "target" keys
|
| 151 |
demo_format (str): formatting string for a single demo, combining fields "source" and "target"
|
| 152 |
-
model_input_format (str) overall product format, combining instruction and source (as read from fields "instruction"
|
| 153 |
-
|
| 154 |
-
format_args: Dict[str,str]: additional format args to be used when formatting the different format strings
|
| 155 |
|
| 156 |
Example:
|
| 157 |
when input instance:
|
|
@@ -423,24 +422,13 @@ class ChatAPIFormat(BaseFormat):
|
|
| 423 |
class HFSystemFormat(ChatAPIFormat):
|
| 424 |
r"""Formats the complete input for the model using the HuggingFace chat template of a given model.
|
| 425 |
|
| 426 |
-
HFSystemFormat
|
| 427 |
-
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
|
| 428 |
-
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
| 429 |
-
from the source dataset), in the context of the underlying task.
|
| 430 |
-
3. A field named "instruction" that contains a (non-None) string.
|
| 431 |
-
4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source"
|
| 432 |
-
and "target", representing a single demo.
|
| 433 |
-
5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt.
|
| 434 |
-
|
| 435 |
-
SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites
|
| 436 |
field "source" of the instance.
|
| 437 |
|
| 438 |
Example:
|
| 439 |
-
HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta")
|
| 440 |
-
|
| 441 |
-
Uses the template defined the in tokenizer_config.json of the model:
|
| 442 |
|
| 443 |
-
"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
| 444 |
|
| 445 |
See more details in https://huggingface.co/docs/transformers/main/en/chat_templating
|
| 446 |
|
|
|
|
| 28 |
def apply_capital_new_line_notation(text: str) -> str:
|
| 29 |
r"""Transforms a given string by applying the Capital New Line Notation.
|
| 30 |
|
| 31 |
+
The Capital New Line Notation ``(\N)`` is designed to manage newline behavior in a string efficiently.
|
| 32 |
+
This custom notation aims to consolidate multiple newline characters ``(\n)`` into a single newline under
|
| 33 |
specific conditions, with tailored handling based on whether there's preceding text. The function
|
| 34 |
distinguishes between two primary scenarios:
|
| 35 |
|
| 36 |
+
1. If there's text (referred to as a prefix) followed by any number of ``\n`` characters and then one or
|
| 37 |
+
more ``\N``, the entire sequence is replaced with a single ``\n``. This effectively simplifies multiple
|
| 38 |
newlines and notation characters into a single newline when there's preceding text.
|
| 39 |
+
|
| 40 |
+
2. If the string starts with ``\n`` characters followed by ``\N`` without any text before this sequence, or if
|
| 41 |
+
``\N`` is at the very beginning of the string, the sequence is completely removed. This case is
|
| 42 |
applicable when the notation should not introduce any newlines due to the absence of preceding text.
|
| 43 |
|
| 44 |
Args:
|
| 45 |
+
text (str): The input string to be transformed, potentially containing the Capital New Line Notation ``(\N)`` mixed with actual newline characters ``(\n)``.
|
|
|
|
| 46 |
|
| 47 |
Returns:
|
| 48 |
+
The string after applying the Capital New Line Notation rules, which either consolidates multiple
|
| 49 |
+
newlines and notation characters into a single newline when text precedes them, or removes the
|
| 50 |
+
notation and any preceding newlines entirely if no text is present before the notation.
|
| 51 |
|
| 52 |
Examples:
|
| 53 |
>>> apply_capital_new_line_notation("Hello World\\n\\n\N")
|
|
|
|
| 131 |
class SystemFormat(BaseFormat):
|
| 132 |
r"""Generates the whole input to the model, from constant strings that are given as args, and from values found in specified fields of the instance.
|
| 133 |
|
| 134 |
+
Important: formats can use ``'\N'`` notations that means new-line if no new-line before and no empty string before.
|
| 135 |
|
| 136 |
SystemFormat expects the input instance to contain:
|
| 137 |
1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task-independent opening text.
|
| 138 |
2. A field named "source" whose value is a string verbalizing the original values in the instance (as read
|
| 139 |
from the source dataset), in the context of the underlying task.
|
| 140 |
3. A field named "instruction" that contains a (non-None) string.
|
| 141 |
+
4. A field named with the value in arg ``'demos_field'``, containing a list of dicts, each dict with fields "source"
|
| 142 |
and "target", representing a single demo.
|
| 143 |
5. A field named "target_prefix" that contains a string to prefix the target in each demo, and to end the whole generated prompt
|
| 144 |
|
| 145 |
+
SystemFormat formats the above fields into a single string to be input to the model. This string overwrites
|
| 146 |
+
field "source" of the instance. Formatting is driven by two args: ``'demo_format'`` and ``'model_input_format'``.
|
| 147 |
SystemFormat also pops fields "system_prompt", "instruction", "target_prefix", and the field containing the demos out from the input instance.
|
| 148 |
|
| 149 |
Args:
|
| 150 |
demos_field (str): the name of the field that contains the demos, being a list of dicts, each with "source" and "target" keys
|
| 151 |
demo_format (str): formatting string for a single demo, combining fields "source" and "target"
|
| 152 |
+
model_input_format (str): overall product format, combining instruction and source (as read from fields "instruction" and "source" of the input instance), together with demos (as formatted into one string)
|
| 153 |
+
format_args (Dict[str,str]): additional format args to be used when formatting the different format strings
|
|
|
|
| 154 |
|
| 155 |
Example:
|
| 156 |
when input instance:
|
|
|
|
| 422 |
class HFSystemFormat(ChatAPIFormat):
|
| 423 |
r"""Formats the complete input for the model using the HuggingFace chat template of a given model.
|
| 424 |
|
| 425 |
+
HFSystemFormat formats instance fields into a single string to be inputted to the model. This string overwrites
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
field "source" of the instance.
|
| 427 |
|
| 428 |
Example:
|
| 429 |
+
``HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta")`` Uses the template defined the in tokenizer_config.json of the model:
|
|
|
|
|
|
|
| 430 |
|
| 431 |
+
``"chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"``
|
| 432 |
|
| 433 |
See more details in https://huggingface.co/docs/transformers/main/en/chat_templating
|
| 434 |
|
image_operators.py
CHANGED
|
@@ -167,12 +167,14 @@ class GridLines(ImageAugmentor):
|
|
| 167 |
"""A class that overlays a fixed number of evenly spaced horizontal and vertical lines on an image.
|
| 168 |
|
| 169 |
Attributes:
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
| 173 |
|
| 174 |
Methods:
|
| 175 |
-
|
| 176 |
"""
|
| 177 |
|
| 178 |
num_lines: int = 128
|
|
@@ -207,11 +209,12 @@ class PixelNoise(ImageAugmentor):
|
|
| 207 |
"""A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
|
| 208 |
|
| 209 |
Attributes:
|
| 210 |
-
|
| 211 |
-
|
|
|
|
| 212 |
|
| 213 |
Methods:
|
| 214 |
-
|
| 215 |
"""
|
| 216 |
|
| 217 |
square_size: int = 1
|
|
|
|
| 167 |
"""A class that overlays a fixed number of evenly spaced horizontal and vertical lines on an image.
|
| 168 |
|
| 169 |
Attributes:
|
| 170 |
+
num_lines (int): The number of horizontal and vertical lines to add.
|
| 171 |
+
|
| 172 |
+
line_thickness (int): Thickness of each line in pixels.
|
| 173 |
+
|
| 174 |
+
line_color (Tuple[int, int, int]): RGB color of the grid lines.
|
| 175 |
|
| 176 |
Methods:
|
| 177 |
+
process_image(image): Adds grid lines to the provided image and returns the modified image.
|
| 178 |
"""
|
| 179 |
|
| 180 |
num_lines: int = 128
|
|
|
|
| 209 |
"""A class that overlays a mask of randomly colored nxn squares across an image based on a specified noise rate.
|
| 210 |
|
| 211 |
Attributes:
|
| 212 |
+
square_size (int): Size of each square in pixels.
|
| 213 |
+
|
| 214 |
+
noise_rate (float): Proportion of the image that should be affected by noise (0 to 1).
|
| 215 |
|
| 216 |
Methods:
|
| 217 |
+
process_image(image): Adds the random square mask to the provided image and returns the modified image.
|
| 218 |
"""
|
| 219 |
|
| 220 |
square_size: int = 1
|
inference.py
CHANGED
|
@@ -23,7 +23,7 @@ from typing import (
|
|
| 23 |
Union,
|
| 24 |
)
|
| 25 |
|
| 26 |
-
from datasets import DatasetDict
|
| 27 |
from tqdm import tqdm, trange
|
| 28 |
from tqdm.asyncio import tqdm_asyncio
|
| 29 |
|
|
@@ -70,21 +70,26 @@ class TextGenerationInferenceOutput:
|
|
| 70 |
"""Contains the prediction results and metadata for the inference.
|
| 71 |
|
| 72 |
Args:
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
output_tokens (int) : number of output tokens to the model.
|
| 83 |
-
stop_reason (str): stop reason for text generation, for example "eos" (end of string).
|
| 84 |
-
seed (int): seed used by the model during generation.
|
| 85 |
-
input_text (str): input to the model.
|
| 86 |
-
model_name (str): the model_name as kept in the InferenceEngine.
|
| 87 |
-
inference_type (str): The label stating the type of the InferenceEngine.
|
| 88 |
"""
|
| 89 |
|
| 90 |
prediction: Union[str, List[Dict[str, Any]]]
|
|
@@ -103,7 +108,7 @@ class InferenceEngine(Artifact):
|
|
| 103 |
@abc.abstractmethod
|
| 104 |
def _infer(
|
| 105 |
self,
|
| 106 |
-
dataset: Union[List[Dict[str, Any]],
|
| 107 |
return_meta_data: bool = False,
|
| 108 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 109 |
"""Perform inference on the input dataset.
|
|
@@ -126,7 +131,7 @@ class InferenceEngine(Artifact):
|
|
| 126 |
|
| 127 |
def infer(
|
| 128 |
self,
|
| 129 |
-
dataset: Union[List[Dict[str, Any]],
|
| 130 |
return_meta_data: bool = False,
|
| 131 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 132 |
"""Verifies instances of a dataset and perform inference on the input dataset.
|
|
@@ -134,6 +139,10 @@ class InferenceEngine(Artifact):
|
|
| 134 |
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
|
| 135 |
predictions.
|
| 136 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
if return_meta_data and not hasattr(self, "get_return_object"):
|
| 138 |
raise NotImplementedError(
|
| 139 |
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
|
|
@@ -147,7 +156,7 @@ class InferenceEngine(Artifact):
|
|
| 147 |
|
| 148 |
def _mock_infer(
|
| 149 |
self,
|
| 150 |
-
dataset: Union[List[Dict[str, Any]],
|
| 151 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 152 |
return [str(instance["source"]) for instance in dataset]
|
| 153 |
|
|
@@ -198,7 +207,7 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
|
|
| 198 |
@abc.abstractmethod
|
| 199 |
def _infer_log_probs(
|
| 200 |
self,
|
| 201 |
-
dataset: Union[List[Dict[str, Any]],
|
| 202 |
return_meta_data: bool = False,
|
| 203 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 204 |
"""Perform inference on the input dataset that returns log probs.
|
|
@@ -211,7 +220,7 @@ class LogProbInferenceEngine(abc.ABC, Artifact):
|
|
| 211 |
|
| 212 |
def infer_log_probs(
|
| 213 |
self,
|
| 214 |
-
dataset: Union[List[Dict[str, Any]],
|
| 215 |
return_meta_data: bool = False,
|
| 216 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 217 |
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
|
|
@@ -446,7 +455,7 @@ class HFInferenceEngineBase(
|
|
| 446 |
|
| 447 |
def infer(
|
| 448 |
self,
|
| 449 |
-
dataset: Union[List[Dict[str, Any]],
|
| 450 |
return_meta_data: bool = False,
|
| 451 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 452 |
if not self._is_loaded():
|
|
@@ -456,14 +465,14 @@ class HFInferenceEngineBase(
|
|
| 456 |
@abc.abstractmethod
|
| 457 |
def _infer(
|
| 458 |
self,
|
| 459 |
-
dataset: Union[List[Dict[str, Any]],
|
| 460 |
return_meta_data: bool = False,
|
| 461 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 462 |
raise NotImplementedError
|
| 463 |
|
| 464 |
def infer_log_probs(
|
| 465 |
self,
|
| 466 |
-
dataset: Union[List[Dict[str, Any]],
|
| 467 |
return_meta_data: bool = False,
|
| 468 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 469 |
if not self._is_loaded():
|
|
@@ -473,7 +482,7 @@ class HFInferenceEngineBase(
|
|
| 473 |
@abc.abstractmethod
|
| 474 |
def _infer_log_probs(
|
| 475 |
self,
|
| 476 |
-
dataset: Union[List[Dict[str, Any]],
|
| 477 |
return_meta_data: bool = False,
|
| 478 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 479 |
raise NotImplementedError
|
|
@@ -524,7 +533,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 524 |
|
| 525 |
def _infer_fn(
|
| 526 |
self,
|
| 527 |
-
dataset: Union[List[Dict[str, Any]],
|
| 528 |
return_meta_data: bool,
|
| 529 |
return_logprobs: bool,
|
| 530 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
@@ -565,7 +574,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 565 |
|
| 566 |
def _infer(
|
| 567 |
self,
|
| 568 |
-
dataset: Union[List[Dict[str, Any]],
|
| 569 |
return_meta_data: bool = False,
|
| 570 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 571 |
self.verify_not_chat_api(dataset)
|
|
@@ -573,7 +582,7 @@ class HFAutoModelInferenceEngine(HFInferenceEngineBase):
|
|
| 573 |
|
| 574 |
def _infer_log_probs(
|
| 575 |
self,
|
| 576 |
-
dataset: Union[List[Dict[str, Any]],
|
| 577 |
return_meta_data: bool = False,
|
| 578 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 579 |
self.verify_not_chat_api(dataset)
|
|
@@ -647,7 +656,7 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
| 647 |
|
| 648 |
def _infer_fn(
|
| 649 |
self,
|
| 650 |
-
dataset: Union[List[Dict[str, Any]],
|
| 651 |
return_meta_data: bool,
|
| 652 |
return_logprobs: bool,
|
| 653 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
@@ -681,14 +690,14 @@ class HFLlavaInferenceEngine(HFInferenceEngineBase):
|
|
| 681 |
|
| 682 |
def _infer(
|
| 683 |
self,
|
| 684 |
-
dataset: Union[List[Dict[str, Any]],
|
| 685 |
return_meta_data: bool = False,
|
| 686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 687 |
return self._infer_fn(dataset, return_meta_data, False)
|
| 688 |
|
| 689 |
def _infer_log_probs(
|
| 690 |
self,
|
| 691 |
-
dataset: Union[List[Dict[str, Any]],
|
| 692 |
return_meta_data: bool = False,
|
| 693 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 694 |
return self._infer_fn(dataset, return_meta_data, True)
|
|
@@ -879,7 +888,7 @@ class HFPipelineBasedInferenceEngine(
|
|
| 879 |
|
| 880 |
def _infer(
|
| 881 |
self,
|
| 882 |
-
dataset: Union[List[Dict[str, Any]],
|
| 883 |
return_meta_data: bool = False,
|
| 884 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 885 |
if not self._is_loaded():
|
|
@@ -933,13 +942,13 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
| 933 |
|
| 934 |
def _mock_infer(
|
| 935 |
self,
|
| 936 |
-
dataset: Union[List[Dict[str, Any]],
|
| 937 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 938 |
return [self.default_inference_value for _ in dataset]
|
| 939 |
|
| 940 |
def _infer(
|
| 941 |
self,
|
| 942 |
-
dataset: Union[List[Dict[str, Any]],
|
| 943 |
return_meta_data: bool = False,
|
| 944 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 945 |
return [
|
|
@@ -951,7 +960,7 @@ class MockInferenceEngine(InferenceEngine, LogProbInferenceEngine):
|
|
| 951 |
|
| 952 |
def _infer_log_probs(
|
| 953 |
self,
|
| 954 |
-
dataset: Union[List[Dict[str, Any]],
|
| 955 |
return_meta_data: bool = False,
|
| 956 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 957 |
return [
|
|
@@ -1047,14 +1056,14 @@ class GenericInferenceEngine(
|
|
| 1047 |
|
| 1048 |
def _infer(
|
| 1049 |
self,
|
| 1050 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1051 |
return_meta_data: bool = False,
|
| 1052 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1053 |
return self.engine._infer(dataset)
|
| 1054 |
|
| 1055 |
def _infer_log_probs(
|
| 1056 |
self,
|
| 1057 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1058 |
return_meta_data: bool = False,
|
| 1059 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1060 |
if not isinstance(self.engine, LogProbInferenceEngine):
|
|
@@ -1082,7 +1091,7 @@ class OllamaInferenceEngine(
|
|
| 1082 |
|
| 1083 |
def _infer(
|
| 1084 |
self,
|
| 1085 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1086 |
return_meta_data: bool = False,
|
| 1087 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1088 |
import ollama
|
|
@@ -1250,7 +1259,7 @@ class IbmGenAiInferenceEngine(
|
|
| 1250 |
|
| 1251 |
def _infer(
|
| 1252 |
self,
|
| 1253 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1254 |
return_meta_data: bool = False,
|
| 1255 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1256 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
@@ -1279,7 +1288,7 @@ class IbmGenAiInferenceEngine(
|
|
| 1279 |
|
| 1280 |
def _infer_log_probs(
|
| 1281 |
self,
|
| 1282 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1283 |
return_meta_data: bool = False,
|
| 1284 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1285 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
@@ -1507,7 +1516,7 @@ class OpenAiInferenceEngine(
|
|
| 1507 |
|
| 1508 |
def _infer(
|
| 1509 |
self,
|
| 1510 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1511 |
return_meta_data: bool = False,
|
| 1512 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1513 |
outputs = []
|
|
@@ -1527,22 +1536,14 @@ class OpenAiInferenceEngine(
|
|
| 1527 |
|
| 1528 |
def _infer_log_probs(
|
| 1529 |
self,
|
| 1530 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1531 |
return_meta_data: bool = False,
|
| 1532 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1533 |
outputs = []
|
| 1534 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
|
|
|
| 1535 |
response = self.client.chat.completions.create(
|
| 1536 |
-
messages=
|
| 1537 |
-
# {
|
| 1538 |
-
# "role": "system",
|
| 1539 |
-
# "content": self.system_prompt,
|
| 1540 |
-
# },
|
| 1541 |
-
{
|
| 1542 |
-
"role": "user",
|
| 1543 |
-
"content": instance["source"],
|
| 1544 |
-
}
|
| 1545 |
-
],
|
| 1546 |
model=self.model_name,
|
| 1547 |
**self._get_completion_kwargs(),
|
| 1548 |
)
|
|
@@ -1681,7 +1682,7 @@ class TogetherAiInferenceEngine(
|
|
| 1681 |
|
| 1682 |
def _infer(
|
| 1683 |
self,
|
| 1684 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1685 |
return_meta_data: bool = False,
|
| 1686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1687 |
from together.types.models import ModelType
|
|
@@ -1943,7 +1944,7 @@ class WMLInferenceEngineBase(
|
|
| 1943 |
@abc.abstractmethod
|
| 1944 |
def _send_requests(
|
| 1945 |
self,
|
| 1946 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1947 |
return_logprobs: bool,
|
| 1948 |
return_meta_data: bool,
|
| 1949 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
@@ -1955,7 +1956,7 @@ class WMLInferenceEngineBase(
|
|
| 1955 |
|
| 1956 |
def _infer(
|
| 1957 |
self,
|
| 1958 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1959 |
return_meta_data: bool = False,
|
| 1960 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1961 |
if self._model is None:
|
|
@@ -1969,7 +1970,7 @@ class WMLInferenceEngineBase(
|
|
| 1969 |
|
| 1970 |
def _infer_log_probs(
|
| 1971 |
self,
|
| 1972 |
-
dataset: Union[List[Dict[str, Any]],
|
| 1973 |
return_meta_data: bool = False,
|
| 1974 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1975 |
if self._model is None:
|
|
@@ -2050,27 +2051,29 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
| 2050 |
|
| 2051 |
Attributes:
|
| 2052 |
concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
|
| 2053 |
-
|
| 2054 |
|
| 2055 |
Examples:
|
| 2056 |
-
|
| 2057 |
|
| 2058 |
-
|
| 2059 |
-
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
| 2060 |
-
}
|
| 2061 |
-
model_name = "google/flan-t5-xxl"
|
| 2062 |
-
wml_inference = WMLInferenceEngineGeneration(
|
| 2063 |
-
credentials=wml_credentials,
|
| 2064 |
-
model_name=model_name,
|
| 2065 |
-
data_classification_policy=["public"],
|
| 2066 |
-
top_p=0.5,
|
| 2067 |
-
random_seed=123,
|
| 2068 |
-
)
|
| 2069 |
|
| 2070 |
-
|
| 2071 |
-
|
| 2072 |
-
|
| 2073 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2074 |
"""
|
| 2075 |
|
| 2076 |
concurrency_limit: int = 10
|
|
@@ -2112,7 +2115,7 @@ class WMLInferenceEngineGeneration(WMLInferenceEngineBase, WMLGenerationParamsMi
|
|
| 2112 |
|
| 2113 |
def _send_requests(
|
| 2114 |
self,
|
| 2115 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2116 |
return_logprobs: bool,
|
| 2117 |
return_meta_data: bool,
|
| 2118 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
@@ -2178,31 +2181,33 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2178 |
|
| 2179 |
Attributes:
|
| 2180 |
image_encoder (EncodeImageToString, optional): operator which encodes images in
|
| 2181 |
-
|
| 2182 |
-
|
| 2183 |
|
| 2184 |
Example:
|
| 2185 |
-
|
| 2186 |
-
from .image_operators
|
| 2187 |
|
| 2188 |
-
|
|
|
|
| 2189 |
|
| 2190 |
-
|
| 2191 |
-
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
| 2192 |
-
}
|
| 2193 |
-
model_name = "meta-llama/llama-3-2-11b-vision-instruct"
|
| 2194 |
-
wml_inference = WMLInferenceEngineChat(
|
| 2195 |
-
credentials=wml_credentials,
|
| 2196 |
-
model_name=model_name,
|
| 2197 |
-
image_encoder=image_encoder,
|
| 2198 |
-
data_classification_policy=["public"],
|
| 2199 |
-
max_tokens=1024,
|
| 2200 |
-
)
|
| 2201 |
|
| 2202 |
-
|
| 2203 |
-
|
| 2204 |
-
|
| 2205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2206 |
"""
|
| 2207 |
|
| 2208 |
image_encoder: Optional[EncodeImageToString] = None
|
|
@@ -2303,7 +2308,7 @@ class WMLInferenceEngineChat(WMLInferenceEngineBase, WMLChatParamsMixin):
|
|
| 2303 |
|
| 2304 |
def _send_requests(
|
| 2305 |
self,
|
| 2306 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2307 |
return_logprobs: bool,
|
| 2308 |
return_meta_data: bool,
|
| 2309 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
@@ -2428,7 +2433,7 @@ class LMMSEvalInferenceEngine(LMMSEvalBaseInferenceEngine):
|
|
| 2428 |
|
| 2429 |
def _infer(
|
| 2430 |
self,
|
| 2431 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2432 |
return_meta_data: bool = False,
|
| 2433 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2434 |
if not self._is_loaded():
|
|
@@ -2500,7 +2505,7 @@ class LMMSEvalLoglikelihoodInferenceEngine(LMMSEvalBaseInferenceEngine):
|
|
| 2500 |
|
| 2501 |
def _infer(
|
| 2502 |
self,
|
| 2503 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2504 |
return_meta_data: bool = False,
|
| 2505 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2506 |
if not self._is_loaded():
|
|
@@ -2555,7 +2560,7 @@ class VLLMInferenceEngine(
|
|
| 2555 |
|
| 2556 |
def _infer(
|
| 2557 |
self,
|
| 2558 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2559 |
return_meta_data: bool = False,
|
| 2560 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2561 |
inputs = []
|
|
@@ -2681,7 +2686,7 @@ class LiteLLMInferenceEngine(
|
|
| 2681 |
|
| 2682 |
def _infer(
|
| 2683 |
self,
|
| 2684 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2685 |
return_meta_data: bool = False,
|
| 2686 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2687 |
"""Main inference entry point."""
|
|
@@ -2735,8 +2740,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
| 2735 |
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
|
| 2736 |
},
|
| 2737 |
"together-ai": {
|
| 2738 |
-
"llama-3-8b-instruct": "together_ai/
|
| 2739 |
-
"llama-3-70b-instruct": "together_ai/
|
| 2740 |
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
|
| 2741 |
},
|
| 2742 |
"aws": {
|
|
@@ -2812,7 +2817,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
| 2812 |
|
| 2813 |
def _infer(
|
| 2814 |
self,
|
| 2815 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2816 |
return_meta_data: bool = False,
|
| 2817 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2818 |
return self.engine._infer(dataset, return_meta_data)
|
|
@@ -2898,7 +2903,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine):
|
|
| 2898 |
|
| 2899 |
def _infer(
|
| 2900 |
self,
|
| 2901 |
-
dataset: Union[List[Dict[str, Any]],
|
| 2902 |
return_meta_data: bool = False,
|
| 2903 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2904 |
inputs = []
|
|
|
|
| 23 |
Union,
|
| 24 |
)
|
| 25 |
|
| 26 |
+
from datasets import Dataset, DatasetDict
|
| 27 |
from tqdm import tqdm, trange
|
| 28 |
from tqdm.asyncio import tqdm_asyncio
|
| 29 |
|
|
|
|
| 70 |
"""Contains the prediction results and metadata for the inference.
|
| 71 |
|
| 72 |
Args:
|
| 73 |
+
prediction (Union[str, List[Dict[str, Any]]]): If this is the result of an _infer call, the string predicted by the model.
|
| 74 |
+
| If this is the results of an _infer_log_probs call, a list of dictionaries. The i'th dictionary represents
|
| 75 |
+
the i'th token in the response. The entry "top_tokens" in the dictionary holds a sorted list of the top tokens
|
| 76 |
+
for this position and their probabilities.
|
| 77 |
+
| For example: ``[ {.. "top_tokens": [ {"text": "a", 'logprob': }, {"text": "b", 'logprob': } ....]},
|
| 78 |
+
{.. "top_tokens": [ {"text": "c", 'logprob': }, {"text": "d", 'logprob': } ....]} ]``
|
| 79 |
+
|
| 80 |
+
input_tokens (int) : number of input tokens to the model.
|
| 81 |
+
|
| 82 |
+
output_tokens (int) : number of output tokens to the model.
|
| 83 |
+
|
| 84 |
+
stop_reason (str): stop reason for text generation, for example "eos" (end of string).
|
| 85 |
+
|
| 86 |
+
seed (int): seed used by the model during generation.
|
| 87 |
+
|
| 88 |
+
input_text (str): input to the model.
|
| 89 |
+
|
| 90 |
+
model_name (str): the model_name as kept in the InferenceEngine.
|
| 91 |
|
| 92 |
+
inference_type (str): The label stating the type of the InferenceEngine.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
"""
|
| 94 |
|
| 95 |
prediction: Union[str, List[Dict[str, Any]]]
|
|
|
|
| 108 |
@abc.abstractmethod
|
| 109 |
def _infer(
|
| 110 |
self,
|
| 111 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 112 |
return_meta_data: bool = False,
|
| 113 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 114 |
"""Perform inference on the input dataset.
|
|
|
|
| 131 |
|
| 132 |
def infer(
|
| 133 |
self,
|
| 134 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 135 |
return_meta_data: bool = False,
|
| 136 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 137 |
"""Verifies instances of a dataset and perform inference on the input dataset.
|
|
|
|
| 139 |
If return_meta_data - returns a list of TextGenerationInferenceOutput, else returns a list of the string
|
| 140 |
predictions.
|
| 141 |
"""
|
| 142 |
+
if not isoftype(dataset, Union[List[Dict[str, Any]], Dataset]):
|
| 143 |
+
raise Exception(
|
| 144 |
+
"Dataset passed to infer() is not list of dictionaries or Huggingface Dataset"
|
| 145 |
+
)
|
| 146 |
if return_meta_data and not hasattr(self, "get_return_object"):
|
| 147 |
raise NotImplementedError(
|
| 148 |
f"Inference engine {self.__class__.__name__} does not support return_meta_data as it "
|
|
|
|
| 156 |
|
| 157 |
def _mock_infer(
|
| 158 |
self,
|
| 159 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 160 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 161 |
return [str(instance["source"]) for instance in dataset]
|
| 162 |
|
|
|
|
| 207 |
@abc.abstractmethod
|
| 208 |
def _infer_log_probs(
|
| 209 |
self,
|
| 210 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 211 |
return_meta_data: bool = False,
|
| 212 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 213 |
"""Perform inference on the input dataset that returns log probs.
|
|
|
|
| 220 |
|
| 221 |
def infer_log_probs(
|
| 222 |
self,
|
| 223 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 224 |
return_meta_data: bool = False,
|
| 225 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 226 |
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.
|
|
|
|
| 455 |
|
| 456 |
def infer(
|
| 457 |
self,
|
| 458 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 459 |
return_meta_data: bool = False,
|
| 460 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 461 |
if not self._is_loaded():
|
|
|
|
| 465 |
@abc.abstractmethod
|
| 466 |
def _infer(
|
| 467 |
self,
|
| 468 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 469 |
return_meta_data: bool = False,
|
| 470 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 471 |
raise NotImplementedError
|
| 472 |
|
| 473 |
def infer_log_probs(
|
| 474 |
self,
|
| 475 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 476 |
return_meta_data: bool = False,
|
| 477 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 478 |
if not self._is_loaded():
|
|
|
|
| 482 |
@abc.abstractmethod
|
| 483 |
def _infer_log_probs(
|
| 484 |
self,
|
| 485 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 486 |
return_meta_data: bool = False,
|
| 487 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 488 |
raise NotImplementedError
|
|
|
|
| 533 |
|
| 534 |
def _infer_fn(
|
| 535 |
self,
|
| 536 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 537 |
return_meta_data: bool,
|
| 538 |
return_logprobs: bool,
|
| 539 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
|
| 574 |
|
| 575 |
def _infer(
|
| 576 |
self,
|
| 577 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 578 |
return_meta_data: bool = False,
|
| 579 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 580 |
self.verify_not_chat_api(dataset)
|
|
|
|
| 582 |
|
| 583 |
def _infer_log_probs(
|
| 584 |
self,
|
| 585 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 586 |
return_meta_data: bool = False,
|
| 587 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 588 |
self.verify_not_chat_api(dataset)
|
|
|
|
| 656 |
|
| 657 |
def _infer_fn(
|
| 658 |
self,
|
| 659 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 660 |
return_meta_data: bool,
|
| 661 |
return_logprobs: bool,
|
| 662 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
|
| 690 |
|
| 691 |
def _infer(
|
| 692 |
self,
|
| 693 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 694 |
return_meta_data: bool = False,
|
| 695 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 696 |
return self._infer_fn(dataset, return_meta_data, False)
|
| 697 |
|
| 698 |
def _infer_log_probs(
|
| 699 |
self,
|
| 700 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 701 |
return_meta_data: bool = False,
|
| 702 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 703 |
return self._infer_fn(dataset, return_meta_data, True)
|
|
|
|
| 888 |
|
| 889 |
def _infer(
|
| 890 |
self,
|
| 891 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 892 |
return_meta_data: bool = False,
|
| 893 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 894 |
if not self._is_loaded():
|
|
|
|
| 942 |
|
| 943 |
def _mock_infer(
|
| 944 |
self,
|
| 945 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 946 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 947 |
return [self.default_inference_value for _ in dataset]
|
| 948 |
|
| 949 |
def _infer(
|
| 950 |
self,
|
| 951 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 952 |
return_meta_data: bool = False,
|
| 953 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 954 |
return [
|
|
|
|
| 960 |
|
| 961 |
def _infer_log_probs(
|
| 962 |
self,
|
| 963 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 964 |
return_meta_data: bool = False,
|
| 965 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 966 |
return [
|
|
|
|
| 1056 |
|
| 1057 |
def _infer(
|
| 1058 |
self,
|
| 1059 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1060 |
return_meta_data: bool = False,
|
| 1061 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1062 |
return self.engine._infer(dataset)
|
| 1063 |
|
| 1064 |
def _infer_log_probs(
|
| 1065 |
self,
|
| 1066 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1067 |
return_meta_data: bool = False,
|
| 1068 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1069 |
if not isinstance(self.engine, LogProbInferenceEngine):
|
|
|
|
| 1091 |
|
| 1092 |
def _infer(
|
| 1093 |
self,
|
| 1094 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1095 |
return_meta_data: bool = False,
|
| 1096 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1097 |
import ollama
|
|
|
|
| 1259 |
|
| 1260 |
def _infer(
|
| 1261 |
self,
|
| 1262 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1263 |
return_meta_data: bool = False,
|
| 1264 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1265 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
|
|
| 1288 |
|
| 1289 |
def _infer_log_probs(
|
| 1290 |
self,
|
| 1291 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1292 |
return_meta_data: bool = False,
|
| 1293 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1294 |
from genai.schema import TextGenerationParameters, TextGenerationResult
|
|
|
|
| 1516 |
|
| 1517 |
def _infer(
|
| 1518 |
self,
|
| 1519 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1520 |
return_meta_data: bool = False,
|
| 1521 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1522 |
outputs = []
|
|
|
|
| 1536 |
|
| 1537 |
def _infer_log_probs(
|
| 1538 |
self,
|
| 1539 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1540 |
return_meta_data: bool = False,
|
| 1541 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1542 |
outputs = []
|
| 1543 |
for instance in tqdm(dataset, desc="Inferring with openAI API"):
|
| 1544 |
+
messages = self.to_messages(instance)
|
| 1545 |
response = self.client.chat.completions.create(
|
| 1546 |
+
messages=messages,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1547 |
model=self.model_name,
|
| 1548 |
**self._get_completion_kwargs(),
|
| 1549 |
)
|
|
|
|
| 1682 |
|
| 1683 |
def _infer(
|
| 1684 |
self,
|
| 1685 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1686 |
return_meta_data: bool = False,
|
| 1687 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1688 |
from together.types.models import ModelType
|
|
|
|
| 1944 |
@abc.abstractmethod
|
| 1945 |
def _send_requests(
|
| 1946 |
self,
|
| 1947 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1948 |
return_logprobs: bool,
|
| 1949 |
return_meta_data: bool,
|
| 1950 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
|
| 1956 |
|
| 1957 |
def _infer(
|
| 1958 |
self,
|
| 1959 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1960 |
return_meta_data: bool = False,
|
| 1961 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 1962 |
if self._model is None:
|
|
|
|
| 1970 |
|
| 1971 |
def _infer_log_probs(
|
| 1972 |
self,
|
| 1973 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 1974 |
return_meta_data: bool = False,
|
| 1975 |
) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
|
| 1976 |
if self._model is None:
|
|
|
|
| 2051 |
|
| 2052 |
Attributes:
|
| 2053 |
concurrency_limit (int): Number of concurrent requests sent to a model. Default is 10,
|
| 2054 |
+
which is also the maximum value.
|
| 2055 |
|
| 2056 |
Examples:
|
| 2057 |
+
.. code-block:: python
|
| 2058 |
|
| 2059 |
+
from .api import load_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2060 |
|
| 2061 |
+
wml_credentials = {
|
| 2062 |
+
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
| 2063 |
+
}
|
| 2064 |
+
model_name = "google/flan-t5-xxl"
|
| 2065 |
+
wml_inference = WMLInferenceEngineGeneration(
|
| 2066 |
+
credentials=wml_credentials,
|
| 2067 |
+
model_name=model_name,
|
| 2068 |
+
data_classification_policy=["public"],
|
| 2069 |
+
top_p=0.5,
|
| 2070 |
+
random_seed=123,
|
| 2071 |
+
)
|
| 2072 |
+
|
| 2073 |
+
dataset = load_dataset(
|
| 2074 |
+
dataset_query="card=cards.argument_topic,template_card_index=0,loader_limit=5"
|
| 2075 |
+
)
|
| 2076 |
+
results = wml_inference.infer(dataset["test"])
|
| 2077 |
"""
|
| 2078 |
|
| 2079 |
concurrency_limit: int = 10
|
|
|
|
| 2115 |
|
| 2116 |
def _send_requests(
|
| 2117 |
self,
|
| 2118 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2119 |
return_logprobs: bool,
|
| 2120 |
return_meta_data: bool,
|
| 2121 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
|
| 2181 |
|
| 2182 |
Attributes:
|
| 2183 |
image_encoder (EncodeImageToString, optional): operator which encodes images in
|
| 2184 |
+
given format to base64 strings required by service. You should specify it when
|
| 2185 |
+
you are using images in your inputs.
|
| 2186 |
|
| 2187 |
Example:
|
| 2188 |
+
.. code-block:: python
|
|
|
|
| 2189 |
|
| 2190 |
+
from .api import load_dataset
|
| 2191 |
+
from .image_operators
|
| 2192 |
|
| 2193 |
+
image_encoder = EncodeImageToString(image_format="JPEG")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2194 |
|
| 2195 |
+
wml_credentials = {
|
| 2196 |
+
"url": "some_url", "project_id": "some_id", "api_key": "some_key"
|
| 2197 |
+
}
|
| 2198 |
+
model_name = "meta-llama/llama-3-2-11b-vision-instruct"
|
| 2199 |
+
wml_inference = WMLInferenceEngineChat(
|
| 2200 |
+
credentials=wml_credentials,
|
| 2201 |
+
model_name=model_name,
|
| 2202 |
+
image_encoder=image_encoder,
|
| 2203 |
+
data_classification_policy=["public"],
|
| 2204 |
+
max_tokens=1024,
|
| 2205 |
+
)
|
| 2206 |
+
|
| 2207 |
+
dataset = load_dataset(
|
| 2208 |
+
dataset_query="card=cards.doc_vqa.en,template=templates.qa.with_context.with_type,loader_limit=30"
|
| 2209 |
+
)
|
| 2210 |
+
results = wml_inference.infer(dataset["test"])
|
| 2211 |
"""
|
| 2212 |
|
| 2213 |
image_encoder: Optional[EncodeImageToString] = None
|
|
|
|
| 2308 |
|
| 2309 |
def _send_requests(
|
| 2310 |
self,
|
| 2311 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2312 |
return_logprobs: bool,
|
| 2313 |
return_meta_data: bool,
|
| 2314 |
) -> Union[List[str], List[Dict], List[TextGenerationInferenceOutput]]:
|
|
|
|
| 2433 |
|
| 2434 |
def _infer(
|
| 2435 |
self,
|
| 2436 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2437 |
return_meta_data: bool = False,
|
| 2438 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2439 |
if not self._is_loaded():
|
|
|
|
| 2505 |
|
| 2506 |
def _infer(
|
| 2507 |
self,
|
| 2508 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2509 |
return_meta_data: bool = False,
|
| 2510 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2511 |
if not self._is_loaded():
|
|
|
|
| 2560 |
|
| 2561 |
def _infer(
|
| 2562 |
self,
|
| 2563 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2564 |
return_meta_data: bool = False,
|
| 2565 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2566 |
inputs = []
|
|
|
|
| 2686 |
|
| 2687 |
def _infer(
|
| 2688 |
self,
|
| 2689 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2690 |
return_meta_data: bool = False,
|
| 2691 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2692 |
"""Main inference entry point."""
|
|
|
|
| 2740 |
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
|
| 2741 |
},
|
| 2742 |
"together-ai": {
|
| 2743 |
+
"llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
|
| 2744 |
+
"llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
|
| 2745 |
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
|
| 2746 |
},
|
| 2747 |
"aws": {
|
|
|
|
| 2817 |
|
| 2818 |
def _infer(
|
| 2819 |
self,
|
| 2820 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2821 |
return_meta_data: bool = False,
|
| 2822 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2823 |
return self.engine._infer(dataset, return_meta_data)
|
|
|
|
| 2903 |
|
| 2904 |
def _infer(
|
| 2905 |
self,
|
| 2906 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
| 2907 |
return_meta_data: bool = False,
|
| 2908 |
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
| 2909 |
inputs = []
|
llm_as_judge.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Literal, Optional
|
|
| 4 |
|
| 5 |
from .api import infer
|
| 6 |
from .dataclass import Field
|
| 7 |
-
from .formats import Format, SystemFormat
|
| 8 |
from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
|
| 9 |
from .metrics import BulkInstanceMetric
|
| 10 |
from .operator import SequentialOperator
|
|
@@ -65,12 +65,17 @@ class LLMAsJudgeBase(BulkInstanceMetric, ArtifactFetcherMixin):
|
|
| 65 |
)
|
| 66 |
|
| 67 |
if isinstance(self.inference_model, OpenAiInferenceEngine):
|
| 68 |
-
if self.format and type(self.format) is not
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
|
| 75 |
raise ValueError(
|
| 76 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
|
@@ -132,16 +137,24 @@ class LLMAsJudge(LLMAsJudgeBase):
|
|
| 132 |
|
| 133 |
Attributes:
|
| 134 |
main_score (str): The main score label used for evaluation.
|
|
|
|
| 135 |
task (Literal["rating.single_turn","rating.single_turn_with_reference",
|
| 136 |
"pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
|
| 137 |
-
|
|
|
|
| 138 |
template (Template): The template used when generating inputs for the judge llm.
|
|
|
|
| 139 |
format (Format): The format used when generating inputs for judge llm.
|
|
|
|
| 140 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
|
|
|
| 141 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
| 142 |
-
|
|
|
|
| 143 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
|
|
|
| 144 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
|
|
|
| 145 |
batch_size (int): The size of the bulk.
|
| 146 |
"""
|
| 147 |
|
|
@@ -318,22 +331,34 @@ class TaskBasedLLMasJudge(LLMAsJudgeBase):
|
|
| 318 |
|
| 319 |
Attributes:
|
| 320 |
main_score (str): The main score label used for evaluation.
|
|
|
|
| 321 |
task (str): The type of task the llm as judge runs.
|
| 322 |
This defines the output and input format of the judge model.
|
|
|
|
| 323 |
template (Template): The template used when generating inputs for the judge llm.
|
|
|
|
| 324 |
format (Format): The format used when generating inputs for judge llm.
|
|
|
|
| 325 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
|
|
|
| 326 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
| 327 |
-
|
|
|
|
| 328 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
|
|
|
| 329 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
|
|
|
| 330 |
batch_size (int): The size of the bulk.
|
|
|
|
| 331 |
infer_log_probs(bool): whether to perform the inference using logprobs. If true, the template's
|
| 332 |
post-processing must support the logprobs output.
|
|
|
|
| 333 |
judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
|
| 334 |
judge task. For example, if the generator task uses "reference_answers" and the judge task expects "ground_truth",
|
| 335 |
include {"ground_truth": "reference_answers"} in this dictionary.
|
| 336 |
-
|
|
|
|
|
|
|
| 337 |
include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
|
| 338 |
|
| 339 |
"""
|
|
@@ -384,7 +409,10 @@ class TaskBasedLLMasJudge(LLMAsJudgeBase):
|
|
| 384 |
# if format is not directly set in constructor, choose according to the inference model
|
| 385 |
def set_format_for_inference_engine(self):
|
| 386 |
model_name = self.inference_model.get_engine_id()
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
| 388 |
format_name = "formats.llama3_instruct"
|
| 389 |
else:
|
| 390 |
format_name = "formats.empty"
|
|
|
|
| 4 |
|
| 5 |
from .api import infer
|
| 6 |
from .dataclass import Field
|
| 7 |
+
from .formats import ChatAPIFormat, Format, SystemFormat
|
| 8 |
from .inference import InferenceEngine, LogProbInferenceEngine, OpenAiInferenceEngine
|
| 9 |
from .metrics import BulkInstanceMetric
|
| 10 |
from .operator import SequentialOperator
|
|
|
|
| 65 |
)
|
| 66 |
|
| 67 |
if isinstance(self.inference_model, OpenAiInferenceEngine):
|
| 68 |
+
if self.format and type(self.format) is not ChatAPIFormat:
|
| 69 |
+
if not (
|
| 70 |
+
type(self.format) is SystemFormat
|
| 71 |
+
and self.format.__id__ == "formats.empty"
|
| 72 |
+
):
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
| 75 |
+
"not support formatting. Please remove the format definition from the recipe,"
|
| 76 |
+
"or set the format to either 'formats.empty' or 'formats.chat_api'"
|
| 77 |
+
" (OpenAi Chat API take care of the formatting automatically)."
|
| 78 |
+
)
|
| 79 |
if self.system_prompt and type(self.system_prompt) is not EmptySystemPrompt:
|
| 80 |
raise ValueError(
|
| 81 |
"Error in 'LLMAsJudge' metric. Inference model 'OpenAiInferenceEngine' does "
|
|
|
|
| 137 |
|
| 138 |
Attributes:
|
| 139 |
main_score (str): The main score label used for evaluation.
|
| 140 |
+
|
| 141 |
task (Literal["rating.single_turn","rating.single_turn_with_reference",
|
| 142 |
"pairwise_comparative_rating.single_turn"]): The type of task the llm as judge runs.
|
| 143 |
+
This defines the output and input format of the judge model.
|
| 144 |
+
|
| 145 |
template (Template): The template used when generating inputs for the judge llm.
|
| 146 |
+
|
| 147 |
format (Format): The format used when generating inputs for judge llm.
|
| 148 |
+
|
| 149 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
| 150 |
+
|
| 151 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
| 152 |
+
inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
|
| 153 |
+
|
| 154 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
| 155 |
+
|
| 156 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
| 157 |
+
|
| 158 |
batch_size (int): The size of the bulk.
|
| 159 |
"""
|
| 160 |
|
|
|
|
| 331 |
|
| 332 |
Attributes:
|
| 333 |
main_score (str): The main score label used for evaluation.
|
| 334 |
+
|
| 335 |
task (str): The type of task the llm as judge runs.
|
| 336 |
This defines the output and input format of the judge model.
|
| 337 |
+
|
| 338 |
template (Template): The template used when generating inputs for the judge llm.
|
| 339 |
+
|
| 340 |
format (Format): The format used when generating inputs for judge llm.
|
| 341 |
+
|
| 342 |
system_prompt (SystemPrompt): The system prompt used when generating inputs for judge llm.
|
| 343 |
+
|
| 344 |
strip_system_prompt_and_format_from_inputs (bool): Whether to strip the system prompt and formatting from the
|
| 345 |
+
inputs that the model being judged received, when they are inserted into the llm-as-judge prompt.
|
| 346 |
+
|
| 347 |
inference_model (InferenceEngine): The module that creates the inference of the judge llm.
|
| 348 |
+
|
| 349 |
reduction_map (dict): A dictionary specifying the reduction method for the metric.
|
| 350 |
+
|
| 351 |
batch_size (int): The size of the bulk.
|
| 352 |
+
|
| 353 |
infer_log_probs(bool): whether to perform the inference using logprobs. If true, the template's
|
| 354 |
post-processing must support the logprobs output.
|
| 355 |
+
|
| 356 |
judge_to_generator_fields_mapping (Dict[str, str]): optional mapping between the names of the fields in the generator task and the
|
| 357 |
judge task. For example, if the generator task uses "reference_answers" and the judge task expects "ground_truth",
|
| 358 |
include {"ground_truth": "reference_answers"} in this dictionary.
|
| 359 |
+
|
| 360 |
+
prediction_field (str): if indicated, and a prediction exists, copies the prediction to this field name in task_data.
|
| 361 |
+
|
| 362 |
include_meta_data (bool): whether to include the inference per-instance metadata in the returned results.
|
| 363 |
|
| 364 |
"""
|
|
|
|
| 409 |
# if format is not directly set in constructor, choose according to the inference model
|
| 410 |
def set_format_for_inference_engine(self):
|
| 411 |
model_name = self.inference_model.get_engine_id()
|
| 412 |
+
# TODO : better format resolution to support more chat_api options
|
| 413 |
+
if "rits" in model_name:
|
| 414 |
+
format_name = "formats.chat_api"
|
| 415 |
+
elif re.search("llama.?3.*instruct", model_name):
|
| 416 |
format_name = "formats.llama3_instruct"
|
| 417 |
else:
|
| 418 |
format_name = "formats.empty"
|
loaders.py
CHANGED
|
@@ -162,14 +162,22 @@ class LoadHF(Loader):
|
|
| 162 |
|
| 163 |
Args:
|
| 164 |
path: The path or identifier of the dataset on the HuggingFace Hub.
|
|
|
|
| 165 |
name: An optional dataset name.
|
|
|
|
| 166 |
data_dir: Optional directory to store downloaded data.
|
|
|
|
| 167 |
split: Optional specification of which split to load.
|
|
|
|
| 168 |
data_files: Optional specification of particular data files to load.
|
|
|
|
| 169 |
revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
filtering_lambda: A lambda function for filtering the data after loading.
|
| 172 |
-
|
|
|
|
| 173 |
|
| 174 |
Example:
|
| 175 |
Loading glue's mrpc dataset
|
|
@@ -355,7 +363,9 @@ class LoadCSV(Loader):
|
|
| 355 |
file_path, nrows=self.get_limit(), sep=self.sep
|
| 356 |
).to_dict("records")
|
| 357 |
else:
|
| 358 |
-
iterables[split_name] = pd.read_csv(file_path).to_dict(
|
|
|
|
|
|
|
| 359 |
return iterables
|
| 360 |
|
| 361 |
|
|
@@ -733,19 +743,24 @@ class LoadFromHFSpace(LoadHF):
|
|
| 733 |
|
| 734 |
Args:
|
| 735 |
space_name (str): Name of the HuggingFace Space to be accessed.
|
|
|
|
| 736 |
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
|
|
|
| 740 |
path (str, optional): Absolute path to a directory where data should be downloaded.
|
|
|
|
| 741 |
revision (str, optional): ID of a Git branch or commit to be used. By default, it is
|
| 742 |
-
|
| 743 |
-
|
|
|
|
| 744 |
use_token (bool, optional): Whether a token is used for authentication when accessing
|
| 745 |
-
|
| 746 |
-
|
|
|
|
| 747 |
token_env (str, optional): Key of an env variable which value will be used for
|
| 748 |
-
|
| 749 |
|
| 750 |
Example:
|
| 751 |
Loading from a HuggingFace Space
|
|
|
|
| 162 |
|
| 163 |
Args:
|
| 164 |
path: The path or identifier of the dataset on the HuggingFace Hub.
|
| 165 |
+
|
| 166 |
name: An optional dataset name.
|
| 167 |
+
|
| 168 |
data_dir: Optional directory to store downloaded data.
|
| 169 |
+
|
| 170 |
split: Optional specification of which split to load.
|
| 171 |
+
|
| 172 |
data_files: Optional specification of particular data files to load.
|
| 173 |
+
|
| 174 |
revision: Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
|
| 175 |
+
|
| 176 |
+
streaming (bool): indicating if streaming should be used.
|
| 177 |
+
|
| 178 |
filtering_lambda: A lambda function for filtering the data after loading.
|
| 179 |
+
|
| 180 |
+
num_proc (int): Optional integer to specify the number of processes to use for parallel dataset loading.
|
| 181 |
|
| 182 |
Example:
|
| 183 |
Loading glue's mrpc dataset
|
|
|
|
| 363 |
file_path, nrows=self.get_limit(), sep=self.sep
|
| 364 |
).to_dict("records")
|
| 365 |
else:
|
| 366 |
+
iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict(
|
| 367 |
+
"records"
|
| 368 |
+
)
|
| 369 |
return iterables
|
| 370 |
|
| 371 |
|
|
|
|
| 743 |
|
| 744 |
Args:
|
| 745 |
space_name (str): Name of the HuggingFace Space to be accessed.
|
| 746 |
+
|
| 747 |
data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]): Relative
|
| 748 |
+
paths to files within a given repository. If given as a mapping, paths should
|
| 749 |
+
be values, while keys should represent the type of respective files
|
| 750 |
+
(training, testing etc.).
|
| 751 |
+
|
| 752 |
path (str, optional): Absolute path to a directory where data should be downloaded.
|
| 753 |
+
|
| 754 |
revision (str, optional): ID of a Git branch or commit to be used. By default, it is
|
| 755 |
+
set to None, thus data is downloaded from the main branch of the accessed
|
| 756 |
+
repository.
|
| 757 |
+
|
| 758 |
use_token (bool, optional): Whether a token is used for authentication when accessing
|
| 759 |
+
the HuggingFace Space. If necessary, the token is read from the HuggingFace
|
| 760 |
+
config folder.
|
| 761 |
+
|
| 762 |
token_env (str, optional): Key of an env variable which value will be used for
|
| 763 |
+
authentication when accessing the HuggingFace Space - if necessary.
|
| 764 |
|
| 765 |
Example:
|
| 766 |
Loading from a HuggingFace Space
|
metrics.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import ast
|
| 2 |
import json
|
|
|
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
import string
|
|
@@ -27,7 +28,11 @@ from .dataclass import (
|
|
| 27 |
)
|
| 28 |
from .deprecation_utils import deprecation
|
| 29 |
from .error_utils import Documentation, UnitxtWarning
|
| 30 |
-
from .inference import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
from .logging_utils import get_logger
|
| 32 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
| 33 |
from .operator import (
|
|
@@ -960,11 +965,13 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
|
|
| 960 |
"""Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
|
| 961 |
|
| 962 |
InstanceMetric currently allows two reductions:
|
|
|
|
| 963 |
1. 'mean', which calculates the mean of instance scores,
|
| 964 |
2. 'group_mean', which first applies an aggregation function specified in the reduction_map
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
|
|
|
| 968 |
"""
|
| 969 |
|
| 970 |
n_resamples: int = OptionalField(
|
|
@@ -1489,13 +1496,17 @@ class StringContainmentRatio(InstanceMetric):
|
|
| 1489 |
|
| 1490 |
Attributes:
|
| 1491 |
field: The field from the task_data that contains the values to be checked for containment.
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
-
|
| 1495 |
-
|
| 1496 |
-
|
| 1497 |
-
|
| 1498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1499 |
"""
|
| 1500 |
|
| 1501 |
reduction_map = {"mean": ["string_containment"]}
|
|
@@ -2776,8 +2787,8 @@ class BertScore(HuggingfaceBulkMetric):
|
|
| 2776 |
|
| 2777 |
|
| 2778 |
class SentenceBert(BulkInstanceMetric):
|
| 2779 |
-
|
| 2780 |
-
|
| 2781 |
batch_size: int = 32
|
| 2782 |
|
| 2783 |
model_name: str
|
|
@@ -2823,12 +2834,12 @@ class SentenceBert(BulkInstanceMetric):
|
|
| 2823 |
refs_group_emb = refs_emb[ref_group_bounds[0] : ref_group_bounds[1]]
|
| 2824 |
scores.append(self.util.cos_sim(pred_emb, refs_group_emb).max().item())
|
| 2825 |
|
| 2826 |
-
return [{
|
| 2827 |
|
| 2828 |
|
| 2829 |
class Reward(BulkInstanceMetric):
|
| 2830 |
-
|
| 2831 |
-
|
| 2832 |
batch_size: int = 32
|
| 2833 |
|
| 2834 |
model_name: str
|
|
@@ -2864,12 +2875,15 @@ class Reward(BulkInstanceMetric):
|
|
| 2864 |
|
| 2865 |
# compute the metric
|
| 2866 |
# add function_to_apply="none" to disable sigmoid
|
| 2867 |
-
|
|
|
|
|
|
|
|
|
|
| 2868 |
|
| 2869 |
|
| 2870 |
class Detector(BulkInstanceMetric):
|
| 2871 |
-
|
| 2872 |
-
|
| 2873 |
batch_size: int = 32
|
| 2874 |
|
| 2875 |
prediction_type = str
|
|
@@ -2896,7 +2910,10 @@ class Detector(BulkInstanceMetric):
|
|
| 2896 |
) -> List[Dict[str, Any]]:
|
| 2897 |
# compute the metric
|
| 2898 |
# add function_to_apply="none" to disable sigmoid
|
| 2899 |
-
|
|
|
|
|
|
|
|
|
|
| 2900 |
|
| 2901 |
|
| 2902 |
class RegardMetric(GlobalMetric):
|
|
@@ -3537,13 +3554,13 @@ class Perplexity(BulkInstanceMetric):
|
|
| 3537 |
|
| 3538 |
|
| 3539 |
class FaithfulnessHHEM(BulkInstanceMetric):
|
| 3540 |
-
|
| 3541 |
-
main_score = "score"
|
| 3542 |
batch_size: int = 2
|
| 3543 |
model_name: str = "vectara/hallucination_evaluation_model"
|
| 3544 |
prediction_type = str
|
| 3545 |
single_reference_per_prediction = True
|
| 3546 |
max_context_words = 4096
|
|
|
|
| 3547 |
|
| 3548 |
_requirements_list: List[str] = ["transformers", "torch"]
|
| 3549 |
|
|
@@ -3587,7 +3604,7 @@ class FaithfulnessHHEM(BulkInstanceMetric):
|
|
| 3587 |
for input_batch in tqdm(input_batches, "input batch"):
|
| 3588 |
batch_scores = self.model.predict(input_batch).cpu().tolist()
|
| 3589 |
scores.extend(batch_scores)
|
| 3590 |
-
return [{
|
| 3591 |
|
| 3592 |
|
| 3593 |
class Squad(HuggingfaceMetric):
|
|
@@ -4019,18 +4036,21 @@ def performance_drop_rate(
|
|
| 4019 |
def interpret_effect_size(x: float):
|
| 4020 |
"""Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
|
| 4021 |
|
| 4022 |
-
See https://en.wikipedia.org/wiki/Effect_size
|
| 4023 |
-
Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
|
| 4024 |
-
Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
|
| 4025 |
|
| 4026 |
Value has interpretation of
|
| 4027 |
-
|
| 4028 |
-
-
|
| 4029 |
-
|
| 4030 |
-
|
| 4031 |
-
|
| 4032 |
-
|
| 4033 |
-
|
|
|
|
|
|
|
|
|
|
| 4034 |
|
| 4035 |
Args:
|
| 4036 |
x: float effect size value
|
|
@@ -4066,7 +4086,7 @@ def normalized_cohens_h(
|
|
| 4066 |
"""Cohen's h effect size between two proportions, normalized to interval [-1,1].
|
| 4067 |
|
| 4068 |
Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
|
| 4069 |
-
https://en.wikipedia.org/wiki/Cohen%27s_h
|
| 4070 |
|
| 4071 |
Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
|
| 4072 |
h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).
|
|
@@ -4077,6 +4097,9 @@ def normalized_cohens_h(
|
|
| 4077 |
Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
|
| 4078 |
|
| 4079 |
Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
|
|
|
|
|
|
|
|
|
|
| 4080 |
- essentially 0 if |norm h| < 0.0031831
|
| 4081 |
- very small if 0.0031831 <= |norm h| < 0.06366198
|
| 4082 |
- small difference if 0.06366198 <= |norm h| < 0.15915494
|
|
@@ -4084,12 +4107,17 @@ def normalized_cohens_h(
|
|
| 4084 |
- a large difference if 0.25464791 <= |norm h| < 0.38197186
|
| 4085 |
- a very large difference if 0.38197186 <= |norm h| < 0.63661977
|
| 4086 |
- a huge difference if 0.63661977 <= |norm h|
|
|
|
|
| 4087 |
Args:
|
| 4088 |
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
|
|
|
| 4089 |
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
|
|
|
| 4090 |
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 4091 |
-
|
|
|
|
| 4092 |
interpret: boolean, whether to interpret the significance of the score or not
|
|
|
|
| 4093 |
Returns:
|
| 4094 |
float score between -1 and 1, and a string interpretation if interpret=True
|
| 4095 |
"""
|
|
@@ -5118,3 +5146,112 @@ class PredictionLength(InstanceMetric):
|
|
| 5118 |
task_data: List[Dict],
|
| 5119 |
) -> dict:
|
| 5120 |
return {self.main_score: [len(prediction)], "score_name": self.main_score}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import ast
|
| 2 |
import json
|
| 3 |
+
import math
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
import string
|
|
|
|
| 28 |
)
|
| 29 |
from .deprecation_utils import deprecation
|
| 30 |
from .error_utils import Documentation, UnitxtWarning
|
| 31 |
+
from .inference import (
|
| 32 |
+
HFPipelineBasedInferenceEngine,
|
| 33 |
+
InferenceEngine,
|
| 34 |
+
WMLInferenceEngineGeneration,
|
| 35 |
+
)
|
| 36 |
from .logging_utils import get_logger
|
| 37 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
| 38 |
from .operator import (
|
|
|
|
| 965 |
"""Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
|
| 966 |
|
| 967 |
InstanceMetric currently allows two reductions:
|
| 968 |
+
|
| 969 |
1. 'mean', which calculates the mean of instance scores,
|
| 970 |
2. 'group_mean', which first applies an aggregation function specified in the reduction_map
|
| 971 |
+
to instance scores grouped by the field grouping_field (which must not be None), and returns the mean
|
| 972 |
+
of the group scores; if grouping_field is None, grouping is disabled.
|
| 973 |
+
See _validate_group_mean_reduction for formatting instructions.
|
| 974 |
+
|
| 975 |
"""
|
| 976 |
|
| 977 |
n_resamples: int = OptionalField(
|
|
|
|
| 1496 |
|
| 1497 |
Attributes:
|
| 1498 |
field: The field from the task_data that contains the values to be checked for containment.
|
| 1499 |
+
|
| 1500 |
+
Example task that contains this metric:
|
| 1501 |
+
|
| 1502 |
+
.. code-block:: python
|
| 1503 |
+
|
| 1504 |
+
Task(
|
| 1505 |
+
input_fields={"question": str},
|
| 1506 |
+
reference_fields={"entities": str},
|
| 1507 |
+
prediction_type=str,
|
| 1508 |
+
metrics=["string_containment_ratio[field=entities]"],
|
| 1509 |
+
)
|
| 1510 |
"""
|
| 1511 |
|
| 1512 |
reduction_map = {"mean": ["string_containment"]}
|
|
|
|
| 2787 |
|
| 2788 |
|
| 2789 |
class SentenceBert(BulkInstanceMetric):
|
| 2790 |
+
main_score = "sbert_score"
|
| 2791 |
+
reduction_map = {"mean": [main_score]}
|
| 2792 |
batch_size: int = 32
|
| 2793 |
|
| 2794 |
model_name: str
|
|
|
|
| 2834 |
refs_group_emb = refs_emb[ref_group_bounds[0] : ref_group_bounds[1]]
|
| 2835 |
scores.append(self.util.cos_sim(pred_emb, refs_group_emb).max().item())
|
| 2836 |
|
| 2837 |
+
return [{self.main_score: score} for score in scores]
|
| 2838 |
|
| 2839 |
|
| 2840 |
class Reward(BulkInstanceMetric):
|
| 2841 |
+
main_score = "reward_score"
|
| 2842 |
+
reduction_map = {"mean": [main_score]}
|
| 2843 |
batch_size: int = 32
|
| 2844 |
|
| 2845 |
model_name: str
|
|
|
|
| 2875 |
|
| 2876 |
# compute the metric
|
| 2877 |
# add function_to_apply="none" to disable sigmoid
|
| 2878 |
+
results = self.pipe(inputs, batch_size=self.batch_size)
|
| 2879 |
+
for result in results:
|
| 2880 |
+
result[self.main_score] = result["score"]
|
| 2881 |
+
return results
|
| 2882 |
|
| 2883 |
|
| 2884 |
class Detector(BulkInstanceMetric):
|
| 2885 |
+
main_score = "detector_score"
|
| 2886 |
+
reduction_map = {"mean": [main_score]}
|
| 2887 |
batch_size: int = 32
|
| 2888 |
|
| 2889 |
prediction_type = str
|
|
|
|
| 2910 |
) -> List[Dict[str, Any]]:
|
| 2911 |
# compute the metric
|
| 2912 |
# add function_to_apply="none" to disable sigmoid
|
| 2913 |
+
results = self.pipe(predictions, batch_size=self.batch_size)
|
| 2914 |
+
for result in results:
|
| 2915 |
+
result[self.main_score] = result["score"]
|
| 2916 |
+
return results
|
| 2917 |
|
| 2918 |
|
| 2919 |
class RegardMetric(GlobalMetric):
|
|
|
|
| 3554 |
|
| 3555 |
|
| 3556 |
class FaithfulnessHHEM(BulkInstanceMetric):
|
| 3557 |
+
main_score = "hhem_score"
|
|
|
|
| 3558 |
batch_size: int = 2
|
| 3559 |
model_name: str = "vectara/hallucination_evaluation_model"
|
| 3560 |
prediction_type = str
|
| 3561 |
single_reference_per_prediction = True
|
| 3562 |
max_context_words = 4096
|
| 3563 |
+
reduction_map = {"mean": [main_score]}
|
| 3564 |
|
| 3565 |
_requirements_list: List[str] = ["transformers", "torch"]
|
| 3566 |
|
|
|
|
| 3604 |
for input_batch in tqdm(input_batches, "input batch"):
|
| 3605 |
batch_scores = self.model.predict(input_batch).cpu().tolist()
|
| 3606 |
scores.extend(batch_scores)
|
| 3607 |
+
return [{self.main_score: score} for score in scores]
|
| 3608 |
|
| 3609 |
|
| 3610 |
class Squad(HuggingfaceMetric):
|
|
|
|
| 4036 |
def interpret_effect_size(x: float):
|
| 4037 |
"""Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
|
| 4038 |
|
| 4039 |
+
| See `Effect size <https://en.wikipedia.org/wiki/Effect_size>`_
|
| 4040 |
+
| Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
|
| 4041 |
+
| Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
|
| 4042 |
|
| 4043 |
Value has interpretation of
|
| 4044 |
+
|
| 4045 |
+
.. code-block:: text
|
| 4046 |
+
|
| 4047 |
+
- essentially 0 if |x| < 0.01
|
| 4048 |
+
- very small if 0.01 <= |x| < 0.2
|
| 4049 |
+
- small difference if 0.2 <= |x| < 0.5
|
| 4050 |
+
- a medium difference if 0.5 <= |x| < 0.8
|
| 4051 |
+
- a large difference if 0.8 <= |x| < 1.2
|
| 4052 |
+
- a very large difference if 1.2 <= |x| < 2.0
|
| 4053 |
+
- a huge difference if 2.0 <= |x|
|
| 4054 |
|
| 4055 |
Args:
|
| 4056 |
x: float effect size value
|
|
|
|
| 4086 |
"""Cohen's h effect size between two proportions, normalized to interval [-1,1].
|
| 4087 |
|
| 4088 |
Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
|
| 4089 |
+
`Cohen's h <https://en.wikipedia.org/wiki/Cohen%27s_h>`_
|
| 4090 |
|
| 4091 |
Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
|
| 4092 |
h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).
|
|
|
|
| 4097 |
Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
|
| 4098 |
|
| 4099 |
Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
|
| 4100 |
+
|
| 4101 |
+
.. code-block:: text
|
| 4102 |
+
|
| 4103 |
- essentially 0 if |norm h| < 0.0031831
|
| 4104 |
- very small if 0.0031831 <= |norm h| < 0.06366198
|
| 4105 |
- small difference if 0.06366198 <= |norm h| < 0.15915494
|
|
|
|
| 4107 |
- a large difference if 0.25464791 <= |norm h| < 0.38197186
|
| 4108 |
- a very large difference if 0.38197186 <= |norm h| < 0.63661977
|
| 4109 |
- a huge difference if 0.63661977 <= |norm h|
|
| 4110 |
+
|
| 4111 |
Args:
|
| 4112 |
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 4113 |
+
|
| 4114 |
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
| 4115 |
+
|
| 4116 |
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 4117 |
+
to be compared to the control group.
|
| 4118 |
+
|
| 4119 |
interpret: boolean, whether to interpret the significance of the score or not
|
| 4120 |
+
|
| 4121 |
Returns:
|
| 4122 |
float score between -1 and 1, and a string interpretation if interpret=True
|
| 4123 |
"""
|
|
|
|
| 5146 |
task_data: List[Dict],
|
| 5147 |
) -> dict:
|
| 5148 |
return {self.main_score: [len(prediction)], "score_name": self.main_score}
|
| 5149 |
+
|
| 5150 |
+
|
| 5151 |
+
class GraniteGuardianWMLMetric(InstanceMetric):
|
| 5152 |
+
"""Return metric for different kinds of "risk" from the Granite-3.0 Guardian model."""
|
| 5153 |
+
|
| 5154 |
+
main_score = "granite_guardian"
|
| 5155 |
+
reduction_map: Dict[str, List[str]] = None
|
| 5156 |
+
prediction_type = float
|
| 5157 |
+
|
| 5158 |
+
model_name: str = "ibm/granite-guardian-3-8b"
|
| 5159 |
+
hf_model_name: str = "ibm-granite/granite-guardian-3.0-8b"
|
| 5160 |
+
safe_token = "No"
|
| 5161 |
+
unsafe_token = "Yes"
|
| 5162 |
+
|
| 5163 |
+
inference_engine: WMLInferenceEngineGeneration = None
|
| 5164 |
+
generation_params: Dict = None
|
| 5165 |
+
risk_name: str = None
|
| 5166 |
+
|
| 5167 |
+
_requirements_list: List[str] = ["ibm_watsonx_ai", "torch", "transformers"]
|
| 5168 |
+
|
| 5169 |
+
def prepare(self):
|
| 5170 |
+
self.reduction_map = {"mean": [self.main_score]}
|
| 5171 |
+
|
| 5172 |
+
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
| 5173 |
+
from transformers import AutoTokenizer
|
| 5174 |
+
|
| 5175 |
+
if not hasattr(self, "_tokenizer") or self._tokenizer is None:
|
| 5176 |
+
self._tokenizer = AutoTokenizer.from_pretrained(self.hf_model_name)
|
| 5177 |
+
self.inference_engine = WMLInferenceEngineGeneration(
|
| 5178 |
+
model_name=self.model_name,
|
| 5179 |
+
)
|
| 5180 |
+
self.inference_engine._load_model()
|
| 5181 |
+
self.model = self.inference_engine._model
|
| 5182 |
+
self.generation_params = self.inference_engine._set_logprobs_params({})
|
| 5183 |
+
|
| 5184 |
+
messages = self.process_input_fields(task_data)
|
| 5185 |
+
guardian_config = {"risk_name": self.risk_name}
|
| 5186 |
+
processed_input = self._tokenizer.apply_chat_template(
|
| 5187 |
+
messages,
|
| 5188 |
+
guardian_config=guardian_config,
|
| 5189 |
+
tokenize=False,
|
| 5190 |
+
add_generation_prompt=True,
|
| 5191 |
+
)
|
| 5192 |
+
|
| 5193 |
+
result = self.model.generate(
|
| 5194 |
+
prompt=[processed_input],
|
| 5195 |
+
params=self.generation_params,
|
| 5196 |
+
)
|
| 5197 |
+
generated_tokens_list = result[0]["results"][0]["generated_tokens"]
|
| 5198 |
+
label, prob_of_risk = self.parse_output(generated_tokens_list)
|
| 5199 |
+
score = 1 - prob_of_risk if label is not None else np.nan
|
| 5200 |
+
return {self.main_score: score}
|
| 5201 |
+
|
| 5202 |
+
def process_input_fields(self, task_data):
|
| 5203 |
+
if self.risk_name == "groundedness":
|
| 5204 |
+
messages = [
|
| 5205 |
+
{"role": "context", "content": "\n".join(task_data["contexts"])},
|
| 5206 |
+
{"role": "assistant", "content": task_data["answer"]},
|
| 5207 |
+
]
|
| 5208 |
+
elif self.risk_name == "answer_relevance":
|
| 5209 |
+
messages = [
|
| 5210 |
+
{"role": "user", "content": task_data["question"]},
|
| 5211 |
+
{"role": "assistant", "content": task_data["answer"]},
|
| 5212 |
+
]
|
| 5213 |
+
elif self.risk_name == "context_relevance":
|
| 5214 |
+
messages = [
|
| 5215 |
+
{"role": "user", "content": task_data["question"]},
|
| 5216 |
+
{"role": "context", "content": "\n".join(task_data["contexts"])},
|
| 5217 |
+
]
|
| 5218 |
+
else:
|
| 5219 |
+
raise NotImplementedError()
|
| 5220 |
+
|
| 5221 |
+
return messages
|
| 5222 |
+
|
| 5223 |
+
def parse_output(self, generated_tokens_list):
|
| 5224 |
+
top_tokens_list = [
|
| 5225 |
+
generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list
|
| 5226 |
+
]
|
| 5227 |
+
prob = self.get_probabilities(top_tokens_list)
|
| 5228 |
+
prob_of_risk = prob[1]
|
| 5229 |
+
|
| 5230 |
+
res = next(iter(generated_tokens_list))["text"].strip()
|
| 5231 |
+
|
| 5232 |
+
if self.unsafe_token.lower() == res.lower():
|
| 5233 |
+
label = self.unsafe_token
|
| 5234 |
+
elif self.safe_token.lower() == res.lower():
|
| 5235 |
+
label = self.safe_token
|
| 5236 |
+
else:
|
| 5237 |
+
label = None
|
| 5238 |
+
|
| 5239 |
+
return label, prob_of_risk
|
| 5240 |
+
|
| 5241 |
+
def get_probabilities(self, top_tokens_list):
|
| 5242 |
+
import torch
|
| 5243 |
+
|
| 5244 |
+
safe_token_prob = 1e-50
|
| 5245 |
+
unsafe_token_prob = 1e-50
|
| 5246 |
+
|
| 5247 |
+
for top_tokens in top_tokens_list:
|
| 5248 |
+
for token in top_tokens:
|
| 5249 |
+
if token["text"].strip().lower() == self.safe_token.lower():
|
| 5250 |
+
safe_token_prob += math.exp(token["logprob"])
|
| 5251 |
+
if token["text"].strip().lower() == self.unsafe_token.lower():
|
| 5252 |
+
unsafe_token_prob += math.exp(token["logprob"])
|
| 5253 |
+
|
| 5254 |
+
return torch.softmax(
|
| 5255 |
+
torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
|
| 5256 |
+
dim=0,
|
| 5257 |
+
).numpy()
|
operators.py
CHANGED
|
@@ -137,34 +137,39 @@ class MapInstanceValues(InstanceOperator):
|
|
| 137 |
|
| 138 |
Attributes:
|
| 139 |
mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
| 142 |
strict (bool): If True, the mapping is applied strictly. That means if a value
|
| 143 |
-
|
| 144 |
-
|
|
|
|
| 145 |
process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
|
| 146 |
-
|
| 147 |
-
|
| 148 |
|
| 149 |
Examples:
|
| 150 |
-
MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
|
| 151 |
-
replaces
|
| 152 |
-
instance {"a":
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
"""
|
| 169 |
|
| 170 |
mappers: Dict[str, Dict[str, str]]
|
|
@@ -234,27 +239,25 @@ class FlattenInstances(InstanceOperator):
|
|
| 234 |
|
| 235 |
|
| 236 |
class Set(InstanceOperator):
|
| 237 |
-
"""
|
| 238 |
|
| 239 |
Args:
|
| 240 |
-
fields (Dict[str, object]): The fields to add to each instance.
|
| 241 |
-
|
| 242 |
use_deepcopy (bool) : Deep copy the input value to avoid later modifications
|
| 243 |
|
| 244 |
Examples:
|
| 245 |
-
#
|
| 246 |
-
Set(fields={"classes": ["positive","negatives"]})
|
| 247 |
|
| 248 |
-
#
|
| 249 |
-
Set(fields={"span/start": 0}
|
| 250 |
|
| 251 |
-
#
|
| 252 |
-
Set(fields={"classes": ["positive","negatives"], apply_to_stream=["train"]})
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
Set(fields={"classes": alist}), use_deepcopy=True)
|
| 257 |
-
# if now alist is modified, still the instances remain intact.
|
| 258 |
"""
|
| 259 |
|
| 260 |
fields: Dict[str, object]
|
|
@@ -333,22 +336,26 @@ class InstanceFieldOperator(InstanceOperator):
|
|
| 333 |
|
| 334 |
Args:
|
| 335 |
field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
|
|
|
|
| 336 |
to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
|
| 337 |
-
|
|
|
|
| 338 |
field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
|
| 339 |
-
|
| 340 |
-
|
| 341 |
is mapped to the field.
|
| 342 |
-
|
| 343 |
-
in the (outer) List. But when the type of argument
|
| 344 |
order. The end result might depend on that order if either (1) two different fields are mapped to the same
|
| 345 |
to_field, or (2) a field shows both as a key and as a value in different mappings.
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
-
Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
|
| 351 |
-
prefix if 'field' and 'to_field' contain a /), then the result of the operation is saved within 'field'
|
| 352 |
"""
|
| 353 |
|
| 354 |
field: Optional[str] = None
|
|
@@ -577,17 +584,18 @@ class Apply(InstanceOperator):
|
|
| 577 |
Args:
|
| 578 |
function (str): name of function.
|
| 579 |
to_field (str): the field to store the result
|
| 580 |
-
|
|
|
|
| 581 |
|
| 582 |
Examples:
|
| 583 |
-
Store in field "b" the uppercase string of the value in field "a"
|
| 584 |
-
Apply("a", function=str.upper, to_field="b")
|
| 585 |
|
| 586 |
-
Dump the json representation of field "t" and store back in the same field
|
| 587 |
-
Apply("t", function=json.dumps, to_field="t")
|
| 588 |
|
| 589 |
-
Set the time in a field 'b'
|
| 590 |
-
Apply(function=time.time, to_field="b")
|
| 591 |
|
| 592 |
"""
|
| 593 |
|
|
@@ -667,14 +675,13 @@ class ListFieldValues(InstanceOperator):
|
|
| 667 |
|
| 668 |
|
| 669 |
class ZipFieldValues(InstanceOperator):
|
| 670 |
-
"""Zips values of multiple fields in a given instance, similar to list(zip(*fields))
|
| 671 |
|
| 672 |
The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
|
| 673 |
are zipped, and stored into 'to_field'.
|
| 674 |
|
| 675 |
-
If 'longest'=False, the length of the zipped result is determined by the shortest input value.
|
| 676 |
-
If 'longest'=
|
| 677 |
-
inputs with None -s.
|
| 678 |
|
| 679 |
"""
|
| 680 |
|
|
@@ -706,11 +713,11 @@ class ZipFieldValues(InstanceOperator):
|
|
| 706 |
class InterleaveListsToDialogOperator(InstanceOperator):
|
| 707 |
"""Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
|
| 708 |
|
| 709 |
-
|
| 710 |
-
|
| 711 |
|
| 712 |
The user turns and assistant turns field are specified in the arguments.
|
| 713 |
-
|
| 714 |
|
| 715 |
"""
|
| 716 |
|
|
@@ -854,13 +861,13 @@ class Copy(FieldOperator):
|
|
| 854 |
|
| 855 |
Examples:
|
| 856 |
An input instance {"a": 2, "b": 3}, when processed by
|
| 857 |
-
Copy(field_to_field={"a": "b"}
|
| 858 |
would yield {"a": 2, "b": 2}, and when processed by
|
| 859 |
-
Copy(field_to_field={"a": "c"} would yield
|
| 860 |
{"a": 2, "b": 3, "c": 2}
|
| 861 |
|
| 862 |
with field names containing / , we can also copy inside the field:
|
| 863 |
-
Copy(field="a/0",to_field="a")
|
| 864 |
would process instance {"a": [1, 3]} into {"a": 1}
|
| 865 |
|
| 866 |
|
|
@@ -930,32 +937,41 @@ class CastFields(InstanceOperator):
|
|
| 930 |
"""Casts specified fields to specified types.
|
| 931 |
|
| 932 |
Args:
|
| 933 |
-
use_nested_query (bool): Whether to cast nested fields, expressed in dpath. Defaults to False.
|
| 934 |
fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
|
| 935 |
-
|
|
|
|
| 936 |
defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
|
|
|
|
| 937 |
process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
|
| 938 |
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
|
| 947 |
-
|
|
|
|
|
|
|
| 948 |
|
| 949 |
"""
|
| 950 |
|
| 951 |
fields: Dict[str, str] = field(default_factory=dict)
|
| 952 |
failure_defaults: Dict[str, object] = field(default_factory=dict)
|
| 953 |
-
use_nested_query: bool =
|
| 954 |
process_every_value: bool = False
|
| 955 |
|
| 956 |
def prepare(self):
|
| 957 |
self.types = {"int": int, "float": float, "str": str, "bool": bool}
|
| 958 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
def _cast_single(self, value, type, field):
|
| 960 |
try:
|
| 961 |
return self.types[type](value)
|
|
@@ -1093,18 +1109,18 @@ class FilterByCondition(StreamOperator):
|
|
| 1093 |
|
| 1094 |
Args:
|
| 1095 |
values (Dict[str, Any]): Field names and respective Values that instances must match according the condition, to be included in the output.
|
|
|
|
| 1096 |
condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in","not in")
|
|
|
|
| 1097 |
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
|
| 1098 |
|
| 1099 |
Examples:
|
| 1100 |
-
FilterByCondition(values = {"a":4}, condition = "gt") will yield only instances where field "a" contains a value
|
| 1101 |
-
FilterByCondition(values = {"a":4}, condition = "le") will yield only instances where "a"<=4
|
| 1102 |
-
FilterByCondition(values = {"a":[4,8]}, condition = "in") will yield only instances where "a" is 4 or 8
|
| 1103 |
-
FilterByCondition(values = {"a":[4,8]}, condition = "not in") will yield only instances where "a" different from 4 or 8
|
| 1104 |
-
FilterByCondition(values = {"a/b":[4,8]}, condition = "not in") will yield only instances where "a" is
|
| 1105 |
-
|
| 1106 |
-
FilterByCondition(values = {"a[2]":4}, condition = "le") will yield only instances where "a" is a list whose 3-rd
|
| 1107 |
-
element is <= 4
|
| 1108 |
|
| 1109 |
|
| 1110 |
"""
|
|
@@ -1805,14 +1821,14 @@ class EncodeLabels(InstanceOperator):
|
|
| 1805 |
Args:
|
| 1806 |
fields (List[str]): The fields to encode together.
|
| 1807 |
|
| 1808 |
-
Example:
|
| 1809 |
-
EncodeLabels(fields = ["a", "b/*"])
|
| 1810 |
-
on input stream = [{"a": "red", "b": ["red", "blue"], "c":"bread"},
|
| 1811 |
-
{"a": "blue", "b": ["green"], "c":"water"}] will yield the
|
| 1812 |
-
output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]
|
| 1813 |
|
| 1814 |
-
Note:
|
| 1815 |
-
input 'fields' with the appendix "/*" as in the above example.
|
| 1816 |
|
| 1817 |
"""
|
| 1818 |
|
|
@@ -2132,21 +2148,23 @@ class CollateInstances(StreamOperator):
|
|
| 2132 |
batch_size (int)
|
| 2133 |
|
| 2134 |
Example:
|
| 2135 |
-
|
| 2136 |
-
|
| 2137 |
-
Given inputs = [
|
| 2138 |
-
{"a": 1, "b": 2},
|
| 2139 |
-
{"a": 2, "b": 2},
|
| 2140 |
-
{"a": 3, "b": 2},
|
| 2141 |
-
{"a": 4, "b": 2},
|
| 2142 |
-
{"a": 5, "b": 2}
|
| 2143 |
-
]
|
| 2144 |
|
| 2145 |
-
|
| 2146 |
-
|
| 2147 |
-
|
| 2148 |
-
|
| 2149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2150 |
|
| 2151 |
|
| 2152 |
"""
|
|
|
|
| 137 |
|
| 138 |
Attributes:
|
| 139 |
mappers (Dict[str, Dict[str, Any]]): The mappers to use for mapping instance values.
|
| 140 |
+
Keys are the names of the fields to undergo mapping, and values are dictionaries
|
| 141 |
+
that define the mapping from old values to new values.
|
| 142 |
+
Note that mapped values are defined by their string representation, so mapped values
|
| 143 |
+
are converted to strings before being looked up in the mappers.
|
| 144 |
+
|
| 145 |
strict (bool): If True, the mapping is applied strictly. That means if a value
|
| 146 |
+
does not exist in the mapper, it will raise a KeyError. If False, values
|
| 147 |
+
that are not present in the mapper are kept as they are.
|
| 148 |
+
|
| 149 |
process_every_value (bool): If True, all fields to be mapped should be lists, and the mapping
|
| 150 |
+
is to be applied to their individual elements. If False, mapping is only applied to a field
|
| 151 |
+
containing a single value.
|
| 152 |
|
| 153 |
Examples:
|
| 154 |
+
``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
|
| 155 |
+
replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
|
| 156 |
+
instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
|
| 157 |
+
since field-name ``"b"`` does not participate in the mappers, and that ``1`` was casted to ``"1"`` before looked
|
| 158 |
+
up in the mapper of ``"a"``.
|
| 159 |
+
|
| 160 |
+
``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
|
| 161 |
+
Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
|
| 162 |
+
each such ``"1"`` with ``"hi"`` and ``"2"`` -- with ``"bye"`` in all instances of all streams:
|
| 163 |
+
instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.
|
| 164 |
+
|
| 165 |
+
``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
|
| 166 |
+
To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
|
| 167 |
+
Input instance ``{"a":"3", "b": 2}`` will raise an exception per the above call,
|
| 168 |
+
because ``"3"`` is not a key in the mapper of ``"a"``.
|
| 169 |
+
|
| 170 |
+
``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
|
| 171 |
+
replaces a list ``[1,2,3,4]`` with the string ``"All"`` and an empty list by string ``"None"``.
|
| 172 |
+
|
| 173 |
"""
|
| 174 |
|
| 175 |
mappers: Dict[str, Dict[str, str]]
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
class Set(InstanceOperator):
|
| 242 |
+
"""Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If fields exist, updates them, if do not exist -- adds them.
|
| 243 |
|
| 244 |
Args:
|
| 245 |
+
fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields
|
| 246 |
+
|
| 247 |
use_deepcopy (bool) : Deep copy the input value to avoid later modifications
|
| 248 |
|
| 249 |
Examples:
|
| 250 |
+
# Set a value of a list consisting of "positive" and "negative" do field "classes" to each and every instance of all streams
|
| 251 |
+
``Set(fields={"classes": ["positive","negatives"]})``
|
| 252 |
|
| 253 |
+
# In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
|
| 254 |
+
``Set(fields={"span/start": 0}``
|
| 255 |
|
| 256 |
+
# In all instances of stream "train" only, Set field "classes" to have the value of a list consisting of "positive" and "negative"
|
| 257 |
+
``Set(fields={"classes": ["positive","negatives"], apply_to_stream=["train"]})``
|
| 258 |
|
| 259 |
+
# Set field "classes" to have the value of a given list, preventing modification of original list from changing the instance.
|
| 260 |
+
``Set(fields={"classes": alist}), use_deepcopy=True)`` if now alist is modified, still the instances remain intact.
|
|
|
|
|
|
|
| 261 |
"""
|
| 262 |
|
| 263 |
fields: Dict[str, object]
|
|
|
|
| 336 |
|
| 337 |
Args:
|
| 338 |
field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
|
| 339 |
+
|
| 340 |
to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
|
| 341 |
+
operation would happen in-place and its result would replace the value of ``field``. Defaults to None
|
| 342 |
+
|
| 343 |
field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
|
| 344 |
+
to names of fields to save the results into. Inner List, if used, should be of length 2.
|
| 345 |
+
| A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
|
| 346 |
is mapped to the field.
|
| 347 |
+
| When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
|
| 348 |
+
in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
|
| 349 |
order. The end result might depend on that order if either (1) two different fields are mapped to the same
|
| 350 |
to_field, or (2) a field shows both as a key and as a value in different mappings.
|
| 351 |
+
| The operator throws an AssertionError in either of these cases.
|
| 352 |
+
| field_to_field defaults to None
|
| 353 |
+
|
| 354 |
+
process_every_value (bool): Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False
|
| 355 |
+
|
| 356 |
+
Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
|
| 357 |
+
prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .
|
| 358 |
|
|
|
|
|
|
|
| 359 |
"""
|
| 360 |
|
| 361 |
field: Optional[str] = None
|
|
|
|
| 584 |
Args:
|
| 585 |
function (str): name of function.
|
| 586 |
to_field (str): the field to store the result
|
| 587 |
+
|
| 588 |
+
any additional arguments are field names whose values will be passed directly to the function specified
|
| 589 |
|
| 590 |
Examples:
|
| 591 |
+
Store in field "b" the uppercase string of the value in field "a":
|
| 592 |
+
``Apply("a", function=str.upper, to_field="b")``
|
| 593 |
|
| 594 |
+
Dump the json representation of field "t" and store back in the same field:
|
| 595 |
+
``Apply("t", function=json.dumps, to_field="t")``
|
| 596 |
|
| 597 |
+
Set the time in a field 'b':
|
| 598 |
+
``Apply(function=time.time, to_field="b")``
|
| 599 |
|
| 600 |
"""
|
| 601 |
|
|
|
|
| 675 |
|
| 676 |
|
| 677 |
class ZipFieldValues(InstanceOperator):
|
| 678 |
+
"""Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.
|
| 679 |
|
| 680 |
The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
|
| 681 |
are zipped, and stored into 'to_field'.
|
| 682 |
|
| 683 |
+
| If 'longest'=False, the length of the zipped result is determined by the shortest input value.
|
| 684 |
+
| If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.
|
|
|
|
| 685 |
|
| 686 |
"""
|
| 687 |
|
|
|
|
| 713 |
class InterleaveListsToDialogOperator(InstanceOperator):
|
| 714 |
"""Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".
|
| 715 |
|
| 716 |
+
The list of tuples if of format (role, turn_content), where the role label is specified by
|
| 717 |
+
the 'user_role_label' and 'assistant_role_label' fields (default to "user" and "assistant").
|
| 718 |
|
| 719 |
The user turns and assistant turns field are specified in the arguments.
|
| 720 |
+
The value of each of the 'fields' is assumed to be a list.
|
| 721 |
|
| 722 |
"""
|
| 723 |
|
|
|
|
| 861 |
|
| 862 |
Examples:
|
| 863 |
An input instance {"a": 2, "b": 3}, when processed by
|
| 864 |
+
``Copy(field_to_field={"a": "b"})``
|
| 865 |
would yield {"a": 2, "b": 2}, and when processed by
|
| 866 |
+
``Copy(field_to_field={"a": "c"})`` would yield
|
| 867 |
{"a": 2, "b": 3, "c": 2}
|
| 868 |
|
| 869 |
with field names containing / , we can also copy inside the field:
|
| 870 |
+
``Copy(field="a/0",to_field="a")``
|
| 871 |
would process instance {"a": [1, 3]} into {"a": 1}
|
| 872 |
|
| 873 |
|
|
|
|
| 937 |
"""Casts specified fields to specified types.
|
| 938 |
|
| 939 |
Args:
|
|
|
|
| 940 |
fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
|
| 941 |
+
e.g: "int", "str", "float", "bool". Basic names of types
|
| 942 |
+
|
| 943 |
defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
|
| 944 |
+
|
| 945 |
process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
|
| 946 |
|
| 947 |
+
Example:
|
| 948 |
+
.. code-block:: python
|
| 949 |
+
|
| 950 |
+
CastFields(
|
| 951 |
+
fields={"a/d": "float", "b": "int"},
|
| 952 |
+
failure_defaults={"a/d": 0.0, "b": 0},
|
| 953 |
+
process_every_value=True,
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
|
| 957 |
+
into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.
|
| 958 |
|
| 959 |
"""
|
| 960 |
|
| 961 |
fields: Dict[str, str] = field(default_factory=dict)
|
| 962 |
failure_defaults: Dict[str, object] = field(default_factory=dict)
|
| 963 |
+
use_nested_query: bool = None # deprecated field
|
| 964 |
process_every_value: bool = False
|
| 965 |
|
| 966 |
def prepare(self):
|
| 967 |
self.types = {"int": int, "float": float, "str": str, "bool": bool}
|
| 968 |
|
| 969 |
+
def verify(self):
|
| 970 |
+
super().verify()
|
| 971 |
+
if self.use_nested_query is not None:
|
| 972 |
+
depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
|
| 973 |
+
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
|
| 974 |
+
|
| 975 |
def _cast_single(self, value, type, field):
|
| 976 |
try:
|
| 977 |
return self.types[type](value)
|
|
|
|
| 1109 |
|
| 1110 |
Args:
|
| 1111 |
values (Dict[str, Any]): Field names and respective Values that instances must match according the condition, to be included in the output.
|
| 1112 |
+
|
| 1113 |
condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in","not in")
|
| 1114 |
+
|
| 1115 |
error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
|
| 1116 |
|
| 1117 |
Examples:
|
| 1118 |
+
| ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
|
| 1119 |
+
| ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
|
| 1120 |
+
| ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
|
| 1121 |
+
| ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
|
| 1122 |
+
| ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
|
| 1123 |
+
| ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose 3-rd element is ``<= 4``
|
|
|
|
|
|
|
| 1124 |
|
| 1125 |
|
| 1126 |
"""
|
|
|
|
| 1821 |
Args:
|
| 1822 |
fields (List[str]): The fields to encode together.
|
| 1823 |
|
| 1824 |
+
Example:
|
| 1825 |
+
applying ``EncodeLabels(fields = ["a", "b/*"])``
|
| 1826 |
+
on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
|
| 1827 |
+
{"a": "blue", "b": ["green"], "c":"water"}]`` will yield the
|
| 1828 |
+
output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``
|
| 1829 |
|
| 1830 |
+
Note: dict_utils are applied here, and hence, fields that are lists, should be included in
|
| 1831 |
+
input 'fields' with the appendix ``"/*"`` as in the above example.
|
| 1832 |
|
| 1833 |
"""
|
| 1834 |
|
|
|
|
| 2148 |
batch_size (int)
|
| 2149 |
|
| 2150 |
Example:
|
| 2151 |
+
.. code-block:: text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2152 |
|
| 2153 |
+
CollateInstances(batch_size=2)
|
| 2154 |
+
|
| 2155 |
+
Given inputs = [
|
| 2156 |
+
{"a": 1, "b": 2},
|
| 2157 |
+
{"a": 2, "b": 2},
|
| 2158 |
+
{"a": 3, "b": 2},
|
| 2159 |
+
{"a": 4, "b": 2},
|
| 2160 |
+
{"a": 5, "b": 2}
|
| 2161 |
+
]
|
| 2162 |
+
|
| 2163 |
+
Returns targets = [
|
| 2164 |
+
{"a": [1,2], "b": [2,2]},
|
| 2165 |
+
{"a": [3,4], "b": [2,2]},
|
| 2166 |
+
{"a": [5], "b": [2]},
|
| 2167 |
+
]
|
| 2168 |
|
| 2169 |
|
| 2170 |
"""
|
span_lableing_operators.py
CHANGED
|
@@ -8,29 +8,35 @@ class IobExtractor(InstanceOperator):
|
|
| 8 |
|
| 9 |
Attributes:
|
| 10 |
labels (List[str]): A list of entity type labels, e.g., ["Person", "Organization", "Location"].
|
|
|
|
| 11 |
begin_labels (List[str]): A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
|
|
|
|
| 12 |
inside_labels (List[str]): A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
|
|
|
|
| 13 |
outside_label (str): The label indicating tokens outside of any entity, typically "O".
|
| 14 |
|
| 15 |
The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
|
| 16 |
|
|
|
|
|
|
|
| 17 |
Example of instantiation and usage:
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
| 34 |
|
| 35 |
For more details on the IOB tagging convention, see: https://en.wikipedia.org/wiki/Inside-outside-beginning_(tagging)
|
| 36 |
|
|
|
|
| 8 |
|
| 9 |
Attributes:
|
| 10 |
labels (List[str]): A list of entity type labels, e.g., ["Person", "Organization", "Location"].
|
| 11 |
+
|
| 12 |
begin_labels (List[str]): A list of labels indicating the beginning of an entity, e.g., ["B-PER", "B-ORG", "B-LOC"].
|
| 13 |
+
|
| 14 |
inside_labels (List[str]): A list of labels indicating the continuation of an entity, e.g., ["I-PER", "I-ORG", "I-LOC"].
|
| 15 |
+
|
| 16 |
outside_label (str): The label indicating tokens outside of any entity, typically "O".
|
| 17 |
|
| 18 |
The extraction process identifies spans of text corresponding to entities and labels them according to their entity type. Each span is annotated with a start and end character offset, the entity text, and the corresponding label.
|
| 19 |
|
| 20 |
+
|
| 21 |
+
|
| 22 |
Example of instantiation and usage:
|
| 23 |
+
|
| 24 |
+
.. code-block:: python
|
| 25 |
+
|
| 26 |
+
operator = IobExtractor(
|
| 27 |
+
labels=["Person", "Organization", "Location"],
|
| 28 |
+
begin_labels=["B-PER", "B-ORG", "B-LOC"],
|
| 29 |
+
inside_labels=["I-PER", "I-ORG", "I-LOC"],
|
| 30 |
+
outside_label="O",
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
instance = {
|
| 34 |
+
"labels": ["B-PER", "I-PER", "O", "B-ORG", "I-ORG"],
|
| 35 |
+
"tokens": ["John", "Doe", "works", "at", "OpenAI"]
|
| 36 |
+
}
|
| 37 |
+
processed_instance = operator.process(instance)
|
| 38 |
+
print(processed_instance["spans"])
|
| 39 |
+
# Output: [{'start': 0, 'end': 8, 'text': 'John Doe', 'label': 'Person'}, ...]
|
| 40 |
|
| 41 |
For more details on the IOB tagging convention, see: https://en.wikipedia.org/wiki/Inside-outside-beginning_(tagging)
|
| 42 |
|
struct_data_operators.py
CHANGED
|
@@ -2,17 +2,25 @@
|
|
| 2 |
|
| 3 |
These operators are specialized in handling structured data like tables.
|
| 4 |
For tables, expected input format is:
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
For triples, expected input format is:
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
For key-value pairs, expected input format is:
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
import json
|
|
@@ -148,11 +156,15 @@ class SerializeTableAsMarkdown(SerializeTable):
|
|
| 148 |
|
| 149 |
Markdown table format is used in GitHub code primarily.
|
| 150 |
Format:
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
"""
|
| 157 |
|
| 158 |
# main method that serializes a table.
|
|
@@ -192,11 +204,14 @@ class SerializeTableAsDFLoader(SerializeTable):
|
|
| 192 |
|
| 193 |
Pandas dataframe based code snippet format serializer.
|
| 194 |
Format(Sample):
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
| 200 |
"""
|
| 201 |
|
| 202 |
# main method that serializes a table.
|
|
@@ -234,11 +249,14 @@ class SerializeTableAsJson(SerializeTable):
|
|
| 234 |
|
| 235 |
Json format based serializer.
|
| 236 |
Format(Sample):
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
| 242 |
"""
|
| 243 |
|
| 244 |
# main method that serializes a table.
|
|
@@ -264,15 +282,18 @@ class SerializeTableAsHTML(SerializeTable):
|
|
| 264 |
|
| 265 |
HTML table format used for rendering tables in web pages.
|
| 266 |
Format(Sample):
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
| 276 |
"""
|
| 277 |
|
| 278 |
# main method that serializes a table.
|
|
@@ -404,7 +425,7 @@ class TruncateTableRows(FieldOperator):
|
|
| 404 |
"""Limits table rows to specified limit by removing excess rows via random selection.
|
| 405 |
|
| 406 |
Args:
|
| 407 |
-
rows_to_keep (int)
|
| 408 |
"""
|
| 409 |
|
| 410 |
rows_to_keep: int = 10
|
|
@@ -563,16 +584,19 @@ class ListToKeyValPairs(InstanceOperator):
|
|
| 563 |
class ConvertTableColNamesToSequential(FieldOperator):
|
| 564 |
"""Replaces actual table column names with static sequential names like col_0, col_1,...
|
| 565 |
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
|
|
|
|
|
|
|
|
|
| 576 |
"""
|
| 577 |
|
| 578 |
def process_value(self, table: Any) -> Any:
|
|
@@ -595,17 +619,19 @@ class ConvertTableColNamesToSequential(FieldOperator):
|
|
| 595 |
class ShuffleTableRows(TypeDependentAugmentor):
|
| 596 |
"""Shuffles the input table rows randomly.
|
| 597 |
|
| 598 |
-
|
| 599 |
-
{
|
| 600 |
-
"header": ["name", "age"],
|
| 601 |
-
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
| 602 |
-
}
|
| 603 |
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
"""
|
| 610 |
|
| 611 |
augmented_type = Table
|
|
@@ -619,17 +645,19 @@ class ShuffleTableRows(TypeDependentAugmentor):
|
|
| 619 |
class ShuffleTableColumns(TypeDependentAugmentor):
|
| 620 |
"""Shuffles the table columns randomly.
|
| 621 |
|
| 622 |
-
|
| 623 |
-
{
|
| 624 |
-
"header": ["name", "age"],
|
| 625 |
-
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
| 626 |
-
}
|
| 627 |
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
"""
|
| 634 |
|
| 635 |
augmented_type = Table
|
|
@@ -662,11 +690,14 @@ class DumpJson(FieldOperator):
|
|
| 662 |
class MapHTMLTableToJSON(FieldOperator):
|
| 663 |
"""Converts HTML table format to the basic one (JSON).
|
| 664 |
|
| 665 |
-
JSON format
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
|
|
|
|
|
|
|
|
|
| 670 |
"""
|
| 671 |
|
| 672 |
_requirements_list = ["bs4"]
|
|
@@ -701,11 +732,14 @@ class MapHTMLTableToJSON(FieldOperator):
|
|
| 701 |
class MapTableListsToStdTableJSON(FieldOperator):
|
| 702 |
"""Converts lists table format to the basic one (JSON).
|
| 703 |
|
| 704 |
-
JSON format
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
|
|
|
|
|
|
|
|
|
| 709 |
"""
|
| 710 |
|
| 711 |
def process_value(self, table: Any) -> Any:
|
|
@@ -755,17 +789,20 @@ class ConstructTableFromRowsCols(InstanceOperator):
|
|
| 755 |
class TransposeTable(TypeDependentAugmentor):
|
| 756 |
"""Transpose a table.
|
| 757 |
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
|
| 764 |
-
Sample Output:
|
| 765 |
-
{
|
| 766 |
-
"header": [" ", "0", "1", "2"],
|
| 767 |
-
"rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
|
| 768 |
-
}
|
| 769 |
"""
|
| 770 |
|
| 771 |
augmented_type = Table
|
|
@@ -791,8 +828,9 @@ class DuplicateTableRows(TypeDependentAugmentor):
|
|
| 791 |
"""Duplicates specific rows of a table for the given number of times.
|
| 792 |
|
| 793 |
Args:
|
| 794 |
-
row_indices (List[int])
|
| 795 |
-
|
|
|
|
| 796 |
"""
|
| 797 |
|
| 798 |
augmented_type = Table
|
|
@@ -823,8 +861,9 @@ class DuplicateTableColumns(TypeDependentAugmentor):
|
|
| 823 |
"""Duplicates specific columns of a table for the given number of times.
|
| 824 |
|
| 825 |
Args:
|
| 826 |
-
column_indices (List[int])
|
| 827 |
-
|
|
|
|
| 828 |
"""
|
| 829 |
|
| 830 |
augmented_type = Table
|
|
|
|
| 2 |
|
| 3 |
These operators are specialized in handling structured data like tables.
|
| 4 |
For tables, expected input format is:
|
| 5 |
+
|
| 6 |
+
.. code-block:: text
|
| 7 |
+
|
| 8 |
+
{
|
| 9 |
+
"header": ["col1", "col2"],
|
| 10 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
| 11 |
+
}
|
| 12 |
|
| 13 |
For triples, expected input format is:
|
| 14 |
+
|
| 15 |
+
.. code-block:: text
|
| 16 |
+
|
| 17 |
+
[[ "subject1", "relation1", "object1" ], [ "subject1", "relation2", "object2"]]
|
| 18 |
|
| 19 |
For key-value pairs, expected input format is:
|
| 20 |
+
|
| 21 |
+
.. code-block:: text
|
| 22 |
+
|
| 23 |
+
{"key1": "value1", "key2": value2, "key3": "value3"}
|
| 24 |
"""
|
| 25 |
|
| 26 |
import json
|
|
|
|
| 156 |
|
| 157 |
Markdown table format is used in GitHub code primarily.
|
| 158 |
Format:
|
| 159 |
+
|
| 160 |
+
.. code-block:: text
|
| 161 |
+
|
| 162 |
+
|col1|col2|col3|
|
| 163 |
+
|---|---|---|
|
| 164 |
+
|A|4|1|
|
| 165 |
+
|I|2|1|
|
| 166 |
+
...
|
| 167 |
+
|
| 168 |
"""
|
| 169 |
|
| 170 |
# main method that serializes a table.
|
|
|
|
| 204 |
|
| 205 |
Pandas dataframe based code snippet format serializer.
|
| 206 |
Format(Sample):
|
| 207 |
+
|
| 208 |
+
.. code-block:: python
|
| 209 |
+
|
| 210 |
+
pd.DataFrame({
|
| 211 |
+
"name" : ["Alex", "Diana", "Donald"],
|
| 212 |
+
"age" : [26, 34, 39]
|
| 213 |
+
},
|
| 214 |
+
index=[0,1,2])
|
| 215 |
"""
|
| 216 |
|
| 217 |
# main method that serializes a table.
|
|
|
|
| 249 |
|
| 250 |
Json format based serializer.
|
| 251 |
Format(Sample):
|
| 252 |
+
|
| 253 |
+
.. code-block:: json
|
| 254 |
+
|
| 255 |
+
{
|
| 256 |
+
"0":{"name":"Alex","age":26},
|
| 257 |
+
"1":{"name":"Diana","age":34},
|
| 258 |
+
"2":{"name":"Donald","age":39}
|
| 259 |
+
}
|
| 260 |
"""
|
| 261 |
|
| 262 |
# main method that serializes a table.
|
|
|
|
| 282 |
|
| 283 |
HTML table format used for rendering tables in web pages.
|
| 284 |
Format(Sample):
|
| 285 |
+
|
| 286 |
+
.. code-block:: html
|
| 287 |
+
|
| 288 |
+
<table>
|
| 289 |
+
<thead>
|
| 290 |
+
<tr><th>name</th><th>age</th><th>sex</th></tr>
|
| 291 |
+
</thead>
|
| 292 |
+
<tbody>
|
| 293 |
+
<tr><td>Alice</td><td>26</td><td>F</td></tr>
|
| 294 |
+
<tr><td>Raj</td><td>34</td><td>M</td></tr>
|
| 295 |
+
</tbody>
|
| 296 |
+
</table>
|
| 297 |
"""
|
| 298 |
|
| 299 |
# main method that serializes a table.
|
|
|
|
| 425 |
"""Limits table rows to specified limit by removing excess rows via random selection.
|
| 426 |
|
| 427 |
Args:
|
| 428 |
+
rows_to_keep (int): number of rows to keep.
|
| 429 |
"""
|
| 430 |
|
| 431 |
rows_to_keep: int = 10
|
|
|
|
| 584 |
class ConvertTableColNamesToSequential(FieldOperator):
|
| 585 |
"""Replaces actual table column names with static sequential names like col_0, col_1,...
|
| 586 |
|
| 587 |
+
.. code-block:: text
|
| 588 |
+
|
| 589 |
+
Sample input:
|
| 590 |
+
{
|
| 591 |
+
"header": ["name", "age"],
|
| 592 |
+
"rows": [["Alex", 21], ["Donald", 34]]
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
Sample output:
|
| 596 |
+
{
|
| 597 |
+
"header": ["col_0", "col_1"],
|
| 598 |
+
"rows": [["Alex", 21], ["Donald", 34]]
|
| 599 |
+
}
|
| 600 |
"""
|
| 601 |
|
| 602 |
def process_value(self, table: Any) -> Any:
|
|
|
|
| 619 |
class ShuffleTableRows(TypeDependentAugmentor):
|
| 620 |
"""Shuffles the input table rows randomly.
|
| 621 |
|
| 622 |
+
.. code-block:: text
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
+
Sample Input:
|
| 625 |
+
{
|
| 626 |
+
"header": ["name", "age"],
|
| 627 |
+
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
Sample Output:
|
| 631 |
+
{
|
| 632 |
+
"header": ["name", "age"],
|
| 633 |
+
"rows": [["Donald", 39], ["Raj", 34], ["Alex", 26]],
|
| 634 |
+
}
|
| 635 |
"""
|
| 636 |
|
| 637 |
augmented_type = Table
|
|
|
|
| 645 |
class ShuffleTableColumns(TypeDependentAugmentor):
|
| 646 |
"""Shuffles the table columns randomly.
|
| 647 |
|
| 648 |
+
.. code-block:: text
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
+
Sample Input:
|
| 651 |
+
{
|
| 652 |
+
"header": ["name", "age"],
|
| 653 |
+
"rows": [["Alex", 26], ["Raj", 34], ["Donald", 39]],
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
Sample Output:
|
| 657 |
+
{
|
| 658 |
+
"header": ["age", "name"],
|
| 659 |
+
"rows": [[26, "Alex"], [34, "Raj"], [39, "Donald"]],
|
| 660 |
+
}
|
| 661 |
"""
|
| 662 |
|
| 663 |
augmented_type = Table
|
|
|
|
| 690 |
class MapHTMLTableToJSON(FieldOperator):
|
| 691 |
"""Converts HTML table format to the basic one (JSON).
|
| 692 |
|
| 693 |
+
JSON format:
|
| 694 |
+
|
| 695 |
+
.. code-block:: json
|
| 696 |
+
|
| 697 |
+
{
|
| 698 |
+
"header": ["col1", "col2"],
|
| 699 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
| 700 |
+
}
|
| 701 |
"""
|
| 702 |
|
| 703 |
_requirements_list = ["bs4"]
|
|
|
|
| 732 |
class MapTableListsToStdTableJSON(FieldOperator):
|
| 733 |
"""Converts lists table format to the basic one (JSON).
|
| 734 |
|
| 735 |
+
JSON format:
|
| 736 |
+
|
| 737 |
+
.. code-block:: json
|
| 738 |
+
|
| 739 |
+
{
|
| 740 |
+
"header": ["col1", "col2"],
|
| 741 |
+
"rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]
|
| 742 |
+
}
|
| 743 |
"""
|
| 744 |
|
| 745 |
def process_value(self, table: Any) -> Any:
|
|
|
|
| 789 |
class TransposeTable(TypeDependentAugmentor):
|
| 790 |
"""Transpose a table.
|
| 791 |
|
| 792 |
+
.. code-block:: text
|
| 793 |
+
|
| 794 |
+
Sample Input:
|
| 795 |
+
{
|
| 796 |
+
"header": ["name", "age", "sex"],
|
| 797 |
+
"rows": [["Alice", 26, "F"], ["Raj", 34, "M"], ["Donald", 39, "M"]],
|
| 798 |
+
}
|
| 799 |
+
|
| 800 |
+
Sample Output:
|
| 801 |
+
{
|
| 802 |
+
"header": [" ", "0", "1", "2"],
|
| 803 |
+
"rows": [["name", "Alice", "Raj", "Donald"], ["age", 26, 34, 39], ["sex", "F", "M", "M"]],
|
| 804 |
+
}
|
| 805 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 806 |
"""
|
| 807 |
|
| 808 |
augmented_type = Table
|
|
|
|
| 828 |
"""Duplicates specific rows of a table for the given number of times.
|
| 829 |
|
| 830 |
Args:
|
| 831 |
+
row_indices (List[int]): rows to be duplicated
|
| 832 |
+
|
| 833 |
+
times(int): each row to be duplicated is to show that many times
|
| 834 |
"""
|
| 835 |
|
| 836 |
augmented_type = Table
|
|
|
|
| 861 |
"""Duplicates specific columns of a table for the given number of times.
|
| 862 |
|
| 863 |
Args:
|
| 864 |
+
column_indices (List[int]): columns to be duplicated
|
| 865 |
+
|
| 866 |
+
times(int): each column to be duplicated is to show that many times
|
| 867 |
"""
|
| 868 |
|
| 869 |
augmented_type = Table
|
task.py
CHANGED
|
@@ -41,24 +41,28 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
|
|
| 41 |
|
| 42 |
Attributes:
|
| 43 |
input_fields (Union[Dict[str, str], List[str]]):
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
reference_fields (Union[Dict[str, str], List[str]]):
|
| 47 |
-
|
| 48 |
-
|
|
|
|
| 49 |
metrics (List[str]): List of names of metrics to be used in the task.
|
|
|
|
| 50 |
prediction_type (Optional[str]):
|
| 51 |
-
|
| 52 |
-
|
|
|
|
| 53 |
defaults (Optional[Dict[str, Any]]):
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
The output instance contains three fields:
|
| 59 |
-
"input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
|
| 60 |
-
"reference_fields" -- for the fields listed in Arg "reference_fields".
|
| 61 |
-
"metrics" -- to contain the value of Arg 'metrics'
|
| 62 |
"""
|
| 63 |
|
| 64 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
|
|
|
| 41 |
|
| 42 |
Attributes:
|
| 43 |
input_fields (Union[Dict[str, str], List[str]]):
|
| 44 |
+
Dictionary with string names of instance input fields and types of respective values.
|
| 45 |
+
In case a list is passed, each type will be assumed to be Any.
|
| 46 |
+
|
| 47 |
reference_fields (Union[Dict[str, str], List[str]]):
|
| 48 |
+
Dictionary with string names of instance output fields and types of respective values.
|
| 49 |
+
In case a list is passed, each type will be assumed to be Any.
|
| 50 |
+
|
| 51 |
metrics (List[str]): List of names of metrics to be used in the task.
|
| 52 |
+
|
| 53 |
prediction_type (Optional[str]):
|
| 54 |
+
Need to be consistent with all used metrics. Defaults to None, which means that it will
|
| 55 |
+
be set to Any.
|
| 56 |
+
|
| 57 |
defaults (Optional[Dict[str, Any]]):
|
| 58 |
+
An optional dictionary with default values for chosen input/output keys. Needs to be
|
| 59 |
+
consistent with names and types provided in 'input_fields' and/or 'output_fields' arguments.
|
| 60 |
+
Will not overwrite values if already provided in a given instance.
|
| 61 |
|
| 62 |
The output instance contains three fields:
|
| 63 |
+
1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
|
| 64 |
+
2. "reference_fields" -- for the fields listed in Arg "reference_fields".
|
| 65 |
+
3. "metrics" -- to contain the value of Arg 'metrics'
|
| 66 |
"""
|
| 67 |
|
| 68 |
input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
|
templates.py
CHANGED
|
@@ -308,19 +308,25 @@ class PairwiseChoiceTemplate(InputOutputTemplate):
|
|
| 308 |
|
| 309 |
Args:
|
| 310 |
choice_a_field (str): The field which contains choice_a value
|
|
|
|
| 311 |
choice_b_field (str): The field which contains choice_b value
|
|
|
|
| 312 |
answer_field (str): The field which contains the answer value.
|
| 313 |
-
|
|
|
|
| 314 |
choice_a_label (str): The label of choice A answer as it is verbalized in the template.
|
|
|
|
| 315 |
choice_b_label (str): The label of choice B answer as it is verbalized in the template.
|
|
|
|
| 316 |
choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
|
|
|
|
| 317 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
| 318 |
|
| 319 |
shuffle: 50% of the time:
|
| 320 |
-
1
|
| 321 |
-
2
|
| 322 |
-
|
| 323 |
-
|
| 324 |
|
| 325 |
"""
|
| 326 |
|
|
@@ -433,14 +439,17 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
|
|
| 433 |
|
| 434 |
Args:
|
| 435 |
choice_a_field (str): The field which contains choice_a value
|
|
|
|
| 436 |
choice_b_field (str): The field which contains choice_b value
|
|
|
|
| 437 |
answer_field (str): The field which contains the answer value. The value should be an int.
|
| 438 |
-
|
|
|
|
| 439 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
| 440 |
|
| 441 |
shuffle: 50% of the time:
|
| 442 |
-
|
| 443 |
-
|
| 444 |
|
| 445 |
"""
|
| 446 |
|
|
|
|
| 308 |
|
| 309 |
Args:
|
| 310 |
choice_a_field (str): The field which contains choice_a value
|
| 311 |
+
|
| 312 |
choice_b_field (str): The field which contains choice_b value
|
| 313 |
+
|
| 314 |
answer_field (str): The field which contains the answer value.
|
| 315 |
+
Should be of type Literal["choice_1", "choice_2", "tie"]
|
| 316 |
+
|
| 317 |
choice_a_label (str): The label of choice A answer as it is verbalized in the template.
|
| 318 |
+
|
| 319 |
choice_b_label (str): The label of choice B answer as it is verbalized in the template.
|
| 320 |
+
|
| 321 |
choice_tie_label (str): The label of a tie answer as it should be verbalized in the template.
|
| 322 |
+
|
| 323 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
| 324 |
|
| 325 |
shuffle: 50% of the time:
|
| 326 |
+
1. The values of choice_a_field and choice_b_field will be swapped.
|
| 327 |
+
2. If the values of answer_field is choice_a_label, set it to choice_b_label.
|
| 328 |
+
| Else if the values of answer_field is choice_b_label, set it to choice_a_label.
|
| 329 |
+
| Else if the value of answer_field is choice_tie_label, do nothing.
|
| 330 |
|
| 331 |
"""
|
| 332 |
|
|
|
|
| 439 |
|
| 440 |
Args:
|
| 441 |
choice_a_field (str): The field which contains choice_a value
|
| 442 |
+
|
| 443 |
choice_b_field (str): The field which contains choice_b value
|
| 444 |
+
|
| 445 |
answer_field (str): The field which contains the answer value. The value should be an int.
|
| 446 |
+
Positive for preferring choice_a, and negative for preferring choice_b
|
| 447 |
+
|
| 448 |
shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.
|
| 449 |
|
| 450 |
shuffle: 50% of the time:
|
| 451 |
+
| 1) The values of choice_a_field and choice_b_field will be swapped.
|
| 452 |
+
| 2) Replace the values of answer_field with its mapped value according to the reverse_preference_map Dict.
|
| 453 |
|
| 454 |
"""
|
| 455 |
|
type_utils.py
CHANGED
|
@@ -307,21 +307,22 @@ def infer_type_string(obj: typing.Any) -> str:
|
|
| 307 |
obj:Any
|
| 308 |
|
| 309 |
Returns:
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
-
formal definition of the returned string:
|
| 313 |
-
Type -> basic | List[Type] | Dict[Type, Type] | Union[Type (, Type)* | Tuple[Type (,Type)*]
|
| 314 |
-
basic -> bool,str,int,float,Any
|
| 315 |
-
no spaces at all.
|
| 316 |
|
| 317 |
Examples:
|
| 318 |
-
infer_type_string({"how_much": 7}) returns "Dict[str,int]"
|
| 319 |
-
infer_type_string([1, 2]) returns "List[int]"
|
| 320 |
-
infer_type_string([]) returns "List[Any]") no contents to list to indicate any type
|
| 321 |
-
infer_type_string([[], [7]]) returns "List[List[int]]" type of parent list indicated
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
infer_type_string([[], 7, True]) returns "List[Union[List[Any],int]]"
|
|
|
|
| 325 |
|
| 326 |
"""
|
| 327 |
|
|
|
|
| 307 |
obj:Any
|
| 308 |
|
| 309 |
Returns:
|
| 310 |
+
a string representation of the type of the object. e.g. ``"str"``, ``"List[int]"``, ``"Dict[str, Any]"``
|
| 311 |
+
|
| 312 |
+
| formal definition of the returned string:
|
| 313 |
+
| Type -> basic | List[Type] | Dict[Type, Type] | Union[Type(, Type)*] | Tuple[Type(, Type)*]
|
| 314 |
+
| basic -> ``bool`` | ``str`` | ``int`` | ``float`` | ``Any``
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
Examples:
|
| 318 |
+
| ``infer_type_string({"how_much": 7})`` returns ``"Dict[str,int]"``
|
| 319 |
+
| ``infer_type_string([1, 2])`` returns ``"List[int]"``
|
| 320 |
+
| ``infer_type_string([])`` returns ``"List[Any]")`` no contents to list to indicate any type
|
| 321 |
+
| ``infer_type_string([[], [7]])`` returns ``"List[List[int]]"`` type of parent list indicated
|
| 322 |
+
by the type of the non-empty child list. The empty child list is indeed, by default, also of
|
| 323 |
+
that type of the non-empty child.
|
| 324 |
+
| ``infer_type_string([[], 7, True])`` returns ``"List[Union[List[Any],int]]"``
|
| 325 |
+
because ``bool`` is also an ``int``
|
| 326 |
|
| 327 |
"""
|
| 328 |
|
utils.py
CHANGED
|
@@ -32,8 +32,8 @@ class LRUCache:
|
|
| 32 |
|
| 33 |
Attributes:
|
| 34 |
max_size (int): The maximum number of items to store in the cache.
|
| 35 |
-
|
| 36 |
-
|
| 37 |
"""
|
| 38 |
|
| 39 |
def __init__(self, max_size=10):
|
|
|
|
| 32 |
|
| 33 |
Attributes:
|
| 34 |
max_size (int): The maximum number of items to store in the cache.
|
| 35 |
+
Items exceeding this limit are automatically removed based on least
|
| 36 |
+
recent usage.
|
| 37 |
"""
|
| 38 |
|
| 39 |
def __init__(self, max_size=10):
|
version.py
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
version = "1.15.
|
|
|
|
| 1 |
+
version = "1.15.10"
|