Update app.py
app.py CHANGED
@@ -1,429 +1,168 @@
-from fastapi import FastAPI, File, UploadFile
-import pytesseract
-import cv2
-import os
-from PIL import Image
-import json
-import unicodedata
-from pdf2image import convert_from_bytes
-from pypdf import PdfReader
-import numpy as np
-from typing import List
-import io
 import logging
-import time
-import
 import psutil
 import cachetools
 import hashlib

-app = FastAPI(

 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

-# Set
-

-# Initialize
-llm = None
 try:
-        max_length=512)
-    logger.info("Lightweight text generation model loaded successfully")
 except Exception as e:
-    logger.error(f"Failed to load
-

-# In-memory
-raw_text_cache = cachetools.TTLCache(maxsize=100, ttl=3600)
 structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)

 def log_memory_usage():
     """Log current memory usage."""

-async def process_pdf_page(img, page_idx):
-    """Process a single PDF page with OCR."""
-    start_time = time.time()
-    logger.info(f"Starting OCR for PDF page {page_idx}")
-    try:
-        img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
-
-        # Preprocess image for better OCR
-        gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
-
-        img_pil = Image.fromarray(gray)
-        custom_config = r'--oem 3 --psm 6 -l eng'
-        page_text = pytesseract.image_to_string(img_pil, config=custom_config)
-
-        logger.info(f"Completed OCR for PDF page {page_idx}, took {time.time() - start_time:.2f} seconds")
-        return page_text + "\n"
-    except Exception as e:
-        logger.error(f"OCR failed for PDF page {page_idx}: {str(e)}")
-        return ""
-
-def rule_based_extraction(raw_text: str):
-    """Rule-based fallback extraction when LLM is not available."""
-    import re
-
-    # Initialize the structure
-    structured_data = {
-        "invoice": {
-            "invoice_number": {"value": "", "accuracy": 0.0},
-            "invoice_date": {"value": "", "accuracy": 0.0},
-            "due_date": {"value": "", "accuracy": 0.0},
-            "purchase_order_number": {"value": "", "accuracy": 0.0},
-            "vendor": {
-                "vendor_id": {"value": "", "accuracy": 0.0},
-                "name": {"value": "", "accuracy": 0.0},
-                "address": {
-                    "line1": {"value": "", "accuracy": 0.0},
-                    "line2": {"value": "", "accuracy": 0.0},
-                    "city": {"value": "", "accuracy": 0.0},
-                    "state": {"value": "", "accuracy": 0.0},
-                    "postal_code": {"value": "", "accuracy": 0.0},
-                    "country": {"value": "", "accuracy": 0.0}
-                },
-                "contact": {
-                    "email": {"value": "", "accuracy": 0.0},
-                    "phone": {"value": "", "accuracy": 0.0}
-                },
-                "tax_id": {"value": "", "accuracy": 0.0}
-            },
-            "buyer": {
-                "buyer_id": {"value": "", "accuracy": 0.0},
-                "name": {"value": "", "accuracy": 0.0},
-                "address": {
-                    "line1": {"value": "", "accuracy": 0.0},
-                    "line2": {"value": "", "accuracy": 0.0},
-                    "city": {"value": "", "accuracy": 0.0},
-                    "state": {"value": "", "accuracy": 0.0},
-                    "postal_code": {"value": "", "accuracy": 0.0},
-                    "country": {"value": "", "accuracy": 0.0}
-                },
-                "contact": {
-                    "email": {"value": "", "accuracy": 0.0},
-                    "phone": {"value": "", "accuracy": 0.0}
-                },
-                "tax_id": {"value": "", "accuracy": 0.0}
-            },
-            "items": [{
-                "item_id": {"value": "", "accuracy": 0.0},
-                "description": {"value": "", "accuracy": 0.0},
-                "quantity": {"value": 0, "accuracy": 0.0},
-                "unit_of_measure": {"value": "", "accuracy": 0.0},
-                "unit_price": {"value": 0, "accuracy": 0.0},
-                "total_price": {"value": 0, "accuracy": 0.0},
-                "tax_rate": {"value": 0, "accuracy": 0.0},
-                "tax_amount": {"value": 0, "accuracy": 0.0},
-                "discount": {"value": 0, "accuracy": 0.0},
-                "net_amount": {"value": 0, "accuracy": 0.0}
-            }],
-            "sub_total": {"value": 0, "accuracy": 0.0},
-            "tax_total": {"value": 0, "accuracy": 0.0},
-            "discount_total": {"value": 0, "accuracy": 0.0},
-            "total_amount": {"value": 0, "accuracy": 0.0},
-            "currency": {"value": "USD", "accuracy": 0.5}
-        }
-    }
-
-    # Simple pattern matching
-    try:
-        # Invoice number
-        inv_pattern = r'(?:invoice|inv)(?:\s*#|\s*no\.?|\s*number)?\s*:?\s*([A-Z0-9\-/]+)'
-        inv_match = re.search(inv_pattern, raw_text, re.IGNORECASE)
-        if inv_match:
-            structured_data["invoice"]["invoice_number"]["value"] = inv_match.group(1)
-            structured_data["invoice"]["invoice_number"]["accuracy"] = 0.7
-
-        # Date patterns
-        date_pattern = r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})'
-        dates = re.findall(date_pattern, raw_text)
-        if dates:
-            structured_data["invoice"]["invoice_date"]["value"] = dates[0]
-            structured_data["invoice"]["invoice_date"]["accuracy"] = 0.6
-
-        # Total amount
-        amount_pattern = r'(?:total|amount|sum)\s*:?\s*\$?(\d+\.?\d*)'
-        amount_match = re.search(amount_pattern, raw_text, re.IGNORECASE)
-        if amount_match:
-            structured_data["invoice"]["total_amount"]["value"] = float(amount_match.group(1))
-            structured_data["invoice"]["total_amount"]["accuracy"] = 0.6
-
-        # Email
-        email_pattern = r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'
-        email_match = re.search(email_pattern, raw_text)
-        if email_match:
-            structured_data["invoice"]["vendor"]["contact"]["email"]["value"] = email_match.group()
-            structured_data["invoice"]["vendor"]["contact"]["email"]["accuracy"] = 0.8
-
-        # Phone
-        phone_pattern = r'(?:\+?1[-.\s]?)?\(?([0-9]{3})\)?[-.\s]?([0-9]{3})[-.\s]?([0-9]{4})'
-        phone_match = re.search(phone_pattern, raw_text)
-        if phone_match:
-            structured_data["invoice"]["vendor"]["contact"]["phone"]["value"] = phone_match.group()
-            structured_data["invoice"]["vendor"]["contact"]["phone"]["accuracy"] = 0.7
-
-    except Exception as e:
-        logger.error(f"Rule-based extraction error: {str(e)}")
-
-    return structured_data
-
-async def process_with_model(filename: str, raw_text: str):
-    """Process raw text with available model or fallback to rule-based."""
-    start_time = time.time()
-    logger.info(f"Starting text processing for {filename}")
-
-    # Check structured data cache
-    text_hash = get_text_hash(raw_text)
     if text_hash in structured_data_cache:
-        logger.info(f"
         return structured_data_cache[text_hash]

-    # Truncate text
-    if len(raw_text) > 5000:
-        raw_text = raw_text[:5000]
-        logger.info(f"Truncated raw text for {filename} to 5000 characters")
-
     try:
     except Exception as e:
-        logger.error(f"
-        return
-
-        "status": "active",
-        "llm_available": llm is not None
-    }
-
-@app.post("/ocr")
-async def extract_and_structure(files: List[UploadFile] = File(...)):
-    """Main endpoint for OCR and data extraction."""
-    output_json = {
-        "success": True,
-        "message": "",
-        "data": []
-    }
-    success_count = 0
-    fail_count = 0
-
-    logger.info(f"Starting processing for {len(files)} files")
-
-    for file in files:
-        total_start_time = time.time()
-        logger.info(f"Processing file: {file.filename}")
-
-        # Validate file format
-        valid_extensions = {'.pdf', '.jpg', '.jpeg', '.png'}
-        file_ext = os.path.splitext(file.filename.lower())[1] if file.filename else '.unknown'
-        if file_ext not in valid_extensions:
-            fail_count += 1
-            output_json["data"].append({
-                "filename": file.filename,
-                "structured_data": {"error": f"Unsupported file format: {file_ext}"},
-                "error": f"Unsupported file format: {file_ext}"
-            })
-            logger.error(f"Unsupported file format for {file.filename}: {file_ext}")
-            continue
-
-        # Read file into memory
-        try:
-            file_start_time = time.time()
-            file_bytes = await file.read()
-            file_stream = io.BytesIO(file_bytes)
-            file_hash = get_file_hash(file_bytes)
-            logger.info(f"Read file {file.filename}, size: {len(file_bytes)/1024:.2f} KB")
-        except Exception as e:
-            fail_count += 1
-            output_json["data"].append({
-                "filename": file.filename,
-                "structured_data": {"error": f"Failed to read file: {str(e)}"},
-                "error": f"Failed to read file: {str(e)}"
-            })
-            logger.error(f"Failed to read file {file.filename}: {str(e)}")
-            continue
-
-        # Check raw text cache
-        raw_text = ""
-        if file_hash in raw_text_cache:
-            raw_text = raw_text_cache[file_hash]
-            logger.info(f"Raw text cache hit for {file.filename}")
-        else:
-            if file_ext == '.pdf':
-                # Try extracting embedded text first
-                try:
-                    extract_start_time = time.time()
-                    reader = PdfReader(file_stream)
-                    for page in reader.pages:
-                        text = page.extract_text()
-                        if text:
-                            raw_text += text + "\n"
-                    logger.info(f"Embedded text extraction for {file.filename}, text length: {len(raw_text)}")
-                except Exception as e:
-                    logger.warning(f"Embedded text extraction failed for {file.filename}: {str(e)}")
-
-                # If no embedded text, perform OCR
-                if not raw_text.strip():
-                    try:
-                        convert_start_time = time.time()
-                        images = convert_from_bytes(file_bytes, dpi=150, first_page=1, last_page=3)  # Limit pages
-                        logger.info(f"PDF to images conversion for {file.filename}, {len(images)} pages")
-
-                        ocr_start_time = time.time()
-                        page_texts = []
-                        for i, img in enumerate(images):
-                            page_text = await process_pdf_page(img, i)
-                            page_texts.append(page_text)
-                        raw_text = "".join(page_texts)
-                        logger.info(f"Total OCR for {file.filename}, text length: {len(raw_text)}")
-                    except Exception as e:
-                        fail_count += 1
-                        output_json["data"].append({
-                            "filename": file.filename,
-                            "structured_data": {"error": f"OCR failed: {str(e)}"},
-                            "error": f"OCR failed: {str(e)}"
-                        })
-                        logger.error(f"OCR failed for {file.filename}: {str(e)}")
-                        continue
-            else:  # JPG/JPEG/PNG
-                try:
-                    ocr_start_time = time.time()
-                    raw_text = await process_image(file_bytes, file.filename, 0)
-                    logger.info(f"Image OCR for {file.filename}, text length: {len(raw_text)}")
-                except Exception as e:
-                    fail_count += 1
-                    output_json["data"].append({
-                        "filename": file.filename,
-                        "structured_data": {"error": f"Image OCR failed: {str(e)}"},
-                        "error": f"Image OCR failed: {str(e)}"
-                    })
-                    logger.error(f"Image OCR failed for {file.filename}: {str(e)}")
-                    continue
-
-            # Normalize text
-            try:
-                raw_text = unicodedata.normalize('NFKC', raw_text)
-                raw_text = raw_text.encode('utf-8', errors='ignore').decode('utf-8')
-                raw_text_cache[file_hash] = raw_text
-                logger.info(f"Text normalization for {file.filename} completed")
-            except Exception as e:
-                logger.warning(f"Text normalization failed for {file.filename}: {str(e)}")
-
-        # Process with model or rule-based extraction
-        if raw_text.strip():
-            structured_data = await process_with_model(file.filename, raw_text)
-            success_count += 1
-            output_json["data"].append({
-                "filename": file.filename,
-                "structured_data": structured_data,
-                "raw_text": raw_text[:500] + "..." if len(raw_text) > 500 else raw_text,  # Include snippet
-                "error": ""
-            })
-        else:
-            fail_count += 1
-            output_json["data"].append({
-                "filename": file.filename,
-                "structured_data": {"error": "No text extracted from file"},
-                "error": "No text extracted from file"
-            })
-
-        logger.info(f"Total processing for {file.filename} completed in {time.time() - total_start_time:.2f} seconds")
-
-    output_json["message"] = f"Processed {len(files)} files. {success_count} succeeded, {fail_count} failed."
-    if fail_count > 0 and success_count == 0:
-        output_json["success"] = False

-    return output_json

-uvicorn.run(app, host="0.0.0.0", port=7860)

+from fastapi import FastAPI, HTTPException
 import logging
+import json
+import os
+import re
+from typing import Optional
+from pydantic import BaseModel
+from vllm import LLM
 import psutil
 import cachetools
 import hashlib

+app = FastAPI()

 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

+# Set cache directories (as environment variables, so the libraries actually pick them up)
+os.environ["HF_HOME"] = "/app/cache"
+os.environ["NUMBA_CACHE_DIR"] = "/app/cache"

+# Initialize BitNet model for CPU-only inference
 try:
+    llm = LLM(
+        model="ChintanSatva/bitnet-finetuned-invoice",  # Replace with ChintanSatva/bitnet-finetuned-transaction after fine-tuning
+        device="cpu",
+        enforce_eager=True,
+        tensor_parallel_size=1,
+        disable_custom_all_reduce=True,
+        max_model_len=2048,
+    )
 except Exception as e:
+    logger.error(f"Failed to load BitNet model: {str(e)}")
+    raise HTTPException(status_code=500, detail="BitNet model initialization failed")

+# In-memory cache (1-hour TTL)
 structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)

 def log_memory_usage():
     """Log current memory usage."""
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    return f"Memory usage: {mem_info.rss / 1024 / 1024:.2f} MB"
+
+def get_text_hash(text: str):
+    """Generate MD5 hash of text."""
+    return hashlib.md5(text.encode('utf-8')).hexdigest()
+
+# Allowed categories and subcategories
+ALLOWED_CATEGORIES = [
+    {"name": "income", "subcategories": ["dividends", "interest earned", "retirement pension", "tax refund", "unemployment", "wages", "other income"]},
+    {"name": "transfer in", "subcategories": ["cash advances and loans", "deposit", "investment and retirement funds", "savings", "account transfer", "other transfer in"]},
+    {"name": "transfer out", "subcategories": ["investment and retirement funds", "savings", "withdrawal", "account transfer", "other transfer out"]},
+    {"name": "loan payments", "subcategories": ["car payment", "credit card payment", "personal loan payment", "mortgage payment", "student loan payment", "other payment"]},
+    {"name": "bank fees", "subcategories": ["atm fees", "foreign transaction fees", "insufficient funds", "interest charge", "overdraft fees", "other bank fees"]},
+    {"name": "entertainment", "subcategories": ["casinos and gambling", "music and audio", "sporting events amusement parks and museums", "tv and movies", "video games", "other entertainment"]},
+    {"name": "food and drink", "subcategories": ["beer wine and liquor", "coffee", "fast food", "groceries", "restaurant", "vending machines", "other food and drink"]},
+    {"name": "general merchandise", "subcategories": ["bookstores and newsstands", "clothing and accessories", "convenience stores", "department stores", "discount stores", "electronics", "gifts and novelties", "office supplies", "online marketplaces", "pet supplies", "sporting goods", "superstores", "tobacco and vape", "other general merchandise"]},
+    {"name": "home improvement", "subcategories": ["furniture", "hardware", "repair and maintenance", "security", "other home improvement"]},
+    {"name": "medical", "subcategories": ["dental care", "eye care", "nursing care", "pharmacies and supplements", "primary care", "veterinary services", "other medical"]},
+    {"name": "personal care", "subcategories": ["gyms and fitness centers", "hair and beauty", "laundry and dry cleaning", "other personal care"]},
+    {"name": "general services", "subcategories": ["accounting and financial planning", "automotive", "childcare", "consulting and legal", "education", "insurance", "postage and shipping", "storage", "other general services"]},
+    {"name": "government and nonprofit", "subcategories": ["donations", "government departments and agencies", "tax payment", "other government and nonprofit"]},
+    {"name": "transportation", "subcategories": ["bikes and scooters", "gas", "parking", "public transit", "taxis and ride shares", "tolls", "other transportation"]},
+    {"name": "travel", "subcategories": ["flights", "lodging", "rental cars", "other travel"]},
+    {"name": "rent and utilities", "subcategories": ["gas and electricity", "internet and cable", "rent", "sewage and waste management", "telephone", "water", "other utilities"]},
+    {"name": "software and technology", "subcategories": ["software subscriptions", "cloud services", "hardware purchases", "online tools", "it support"]}
+]
+
+class TransactionRequest(BaseModel):
+    description: str
+    amount: float
+    model: str = "BITNET"
+    apiKey: Optional[str] = None
+
+async def categorize_with_bitnet(description: str, amount: float):
+    """Categorize transaction using BitNet."""
+    logger.info(f"Processing transaction: {description}, amount: {amount}, {log_memory_usage()}")
+
+    # Create cache key
+    text = f"{description}|{amount}"
+    text_hash = get_text_hash(text)
     if text_hash in structured_data_cache:
+        logger.info(f"Cache hit for transaction: {description}, {log_memory_usage()}")
         return structured_data_cache[text_hash]

     try:
+        prompt = f"""You are an expert financial transaction categorizer using BitNet b1.2-3B. Given a transaction description and amount, categorize it into the specified categories and subcategories. Assign confidence scores (0 to 1). Follow these rules:
+- Select category and subcategory from this list (case-insensitive, use exact names):
+{', '.join([f'{c["name"]} ({", ".join(c["subcategories"])})' for c in ALLOWED_CATEGORIES])}
+- For positive amounts, use 'income' and one of its subcategories.
+- If unsure, set confidence to 0.7.
+- If no match, use 'miscellaneous' and 'other'.
+- Do NOT add markdown or explanations, only output valid JSON.
+
+Description: {description}
+Amount: {amount}
+
+Output JSON:
+{{
+    "category": "",
+    "subcategory": "",
+    "category_confidence": 0.0,
+    "subcategory_confidence": 0.0
+}}
+"""
+        outputs = llm.generate(prompts=[prompt])
+        json_str = outputs[0].outputs[0].text
+        json_start = json_str.find("{")
+        json_end = json_str.rfind("}") + 1
+        result = json.loads(json_str[json_start:json_end])
+
+        # Normalize category and subcategory
+        def normalize(s):
+            # Collapse runs of whitespace; avoid shadowing the built-in `str`
+            return re.sub(r"\s+", " ", s.strip().lower()) if s else ""
+
+        category_name = normalize(result.get("category", ""))
+        subcategory_name = normalize(result.get("subcategory", ""))
+        matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == category_name), None)
+        if not matched_category:
+            matched_category = next((cat for cat in ALLOWED_CATEGORIES if normalize(cat["name"]) == "miscellaneous"), None)
+            category_name = "miscellaneous"
+
+        matched_subcategory = ""
+        if matched_category:
+            matched_subcategory = next((sub for sub in matched_category["subcategories"] if normalize(sub) == subcategory_name), "")
+            if not matched_subcategory:
+                matched_subcategory = next((sub for sub in matched_category["subcategories"] if "other" in normalize(sub)), matched_category["subcategories"][0])
+
+        # Enforce income for positive amounts
+        if amount > 0:
+            matched_category = next((cat for cat in ALLOWED_CATEGORIES if cat["name"] == "income"), None)
+            category_name = "income"
+            matched_subcategory = next((sub for sub in matched_category["subcategories"] if normalize(sub) == subcategory_name), "other income")
+
+        category_result = {
+            "category": matched_category["name"] if matched_category else "",
+            "subcategory": matched_subcategory,
+            "category_confidence": float(result.get("category_confidence", 0.7)),
+            "subcategory_confidence": float(result.get("subcategory_confidence", 0.7))
+        }
+        structured_data_cache[text_hash] = category_result
+        logger.info(f"BitNet categorization completed for {description}, {log_memory_usage()}")
+        return category_result
     except Exception as e:
+        logger.error(f"BitNet categorization failed for {description}: {str(e)}, {log_memory_usage()}")
+        return {
+            "category": "",
+            "subcategory": "",
+            "category_confidence": 0.0,
+            "subcategory_confidence": 0.0,
+            "error": f"BitNet categorization failed: {str(e)}"
+        }
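A note on the generation call above: `llm.generate(prompts=[prompt])` leaves vLLM's sampling at its defaults, so reply length and randomness are unbounded, which makes the find/rfind JSON extraction less predictable. A minimal sketch of pinning the call down with vLLM's SamplingParams; the specific values are assumptions, not part of this commit:

from vllm import SamplingParams

# Greedy decoding with a bounded output keeps replies short and JSON-like.
sampling = SamplingParams(temperature=0.0, max_tokens=128)  # assumed values
outputs = llm.generate(prompts=[prompt], sampling_params=sampling)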

+@app.post("/categorize")
+async def categorize_transaction(request: TransactionRequest):
+    """Categorize a financial transaction."""
+    logger.info(f"Received request: description={request.description}, amount={request.amount}, model={request.model}, {log_memory_usage()}")
+
+    if request.model != "BITNET":
+        return {
+            "category": "",
+            "subcategory": "",
+            "category_confidence": 0.0,
+            "subcategory_confidence": 0.0,
+            "error": "Only BITNET model is supported"
+        }
+
+    result = await categorize_with_bitnet(request.description, request.amount)
+    return result
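For reference, a minimal sketch of exercising the new endpoint once the app is served. The host, port, and example transaction are assumptions; the request fields follow TransactionRequest and the response shape follows categorize_with_bitnet:

import requests

# Assumes the app is running locally, e.g. via `uvicorn app:app --host 0.0.0.0 --port 7860`.
resp = requests.post(
    "http://localhost:7860/categorize",  # assumed host/port
    json={
        "description": "UBER TRIP HELP.UBER.COM",  # hypothetical transaction
        "amount": -23.50,                          # negative = spend, per the amount > 0 income rule
        "model": "BITNET",
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
# Expected shape (values depend on the model), e.g.:
# {"category": "transportation", "subcategory": "taxis and ride shares",
#  "category_confidence": 0.9, "subcategory_confidence": 0.85}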