Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline | |
| from transformers import AutoTokenizer | |
# UI display name -> model identifier passed to ``from_pretrained`` / ``pipeline``.
models = {
    "RUPunct-small": "RUPunct/RUPunct_small",
    "RUPunct-big": "RUPunct/RUPunct_big",
    "RUPunct-medium": "RUPunct/RUPunct_medium"
}


def _build_pipeline(model_path):
    """Create the ``"ner"`` pipeline for one RUPunct checkpoint.

    The tokenizer is constructed explicitly so that accents are kept
    (``strip_accents=False``) and a prefix space is added
    (``add_prefix_space=True``). Subword predictions are merged per word
    with ``aggregation_strategy="first"``.
    """
    tok = AutoTokenizer.from_pretrained(
        model_path, strip_accents=False, add_prefix_space=True
    )
    return pipeline(
        "ner", model=model_path, tokenizer=tok, aggregation_strategy="first"
    )


# Every model is loaded eagerly at import time, in the order listed above.
pipelines = {name: _build_pipeline(path) for name, path in models.items()}
# Case transformation selected by the label prefix (everything before the
# final "_"). "LOWER" leaves the token exactly as it was received.
_CASE_TRANSFORMS = {
    "LOWER": lambda token: token,
    "UPPER": str.capitalize,
    "UPPER_TOTAL": str.upper,
}

# Punctuation appended after the token, selected by the label suffix (the part
# after the final "_"). "O" means no punctuation. Several names are Russian
# transliterations: TIRE = dash, DVOETOCHIE = colon, VOSKL = exclamation mark,
# DEFIS = hyphen, MNOGOTOCHIE = ellipsis; PERIODCOMMA is the semicolon.
_PUNCTUATION_SUFFIXES = {
    "O": "",
    "PERIOD": ".",
    "COMMA": ",",
    "QUESTION": "?",
    # A dash between words is spaced on both sides; the caller joins tokens
    # with a single space, so the dash carries its own leading space. This
    # now applies to every case variant (LOWER_TIRE previously had no space).
    "TIRE": " —",
    "DVOETOCHIE": ":",
    "VOSKL": "!",
    "PERIODCOMMA": ";",
    "DEFIS": "-",
    "MNOGOTOCHIE": "...",
    "QUESTIONVOSKL": "?!",
}


def process_token(token, label):
    """Apply the casing and trailing punctuation encoded in a RUPunct label.

    Labels have the form ``<CASE>_<PUNCT>``. ``<CASE>`` is one of:

    * ``LOWER``: the token is kept as-is.
    * ``UPPER``: ``str.capitalize``, which also lower-cases the rest.
    * ``UPPER_TOTAL``: the whole token is upper-cased.

    ``<PUNCT>`` selects the punctuation to append, and ``O`` means none.

    Args:
        token: The word text to decorate.
        label: The entity group predicted for the word.

    Returns:
        The transformed token. A label that does not follow the scheme above
        (e.g. a bare ``"O"``) leaves the token unchanged rather than
        returning ``None``.
    """
    # rpartition splits on the LAST underscore, so the two-part case prefix
    # "UPPER_TOTAL" stays intact.
    case, _, punct = label.rpartition("_")
    transform = _CASE_TRANSFORMS.get(case)
    suffix = _PUNCTUATION_SUFFIXES.get(punct)
    if transform is None or suffix is None:
        return token
    return transform(token) + suffix
def punctuate(input_text, model_name):
    """Restore punctuation and capitalisation in ``input_text``.

    Args:
        input_text: Raw text to punctuate.
        model_name: Key into ``pipelines`` selecting which RUPunct model runs.

    Returns:
        The words returned by the model, each decorated by ``process_token``
        according to its predicted label and joined with single spaces.
        ``None``, empty or whitespace-only input returns ``""`` without
        invoking the model.

    Raises:
        KeyError: If ``model_name`` is not a key of ``pipelines``.
    """
    if not input_text or not input_text.strip():
        return ""
    classifier = pipelines[model_name]
    pieces = []
    for item in classifier(input_text):
        word = item["word"].strip()
        # A literal "." already present in the input is emitted unchanged.
        # "LOWER_O" is the no-op label. A bare "O" is not a label
        # process_token recognises and used to make it return None.
        label = "LOWER_O" if word == "." else item["entity_group"]
        pieces.append(process_token(word, label))
    return " ".join(pieces).strip()
# Gradio UI: text in, punctuated text out. The radio button picks which of
# the preloaded RUPunct models handles the request.
iface = gr.Interface(
    fn=punctuate,
    inputs=[
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4.
        # The top-level components below work on Gradio 3.x and later.
        gr.Textbox(lines=5, placeholder="Введите текст"),
        # Preselect the first model so the callback never receives None,
        # which would fail the pipelines lookup with a KeyError.
        gr.Radio(list(models.keys()), value=next(iter(models)), label="Модель"),
    ],
    outputs="text",
    title="RUPunct",
    description="Демо RUPunct - модели для автоматической расстановки знаков препинания в русском тексте.",
)
iface.launch()