File size: 4,610 Bytes
bd5a4f4
828a989
 
bd5a4f4
828a989
 
 
 
 
bd5a4f4
828a989
 
 
 
 
 
 
 
 
 
 
bd5a4f4
828a989
 
 
 
 
 
 
bd5a4f4
828a989
 
 
 
 
 
 
 
 
bd5a4f4
828a989
 
 
 
 
 
bd5a4f4
828a989
 
bd5a4f4
828a989
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd5a4f4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
import torch

# -------------------------------
# Модель суммаризации
# -------------------------------
sum_tokenizer = AutoTokenizer.from_pretrained("LaciaStudio/Lacia_sum_small_v1")
sum_model = AutoModelForSeq2SeqLM.from_pretrained("LaciaStudio/Lacia_sum_small_v1")

def summarize_document(file):
    if file is None:
        return "Файл не загружен."
    # Открываем файл и читаем его содержимое
    with open(file, "r", encoding="utf-8") as f:
        text = f.read()
    input_text = "summarize: " + text
    inputs = sum_tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = sum_model.generate(inputs["input_ids"], max_length=150, num_beams=4, early_stopping=True)
    summary = sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

# -------------------------------
# Модель вопросов-ответов (Q&A)
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qa_tokenizer = AutoTokenizer.from_pretrained("LaciaStudio/Kaleidoscope_large_v1")
qa_model = AutoModelForQuestionAnswering.from_pretrained("LaciaStudio/Kaleidoscope_large_v1")
qa_model.to(device)

def answer_question(context, question):
    inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True, max_length=384)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = qa_model(**inputs)
    start_index = torch.argmax(outputs.start_logits)
    end_index = torch.argmax(outputs.end_logits)
    answer_tokens = inputs["input_ids"][0][start_index:end_index + 1]
    answer = qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer

def answer_question_file(file, question):
    if file is None:
        return "Файл не загружен."
    with open(file, "r", encoding="utf-8") as f:
        context = f.read()
    return answer_question(context, question)

def answer_question_text(context, question):
    return answer_question(context, question)

# -------------------------------
# Интерфейс Gradio
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Интерфейс для суммаризации и вопросов-ответов")
    with gr.Row():
        # Левая колонка – суммаризация
        with gr.Column():
            gr.Markdown("## Суммаризация документа")
            file_input_sum = gr.File(label="Прикрепить файл для суммаризации", file_count="single", type="file")
            summarize_button = gr.Button("Суммаризировать")
            summary_output = gr.Textbox(label="Суммаризация", lines=10)
            summarize_button.click(fn=summarize_document, inputs=file_input_sum, outputs=summary_output)
        
        # Правая колонка – Q&A с двумя вкладками
        with gr.Column():
            gr.Markdown("## Вопрос-ответ по документу")
            with gr.Tabs():
                with gr.Tab("Загрузить файл"):
                    file_input_qa = gr.File(label="Прикрепить файл с документом", file_count="single", type="file")
                    question_input_file = gr.Textbox(label="Введите вопрос", placeholder="Ваш вопрос здесь")
                    answer_button_file = gr.Button("Получить ответ")
                    answer_output_file = gr.Textbox(label="Ответ", lines=5)
                    answer_button_file.click(fn=answer_question_file, inputs=[file_input_qa, question_input_file], outputs=answer_output_file)
                with gr.Tab("Ввести текст"):
                    context_input = gr.Textbox(label="Введите текст документа", lines=10, placeholder="Текст документа здесь")
                    question_input_text = gr.Textbox(label="Введите вопрос", placeholder="Ваш вопрос здесь")
                    answer_button_text = gr.Button("Получить ответ")
                    answer_output_text = gr.Textbox(label="Ответ", lines=5)
                    answer_button_text.click(fn=answer_question_text, inputs=[context_input, question_input_text], outputs=answer_output_text)
                    
if __name__ == "__main__":
    demo.launch()