Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,14 @@ import logging
|
|
3 |
import json
|
4 |
import os
|
5 |
from pydantic import BaseModel
|
6 |
-
from
|
|
|
7 |
import psutil
|
8 |
import cachetools
|
9 |
import hashlib
|
10 |
|
11 |
-
# Set environment variable for
|
12 |
-
os.environ["
|
13 |
|
14 |
app = FastAPI()
|
15 |
|
@@ -17,21 +18,17 @@ app = FastAPI()
|
|
17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
-
#
|
21 |
-
HF_HOME = "/app/cache"
|
22 |
-
NUMBA_CACHE_DIR = "/app/cache"
|
23 |
-
|
24 |
-
# Initialize BitNet model for CPU-only
|
25 |
try:
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
dtype="float32",
|
34 |
)
|
|
|
35 |
except Exception as e:
|
36 |
logger.error(f"Failed to load BitNet model: {str(e)}")
|
37 |
raise HTTPException(status_code=500, detail=f"BitNet model initialization failed: {str(e)}")
|
@@ -107,11 +104,10 @@ Output JSON:
|
|
107 |
"subcategory_confidence": 0.0
|
108 |
}}
|
109 |
"""
|
110 |
-
outputs =
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
result = json.loads(json_str[json_start:json_end])
|
115 |
|
116 |
# Normalize category and subcategory
|
117 |
def normalize(s):
|
|
|
3 |
import json
import os

# HF_HOME must be exported BEFORE transformers/huggingface_hub are imported:
# huggingface_hub resolves its cache directory from HF_HOME at import time,
# so setting it after the import (as the original code did) has no effect.
os.environ["HF_HOME"] = "/app/cache"

from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import psutil
import cachetools
import hashlib

app = FastAPI()
|
16 |
|
|
|
18 |
# Configure root logging once at startup: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
20 |
|
21 |
+
# Initialize BitNet model and tokenizer
|
|
|
|
|
|
|
|
|
22 |
# Initialize the BitNet model, tokenizer, and text-generation pipeline once
# at process startup (CPU-only; weights cached under /app/cache).
try:
    model_name = "1bitLLM/bitnet_b1_58-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/app/cache")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # safe CPU dtype; no half-precision kernels needed
        device_map="cpu",
        cache_dir="/app/cache",
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
except Exception as e:
    # logger.exception records the full traceback, not just the message.
    logger.exception("Failed to load BitNet model")
    # BUG FIX: the original raised HTTPException here, but HTTPException is
    # only meaningful inside a request handler — at import time it surfaces
    # as a confusing unhandled error. Fail fast with a RuntimeError chained
    # to the original cause so the process refuses to start without a model.
    raise RuntimeError(f"BitNet model initialization failed: {str(e)}") from e
|
|
|
104 |
"subcategory_confidence": 0.0
|
105 |
}}
|
106 |
"""
|
107 |
+
# Run the text-generation pipeline; generated_text includes the prompt echo
# followed by the model's continuation.
outputs = pipe(prompt)[0]["generated_text"]
# Take the LAST '{'...'}' span so we pick up the generated JSON object rather
# than the JSON template embedded in the prompt itself.
# NOTE(review): rfind("{") grabs the last opening brace — if the model emits
# nested JSON objects this selects the innermost one and json.loads will fail;
# a balanced-brace scan would be more robust. TODO confirm expected output shape.
json_start = outputs.rfind("{")
json_end = outputs.rfind("}") + 1
# Raises json.JSONDecodeError on malformed model output — presumably handled
# (or propagated as a 500) by the enclosing endpoint; verify against caller.
result = json.loads(outputs[json_start:json_end])
|
|
|
111 |
|
112 |
# Normalize category and subcategory
|
113 |
def normalize(s):
|