Spaces:
Sleeping
Sleeping
File size: 2,109 Bytes
40f633f ef48bcf a894b74 ef48bcf 40f633f ef48bcf a894b74 ef48bcf 40f633f ef48bcf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import copy
import enum
import os
from typing import List, Optional
import requests
import streamlit as st
# Module-level HTTP session. A requests.Session pools connections, so requests
# issued through it can reuse an open connection to the same host instead of
# reconnecting on every call.
http_session = requests.Session()
@enum.unique
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category-classifier models that can be selected in the demo.

    The demo UI lists and submits the member *names* (e.g. ``keras_sota_3_0``),
    not the string values. ``@enum.unique`` rejects duplicate values at import
    time. Member order matters: it is the order shown in the model picker.
    """

    keras_2_0 = "keras-2.0"
    # NOTE(review): this value uses "3-0" where the others use "3.0";
    # presumably it mirrors the server-side identifier — confirm before changing.
    keras_sota_3_0 = "keras-sota-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"
# Flip to True to talk to a Robotoff API served on localhost:5500 instead of
# the public robotoff.openfoodfacts.org instance.
LOCAL_DB = False

ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)
# Endpoint used by get_predictions for category predictions.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"
@st.cache
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    timeout: float = 30.0,
):
    """Request neural category predictions for one product from Robotoff.

    The result is memoised by ``st.cache`` (keyed on the arguments), so callers
    must not mutate the returned object.

    Args:
        barcode: Barcode of the product to classify.
        model_name: Sent to the API as ``neural_model_name``.
        threshold: Forwarded to the API when given; omitted entirely when None.
        timeout: Seconds to wait for the API before giving up, so a stalled
            request cannot freeze the app indefinitely.

    Returns:
        The ``"neural"`` entry of the JSON response body.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the API does not answer within ``timeout``.
    """
    payload = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        payload["threshold"] = threshold
    # Go through the shared module-level session so connections are reused
    # across calls (it was created for this but previously never used).
    response = http_session.post(PREDICTION_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()["neural"]
def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
):
    """Render each selected model's category predictions in the Streamlit app.

    If a response carries a ``"debug"`` entry, it is displayed once (from the
    first model whose response contains it). It is then left out of every
    per-model output.

    Args:
        barcode: Barcode of the product to classify.
        model_names: Model names to query, rendered in the given order.
        threshold: Optional prediction threshold forwarded to the API.
    """
    debug_shown = False
    for model_name in model_names:
        response = get_predictions(barcode, model_name, threshold)
        if "debug" in response and not debug_shown:
            st.write(response["debug"])
            debug_shown = True
        # Build a filtered view instead of popping from the response: it comes
        # from an st.cache'd function, and mutating it would alter the cache.
        predictions = {key: value for key, value in response.items() if key != "debug"}
        # No spaces inside the asterisks — "** name **" is not Markdown bold
        # and would render the asterisks literally.
        st.write(f"**{model_name}**")
        st.write(predictions)
st.sidebar.title("Category Prediction Demo")

# Strip before the truthiness check below, so a whitespace-only input does not
# send a request with an empty barcode.
barcode = st.sidebar.text_input("Product barcode").strip()

# A value of 0 is mapped to None, which makes get_predictions omit the
# threshold from the request entirely.
threshold = st.sidebar.number_input("Threshold", format="%f") or None

model_names = st.multiselect(
    "Name of the model",
    [model.name for model in NeuralCategoryClassifierModel],
    default=NeuralCategoryClassifierModel.keras_sota_3_0.name,
)

if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )
|