Image Classification
Transformers
Safetensors
PyTorch
timm_wrapper
Not-For-All-Audiences
ccabrerafreepik commited on
Commit
0dfcd66
·
1 Parent(s): abdaeb9

Improved method

Browse files
Files changed (1) hide show
  1. README.md +6 -0
README.md CHANGED
@@ -165,6 +165,12 @@ def prediction(model, processor, img_batch: List[Image.Image], class_to_predict:
165
  """
166
  Predict if images meet or exceed a specific NSFW threshold
167
  """
 
 
 
 
 
 
168
  output = predict_batch_values(model, processor, img_batch)
169
  return [output[i][class_to_predict] >= threshold for i in range(len(output))]
170
 
 
165
  """
166
  Predict if images meet or exceed a specific NSFW threshold
167
  """
168
+ if class_to_predict not in ["low", "medium", "high"]:
169
+ raise ValueError("class_to_predict must be one of: low, medium, high")
170
+
171
+ if not 0 <= threshold <= 1:
172
+ raise ValueError("threshold must be between 0 and 1")
173
+
174
  output = predict_batch_values(model, processor, img_batch)
175
  return [output[i][class_to_predict] >= threshold for i in range(len(output))]
176