File size: 4,038 Bytes
ddaff53
 
 
 
8a142a6
 
ddaff53
 
 
 
 
 
 
 
 
dab8aab
665e5a3
 
 
 
ddaff53
8a142a6
 
 
ddaff53
 
8a142a6
 
 
 
 
ddaff53
8a142a6
 
ddaff53
8151596
 
8a142a6
 
 
 
8151596
ddaff53
8a142a6
 
 
 
 
ddaff53
 
8a142a6
ddaff53
 
 
 
 
 
 
 
665e5a3
ddaff53
 
 
 
 
 
 
 
 
665e5a3
ddaff53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8151596
ddaff53
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .prompts import format_rag_prompt

# --- Legacy dummy model summaries (disabled) ---
# Placeholder functions that once simulated model summary generation; kept for reference only.
# models = {
#     "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
#     "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
#     "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
#     "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
# }

# Display name -> repository id. generate_summaries looks models up by display
# name and run_inference passes the id to `from_pretrained`.
models = {
    # Qwen 2.5 instruct models.
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    # NOTE(review): this entry was annotated "remove gated for now"; intent unclear — confirm.
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    # Llama 3.2 instruct models.
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    # Gemma 3 instruction-tuned model.
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # Disabled for now:
    # "Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
    # TODO: add more models.
}

# Display names, in the same insertion order as `models`.
model_names = [*models]


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generate one answer per model for a single RAG example.

    The example's ``full_contexts`` entries are flattened into one plain-text
    context: only dict entries that carry a ``"content"`` key contribute, and
    their contents are joined with a ``---`` separator line. That context and
    the example's ``question`` are then passed to both models via
    ``run_inference``.

    Args:
        example: Mapping with a ``"full_contexts"`` sequence and an optional
            ``"question"`` string (missing question defaults to ``""``).
        model_a_name: Display name of the first model; must be a key of ``models``.
        model_b_name: Display name of the second model; must be a key of ``models``.

    Returns:
        Tuple ``(summary_a, summary_b)`` with each model's generated text.

    Raises:
        ValueError: If ``example`` has no ``"full_contexts"`` key.
        KeyError: If either model name is not a key of ``models``.
    """
    if "full_contexts" not in example:
        raise ValueError("No context found in the example.")

    # Non-dict entries and dicts without "content" are skipped silently.
    context_text = "\n---\n".join(
        ctx["content"]
        for ctx in example["full_contexts"]
        if isinstance(ctx, dict) and "content" in ctx
    )
    question = example.get("question", "")

    # Resolve display names to repository ids and run each model in turn.
    summary_a = run_inference(models[model_a_name], context_text, question)
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b

def run_inference(model_name, context, question, *, max_new_tokens=512):
    """
    Load ``model_name`` and generate an answer to ``question`` grounded in ``context``.

    The tokenizer and model weights are loaded on every call via
    ``from_pretrained(..., token=True)``. The model runs in bfloat16 on CUDA
    when available, otherwise on CPU. The prompt comes from
    ``format_rag_prompt`` and is rendered with the tokenizer's chat template.

    Args:
        model_name: Repository id passed to ``from_pretrained``.
        context: Plain-text context the answer should be based on.
        question: The question to answer.
        max_new_tokens: Maximum number of tokens to generate. It is keyword-only
            and defaults to 512, the previously hard-coded value.

    Returns:
        The decoded completion only: prompt tokens are stripped and special
        tokens are skipped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
    # Workaround for Gemma: its chat template contains "System role not supported",
    # so the prompt must not use a system message. `or ""` keeps this check from
    # raising TypeError when a tokenizer has no chat template (chat_template is None).
    accepts_sys = "System role not supported" not in (tokenizer.chat_template or "")

    # Fall back to EOS as the pad token; pad_token_id is passed to generate() below.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
    ).to(device)

    messages = format_rag_prompt(question, context, accepts_sys)

    # NOTE(review): max_length is passed without truncation=True. In recent
    # transformers versions this appears not to truncate at all — confirm.
    # Default right-side truncation would also cut off the generation prompt.
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        tokenize=True,
        max_length=2048,
        add_generation_prompt=True,
    ).to(device)

    prompt_length = input_ids.shape[1]
    # There is one unpadded sequence, so every position is attended to.
    # torch.ones_like already allocates on input_ids' device.
    attention_mask = torch.ones_like(input_ids)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + completion; decode only the new tokens.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)