import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np
# Load models
MODELS = [
    ("Bayes Enron1 spam", BAYES := "bayes-enron1-spam"),
    ("NN Enron1 spam", NN := "nn-enron1-spam"),
    ("GISTy Enron1 spam", LLM := "gisty-enron1-spam"),
]
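# Per-word spam probabilities for the naive Bayes model, keyed by word (includes an entry for UNK).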
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
with open(model_probs_path) as f:
    model_probs = json.load(f)
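# Keras neural network that classifies raw email text (see predict_nn below).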
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)
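# Keras classifier that operates on GIST sentence embeddings (see predict_llm below).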
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)
# Sentence Transformers should be imported after the Keras models are loaded, to prevent it from switching Keras to legacy mode.
from sentence_transformers import SentenceTransformer
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
# Utils for Bayes
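# Key for the fallback probability of words not present in model_probs.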
UNK = '[UNK]'
def tokenize(text):
    return tf.keras.preprocessing.text.text_to_word_sequence(text)
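# Combine per-word spam probabilities with the naive Bayes formula prod(p) / (prod(p) + prod(1 - p)).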
def combine(probs):
    if any(p == 0 for p in probs):
        return 0
    prod = np.prod(probs)
    neg_prod = np.prod([1 - p for p in probs])
    if prod + neg_prod == 0:  # Still possible due to floating point arithmetic
        return 0.5  # Assume that prod and neg_prod are equally small
    return prod / (prod + neg_prod)
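# Select the intr_threshold probabilities that deviate most from the neutral 0.5.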
def get_interesting_probs(probs, intr_threshold):
    return sorted(probs,
                  key=lambda p: abs(p - 0.5),
                  reverse=True)[:intr_threshold]
DEFAULT_INTR_THRESHOLD = 15
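# Undo the 2x weight Graham gives to ham occurrences: if p = b / (b + 2g), then 2p / (p + 1) = b / (b + g).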
def unbias(p):
    return (2 * p) / (p + 1)
# Predict functions
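# Look up a spam probability for each word (falling back to UNK), keep the most interesting ones, and combine them.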
def predict_bayes(text, intr_threshold, unbiased=False):
    words = tokenize(text)
    probs = []
    for w in words:
        if w in model_probs:
            p = model_probs[w]
            if unbiased:
                p = unbias(p)
        else:
            p = model_probs[UNK]
        probs.append(p)
    interesting_probs = get_interesting_probs(probs, intr_threshold)
    return combine(interesting_probs)
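# The NN model accepts raw strings, so text preprocessing is assumed to happen inside the model.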
def predict_nn(text):
    return float(nn_model(np.array([text]))[0][0].numpy())
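# Embed the email text with the Sentence Transformers model, then classify the embedding.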
def predict_llm(text):
    embedding = st_model.encode(text)
    return float(llm_model(np.array([embedding]))[0][0].numpy())
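# Dispatch to the model selected in the UI; unbiased and intr_threshold only affect the Bayes model.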
def predict(model, input_txt, unbiased, intr_threshold):
    if model == BAYES:
        return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    elif model == NN:
        return predict_nn(input_txt)
    elif model == LLM:
        return predict_llm(input_txt)
# UI
demo = gr.Interface(
    theme=gr.themes.Origin(  # Gradio 4-like
        primary_hue="yellow",
    ),
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model",
                    # FIXME: Font size should be smaller by default. Remove workaround when fixed in Gradio: https://github.com/gradio-app/gradio/issues/9642
                    info="Learn more about the models [here](https://huggingface.co/collections/tbitai/bayes-or-spam-6700033fa145e298ec849249)"),
        gr.TextArea(label="Email"),
    ],
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(minimum=1, maximum=DEFAULT_INTR_THRESHOLD + 5, step=1, value=DEFAULT_INTR_THRESHOLD, 
                  label="Interestingness threshold", 
                  info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)"),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict whether your email is spam! 📨",
    examples=[
        [NN, "Enron actuals for June 26, 2000", None, None],
        [BAYES, "Stop the aging clock\nNerissa", True, DEFAULT_INTR_THRESHOLD],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)
if __name__ == "__main__":
    demo.launch()