import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

from .prompts import format_rag_prompt
from .shared import generation_interrupt

# Display name -> Hugging Face Hub repo id for the models being compared.
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        # Returning True tells `generate()` to stop as soon as the event is set.
        return self.interrupt_event.is_set()


def generate_summaries(example, model_a_name, model_b_name):
    """Generate summaries for the given example using the two assigned models."""
    if generation_interrupt.is_set():
        return "", ""

    # Join all retrieved document contents into a single context string.
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b


def run_inference(model_name, context, question):
    """Run inference with the specified model and return the decoded completion."""
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, padding_side="left", token=True
        )
        # Some chat templates reject a system role; detect that from the template text.
        accepts_sys = "System role not supported" not in (tokenizer.chat_template or "")

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if generation_interrupt.is_set():
            return ""

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",
            token=True,
        ).to(device)

        text_input = format_rag_prompt(question, context, accepts_sys)

        if generation_interrupt.is_set():
            return ""

        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            truncation=True,  # max_length is ignored unless truncation is enabled
            max_length=2048,
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)

        if generation_interrupt.is_set():
            return ""

        # Stop generation mid-stream if the shared interrupt event is set.
        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
            )

        if generation_interrupt.is_set():
            return ""

        # Decode only the newly generated tokens, not the echoed prompt.
        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return result

    except Exception as e:
        print(f"Error in inference: {e}")
        return f"Error generating response: {str(e)[:100]}..."
    finally:
        # Free model and tokenizer references and release GPU memory between calls.
        if "model" in locals():
            del model
        if "tokenizer" in locals():
            del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
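
# --- Usage sketch (illustrative assumption, not part of the original app) ---
# A minimal example of how generate_summaries() might be driven, assuming this
# module is executed inside its package (e.g. via `python -m <package>.<module>`,
# package name not shown in the source) and that the `example` dict follows the
# structure the functions above expect: a "full_contexts" list of
# {"content": ...} dicts plus a "question" string. The document texts and the
# choice of model pairing below are placeholders.
if __name__ == "__main__":
    demo_example = {
        "full_contexts": [
            {"content": "Paris is the capital and largest city of France."},
            {"content": "France is a country in Western Europe."},
        ],
        "question": "What is the capital of France?",
    }
    # Compare the first two configured models on the placeholder example.
    summary_a, summary_b = generate_summaries(demo_example, model_names[0], model_names[1])
    print("Model A summary:", summary_a)
    print("Model B summary:", summary_b)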