Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering | |
import torch | |
# ------------------------------- | |
# Модель суммаризации | |
# ------------------------------- | |
sum_tokenizer = AutoTokenizer.from_pretrained("LaciaStudio/Lacia_sum_small_v1") | |
sum_model = AutoModelForSeq2SeqLM.from_pretrained("LaciaStudio/Lacia_sum_small_v1") | |
def summarize_document(file): | |
if file is None: | |
return "Файл не загружен." | |
# Открываем файл и читаем его содержимое | |
with open(file, "r", encoding="utf-8") as f: | |
text = f.read() | |
input_text = "summarize: " + text | |
inputs = sum_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True) | |
summary_ids = sum_model.generate(inputs["input_ids"], max_length=150, num_beams=4, early_stopping=True) | |
summary = sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True) | |
return summary | |
# ------------------------------- | |
# Модель вопросов-ответов (Q&A) | |
# ------------------------------- | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
qa_tokenizer = AutoTokenizer.from_pretrained("LaciaStudio/Kaleidoscope_large_v1") | |
qa_model = AutoModelForQuestionAnswering.from_pretrained("LaciaStudio/Kaleidoscope_large_v1") | |
qa_model.to(device) | |
def answer_question(context, question): | |
inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True, max_length=384) | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
outputs = qa_model(**inputs) | |
start_index = torch.argmax(outputs.start_logits) | |
end_index = torch.argmax(outputs.end_logits) | |
answer_tokens = inputs["input_ids"][0][start_index:end_index + 1] | |
answer = qa_tokenizer.decode(answer_tokens, skip_special_tokens=True) | |
return answer | |
def answer_question_file(file, question): | |
if file is None: | |
return "Файл не загружен." | |
with open(file, "r", encoding="utf-8") as f: | |
context = f.read() | |
return answer_question(context, question) | |
def answer_question_text(context, question): | |
return answer_question(context, question) | |
# ------------------------------- | |
# Интерфейс Gradio | |
# ------------------------------- | |
with gr.Blocks() as demo: | |
gr.Markdown("# Интерфейс для суммаризации и вопросов-ответов") | |
with gr.Row(): | |
# Левая колонка – суммаризация | |
with gr.Column(): | |
gr.Markdown("## Суммаризация документа") | |
file_input_sum = gr.File(label="Прикрепить файл для суммаризации", file_count="single", type="file") | |
summarize_button = gr.Button("Суммаризировать") | |
summary_output = gr.Textbox(label="Суммаризация", lines=10) | |
summarize_button.click(fn=summarize_document, inputs=file_input_sum, outputs=summary_output) | |
# Правая колонка – Q&A с двумя вкладками | |
with gr.Column(): | |
gr.Markdown("## Вопрос-ответ по документу") | |
with gr.Tabs(): | |
with gr.Tab("Загрузить файл"): | |
file_input_qa = gr.File(label="Прикрепить файл с документом", file_count="single", type="file") | |
question_input_file = gr.Textbox(label="Введите вопрос", placeholder="Ваш вопрос здесь") | |
answer_button_file = gr.Button("Получить ответ") | |
answer_output_file = gr.Textbox(label="Ответ", lines=5) | |
answer_button_file.click(fn=answer_question_file, inputs=[file_input_qa, question_input_file], outputs=answer_output_file) | |
with gr.Tab("Ввести текст"): | |
context_input = gr.Textbox(label="Введите текст документа", lines=10, placeholder="Текст документа здесь") | |
question_input_text = gr.Textbox(label="Введите вопрос", placeholder="Ваш вопрос здесь") | |
answer_button_text = gr.Button("Получить ответ") | |
answer_output_text = gr.Textbox(label="Ответ", lines=5) | |
answer_button_text.click(fn=answer_question_text, inputs=[context_input, question_input_text], outputs=answer_output_text) | |
if __name__ == "__main__": | |
demo.launch() | |