File size: 9,054 Bytes
1a27bf2
fd8097f
1a27bf2
7866c2d
1a27bf2
175aecd
c5541a4
fd8097f
 
 
 
c5541a4
 
7866c2d
1a27bf2
fd8097f
 
5203be9
fd8097f
 
c5541a4
fd8097f
c5541a4
458b124
c5541a4
 
3666246
c5541a4
0dd89f7
458b124
3666246
1a27bf2
fd8097f
1a27bf2
7866c2d
fd8097f
1a27bf2
5203be9
fd8097f
 
 
1a27bf2
 
 
 
 
 
 
 
3666246
1a27bf2
 
5bed1d5
1a27bf2
 
 
5bed1d5
 
 
 
1a27bf2
5bed1d5
1a27bf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd8097f
5bed1d5
fd8097f
 
 
175aecd
 
1a27bf2
 
 
 
 
 
 
 
 
175aecd
 
 
 
3666246
175aecd
 
 
 
 
 
 
1a27bf2
 
7866c2d
 
1a27bf2
 
 
 
 
7866c2d
1a27bf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7866c2d
1a27bf2
 
 
 
 
 
 
fd8097f
1a27bf2
 
7866c2d
 
1a27bf2
 
 
 
fd8097f
1a27bf2
 
 
 
 
 
 
7866c2d
 
1a27bf2
 
 
 
cc3cef4
1a27bf2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import asyncio
import hashlib
import json
import logging
import os
import threading
from typing import Optional

import cachetools
import psutil
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face cache location.
# NOTE(review): `transformers` is imported above this point, and the HF libraries
# typically read HF_HOME when they are imported, so this assignment may have no
# effect on them. The explicit cache_dir="/app/cache" arguments passed when the
# tokenizer and model are loaded are what pin the download location; consider
# setting HF_HOME in the container environment instead.
os.environ["HF_HOME"] = "/app/cache"

app = FastAPI()

# Root logging: INFO and above, one "<time> - <LEVEL> - <message>" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load the BitNet model and tokenizer once, at import time.
# NOTE(review): the tokenizer comes from a separate repository
# ("hf-internal-testing/llama-tokenizer") rather than from model_name;
# presumably the checkpoint uses the LLaMA vocabulary — confirm.
try:
    model_name = "1bitLLM/bitnet_b1_58-3B"
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", cache_dir="/app/cache")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # load weights as bfloat16
        device_map="cpu",
        low_cpu_mem_usage=True,
        cache_dir="/app/cache",
        trust_remote_code=True,  # runs modeling code shipped with the checkpoint
    )
except Exception as e:
    # This runs while the module is imported, outside any request, so an
    # HTTPException would never reach a client. Fail startup with a plain
    # error that keeps the original cause chained.
    logger.exception(f"Failed to load BitNet model: {str(e)}")
    raise RuntimeError(f"BitNet model initialization failed: {str(e)}") from e

# In-memory result cache: at most 100 entries, each expiring after 3600 s (1 hour).
structured_data_cache = cachetools.TTLCache(maxsize=100, ttl=3600)

def log_memory_usage():
    """Return a short text describing this process's resident memory (RSS).

    Despite the name, nothing is logged here: callers embed the returned
    string, e.g. "Memory usage: 512.34 MB", in their own log messages.
    """
    rss_mib = psutil.Process().memory_info().rss / 1024 / 1024
    return f"Memory usage: {rss_mib:.2f} MB"

def get_text_hash(text: str):
    """Return the hexadecimal MD5 digest of *text* encoded as UTF-8.

    In this module it serves as a compact cache key, not as a security measure.
    """
    hasher = hashlib.md5()
    hasher.update(text.encode('utf-8'))
    return hasher.hexdigest()

# Category/subcategory taxonomy that model answers are matched against.
# It is NOT included in the prompt, so the model never sees these names; its
# free-form answer is normalized and compared with them after generation.
# NOTE(review): there is no "miscellaneous" entry here. An unmatched category
# falls back to a synthetic {"name": "miscellaneous", "subcategories": ["other"]}
# in categorize_with_bitnet. "software and technology" has no subcategory
# containing "other", so its unmatched-subcategory fallback is its first entry.
ALLOWED_CATEGORIES = [
    {"name": "income", "subcategories": ["dividends", "interest earned", "retirement pension", "tax refund", "unemployment", "wages", "other income"]},
    {"name": "transfer in", "subcategories": ["cash advances and loans", "deposit", "investment and retirement funds", "savings", "account transfer", "other transfer in"]},
    {"name": "transfer out", "subcategories": ["investment and retirement funds", "savings", "withdrawal", "account transfer", "other transfer out"]},
    {"name": "loan payments", "subcategories": ["car payment", "credit card payment", "personal loan payment", "mortgage payment", "student loan payment", "other payment"]},
    {"name": "bank fees", "subcategories": ["atm fees", "foreign transaction fees", "insufficient funds", "interest charge", "overdraft fees", "other bank fees"]},
    {"name": "entertainment", "subcategories": ["casinos and gambling", "music and audio", "sporting events amusement parks and museums", "tv and movies", "video games", "other entertainment"]},
    {"name": "food and drink", "subcategories": ["beer wine and liquor", "coffee", "fast food", "groceries", "restaurant", "vending machines", "other food and drink"]},
    {"name": "general merchandise", "subcategories": ["bookstores and newsstands", "clothing and accessories", "convenience stores", "department stores", "discount stores", "electronics", "gifts and novelties", "office supplies", "online marketplaces", "pet supplies", "sporting goods", "superstores", "tobacco and vape", "other general merchandise"]},
    {"name": "home improvement", "subcategories": ["furniture", "hardware", "repair and maintenance", "security", "other home improvement"]},
    {"name": "medical", "subcategories": ["dental care", "eye care", "nursing care", "pharmacies and supplements", "primary care", "veterinary services", "other medical"]},
    {"name": "personal care", "subcategories": ["gyms and fitness centers", "hair and beauty", "laundry and dry cleaning", "other personal care"]},
    {"name": "general services", "subcategories": ["accounting and financial planning", "automotive", "childcare", "consulting and legal", "education", "insurance", "postage and shipping", "storage", "other general services"]},
    {"name": "government and nonprofit", "subcategories": ["donations", "government departments and agencies", "tax payment", "other government and nonprofit"]},
    {"name": "transportation", "subcategories": ["bikes and scooters", "gas", "parking", "public transit", "taxis and ride shares", "tolls", "other transportation"]},
    {"name": "travel", "subcategories": ["flights", "lodging", "rental cars", "other travel"]},
    {"name": "rent and utilities", "subcategories": ["gas and electricity", "internet and cable", "rent", "sewage and waste management", "telephone", "water", "other utilities"]},
    {"name": "software and technology", "subcategories": ["software subscriptions", "cloud services", "hardware purchases", "online tools", "it support"]}
]

class TransactionRequest(BaseModel):
    """JSON body accepted by POST /categorize."""

    description: str  # free-text transaction description passed to the model
    amount: float  # amounts > 0 are always categorized as "income"
    model: str = "BITNET"  # the endpoint only accepts "BITNET"
    # Optional[...] so that an explicit JSON null is accepted as well as an
    # omitted field; a bare ``str`` annotation makes pydantic v2 reject null.
    # NOTE(review): not read anywhere in the visible code.
    apiKey: Optional[str] = None

def _normalize_label(value):
    """Return *value* lower-cased, trimmed, with internal whitespace runs collapsed.

    Non-string or empty values map to "" so that a malformed model answer
    simply fails to match a known label instead of raising.
    """
    if not isinstance(value, str):
        return ""
    # str.split() with no argument splits on any whitespace run. That is what
    # a ``.replace(" +", " ")`` call (a literal replace, not a regex) cannot do.
    return " ".join(value.lower().split())


def _extract_json_object(text):
    """Return the first JSON object (dict) that can be decoded from *text*.

    Decoding is attempted at each "{" position in turn, so leading chatter
    and any trailing text after the object are ignored.

    Raises:
        ValueError: if no position yields a JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    raise ValueError("no JSON object found in model output")


def _coerce_confidence(value, default=0.7):
    """Convert a model-reported confidence to a float clamped to [0, 1].

    Unparseable values (None, non-numeric strings, NaN, huge ints) fall back
    to *default* instead of failing the whole categorization.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if number != number:  # NaN is the only float not equal to itself
        return default
    return min(max(number, 0.0), 1.0)


# Serializes model.generate() across executor threads. This keeps the
# one-generation-at-a-time behavior the service had when generate() ran
# directly on the event loop, while leaving the loop free to serve requests.
_GENERATE_LOCK = threading.Lock()


async def categorize_with_bitnet(description: str, amount: float):
    """Categorize one transaction with the BitNet model.

    Successful results are cached in ``structured_data_cache``, keyed on
    description and amount. Positive amounts are always mapped to "income".
    Otherwise the model's category is matched against ALLOWED_CATEGORIES,
    falling back to "miscellaneous"/"other" when nothing matches.

    Args:
        description: Free-text transaction description.
        amount: Transaction amount; ``amount > 0`` forces the income category.

    Returns:
        A new dict with "category", "subcategory", "category_confidence" and
        "subcategory_confidence". On any failure it returns a
        "miscellaneous"/"other" payload with zero confidences and an "error"
        message instead of raising.
    """
    logger.info(f"Processing transaction: {description}, amount: {amount}, {log_memory_usage()}")

    text_hash = get_text_hash(f"{description}|{amount}")
    # Single lookup: with a TTL cache, an entry can expire between a separate
    # "in" check and the subsequent read, which would raise KeyError here.
    try:
        cached = structured_data_cache[text_hash]
    except KeyError:
        cached = None
    if cached is not None:
        logger.info(f"Cache hit for transaction: {description}, {log_memory_usage()}")
        return dict(cached)  # copy so callers cannot mutate the cached entry

    try:
        prompt = f"""Categorize this transaction into a category and subcategory with confidence scores (0 to 1). Use 'income' for positive amounts. If unsure, use confidence 0.7 and 'miscellaneous'/'other' if no match. Output only JSON.

Description: {description}
Amount: {amount}

{{
  "category": "",
  "subcategory": "",
  "category_confidence": 0.0,
  "subcategory_confidence": 0.0
}}"""
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
        prompt_length = inputs["input_ids"].shape[-1]

        def _generate():
            with _GENERATE_LOCK:
                return model.generate(
                    **inputs,
                    # NOTE(review): a fully filled-in answer object may need more
                    # than 50 tokens; if it is cut off, parsing fails and the
                    # error payload below is returned.
                    max_new_tokens=50,
                    do_sample=False,
                    num_beams=1,
                )

        # generate() is CPU-bound and blocking; run it off the event loop.
        outputs = await asyncio.get_running_loop().run_in_executor(None, _generate)

        # Decode only the newly generated tokens. The prompt itself contains a
        # JSON template, which would otherwise be parsed as the "answer".
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        result = _extract_json_object(generated_text)

        category_name = _normalize_label(result.get("category"))
        subcategory_name = _normalize_label(result.get("subcategory"))

        if amount > 0:
            # Positive amounts are always income, whatever the model said.
            matched_category = next(cat for cat in ALLOWED_CATEGORIES if cat["name"] == "income")
            matched_subcategory = next(
                (sub for sub in matched_category["subcategories"] if _normalize_label(sub) == subcategory_name),
                "other income",
            )
        else:
            matched_category = next(
                (cat for cat in ALLOWED_CATEGORIES if _normalize_label(cat["name"]) == category_name),
                None,
            )
            if matched_category is None:
                matched_category = next(
                    (cat for cat in ALLOWED_CATEGORIES if _normalize_label(cat["name"]) == "miscellaneous"),
                    {"name": "miscellaneous", "subcategories": ["other"]},
                )
            subcategories = matched_category["subcategories"]
            matched_subcategory = next(
                (sub for sub in subcategories if _normalize_label(sub) == subcategory_name),
                None,
            )
            if matched_subcategory is None:
                # Prefer an "other ..." bucket; otherwise use the first subcategory.
                matched_subcategory = next(
                    (sub for sub in subcategories if "other" in _normalize_label(sub)),
                    subcategories[0],
                )

        category_result = {
            "category": matched_category["name"],
            "subcategory": matched_subcategory,
            "category_confidence": _coerce_confidence(result.get("category_confidence", 0.7)),
            "subcategory_confidence": _coerce_confidence(result.get("subcategory_confidence", 0.7)),
        }
        structured_data_cache[text_hash] = category_result
        logger.info(f"BitNet categorization completed for {description}, {log_memory_usage()}")
        return dict(category_result)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure in the payload
        # rather than failing the HTTP request. Failures are not cached.
        logger.error(f"BitNet categorization failed for {description}: {str(e)}, {log_memory_usage()}", exc_info=True)
        return {
            "category": "miscellaneous",
            "subcategory": "other",
            "category_confidence": 0.0,
            "subcategory_confidence": 0.0,
            "error": f"BitNet categorization failed: {str(e)}"
        }

@app.post("/categorize")
async def categorize_transaction(request: TransactionRequest):
    """Handle POST /categorize.

    Only the "BITNET" model is served. Any other value produces a
    "miscellaneous"/"other" payload with zero confidences and an "error"
    message, returned as a normal response body rather than an HTTP error.
    """
    logger.info(f"Received request: description={request.description}, amount={request.amount}, model={request.model}, {log_memory_usage()}")

    if request.model == "BITNET":
        return await categorize_with_bitnet(request.description, request.amount)

    return {
        "category": "miscellaneous",
        "subcategory": "other",
        "category_confidence": 0.0,
        "subcategory_confidence": 0.0,
        "error": "Only BITNET model is supported",
    }