Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers. (#774)
* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.
* use a strict option for handling incorrect turn data
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
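
The new option is read from the dataset config (ds_cfg) in load(). A minimal sketch of such a config, assuming ds_cfg is a plain dict; only the keys appear in this commit, the values are illustrative:

# Sketch only: the keys mirror what load() looks up; the values are hypothetical.
ds_cfg = {
    "conversation": "vicuna_v1.1",  # hypothetical conversation template name
    "field_human": "human",         # optional override for the human role key
    "field_model": "gpt",           # optional override for the model role key
    "strict": False,                # new: False remaps "assistant" turns to "gpt"
}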
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "strict" in ds_cfg:
+        strategy.strict = ds_cfg["strict"]
+    return strategy


 def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     basic sharegpt strategy to grab conversations from the sample row
     """

+    _strict = True
+
+    @property
+    def strict(self):
+        return self._strict
+
+    @strict.setter
+    def strict(self, strict):
+        self._strict = strict
+
     def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+        conversations = prompt["conversations"]
+        if self.strict:
+            return conversations
+        # remap roles - allow for assistant turn
+        role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
+        turns = [
+            {"from": role_map[t["from"]], "value": t["value"]} for t in conversations
+        ]
+        return turns


 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
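
For reference, a standalone sketch of the non-strict remapping added above, run on a made-up conversation that uses "assistant" instead of "gpt":

# Illustrative only: reproduces the role_map logic from get_conversation_thread.
conversations = [
    {"from": "human", "value": "Hello!"},
    {"from": "assistant", "value": "Hi! How can I help?"},
]
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [{"from": role_map[t["from"]], "value": t["value"]} for t in conversations]
print(turns)
# [{'from': 'human', 'value': 'Hello!'}, {'from': 'gpt', 'value': 'Hi! How can I help?'}]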