Spaces:
Running
Running
# Masked Word Predictor | CPU-only HF Space
| import gradio as gr | |
| import pandas as pd | |
| from transformers import pipeline | |
| from transformers.pipelines.base import PipelineException | |
# 1. Build the fill-mask pipeline a single time at import so every request
#    reuses the same loaded model; device=-1 keeps inference on the CPU.
fill_mask = pipeline(
    task="fill-mask",
    model="distilroberta-base",
    device=-1,
)
_RESULT_COLUMNS = ["Sequence", "Score"]


def _error_frame(message: str) -> pd.DataFrame:
    """Wrap *message* as a one-row result table so the Dataframe output still renders."""
    return pd.DataFrame([[f"Error: {message}", 0.0]], columns=_RESULT_COLUMNS)


def predict_mask(sentence: str, top_k: int) -> pd.DataFrame:
    """Return the top-k fill-mask completions for *sentence* as a table.

    Users may type the BERT-style placeholder ``[MASK]``; it is translated to
    the model's own mask token (``fill_mask.tokenizer.mask_token``) before
    inference. Problems are reported in-band as a single ``"Error: ..."`` row
    (score ``0.0``) so the UI always receives a DataFrame.

    Args:
        sentence: Input text containing exactly one ``[MASK]`` (or the
            tokenizer's native mask token). ``None`` is treated as empty text.
        top_k: Number of completions to return; coerced to ``int`` (min 1).

    Returns:
        DataFrame with columns ``["Sequence", "Score"]``, scores rounded to
        3 decimals.
    """
    # 2. The model's native mask token (e.g. "<mask>").
    mask = fill_mask.tokenizer.mask_token
    # 3. Allow users to type [MASK]; a cleared textbox (None) becomes "".
    sentence = (sentence or "").replace("[MASK]", mask)

    # 4. Validate that exactly one mask is present.
    mask_count = sentence.count(mask)
    if mask_count == 0:
        return _error_frame("please include `[MASK]` in your sentence.")
    if mask_count > 1:
        # With several masks the pipeline yields one prediction list per mask
        # (or raises, depending on the transformers version), which the row
        # builder below cannot consume -- reject it up front.
        return _error_frame("please include only one `[MASK]` in your sentence.")

    # Slider values may arrive as floats (e.g. 5.0); the pipeline's top-k
    # selection needs an int.
    k = max(1, int(top_k))

    # 5. Run the pipeline safely.
    try:
        preds = fill_mask(sentence, top_k=k)
    except PipelineException as e:
        return _error_frame(str(e))

    # 6. Single-mask input yields a flat list of {"sequence", "score", ...} dicts.
    rows = [[p["sequence"], round(p["score"], 3)] for p in preds]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)
# Gradio UI: a sentence box and a top-K slider feed predict_mask, whose
# DataFrame is shown read-only. User-facing strings were repaired from
# mis-decoded text (garbled emoji bytes removed, "salon’s" restored).
with gr.Blocks(title="Masked Word Predictor") as demo:
    gr.Markdown(
        "# Masked Word Predictor\n"
        "Enter a sentence with one `[MASK]` token and see the top-K completions."
    )
    with gr.Row():
        sentence = gr.Textbox(
            lines=2,
            placeholder="e.g. The salon’s new color treatment is [MASK].",
            label="Input Sentence",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=5,
            label="Top K Predictions",
        )
    predict_btn = gr.Button("Predict", variant="primary")
    # Headers/datatypes mirror the ["Sequence", "Score"] frame predict_mask returns.
    results_df = gr.Dataframe(
        headers=["Sequence", "Score"],
        datatype=["str", "number"],
        wrap=True,
        interactive=False,
        label="Predictions",
    )
    predict_btn.click(
        fn=predict_mask,
        inputs=[sentence, top_k],
        outputs=results_df,
    )
if __name__ == "__main__":
    # Bind every network interface (0.0.0.0), not just loopback, so the app
    # is reachable from outside the host/container.
    demo.launch(
        server_name="0.0.0.0",
    )