import gradio as gr
from hashformers import TransformerWordSegmenter as WordSegmenter
import pandas as pd
# Static text handed to ``gr.Interface`` (title, description and article).
app_title = "Hashtag segmentation"
article_string = "Author: Ruan Chaves Rodrigues. Read more about the Hashformers library."
app_description = """
Hashtag segmentation is the task of automatically adding spaces between the words on a hashtag.
This app uses the Hashformers library to suggest segmentations for hashtags.
Enter a hashtag or pick one from the examples below. The app will suggest the best segmentation for the hashtag.
"""

# Example inputs handed to ``gr.Interface(examples=...)``.
app_examples = [["#helloworld"]]

# Not referenced anywhere in the visible part of this file.
output_json_component_description = {"": ""}
# (segmenter checkpoint, reranker checkpoint) for each supported language.
# Insertion order is preserved and becomes the order of the language dropdown.
_MODEL_CHECKPOINTS = {
    "english": ("gpt2", "bert-base-uncased"),
    "english (fast)": ("distilgpt2", "distilbert-base-uncased"),
    "spanish": (
        "mrm8488/spanish-gpt2",
        "dccuchile/bert-base-spanish-wwm-cased",
    ),
    "portuguese": (
        "pierreguillou/gpt2-small-portuguese",
        "neuralmind/bert-base-portuguese-cased",
    ),
    "german": ("dbmdz/german-gpt2", "bert-base-german-cased"),
}

# Every segmenter is constructed eagerly when the module is executed; the
# segmenter model is pinned to the CPU.
model_dict = {
    language: WordSegmenter(
        segmenter_model_name_or_path=segmenter_name,
        reranker_model_name_or_path=reranker_name,
        segmenter_device="cpu",
    )
    for language, (segmenter_name, reranker_name) in _MODEL_CHECKPOINTS.items()
}
language_list = list(model_dict)
def format_dataframe(df):
    """Return the ``segmentation``/``score`` columns of a rank table for display.

    Each score is replaced by its reciprocal, rounded to four decimal places.

    Args:
        df: A ``pandas.DataFrame`` with ``segmentation`` and ``score`` columns.
            Any other value (e.g. ``None``) is returned untouched.

    Returns:
        A new two-column DataFrame; the frame passed in is never modified.
    """
    if not isinstance(df, pd.DataFrame):
        return df
    # Work on an explicit copy: assigning into a column subset of the caller's
    # frame is the pattern that triggers pandas' SettingWithCopyWarning.
    formatted = df[["segmentation", "score"]].copy()
    formatted["score"] = formatted["score"].apply(
        lambda score: round(1 / score, 4)
    )
    return formatted
def convert_to_score_dict(df):
    """Map each segmentation in a rank table to its score.

    Args:
        df: A ``pandas.DataFrame`` with ``segmentation`` and ``score`` columns.
            Any other value (e.g. ``None``) yields an empty dict.

    Returns:
        A ``{segmentation: score}`` dict in row order; empty when ``df`` has
        no rows. If a segmentation occurs more than once, the last row wins.
    """
    if not isinstance(df, pd.DataFrame):
        return {}
    # Pair the two columns directly so that a frame with zero rows simply
    # produces an empty mapping.
    return dict(zip(df["segmentation"].tolist(), df["score"].tolist()))
def get_candidates_df(candidates, segmenter_score_dict, reranker_score_dict):
    """Build a score table with one row per candidate segmentation.

    Args:
        candidates: Iterable of segmentation strings.
        segmenter_score_dict: Mapping from segmentation to segmenter score.
        reranker_score_dict: Mapping from segmentation to reranker score.

    Returns:
        A DataFrame with the columns ``segmentation``, ``segmenter score`` and
        ``reranker score``. A candidate missing from a mapping gets 0 in the
        corresponding column. With no candidates the frame is empty and has
        no columns.
    """
    rows = [
        {
            "segmentation": text,
            "segmenter score": segmenter_score_dict.get(text, 0),
            "reranker score": reranker_score_dict.get(text, 0),
        }
        for text in candidates
    ]
    return pd.DataFrame(rows)
def parse_candidates(candidates):
    """Split a comma-separated string of candidate segmentations.

    Args:
        candidates: Text such as ``"hello world, hell o world"``; may be
            ``None`` or empty.

    Returns:
        The candidates in input order with surrounding whitespace removed.
        Blank entries (from ``"a,,b"``, a trailing comma, or whitespace-only
        input) are dropped so they do not turn into empty candidates.
    """
    if not candidates:
        return []
    stripped = (piece.strip() for piece in candidates.split(","))
    return [piece for piece in stripped if piece]
def predict(s1, language, candidates, use_reranker, topk, steps):
    """Segment one hashtag and score it alongside user-supplied candidates.

    Args:
        s1: The hashtag to segment.
        language: Key of ``model_dict`` selecting the segmenter to use.
        candidates: Comma-separated alternative segmentations to score; may
            be empty.
        use_reranker: Whether the reranker is run; also selects the column
            the result table is sorted by.
        topk: Beam-search ``topk`` forwarded to the segmenter.
        steps: Beam-search ``steps`` forwarded to the segmenter.

    Returns:
        A tuple ``(top_segmentation, output_df)``. ``output_df`` holds one row
        per distinct segmentation (the top one plus the candidates), sorted
        by the selected score in descending order. Segmentations missing
        from a rank table score 0 in that column.
    """
    segmentation = model_dict[language].segment(
        [s1],
        use_reranker=use_reranker,
        return_ranks=True,
        # The sliders may deliver floats; forward whole numbers to the
        # beam search.
        topk=int(topk),
        steps=int(steps),
    )
    top_segmentation = segmentation.output[0]

    # A rank that is not a DataFrame passes through format_dataframe as-is
    # and becomes an empty score dict.
    segmenter_score_dict = convert_to_score_dict(
        format_dataframe(segmentation.segmenter_rank)
    )
    reranker_score_dict = convert_to_score_dict(
        format_dataframe(segmentation.reranker_rank)
    )

    top_segmentation_df = get_candidates_df(
        [top_segmentation], segmenter_score_dict, reranker_score_dict
    )
    candidates_df = get_candidates_df(
        parse_candidates(candidates), segmenter_score_dict, reranker_score_dict
    )
    output_df = pd.concat([top_segmentation_df, candidates_df], axis=0)

    sort_column = "reranker score" if use_reranker else "segmenter score"
    output_df = output_df.sort_values(by=sort_column, ascending=False)
    # After sorting, keeping the first duplicate keeps the first-sorted row.
    output_df = output_df.drop_duplicates(subset="segmentation", keep="first")
    return top_segmentation, output_df
# Input components, listed in the order of ``predict``'s parameters.
inputs = [
    gr.Textbox(label="Hashtag"),
    gr.Dropdown(choices=language_list, value="english (fast)", label="Language"),
    gr.Textbox(label="Candidate segmentations"),
    gr.Checkbox(value=True, label="Use reranker"),
    gr.Slider(
        minimum=0,
        maximum=100,
        value=20,
        label="Advanced setting - Beamsearch top-k",
    ),
    gr.Slider(
        minimum=0,
        maximum=100,
        value=13,
        label="Advanced setting - Beamsearch steps",
    ),
]

# Output components, matching the ``(top_segmentation, output_df)`` pair
# returned by ``predict``.
outputs = [
    gr.Textbox(label="Suggested segmentation"),
    gr.DataFrame(label="Scores"),
]

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=app_title,
    description=app_description,
    article=article_string,
    examples=app_examples,
)
# Launched as soon as this module is executed (there is no ``__main__`` guard).
demo.launch()