Spaces:
Running
Running
Update src/app/predict.py
Browse files- src/app/predict.py +4 -6
src/app/predict.py
CHANGED
@@ -8,16 +8,13 @@ from transformers import pipeline
|
|
8 |
|
9 |
|
10 |
# Load the zero-shot classification model
|
11 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
classifier = pipeline(
|
13 |
-
"zero-shot-classification",
|
14 |
-
model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
|
15 |
-
device=device,
|
16 |
)
|
17 |
|
18 |
|
19 |
def ZeroShotTextClassification(
|
20 |
-
text_input: str, candidate_labels: str
|
21 |
) -> Dict[str, float]:
|
22 |
"""
|
23 |
Performs zero-shot classification on the given text input.
|
@@ -25,6 +22,7 @@ def ZeroShotTextClassification(
|
|
25 |
Args:
|
26 |
- text_input: The input text to classify.
|
27 |
- candidate_labels: A comma-separated string of candidate labels.
|
|
|
28 |
|
29 |
Returns:
|
30 |
Dictionary containing label-score pairs.
|
@@ -42,7 +40,7 @@ def ZeroShotTextClassification(
|
|
42 |
text_input,
|
43 |
labels,
|
44 |
hypothesis_template=hypothesis_template,
|
45 |
-
multi_label=
|
46 |
)
|
47 |
|
48 |
# Return the classification results
|
|
|
8 |
|
9 |
|
10 |
# Load the zero-shot classification model
|
|
|
11 |
classifier = pipeline(
|
12 |
+
"zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
|
|
|
|
|
13 |
)
|
14 |
|
15 |
|
16 |
def ZeroShotTextClassification(
|
17 |
+
text_input: str, candidate_labels: str, multi_label: bool
|
18 |
) -> Dict[str, float]:
|
19 |
"""
|
20 |
Performs zero-shot classification on the given text input.
|
|
|
22 |
Args:
|
23 |
- text_input: The input text to classify.
|
24 |
- candidate_labels: A comma-separated string of candidate labels.
|
25 |
+
- multi_label: A boolean indicating whether to allow the model to choose multiple classes.
|
26 |
|
27 |
Returns:
|
28 |
Dictionary containing label-score pairs.
|
|
|
40 |
text_input,
|
41 |
labels,
|
42 |
hypothesis_template=hypothesis_template,
|
43 |
+
multi_label=multi_label,
|
44 |
)
|
45 |
|
46 |
# Return the classification results
|