Commit e1e8013
Parent(s): a704797

Improved input query filters

Changed files:
- app.py (+53, -14)
- data_filters.py (+9, -0)

app.py CHANGED

@@ -21,6 +21,8 @@ from data_filters import (
     restricted_patterns,
     restricted_topics,
     FINANCIAL_DATA_PATTERNS,
+    FINANCIAL_ENTITY_LABELS,
+    GENERAL_KNOWLEDGE_PATTERNS,
     sensitive_terms,
     FINANCIAL_TERMS,
 )
@@ -37,8 +39,8 @@ os.makedirs("data", exist_ok=True)
 # SLM: Microsoft PHI-2 model is loaded
 # It does have higher memory and compute requirements compared to TinyLlama and Falcon
 # But it gives the best results among the three
-DEVICE = "cpu" # or cuda
-
+# DEVICE = "cpu" # or cuda
+DEVICE = "cuda" # or cuda
 # MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
 # MODEL_NAME = "tiiuae/falcon-rw-1b"
 MODEL_NAME = "microsoft/phi-2"
@@ -55,7 +57,7 @@ if tokenizer.pad_token is None:
 # Since the model is to be hosted on a cpu instance, we use float32
 # For GPU, we can use float16 or bfloat16
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, torch_dtype=torch.
+    MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
 ).to(DEVICE)
 model.eval()
 # model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
@@ -234,25 +236,62 @@ def process_files(files, chunk_size=512):
         pickle.dump(bm25_data, f)
     return "Files processed successfully! You can now query."
 
+
 def contains_financial_entities(query):
-    """Check if
+    """Check if query contains financial entities"""
     doc = nlp(query)
     for ent in doc.ents:
         if ent.label_ in FINANCIAL_ENTITY_LABELS:
             return True
     return False
 
+
+def contains_geographical_entities(query):
+    """Check if the query contains geographical entities"""
+    doc = nlp(query)
+    return any(ent.label_ == "GPE" for ent in doc.ents)
+
+
+def contains_financial_terms(query):
+    """Check if the query contains financial terms"""
+    return any(term in query.lower() for term in FINANCIAL_TERMS)
+
+
+def is_general_knowledge_query(query):
+    """Check if query contains general knowledge"""
+    query_lower = query.lower()
+    for pattern in GENERAL_KNOWLEDGE_PATTERNS:
+        if re.search(pattern, query_lower):
+            return True
+    return False
+
+
+def is_irrelevant_query(query):
+    """Check if the query is not finance related"""
+    # If the query is general knowledge and not finance-related
+    if is_general_knowledge_query(query) and not contains_financial_terms(query):
+        return True
+    # If the query contains only geographical terms without financial entities
+    if contains_geographical_entities(query) and not contains_financial_entities(query):
+        return True
+    return False
+
+
 # Input guardrail implementation
+# NER + Regex + List of terms used to filter irrelevant queries
 # Regex is used to filter queries related to sensitive topics
 # Uses spaCy model's Named Entity Recognition to filter queries for personal details
 # Uses cosine similarity with the embedded query and sensitive topic vectors
 # to filter out queries violating confidential/security rules (additional)
 def is_query_allowed(query):
     """Checks if the query violates security or confidentiality rules"""
+    if is_irrelevant_query(query):
+        return False, "Query is not finance-related. Please ask a financial question."
     for pattern in restricted_patterns:
         if re.search(pattern, query.lower(), re.IGNORECASE):
             return False, "This query requests sensitive or confidential information."
     doc = nlp(query)
+    # Check if there's a person entity and contains sensitive terms
     for ent in doc.ents:
         if ent.label_ == "PERSON":
             for token in ent.subtree:
@@ -265,6 +304,7 @@ def is_query_allowed(query):
     topic_embeddings = embed_model.encode(
         list(restricted_topics), normalize_embeddings=True
     )
+    # Check similarities between the restricted topics and the query
     similarities = np.dot(topic_embeddings, query_embedding)
     if np.max(similarities) > 0.85:
         return False, "This query requests sensitive or confidential information."
@@ -368,8 +408,9 @@ def compute_response_confidence(
     normalized_bm25 = 0.0
     logger.info(
         f"Faiss score: {normalized_faiss}, bm25: {normalized_bm25}, "
-        f"Mean Top Token + Entropy Avg: {model_conf_signal}"
+        f"Mean Top Token + 1-Entropy Avg: {model_conf_signal}"
     )
+    # Weighted sum of all the normalized scores
     confidence_score = (
         lambda_faiss * normalized_faiss
         + model_conf_signal * lambda_conf
@@ -436,13 +477,10 @@ def query_model(
         "You are a financial analyst. Answer financial queries concisely using only the numerical data "
         "explicitly present in the provided financial context:\n\n"
         f"{context}\n\n"
-        "
-        " Retain the original format of financial figures
-        "
-        "
-        "'No relevant financial data available.'"
-        " Provide exactly one answer in a single sentence."
-        " Do not generate explanations, additional text, or answer multiple queries."
+        "Use only the given financial data—do not assume, infer, or generate missing values."
+        " Retain the original format of financial figures without conversion."
+        " If the requested information is unavailable, respond with 'No relevant financial data available.'"
+        " Provide a single-sentence answer without explanations, additional text, or multiple responses."
         f"\nQuery: {query}"
     )
     inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
@@ -463,7 +501,8 @@ def query_model(
     sequences = output["sequences"][0][input_len:]
     execution_time = time.perf_counter() - start_time
     logger.info(f"Query processed in {execution_time:.2f} seconds.")
-
+    # Get the logits per generated token
+    log_probs = output["scores"]
     token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
     # Extract top token probabilities for each step
     token_confidences = [tp.max().item() for tp in token_probs]
@@ -487,7 +526,7 @@ def query_model(
     final_out += f"Context: {context}\nQuery: {query}\n"
     final_out += f"Response: {response}"
     return (
-
+        final_out,
         f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds",
     )
 
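The core of this commit is the relevance gate that is_query_allowed now runs first: general-knowledge questions with no financial vocabulary, and purely geographical questions with no financial entities, are rejected before the sensitive-information checks. The condensed sketch below folds the four new helpers into one function so the behaviour can be tried in isolation; the spaCy model name, the sample FINANCIAL_TERMS set, and the shortened pattern list are stand-ins, since those definitions sit outside the lines shown in this diff.

# Condensed sketch of the new relevance gate (not the exact app.py code).
import re

import spacy

# Stand-ins for values defined elsewhere in the repo (assumptions for illustration only):
FINANCIAL_ENTITY_LABELS = {"MONEY", "PERCENT", "CARDINAL", "ORG"}
FINANCIAL_TERMS = {"revenue", "profit", "margin", "ebitda", "dividend"}
GENERAL_KNOWLEDGE_PATTERNS = [r"\b(?:capital of|where is|who is|what is|history of)\b"]

nlp = spacy.load("en_core_web_sm")  # assumption: the diff does not show which spaCy model app.py loads


def is_irrelevant_query(query):
    """Reject general-knowledge or purely geographical queries with no financial signal."""
    query_lower = query.lower()
    is_general = any(re.search(p, query_lower) for p in GENERAL_KNOWLEDGE_PATTERNS)
    has_financial_terms = any(term in query_lower for term in FINANCIAL_TERMS)
    ents = nlp(query).ents
    has_geo = any(ent.label_ == "GPE" for ent in ents)
    has_financial_entities = any(ent.label_ in FINANCIAL_ENTITY_LABELS for ent in ents)
    if is_general and not has_financial_terms:
        return True
    if has_geo and not has_financial_entities:
        return True
    return False


print(is_irrelevant_query("What is the capital of France?"))         # True: general knowledge, no finance signal
print(is_irrelevant_query("What is the 2023 revenue of Acme Corp?"))  # False: financial term (and likely ORG entity)

In app.py the same outcome is reached through the separate helpers, and is_query_allowed returns the "Query is not finance-related" message when this gate fires.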
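For the sensitive-topic check that the new comment annotates, the dot product acts as a cosine similarity only because both sides are encoded with normalize_embeddings=True. A minimal sketch of that check, assuming embed_model is a SentenceTransformer (the checkpoint name and the topic subset below are placeholders not shown in this diff):

import numpy as np
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumption: app.py's actual checkpoint is not in this diff
restricted_topics = {"executive compensation details", "upcoming merger plans"}  # illustrative subset

query = "Tell me about the upcoming merger plans."

# Unit-norm embeddings make the dot product equal to cosine similarity.
query_embedding = embed_model.encode(query, normalize_embeddings=True)
topic_embeddings = embed_model.encode(list(restricted_topics), normalize_embeddings=True)

similarities = np.dot(topic_embeddings, query_embedding)
if np.max(similarities) > 0.85:  # same threshold used in is_query_allowed
    print("Blocked: requests sensitive or confidential information.")
else:
    print("Passed the similarity check.")

The 0.85 cutoff is strict by design: near-paraphrases of a restricted topic clear it, while loosely related wording generally does not.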
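Two other hunks feed the answer-confidence estimate: query_model now reads the per-step logits from output["scores"], and compute_response_confidence logs a "Mean Top Token + 1-Entropy Avg" signal before taking a weighted sum with the retrieval scores. The sketch below reproduces that pipeline end to end on dummy logits; the 0.5/0.5 averaging, the lambda weights, and the bm25 term are assumptions, since those values sit outside the lines shown here.

import numpy as np
import torch

# Dummy stand-in for output["scores"]: one (1, vocab) logits tensor per generated token.
# In app.py these come from model.generate(...) with output_scores=True and
# return_dict_in_generate=True (assumption: those flags are set outside the shown hunks).
scores = [torch.randn(1, 32000) for _ in range(8)]

token_probs = [torch.softmax(lp, dim=-1) for lp in scores]
token_confidences = [tp.max().item() for tp in token_probs]  # top-token probability per step

# Normalized entropy per step: 0 = fully confident, 1 = uniform over the vocabulary.
entropies = [
    (-(tp * tp.clamp_min(1e-12).log()).sum() / np.log(tp.shape[-1])).item()
    for tp in token_probs
]
model_conf_signal = 0.5 * np.mean(token_confidences) + 0.5 * np.mean([1 - e for e in entropies])

# Weighted sum of the normalized signals, mirroring compute_response_confidence.
lambda_faiss, lambda_conf, lambda_bm25 = 0.4, 0.4, 0.2  # placeholder weights
normalized_faiss, normalized_bm25 = 0.7, 0.5            # placeholder retrieval scores
confidence_score = (
    lambda_faiss * normalized_faiss
    + lambda_conf * model_conf_signal
    + lambda_bm25 * normalized_bm25
)
print(f"Confidence: {round(float(confidence_score) * 100, 2)}%")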
data_filters.py CHANGED

@@ -29,6 +29,15 @@ restricted_topics = {
     "financial package",
 }
 
+FINANCIAL_ENTITY_LABELS = {"MONEY", "PERCENT", "CARDINAL", "ORG"}
+
+GENERAL_KNOWLEDGE_PATTERNS = [
+    r"\b(?:capital of|where is|who is|when did|what is|history of|define|meaning of|synonym of|antonym of|explain|how does|why is)\b",
+    r"\b(?:country|city|continent|leader|president|prime minister|language|currency|population|politics|war|anthem|flag|national animal|national bird|national flower|national sport|monarch|king|queen|ruler|army|military|constitution|government|laws|famous person|historical figure|famous landmark|ocean|mountain|river|lake|climate|weather|culture|tradition|festival|holiday|invention|discovery|science|technology|art|literature|music|religion|mythology|folklore|education|university|school|mathematics|physics|chemistry|biology|philosophy|astronomy|space|planet|star|galaxy|universe|health|medicine|disease|virus|bacteria|genetics|DNA|evolution|ecology|environment|pollution|wildlife|habitat|natural disaster|earthquake|volcano|tsunami|hurricane|storm|flood|drought)\b",
+    r"\b(?:[A-Z][a-z]+(?:'s)?\s+(?:capital|president|prime minister|national animal|national bird|national flower|national sport|anthem|flag|currency|language|leader|government|constitution|laws|monarch|king|queen|army|military|famous person|historical figure|landmark|river|ocean|mountain|religion|festival|holiday))\b",
+]
+
+
 sensitive_terms = {
     "salary",
     "compensation",
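data_filters.py gains the entity labels and the three general-knowledge patterns the new app.py helpers rely on: the first pattern keys on question stems, the second on a broad general-knowledge vocabulary, and the third on capitalized possessives such as "France's capital". A quick sanity check, assuming the script is run from the repo root so data_filters is importable:

import re

from data_filters import GENERAL_KNOWLEDGE_PATTERNS

samples = [
    "What is the capital of Japan?",   # question stem ("what is", "capital of")
    "Japan's national anthem",         # general-knowledge vocabulary ("anthem")
    "Summarize Q3 operating margin",   # finance wording, no pattern should fire
]
for query in samples:
    matched = any(re.search(p, query.lower()) for p in GENERAL_KNOWLEDGE_PATTERNS)
    print(f"{query!r} -> general knowledge: {matched}")

Note that is_general_knowledge_query matches against query.lower(), so the third, capitalization-dependent pattern can only fire if it is also applied to the raw query.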