Spaces:
Sleeping
Sleeping
"""Generator component for the RAG system.""" | |
from typing import List, Dict | |
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
LogitsProcessor, | |
LogitsProcessorList | |
) | |
class FinancialContextProcessor(LogitsProcessor): | |
"""Custom logits processor for financial context.""" | |
def __init__(self, financial_constraints: Dict): | |
self.constraints = financial_constraints | |
def __call__(self, input_ids: torch.LongTensor, | |
scores: torch.FloatTensor) -> torch.FloatTensor: | |
# Apply financial domain constraints | |
# This is a placeholder for actual constraints | |
return scores | |
class RAGGenerator: | |
def __init__(self, config: Dict): | |
"""Initialize the generator.""" | |
self.model_name = "gpt2" # Can be configured based on needs | |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
self.model = AutoModelForCausalLM.from_pretrained(self.model_name) | |
self.max_length = 512 | |
def prepare_context(self, retrieved_docs: List[Dict]) -> str: | |
"""Prepare context from retrieved documents.""" | |
context = "" | |
for doc in retrieved_docs: | |
context += f"{doc['document']['text']}\n" | |
return context.strip() | |
def generate(self, query: str, retrieved_docs: List[Dict], | |
financial_constraints: Dict = None) -> str: | |
"""Generate text based on query and retrieved documents.""" | |
context = self.prepare_context(retrieved_docs) | |
prompt = f"Context: {context}\nQuery: {query}\nResponse:" | |
# Prepare logits processors | |
processors = LogitsProcessorList() | |
if financial_constraints: | |
processors.append(FinancialContextProcessor(financial_constraints)) | |
# Generate response | |
inputs = self.tokenizer(prompt, return_tensors="pt") | |
outputs = self.model.generate( | |
inputs.input_ids, | |
max_length=self.max_length, | |
num_return_sequences=1, | |
logits_processor=processors, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
return self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |