import os

os.environ['MKL_THREADING_LAYER'] = 'GNU'

import queue
import threading

import torch
from transformers import AutoTokenizer, StoppingCriteria
from vllm import LLM, SamplingParams

from .prompts import format_rag_prompt
from .shared import generation_interrupt

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # --- Model A ---
    # Use a queue to get the result back from the worker thread.
    result_queue_a = queue.Queue()
    thread_a = threading.Thread(
        target=run_inference,
        args=(models[model_a_name], context_text, question, result_queue_a),
    )
    thread_a.start()

    summary_a = ""
    while thread_a.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model A ({model_a_name})...")
            # The worker thread checks the interrupt flag itself; here we just
            # return early from the main control flow.
            thread_a.join(timeout=1.0)  # Give the thread a moment to stop
            return "", ""
        try:
            summary_a = result_queue_a.get(timeout=0.1)  # Poll the queue periodically
            break  # Got a result
        except queue.Empty:
            continue  # Still running; check the interrupt flag again

    # The thread may have finished without us picking up the result
    # (e.g., it was interrupted just before putting it in the queue).
    if not summary_a and not result_queue_a.empty():
        summary_a = result_queue_a.get_nowait()
    elif not summary_a and generation_interrupt.is_set():
        return "", ""

    if generation_interrupt.is_set():  # Check between models
        return summary_a, ""

    # --- Model B ---
    result_queue_b = queue.Queue()
    thread_b = threading.Thread(
        target=run_inference,
        args=(models[model_b_name], context_text, question, result_queue_b),
    )
    thread_b.start()

    summary_b = ""
    while thread_b.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model B ({model_b_name})...")
            thread_b.join(timeout=1.0)
            return summary_a, ""  # Return summary_a obtained so far
        try:
            summary_b = result_queue_b.get(timeout=0.1)
            break
        except queue.Empty:
            continue

    if not summary_b and not result_queue_b.empty():
        summary_b = result_queue_b.get_nowait()
    elif not summary_b and generation_interrupt.is_set():
        return summary_a, ""

    return summary_a, summary_b


# Runs in a worker thread and reports its result via the queue.
def run_inference(model_name, context, question, result_queue):
    """
    Run inference using the specified model. Designed to be run in a thread.
    Puts the result or an error string into the result_queue.
""" # Check interrupt at the very beginning of the thread if generation_interrupt.is_set(): result_queue.put("") return device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = None tokenizer = None result = "" try: tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True) accepts_sys = ( "System role not supported" not in tokenizer.chat_template if tokenizer.chat_template else False # Handle missing chat_template ) # if tokenizer.pad_token is None: # tokenizer.pad_token = tokenizer.eos_token # # Check interrupt before loading the model # if generation_interrupt.is_set(): # result_queue.put("") # return # model = AutoModelForCausalLM.from_pretrained( # model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True # ).to(device) # model.eval() # Set model to evaluation mode text_input = format_rag_prompt(question, context, accepts_sys) # # Check interrupt before tokenization/template application # if generation_interrupt.is_set(): # result_queue.put("") # return # actual_input = tokenizer.apply_chat_template( # text_input, # return_tensors="pt", # tokenize=True, # # Consider reducing max_length if context/question is very long # # max_length=tokenizer.model_max_length, # Use model's max length # # truncation=True, # Ensure truncation if needed # max_length=2048, # Keep original max_length for now # add_generation_prompt=True, # ).to(device) # # Ensure input does not exceed model max length after adding generation prompt # # This check might be redundant if tokenizer handles it, but good for safety # # if actual_input.shape[1] > tokenizer.model_max_length: # # # Handle too long input - maybe truncate manually or raise error # # print(f"Warning: Input length {actual_input.shape[1]} exceeds model max length {tokenizer.model_max_length}") # # # Simple truncation (might lose important info): # # # actual_input = actual_input[:, -tokenizer.model_max_length:] # input_length = actual_input.shape[1] # attention_mask = torch.ones_like(actual_input).to(device) # # Check interrupt before generation # if generation_interrupt.is_set(): # result_queue.put("") # return # stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)]) # with torch.inference_mode(): # outputs = model.generate( # actual_input, # attention_mask=attention_mask, # max_new_tokens=512, # pad_token_id=tokenizer.pad_token_id, # stopping_criteria=stopping_criteria, # do_sample=True, # Consider adding sampling parameters if needed # temperature=0.6, # top_p=0.9, # ) # # Check interrupt immediately after generation finishes or stops # if generation_interrupt.is_set(): # result = "" # Discard potentially partial result if interrupted # else: # # Decode the generated tokens, excluding the input tokens # result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True) llm = LLM(model_name, dtype=torch.bfloat16, hf_token=True, enforce_eager=True) params = SamplingParams( max_tokens=512, ) # Check interrupt before generation if generation_interrupt.is_set(): result_queue.put("") return # Generate the response outputs = llm.chat( text_input, sampling_params=params, # stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]), ) # Check interrupt immediately after generation finishes or stops result_queue.put(outputs[0].outputs[0].text) except Exception as e: print(f"Error in inference thread for {model_name}: {e}") # Put error message in queue for the main thread to handle/display result_queue.put(f"Error generating 
response: {str(e)[:200]}...") finally: # Clean up resources within the thread del model del tokenizer del text_input del outputs if torch.cuda.is_available(): torch.cuda.empty_cache()
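

# Minimal usage sketch (illustrative only, not part of the app): it shows the
# `example` shape that generate_summaries() expects -- a "full_contexts" list of
# {"content": ...} dicts plus a "question" -- and uses two keys of the `models`
# dict above. Because this module uses relative imports, run it as part of its
# package (e.g. `python -m <package>.<this_module>`). The context sentences below
# are dummy data for demonstration.
if __name__ == "__main__":
    example = {
        "full_contexts": [
            {"content": "The Eiffel Tower is about 330 metres tall."},
            {"content": "It was completed in 1889 for the World's Fair in Paris."},
        ],
        "question": "How tall is the Eiffel Tower, and when was it built?",
    }
    summary_a, summary_b = generate_summaries(
        example, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)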