Reduce simple dataset generation time
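This change splits single-turn generation into two shorter calls: MagpieGenerator now emits only the instructions (only_instructions=True, max_new_tokens=512), and a separate TextGeneration step writes the responses (max_new_tokens=1024). Multi-turn generation keeps the original single MagpieGenerator step with a 2048-token budget. Judging by the diff, the smaller per-call token budgets on the single-turn path are what cut the generation time.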
src/distilabel_dataset_generator/pipelines/sft.py
@@ -190,31 +190,73 @@ if __name__ == "__main__":
 def get_pipeline(num_turns, num_rows, system_prompt):
     input_mappings = _get_output_mappings(num_turns)
     output_mappings = input_mappings
-    with Pipeline(name="sft") as pipeline:
-        magpie = MagpieGenerator(
-            llm=InferenceEndpointsLLM(
-                model_id=MODEL,
-                tokenizer_id=MODEL,
-                api_key=os.environ["HF_TOKEN"],
-                magpie_pre_query_template="llama3",
-                generation_kwargs={
-                    "temperature": 0.8,  # it's the best value for Llama 3.1 70B Instruct
-                    "do_sample": True,
-                    "max_new_tokens": 2048,
-                    "stop_sequences": _STOP_SEQUENCES,
-                },
-            ),
-            batch_size=2,
-            n_turns=num_turns,
-            num_rows=num_rows,
-            system_prompt=system_prompt,
-            output_mappings=output_mappings,
-        )
-        keep_columns = KeepColumns(
-            columns=list(output_mappings.values()) + ["model_name"],
-        )
-        magpie.connect(keep_columns)
-        return pipeline
+    if num_turns == 1:
+        with Pipeline(name="sft") as pipeline:
+            magpie = MagpieGenerator(
+                llm=InferenceEndpointsLLM(
+                    model_id=MODEL,
+                    tokenizer_id=MODEL,
+                    api_key=os.environ["HF_TOKEN"],
+                    magpie_pre_query_template="llama3",
+                    generation_kwargs={
+                        "temperature": 0.8,  # it's the best value for Llama 3.1 70B Instruct
+                        "do_sample": True,
+                        "max_new_tokens": 512,
+                        "stop_sequences": _STOP_SEQUENCES,
+                    },
+                ),
+                batch_size=2,
+                n_turns=num_turns,
+                num_rows=num_rows,
+                system_prompt=system_prompt,
+                output_mappings=output_mappings,
+                only_instructions=True,
+            )
+
+            generate_response = TextGeneration(
+                llm=InferenceEndpointsLLM(
+                    model_id=MODEL,
+                    tokenizer_id=MODEL,
+                    generation_kwargs={
+                        "temperature": 0.8,
+                        "max_new_tokens": 1024,
+                    },
+                )
+            )
+
+            keep_columns = KeepColumns(
+                columns=list(output_mappings.values()) + ["model_name"],
+            )
+
+            magpie.connect(generate_response)
+            generate_response.connect(keep_columns)
+            return pipeline
+    else:
+        with Pipeline(name="sft") as pipeline:
+            magpie = MagpieGenerator(
+                llm=InferenceEndpointsLLM(
+                    model_id=MODEL,
+                    tokenizer_id=MODEL,
+                    api_key=os.environ["HF_TOKEN"],
+                    magpie_pre_query_template="llama3",
+                    generation_kwargs={
+                        "temperature": 0.8,  # it's the best value for Llama 3.1 70B Instruct
+                        "do_sample": True,
+                        "max_new_tokens": 2048,
+                        "stop_sequences": _STOP_SEQUENCES,
+                    },
+                ),
+                batch_size=2,
+                n_turns=num_turns,
+                num_rows=num_rows,
+                system_prompt=system_prompt,
+                output_mappings=output_mappings,
+            )
+            keep_columns = KeepColumns(
+                columns=list(output_mappings.values()) + ["model_name"],
+            )
+            magpie.connect(keep_columns)
+            return pipeline
 
 
 def get_prompt_generation_step():
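As a quick illustration, here is a minimal, hypothetical driver for the updated get_pipeline. It is not part of the commit: the argument values are made up, and it assumes distilabel is installed, that HF_TOKEN is set in the environment (InferenceEndpointsLLM reads it above), and that Pipeline.run returns a Distiset as in distilabel's public API.

# Hypothetical usage sketch, not part of this commit. The import path
# mirrors the file being changed; num_rows/system_prompt values are examples.
import os

from distilabel_dataset_generator.pipelines.sft import get_pipeline

assert "HF_TOKEN" in os.environ  # required by InferenceEndpointsLLM above

# num_turns=1 takes the new fast path: instruction-only Magpie generation
# (512 tokens) followed by a separate TextGeneration response step.
pipeline = get_pipeline(
    num_turns=1,
    num_rows=100,
    system_prompt="You are a helpful assistant that answers questions about Python.",
)
distiset = pipeline.run(use_cache=False)  # standard distilabel Pipeline.run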