import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import spaces
import torch
from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    # "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    # "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
}
# List of model names for easy access
model_names = list(models.keys())
# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()
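
# Note: InterruptCriteria is defined above but never attached to a generation call
# in this module. As an illustrative sketch only (an assumption, not part of the
# original flow), it could be forwarded through the pipeline call, which passes
# extra keyword arguments on to generate():
#
#     stopping = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
#     outputs = pipe(text_input, max_new_tokens=512, stopping_criteria=stopping)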
@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""

    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b
@spaces.GPU
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    Returns the generated text, or an empty string if interrupted.
    """
    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map='auto',
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before generation
        if generation_interrupt.is_set():
            return ""

        # The output length is controlled here via max_new_tokens (rather than also
        # setting max_length at pipeline creation, which would conflict with it).
        outputs = pipe(text_input, max_new_tokens=512)
        # With chat-style inputs, generated_text holds the full message list; the
        # last entry is the assistant reply.
        result = outputs[0]['generated_text'][-1]['content']

    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        result = f"Error generating response: {str(e)[:200]}..."

    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result
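
# Illustrative usage sketch (an assumption, not part of this module): the field
# names below simply mirror the accesses in generate_summaries, not a documented
# schema.
#
#     example = {
#         "question": "What is the capital of France?",
#         "full_contexts": [{"content": "Paris is the capital of France."}],
#     }
#     summary_a, summary_b = generate_summaries(
#         example, "Qwen2.5-1.5b-Instruct", "Gemma-2-2b-it"
#     )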