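"""Generate summaries from two assigned models for a question and its retrieved contexts.

Defines the pool of candidate instruction-tuned models and the inference helpers
used to produce one grounded summary per assigned model.
"""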
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from .prompts import format_rag_prompt

# --- Dummy Model Summaries ---
# Define functions that simulate model summary generation
# models = {
#     "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
#     "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
#     "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
#     "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
# }

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",  # gated models removed for now
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
    # TODO: add more models
}

# List of model names for easy access
model_names = list(models.keys())


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the two assigned models.
    """
    # Build a plain-text version of the retrieved contexts for the models
    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    # 'Answerable' status is read here but not currently passed to the models
    answerable = example.get("Answerable", True)
    question = example.get("question", "")

    # Run inference with both assigned models
    summary_a = run_inference(models[model_a_name], context_text, question)
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b


def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer (left padding so generation continues right after the prompt)
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
    accepts_sys = (
        "System role not supported" not in (tokenizer.chat_template or "")
    )  # Workaround for Gemma chat templates, which reject a system role

    # Set padding token if not set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
    ).to(device)

    text_input = format_rag_prompt(question, context, accepts_sys)

    # Tokenize the chat-formatted prompt (truncate so the prompt stays within 2048 tokens)
    actual_input = tokenizer.apply_chat_template(
        text_input,
        return_tensors="pt",
        tokenize=True,
        truncation=True,
        max_length=2048,
        add_generation_prompt=True,
    ).to(device)

    input_length = actual_input.shape[1]
    attention_mask = torch.ones_like(actual_input).to(device)

    # Generate output
    with torch.inference_mode():
        outputs = model.generate(
            actual_input,
            attention_mask=attention_mask,
            max_new_tokens=512,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode the output
    result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return result
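

# --- Usage sketch (illustrative only) ---
# Minimal local smoke test. The structure of `example` below is an assumption based
# on how generate_summaries reads it ("full_contexts", "question", "Answerable");
# adjust it to match the real dataset records. Because of the relative
# `from .prompts import ...` import, run this as a module (python -m ...) rather
# than as a standalone script.
if __name__ == "__main__":
    example = {
        "question": "What is the capital of France?",
        "Answerable": True,
        "full_contexts": [
            {"content": "Paris is the capital and most populous city of France."},
            {"content": "France is a country in Western Europe."},
        ],
    }
    summary_a, summary_b = generate_summaries(
        example, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)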