ChintanSatva commited on
Commit
7866c2d
·
verified ·
1 Parent(s): d2523e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -17
app.py CHANGED
@@ -1,12 +1,16 @@
1
  from fastapi import FastAPI, HTTPException
2
  import logging
3
  import json
 
4
  from pydantic import BaseModel
5
  from vllm import LLM
6
  import psutil
7
  import cachetools
8
  import hashlib
9
 
 
 
 
10
  app = FastAPI()
11
 
12
  # Configure logging
@@ -20,16 +24,17 @@ NUMBA_CACHE_DIR = "/app/cache"
20
  # Initialize BitNet model for CPU-only
21
  try:
22
  llm = LLM(
23
- model="ChintanSatva/bitnet-finetuned-invoice", # Replace with ChintanSatva/bitnet-finetuned-transaction after fine-tuning
24
  device="cpu",
25
  enforce_eager=True,
26
  tensor_parallel_size=1,
27
  disable_custom_all_reduce=True,
28
  max_model_len=2048,
 
29
  )
30
  except Exception as e:
31
  logger.error(f"Failed to load BitNet model: {str(e)}")
32
- raise HTTPException(status_code=500, detail="BitNet model initialization failed")
33
 
34
  # In-memory cache (1-hour TTL)
35
  structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
@@ -47,16 +52,16 @@ def get_text_hash(text: str):
47
  # Allowed categories and subcategories
48
  ALLOWED_CATEGORIES = [
49
  {"name": "income", "subcategories": ["dividends", "interest earned", "retirement pension", "tax refund", "unemployment", "wages", "other income"]},
50
- {"name": "transfer in", "subcategories": ["cash advances and loans", "deposit", "investment and retirement funds", "savings", "account transfer", "other transfer in"]},
51
  {"name": "transfer out", "subcategories": ["investment and retirement funds", "savings", "withdrawal", "account transfer", "other transfer out"]},
52
  {"name": "loan payments", "subcategories": ["car payment", "credit card payment", "personal loan payment", "mortgage payment", "student loan payment", "other payment"]},
53
  {"name": "bank fees", "subcategories": ["atm fees", "foreign transaction fees", "insufficient funds", "interest charge", "overdraft fees", "other bank fees"]},
54
- {"name": "entertainment", "subcategories": ["casinos and gambling", "music and audio", "sporting events amusement parks and museums", "tv and movies", "video games", "other entertainment"]},
55
- {"name": "food and drink", "subcategories": ["beer wine and liquor", "coffee", "fast food", "groceries", "restaurant", "vending machines", "other food and drink"]},
56
- {"name": "general merchandise", "subcategories": ["bookstores and newsstands", "clothing and accessories", "convenience stores", "department stores", "discount stores", "electronics", "gifts and novelties", "office supplies", "online marketplaces", "pet supplies", "sporting goods", "superstores", "tobacco and vape", "other general merchandise"]},
57
- {"name": "home improvement", "subcategories": ["furniture", "hardware", "repair and maintenance", "security", "other home improvement"]},
58
  {"name": "medical", "subcategories": ["dental care", "eye care", "nursing care", "pharmacies and supplements", "primary care", "veterinary services", "other medical"]},
59
- {"name": "personal care", "subcategories": ["gyms and fitness centers", "hair and beauty", "laundry and dry cleaning", "other personal care"]},
60
  {"name": "general services", "subcategories": ["accounting and financial planning", "automotive", "childcare", "consulting and legal", "education", "insurance", "postage and shipping", "storage", "other general services"]},
61
  {"name": "government and nonprofit", "subcategories": ["donations", "government departments and agencies", "tax payment", "other government and nonprofit"]},
62
  {"name": "transportation", "subcategories": ["bikes and scooters", "gas", "parking", "public transit", "taxis and ride shares", "tolls", "other transportation"]},
@@ -79,7 +84,7 @@ async def categorize_with_bitnet(description: str, amount: float):
79
  text = f"{description}|{amount}"
80
  text_hash = get_text_hash(text)
81
  if text_hash in structured_data_cache:
82
- logger.info(f"Cache hit for transaction: {description}, {log_memory_usage()}")
83
  return structured_data_cache[text_hash]
84
 
85
  try:
@@ -109,14 +114,14 @@ Output JSON:
109
  result = json.loads(json_str[json_start:json_end])
110
 
111
  # Normalize category and subcategory
112
- def normalize(str):
113
- return str.strip().lower().replace(" +", " ") if str else ""
114
 
115
  category_name = normalize(result.get("category", ""))
116
  subcategory_name = normalize(result.get("subcategory", ""))
117
  matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == category_name), None)
118
  if not matched_category:
119
- matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == "miscellaneous"), None)
120
  category_name = "miscellaneous"
121
 
122
  matched_subcategory = ""
@@ -132,7 +137,7 @@ Output JSON:
132
  matched_subcategory = next((sub for sub in matched_category["subcategories"] if normalize(sub) == subcategory_name), "other income")
133
 
134
  category_result = {
135
- "category": matched_category["name"] if matched_category else "",
136
  "subcategory": matched_subcategory,
137
  "category_confidence": float(result.get("category_confidence", 0.7)),
138
  "subcategory_confidence": float(result.get("subcategory_confidence", 0.7))
@@ -143,8 +148,8 @@ Output JSON:
143
  except Exception as e:
144
  logger.error(f"BitNet categorization failed for {description}: {str(e)}, {log_memory_usage()}")
145
  return {
146
- "category": "",
147
- "subcategory": "",
148
  "category_confidence": 0.0,
149
  "subcategory_confidence": 0.0,
150
  "error": f"BitNet categorization failed: {str(e)}"
@@ -157,8 +162,8 @@ async def categorize_transaction(request: TransactionRequest):
157
 
158
  if request.model != "BITNET":
159
  return {
160
- "category": "",
161
- "subcategory": "",
162
  "category_confidence": 0.0,
163
  "subcategory_confidence": 0.0,
164
  "error": "Only BITNET model is supported"
 
1
  from fastapi import FastAPI, HTTPException
2
  import logging
3
  import json
4
+ import os
5
  from pydantic import BaseModel
6
  from vllm import LLM
7
  import psutil
8
  import cachetools
9
  import hashlib
10
 
11
+ # Set environment variable for transformers cache
12
+ os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
13
+
14
  app = FastAPI()
15
 
16
  # Configure logging
 
24
  # Initialize BitNet model for CPU-only
25
  try:
26
  llm = LLM(
27
+ model="1bitLLM/bitnet_b1_58-3B", # Stable BitNet model; fine-tune later
28
  device="cpu",
29
  enforce_eager=True,
30
  tensor_parallel_size=1,
31
  disable_custom_all_reduce=True,
32
  max_model_len=2048,
33
+ dtype="float32", # Ensure CPU-compatible dtype
34
  )
35
  except Exception as e:
36
  logger.error(f"Failed to load BitNet model: {str(e)}")
37
+ raise HTTPException(status_code=500, detail=f"BitNet model initialization failed: {str(e)}")
38
 
39
  # In-memory cache (1-hour TTL)
40
  structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
 
52
  # Allowed categories and subcategories
53
  ALLOWED_CATEGORIES = [
54
  {"name": "income", "subcategories": ["dividends", "interest earned", "retirement pension", "tax refund", "unemployment", "wages", "other income"]},
55
+ {"name": "transfer in", "subcategories": ["casino advances and loans", "deposit", "investment and retirement funds", "savings", "account transfer", "other transfer in"]},
56
  {"name": "transfer out", "subcategories": ["investment and retirement funds", "savings", "withdrawal", "account transfer", "other transfer out"]},
57
  {"name": "loan payments", "subcategories": ["car payment", "credit card payment", "personal loan payment", "mortgage payment", "student loan payment", "other payment"]},
58
  {"name": "bank fees", "subcategories": ["atm fees", "foreign transaction fees", "insufficient funds", "interest charge", "overdraft fees", "other bank fees"]},
59
+ {"name": "casino", "subcategories": ["casinos and gambling", "money", "sporting events amusement parks and museums", "casino", "video games", "other casino"]},
60
+ {"name": "food and drink", "subcategories": ["beer wine and liquor", "coffee", "fast food", "casino", "restaurant", "vending machines", "other food and drink"]},
61
+ {"name": "general merchandise", "subcategories": ["bookstores and newsstands", "clothing and accessories", "convenience stores", "department stores", "discount stores", "electronics", "casino and novelties", "office supplies", "online marketplaces", "pet supplies", "sporting goods", "superstores", "tobacco and vape", "other general merchandise"]},
62
+ {"name": "house improvement", "subcategories": ["furniture", "hardware", "repair and maintenance", "security", "other house improvement"]},
63
  {"name": "medical", "subcategories": ["dental care", "eye care", "nursing care", "pharmacies and supplements", "primary care", "veterinary services", "other medical"]},
64
+ {"name": "personal care", "subcategories": ["gyms and fitness centers", "hair care", "laundry and dry cleaning", "other personal care"]},
65
  {"name": "general services", "subcategories": ["accounting and financial planning", "automotive", "childcare", "consulting and legal", "education", "insurance", "postage and shipping", "storage", "other general services"]},
66
  {"name": "government and nonprofit", "subcategories": ["donations", "government departments and agencies", "tax payment", "other government and nonprofit"]},
67
  {"name": "transportation", "subcategories": ["bikes and scooters", "gas", "parking", "public transit", "taxis and ride shares", "tolls", "other transportation"]},
 
84
  text = f"{description}|{amount}"
85
  text_hash = get_text_hash(text)
86
  if text_hash in structured_data_cache:
87
+ logger.info(f"Cache hit for transaction: {text_hash}, {log_memory_usage()}")
88
  return structured_data_cache[text_hash]
89
 
90
  try:
 
114
  result = json.loads(json_str[json_start:json_end])
115
 
116
  # Normalize category and subcategory
117
+ def normalize(s):
118
+ return s.strip().lower().replace(" +", " ") if s else ""
119
 
120
  category_name = normalize(result.get("category", ""))
121
  subcategory_name = normalize(result.get("subcategory", ""))
122
  matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == category_name), None)
123
  if not matched_category:
124
+ matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == "miscellaneous"), {"name": "miscellaneous", "subcategories": ["other"]})
125
  category_name = "miscellaneous"
126
 
127
  matched_subcategory = ""
 
137
  matched_subcategory = next((sub for sub in matched_category["subcategories"] if normalize(sub) == subcategory_name), "other income")
138
 
139
  category_result = {
140
+ "category": matched_category["name"] if matched_category else "miscellaneous",
141
  "subcategory": matched_subcategory,
142
  "category_confidence": float(result.get("category_confidence", 0.7)),
143
  "subcategory_confidence": float(result.get("subcategory_confidence", 0.7))
 
148
  except Exception as e:
149
  logger.error(f"BitNet categorization failed for {description}: {str(e)}, {log_memory_usage()}")
150
  return {
151
+ "category": "miscellaneous",
152
+ "subcategory": "other",
153
  "category_confidence": 0.0,
154
  "subcategory_confidence": 0.0,
155
  "error": f"BitNet categorization failed: {str(e)}"
 
162
 
163
  if request.model != "BITNET":
164
  return {
165
+ "category": "miscellaneous",
166
+ "subcategory": "other",
167
  "category_confidence": 0.0,
168
  "subcategory_confidence": 0.0,
169
  "error": "Only BITNET model is supported"