File size: 4,667 Bytes
51da700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108a050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1c51ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import plotly.express as px
import os

# --- 1. 模型加载 ---
# 负责同学: [填写负责这个模型的同学姓名,例如:张三]
# 注意:QuantFactory/Apollo2-7B-GGUF 模型通常不直接兼容 pipeline("text-generation", ...)
# 除非有额外的llama.cpp或特定的transformers加载配置。
# 为了演示和确保运行流畅,这里使用 gpt2-large 作为替代。
try:
    model1_name = "gpt2-large" # 替代 QuantFactory/Apollo2-7B-GGUF 以确保兼容性
    generator1 = pipeline("text-generation", model=model1_name, device=0 if torch.cuda.is_available() else -1)
    print(f"✅ 模型 1 (文本生成: {model1_name}) 加载成功!")
except Exception as e:
    print(f"❌ 模型 1 (文本生成: {model1_name}) 加载失败: {e}")
    generator1 = None

# 负责同学: [填写负责这个模型的同学姓名,例如:李四]
# deepset/roberta-base-squad2 是一个问答模型,需要 context
try:
    model2_name = "deepset/roberta-base-squad2"
    qa_model = pipeline("question-answering", model=model2_name, device=0 if torch.cuda.is_available() else -1)
    print(f"✅ 模型 2 (问答: {model2_name}) 加载成功!")
except Exception as e:
    print(f"❌ 模型 2 (问答: {model2_name}) 加载失败: {e}")
    qa_model = None
    # --- 2. 推理函数 ---
# 这个函数现在接受一个问题/提示词和一个上下文
def get_model_outputs(question_or_prompt, context, max_length=100):
    output_text_gen = "文本生成模型未加载或生成失败。"
    output_qa = "问答模型未加载或生成失败。"

    # 模型 1: 文本生成
    if generator1:
        try:
            # 文本生成模型将问题和上下文作为其prompt的一部分
            full_prompt_for_gen = f"{question_or_prompt}\nContext: {context}" if context else question_or_prompt
            gen_result = generator1(full_prompt_for_gen, max_new_tokens=max_length, num_return_sequences=1, truncation=True)
            output_text_gen = gen_result[0]['generated_text']
            # 清理:移除输入部分,只保留生成内容
            if output_text_gen.startswith(full_prompt_for_gen):
                output_text_gen = output_text_gen[len(full_prompt_for_gen):].strip()
        except Exception as e:
            output_text_gen = f"文本生成模型 ({model1_name}) 错误: {e}"

    # 模型 2: 问答
    if qa_model and context: # 问答模型必须有上下文
        try:
            qa_result = qa_model(question=question_or_prompt, context=context)
            output_qa = qa_result['answer']
        except Exception as e:
            output_qa = f"问答模型 ({model2_name}) 错误: {e}"
    elif qa_model and not context:
        output_qa = "问答模型需要提供上下文才能回答问题。"

    return output_text_gen, output_qa


# Arena 选项卡内容创建函数 (40分)
def create_arena_tab():
    with gr.Blocks() as arena_block:
        gr.Markdown("## ⚔️ Arena: 模型实时对比")
        gr.Markdown("在这里,您可以输入一个问题或提示词,并提供一段上下文。文本生成模型将根据问题和上下文生成文本,问答模型将从上下文中抽取答案。")

        with gr.Row():
            # 统一输入框 1: 问题/提示词
            question_input = gr.Textbox(label="问题/提示词:", placeholder="请输入您的问题或想让模型生成的提示词...", lines=3)
            # 统一输入框 2: 上下文 (主要用于问答模型)
            context_input = gr.Textbox(label="上下文 (Context):", placeholder="请输入问答模型需要从中抽取答案的上下文...", lines=5)

        with gr.Row():
            # 增加生成长度控制(主要针对文本生成模型)
            gen_length_slider = gr.Slider(minimum=20, maximum=300, value=100, step=10, label="文本生成最大长度")
            generate_btn = gr.Button("🚀 生成并对比")

        with gr.Row():
            # 模型 1 输出 (文本生成)
            output_text_gen = gr.Textbox(label=f"模型 1 (文本生成: {model1_name}) 输出:", interactive=False, lines=10)
            # 模型 2 输出 (问答)
            output_qa = gr.Textbox(label=f"模型 2 (问答: {model2_name}) 输出:", interactive=False, lines=10)

        # 绑定按钮点击事件到推理函数
        generate_btn.click(
            fn=get_model_outputs,
            inputs=[question_input, context_input, gen_length_slider],
            outputs=[output_text_gen, output_qa]
        )
    return arena_block