File size: 2,109 Bytes
40f633f
ef48bcf
 
a894b74
ef48bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40f633f
ef48bcf
 
 
 
 
 
 
 
 
 
 
a894b74
ef48bcf
 
 
 
 
40f633f
ef48bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import copy
import enum
import os
from typing import List, Optional

import requests
import streamlit as st


# Module-level HTTP session created at import time.
# NOTE(review): nothing in the visible code references it (get_predictions
# calls requests.post directly) — confirm whether it is still needed.
http_session = requests.Session()

@enum.unique
class NeuralCategoryClassifierModel(enum.Enum):
    """Neural category classifier models selectable in this demo.

    Each member pairs a Python identifier with a string value.
    ``@enum.unique`` makes the class definition fail if two members share
    the same value. Declaration order is the iteration order, and therefore
    the order in which the models are listed to the user.
    """

    keras_2_0 = "keras-2.0"
    # NOTE(review): this value ends in "3-0" while its siblings end in "3.0"
    # — confirm whether the difference is intentional.
    keras_sota_3_0 = "keras-sota-3-0"
    keras_ingredient_ocr_3_0 = "keras-ingredient-ocr-3.0"
    keras_baseline_3_0 = "keras-baseline-3.0"
    keras_original_3_0 = "keras-original-3.0"
    keras_product_name_only_3_0 = "keras-product-name-only-3.0"


# Set to True to send requests to a server listening on localhost:5500
# instead of the public Robotoff instance.
LOCAL_DB = False

# Root of the v1 API for the selected server.
ROBOTOFF_BASE_URL = (
    "http://localhost:5500/api/v1"
    if LOCAL_DB
    else "https://robotoff.openfoodfacts.org/api/v1"
)

# Endpoint that receives the category prediction requests.
PREDICTION_URL = f"{ROBOTOFF_BASE_URL}/predict/category"


@st.cache
def get_predictions(
    barcode: str,
    model_name: str,
    threshold: Optional[float] = None,
    timeout: float = 30.0,
):
    """Fetch the neural category predictions for one product.

    Sends a POST request to ``PREDICTION_URL`` asking for the ``neural``
    predictor only, and returns the ``"neural"`` entry of the JSON response.
    The function is wrapped in ``st.cache``, so the returned object is the
    cached one: copy it before mutating it.

    Args:
        barcode: Barcode of the product to classify.
        model_name: Sent as ``neural_model_name`` in the request body.
        threshold: Optional value sent as ``threshold``; the key is omitted
            from the request when this is ``None``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The value stored under the ``"neural"`` key of the decoded response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within ``timeout``.
        KeyError: If the response has no ``"neural"`` key.
    """
    payload = {
        "barcode": barcode,
        "predictors": ["neural"],
        "neural_model_name": model_name,
    }
    if threshold is not None:
        payload["threshold"] = threshold

    # requests never times out on its own; without an explicit timeout an
    # unresponsive server would block the Streamlit script run forever.
    r = requests.post(PREDICTION_URL, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json()["neural"]

def display_predictions(
    barcode: str,
    model_names: List[str],
    threshold: Optional[float] = None,
):
    """Render the predictions of every requested model on the page.

    For each model, the predictions are fetched with ``get_predictions`` and
    written under a bold header carrying the model name. If a response
    contains a ``"debug"`` entry, that entry is removed before the response
    is displayed; the debug entry of the first response that has one is
    written once, ahead of that model's header.

    Args:
        barcode: Barcode of the product to classify.
        model_names: Model names to query, in display order.
        threshold: Optional threshold forwarded to ``get_predictions``.
    """
    debug_shown = False
    for model_name in model_names:
        # get_predictions is cached by Streamlit; deep-copy its result so
        # that removing "debug" below never mutates the cached object.
        response = copy.deepcopy(get_predictions(barcode, model_name, threshold))
        if "debug" in response:
            debug_info = response.pop("debug")
            if not debug_shown:
                debug_shown = True
                st.write(debug_info)
        # No spaces inside the asterisks: Markdown only renders `**text**`
        # as bold, `** text **` is displayed with literal asterisks.
        st.write(f"**{model_name}**")
        st.write(response)



st.sidebar.title("Category Prediction Demo")
# Strip before testing for emptiness, so that a whitespace-only entry counts
# as "no barcode" instead of being sent to the API as an empty string.
barcode = st.sidebar.text_input("Product barcode").strip()
# A value of 0 is falsy, so `or None` turns it into "no threshold".
threshold = st.sidebar.number_input("Threshold", format="%f") or None
# The options are the enum member names (not their values); the selected
# names are what gets passed on as model names.
model_names = st.multiselect(
    "Name of the model",
    [x.name for x in NeuralCategoryClassifierModel],
    default=NeuralCategoryClassifierModel.keras_sota_3_0.name,
)

if barcode:
    display_predictions(
        barcode=barcode,
        threshold=threshold,
        model_names=model_names,
    )