Commit
·
0dfcd66
1
Parent(s):
abdaeb9
Improved method
Browse files
README.md
CHANGED
@@ -165,6 +165,12 @@ def prediction(model, processor, img_batch: List[Image.Image], class_to_predict:
|
|
165 |
"""
|
166 |
Predict if images meet or exceed a specific NSFW threshold
|
167 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
output = predict_batch_values(model, processor, img_batch)
|
169 |
return [output[i][class_to_predict] >= threshold for i in range(len(output))]
|
170 |
|
|
|
165 |
"""
|
166 |
Predict if images meet or exceed a specific NSFW threshold
|
167 |
"""
|
168 |
+
if class_to_predict not in ["low", "medium", "high"]:
|
169 |
+
raise ValueError("class_to_predict must be one of: low, medium, high")
|
170 |
+
|
171 |
+
if not 0 <= threshold <= 1:
|
172 |
+
raise ValueError("threshold must be between 0 and 1")
|
173 |
+
|
174 |
output = predict_batch_values(model, processor, img_batch)
|
175 |
return [output[i][class_to_predict] >= threshold for i in range(len(output))]
|
176 |
|