# SLM-RAG-Arena / utils/models.py
# Author: oliver-aizip
# Commit b8ee0a2: "prepare for zeroGPU" (9.1 kB)
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import spaces
import torch
from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
import threading
import queue
import time # Added for sleep
from vllm import LLM, SamplingParams
# Arena display name -> Hugging Face Hub repository id. The id is what
# run_inference hands to AutoTokenizer.from_pretrained / pipeline(model=...).
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# Display names only, in the same (insertion) order as ``models``.
model_names = [*models]
class InterruptCriteria(StoppingCriteria):
    """Stopping criterion that asks ``generate`` to halt once an interrupt fires.

    The decision ignores the generated tokens and scores entirely; it only
    reflects whether the wrapped event has been set.
    """

    def __init__(self, interrupt_event):
        # Any object with an ``is_set()`` method works; in this file it is the
        # shared ``generation_interrupt`` (presumably a threading.Event —
        # confirm in .shared).
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids / scores are part of the StoppingCriteria call protocol but
        # are irrelevant here: stopping depends solely on the external flag.
        stop_requested = self.interrupt_event.is_set()
        return stop_requested
def _run_model_with_interrupt(model_label, model_name, context_text, question):
    """Run ``run_inference`` for one model in a worker thread, polling for interrupts.

    Args:
        model_label: Short tag ("A" / "B") used only in the interrupt log line.
        model_name: Key into ``models``; a missing key raises KeyError here,
            in the calling thread.
        context_text: Joined context passages for the prompt.
        question: Question text for the prompt.

    Returns:
        The string the worker put on its queue (answer, "" or error message),
        "" if the worker exited without producing anything, or ``None`` when
        ``generation_interrupt`` was set before a result could be collected.
    """
    result_queue = queue.Queue()
    worker = threading.Thread(
        target=run_inference,
        args=(models[model_name], context_text, question, result_queue),
    )
    worker.start()

    while worker.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model {model_label} ({model_name})...")
            # The worker cannot be killed; give it a moment to wind down, then
            # abandon it and stop waiting.
            worker.join(timeout=1.0)
            return None
        try:
            # Short timeout so the interrupt flag is re-checked ~10x per second.
            return result_queue.get(timeout=0.1)
        except queue.Empty:
            continue

    # The worker finished between polls: collect whatever it left behind.
    try:
        return result_queue.get_nowait()
    except queue.Empty:
        pass
    return None if generation_interrupt.is_set() else ""


def generate_summaries(example, model_a_name, model_b_name):
    """Generate one answer per model for a RAG example, honoring ``generation_interrupt``.

    The two models run one after the other, each in its own worker thread, so
    the interrupt flag can be checked while a model is generating.

    Args:
        example: Mapping with a ``"full_contexts"`` list (the ``"content"`` of
            each dict entry that has one is used; other entries are skipped)
            and an optional ``"question"`` string.
        model_a_name: Key into ``models`` for the first model.
        model_b_name: Key into ``models`` for the second model.

    Returns:
        ``(summary_a, summary_b)``. A model that was interrupted, or never
        started because of an interrupt, contributes ``""``; model A's answer
        is still returned if the interrupt arrives while model B is running.

    Raises:
        ValueError: if ``example`` has no ``"full_contexts"`` key.
    """
    if generation_interrupt.is_set():
        return "", ""

    if "full_contexts" not in example:
        raise ValueError("No context found in the example.")
    context_text = "\n---\n".join(
        ctx["content"]
        for ctx in example["full_contexts"]
        if isinstance(ctx, dict) and "content" in ctx
    )
    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    summary_a = _run_model_with_interrupt("A", model_a_name, context_text, question)
    if summary_a is None:
        return "", ""
    # Don't start model B if an interrupt arrived while A was finishing.
    if generation_interrupt.is_set():
        return summary_a, ""

    summary_b = _run_model_with_interrupt("B", model_b_name, context_text, question)
    if summary_b is None:
        return summary_a, ""
    return summary_a, summary_b
@spaces.GPU
def run_inference(model_name, context, question, result_queue):
    """Generate an answer with one model; designed to run in a worker thread.

    Loads the tokenizer and a ``text-generation`` pipeline for ``model_name``,
    builds the RAG prompt with ``format_rag_prompt`` and samples up to 512 new
    tokens. Exactly one string is put on ``result_queue``:

    * the generated answer on success;
    * ``""`` if ``generation_interrupt`` is set at any checkpoint, including
      after generation (a partial answer cut short by the interrupt is
      discarded);
    * ``"Error generating response: <first 200 chars of the error>..."`` if
      anything raises.

    Args:
        model_name: Hugging Face Hub repository id for ``from_pretrained``.
        context: Context text inserted into the prompt.
        question: Question text inserted into the prompt.
        result_queue: Queue that receives the single result string.
    """
    if generation_interrupt.is_set():
        result_queue.put("")
        return

    # Bound before ``try`` so the ``finally`` cleanup can always reference them.
    tokenizer = None
    pipe = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        # Chat templates that reject a system turn contain this message; a
        # tokenizer without any chat template is treated the same way.
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Last cheap exit before the model weights are loaded.
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        # The length limit is set per call via max_new_tokens; a constructor
        # max_length would only be overridden by it (with a warning).
        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        text_input = format_rag_prompt(question, context, accepts_sys)

        if generation_interrupt.is_set():
            result_queue.put("")
            return

        # stopping_criteria is forwarded to ``generate`` so that setting the
        # interrupt flag stops decoding mid-sequence instead of running on.
        output = pipe(
            text_input,
            max_new_tokens=512,
            stopping_criteria=StoppingCriteriaList([InterruptCriteria(generation_interrupt)]),
        )

        if generation_interrupt.is_set():
            result_queue.put("")
            return

        # Assumes format_rag_prompt returns chat messages: the pipeline then
        # returns the whole conversation, and the last message is the newly
        # generated assistant turn.
        result_queue.put(output[0]["generated_text"][-1]["content"])
    except Exception as e:
        print(f"Error in inference thread for {model_name}: {e}")
        # Report the failure to the waiting thread instead of leaving it empty.
        result_queue.put(f"Error generating response: {str(e)[:200]}...")
    finally:
        # Drop the references that keep the model weights alive *before*
        # empty_cache(); cached blocks still in use cannot be released.
        del pipe, tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()