sitammeur commited on
Commit
7747686
·
verified ·
1 Parent(s): 9b22410

Update src/app/predict.py

Browse files
Files changed (1) hide show
  1. src/app/predict.py +4 -6
src/app/predict.py CHANGED
@@ -8,16 +8,13 @@ from transformers import pipeline
8
 
9
 
10
  # Load the zero-shot classification model
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  classifier = pipeline(
13
- "zero-shot-classification",
14
- model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
15
- device=device,
16
  )
17
 
18
 
19
  def ZeroShotTextClassification(
20
- text_input: str, candidate_labels: str
21
  ) -> Dict[str, float]:
22
  """
23
  Performs zero-shot classification on the given text input.
@@ -25,6 +22,7 @@ def ZeroShotTextClassification(
25
  Args:
26
  - text_input: The input text to classify.
27
  - candidate_labels: A comma-separated string of candidate labels.
 
28
 
29
  Returns:
30
  Dictionary containing label-score pairs.
@@ -42,7 +40,7 @@ def ZeroShotTextClassification(
42
  text_input,
43
  labels,
44
  hypothesis_template=hypothesis_template,
45
- multi_label=False,
46
  )
47
 
48
  # Return the classification results
 
8
 
9
 
10
  # Load the zero-shot classification model
 
11
  classifier = pipeline(
12
+ "zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
 
 
13
  )
14
 
15
 
16
  def ZeroShotTextClassification(
17
+ text_input: str, candidate_labels: str, multi_label: bool
18
  ) -> Dict[str, float]:
19
  """
20
  Performs zero-shot classification on the given text input.
 
22
  Args:
23
  - text_input: The input text to classify.
24
  - candidate_labels: A comma-separated string of candidate labels.
25
+ - multi_label: A boolean indicating whether to allow the model to choose multiple classes.
26
 
27
  Returns:
28
  Dictionary containing label-score pairs.
 
40
  text_input,
41
  labels,
42
  hypothesis_template=hypothesis_template,
43
+ multi_label=multi_label,
44
  )
45
 
46
  # Return the classification results