import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt

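# Display names mapped to their Hugging Face model identifiers.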
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())

# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event
        
    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()

def generate_summaries(example, model_a_name, model_b_name):
    """
    Generate summaries for the given example with the two assigned models.

    Returns a tuple (summary_a, summary_b); an entry is an empty string if
    generation was interrupted before that model finished.
    """
    if generation_interrupt.is_set():
        return "", ""

    if "full_contexts" not in example:
        raise ValueError("No context found in the example.")

    # Join the retrieved passages into a single context block.
    context_parts = [
        ctx["content"]
        for ctx in example["full_contexts"]
        if isinstance(ctx, dict) and "content" in ctx
    ]
    context_text = "\n---\n".join(context_parts)

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    summary_a = run_inference(models[model_a_name], context_text, question)
    
    if generation_interrupt.is_set():
        return summary_a, ""
        
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b

def run_inference(model_name, context, question):
    """
    Run a single generation with the specified model.

    Loads the tokenizer and model, builds the RAG prompt from the question and
    retrieved context, and generates a response while honoring the global
    interrupt flag. Returns the decoded text, or an error message on failure.
    """
    if generation_interrupt.is_set():
        return ""
        
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        # Some chat templates embed an explicit "System role not supported" error;
        # also guard against models that ship no chat template at all.
        accepts_sys = (
            tokenizer.chat_template is not None
            and "System role not supported" not in tokenizer.chat_template
        )
        
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            
        if generation_interrupt.is_set():
            return ""
            
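        # Load the weights in bfloat16 with the eager attention implementation.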
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)

        text_input = format_rag_prompt(question, context, accepts_sys)

        if generation_interrupt.is_set():
            return ""
            
        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            truncation=True,  # max_length only takes effect when truncation is enabled
            max_length=2048,
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)
        
        if generation_interrupt.is_set():
            return ""
            
        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
        
        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,  
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria
            )

        if generation_interrupt.is_set():
            return ""
            
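        # Decode only the newly generated tokens, skipping the prompt.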
        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        
        return result
        
    except Exception as e:
        print(f"Error in inference: {e}")
        return f"Error generating response: {str(e)[:100]}..."
        
    finally:
        if 'model' in locals():
            del model
        if 'tokenizer' in locals():
            del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
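

# --- Usage sketch (illustrative, not part of the original module) ---
# The example dict below is a hypothetical stand-in matching the structure
# generate_summaries expects: a "question" plus "full_contexts" entries that
# carry a "content" field. Because this module uses relative imports, run it
# as a module from the package root (python -m <package>.<this_module>).
if __name__ == "__main__":
    demo_example = {
        "question": "What is retrieval-augmented generation?",
        "full_contexts": [
            {"content": "Retrieval-augmented generation (RAG) grounds a language "
                        "model's answer in passages returned by a retriever."},
        ],
    }
    summary_a, summary_b = generate_summaries(demo_example, model_names[0], model_names[1])
    print(f"--- {model_names[0]} ---\n{summary_a}\n")
    print(f"--- {model_names[1]} ---\n{summary_b}")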