Spaces:
Runtime error
Runtime error
Commit
·
cb57cce
1
Parent(s):
d987e13
fix code generation for pipeline textcat
Browse files
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
|
@@ -171,6 +171,8 @@ def generate_pipeline_code(
|
|
| 171 |
temperature: float = 0.9,
|
| 172 |
) -> str:
|
| 173 |
labels = get_preprocess_labels(labels)
|
|
|
|
|
|
|
| 174 |
base_code = f"""
|
| 175 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
| 176 |
import os
|
|
@@ -192,15 +194,13 @@ with Pipeline(name="textcat") as pipeline:
|
|
| 192 |
task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
|
| 193 |
|
| 194 |
textcat_generation = GenerateTextClassificationData(
|
| 195 |
-
llm=
|
| 196 |
-
|
| 197 |
base_url=BASE_URL,
|
| 198 |
api_key=os.environ["API_KEY"],
|
| 199 |
generation_kwargs={{
|
| 200 |
"temperature": {temperature},
|
| 201 |
"max_new_tokens": {MAX_NUM_TOKENS},
|
| 202 |
-
"do_sample": True,
|
| 203 |
-
"top_k": 50,
|
| 204 |
"top_p": 0.95,
|
| 205 |
}},
|
| 206 |
),
|
|
@@ -236,8 +236,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
| 236 |
)
|
| 237 |
|
| 238 |
textcat_labeller = TextClassification(
|
| 239 |
-
llm=
|
| 240 |
-
|
| 241 |
base_url=BASE_URL,
|
| 242 |
api_key=os.environ["API_KEY"],
|
| 243 |
generation_kwargs={{
|
|
|
|
| 171 |
temperature: float = 0.9,
|
| 172 |
) -> str:
|
| 173 |
labels = get_preprocess_labels(labels)
|
| 174 |
+
MODEL_ARG = "model_id" if BASE_URL else "model"
|
| 175 |
+
MODEL_CLASS = "InferenceEndpointsLLM" if BASE_URL else "OpenAILLM"
|
| 176 |
base_code = f"""
|
| 177 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
| 178 |
import os
|
|
|
|
| 194 |
task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
|
| 195 |
|
| 196 |
textcat_generation = GenerateTextClassificationData(
|
| 197 |
+
llm={MODEL_CLASS}(
|
| 198 |
+
{MODEL_ARG}=MODEL,
|
| 199 |
base_url=BASE_URL,
|
| 200 |
api_key=os.environ["API_KEY"],
|
| 201 |
generation_kwargs={{
|
| 202 |
"temperature": {temperature},
|
| 203 |
"max_new_tokens": {MAX_NUM_TOKENS},
|
|
|
|
|
|
|
| 204 |
"top_p": 0.95,
|
| 205 |
}},
|
| 206 |
),
|
|
|
|
| 236 |
)
|
| 237 |
|
| 238 |
textcat_labeller = TextClassification(
|
| 239 |
+
llm={MODEL_CLASS}(
|
| 240 |
+
{MODEL_ARG}=MODEL,
|
| 241 |
base_url=BASE_URL,
|
| 242 |
api_key=os.environ["API_KEY"],
|
| 243 |
generation_kwargs={{
|