import os

os.environ["MKL_THREADING_LAYER"] = "GNU"

import traceback

import spaces
import torch
from peft import PeftModel
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)

from .prompts import format_rag_prompt
from .shared import generation_interrupt

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    # "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
    # "MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
    "Qwen3-0.6b": "qwen/qwen3-0.6b",
    "Qwen3-1.7b": "qwen/qwen3-1.7b",
    "Qwen3-4b": "qwen/qwen3-4b",
    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
    "icecream-3b": "aizip-dev/icecream-3b",
}

# Tokenizers are cached per model so repeated inference calls don't reload them.
tokenizer_cache = {}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()


@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
    """Generate summaries for the given example using the assigned models sequentially."""
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []

    if "full_contexts" in example and example["full_contexts"]:
        for i, ctx in enumerate(example["full_contexts"]):
            content = ""

            # Extract content from either a dict or a plain string
            if isinstance(ctx, dict) and "content" in ctx:
                content = ctx["content"]
            elif isinstance(ctx, str):
                content = ctx

            # Add a document number if it is not already present
            if not content.strip().startswith("Document"):
                content = f"Document {i + 1}:\n{content}"

            context_parts.append(content)

        context_text = "\n\n".join(context_parts)
    else:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        context_text = ""

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""

    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b


@spaces.GPU
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.

    Returns the generated text, or an empty string if interrupted.
    """
    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""

    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }
    generation_kwargs = {
        "max_new_tokens": 512,
    }

    # Make sure Qwen3 models don't use "thinking" mode
    if "qwen3" in model_name.lower():
        print(
            f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False."
        )
        tokenizer_kwargs["enable_thinking"] = False

    try:
        print("REACHED HERE BEFORE tokenizer")

        if model_name in tokenizer_cache:
            tokenizer = tokenizer_cache[model_name]
        else:
            # Common arguments for tokenizer loading
            tokenizer_load_args = {"padding_side": "left", "token": True}

            # The icecream adapter reuses the Llama-3.2-3B-Instruct tokenizer
            actual_model_name_for_tokenizer = model_name
            if "icecream" in model_name.lower():
                actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"

            tokenizer = AutoTokenizer.from_pretrained(
                actual_model_name_for_tokenizer, **tokenizer_load_args
            )
            tokenizer_cache[model_name] = tokenizer

        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template
            else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        print("REACHED HERE BEFORE pipe")
        print(f"Loading model {model_name}...")

        if "icecream" not in model_name.lower():
            pipe = pipeline(
                "text-generation",
                model=model_name,
                tokenizer=tokenizer,
                device_map="cuda",
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
                model_kwargs={
                    "attn_implementation": "eager",
                },
            )
        else:
            # icecream-3b is a PEFT adapter on top of Llama-3.2-3B-Instruct
            base_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/llama-3.2-3b-instruct",
                device_map="cuda",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            model = PeftModel.from_pretrained(
                base_model,
                "aizip-dev/icecream-3b",
                device_map="cuda",
                torch_dtype=torch.bfloat16,
            )

        text_input = format_rag_prompt(question, context, accepts_sys)

        if "gemma-3" in model_name.lower():
            print("REACHED HERE BEFORE GEN")
            # Pass the chat messages directly; the pipeline returns the full
            # conversation, so take the content of the last (assistant) turn.
            # Special tokens are already skipped by the pipeline's decode step.
            result = pipe(
                text_input,
                max_new_tokens=512,
            )[0]["generated_text"]
            result = result[-1]["content"]
        elif "icecream" in model_name.lower():
            print("ICECREAM")
            model_inputs = tokenizer.apply_chat_template(
                text_input,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
                **tokenizer_kwargs,
            )
            model_inputs = model_inputs.to(model.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            prompt_tokens_length = input_ids.shape[1]

            with torch.inference_mode():
                # Check interrupt before generation
                if generation_interrupt.is_set():
                    return ""

                output_sequences = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=512,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,  # Addresses the warning
                )

            generated_token_ids = output_sequences[0][prompt_tokens_length:]
            result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
        else:
            # For other models: format the prompt as a string, then strip the
            # prompt prefix from the returned full text.
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )
            input_length = len(formatted)

            # Check interrupt before generation
            outputs = pipe(
                formatted,
                max_new_tokens=512,
            )
            # print(outputs[0]['generated_text'])
            result = outputs[0]["generated_text"][input_length:]

    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        print(traceback.format_exc())
        result = f"Error generating response: {str(e)[:200]}..."
    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result