ChintanSatva committed on
Commit
c5541a4
·
verified ·
1 Parent(s): cde6b17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -21
app.py CHANGED
@@ -3,13 +3,14 @@ import logging
3
  import json
4
  import os
5
  from pydantic import BaseModel
6
- from vllm import LLM
 
7
  import psutil
8
  import cachetools
9
  import hashlib
10
 
11
- # Set environment variable for transformers cache
12
- os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
13
 
14
  app = FastAPI()
15
 
@@ -17,21 +18,17 @@ app = FastAPI()
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
- # Set cache directories
21
- HF_HOME = "/app/cache"
22
- NUMBA_CACHE_DIR = "/app/cache"
23
-
24
- # Initialize BitNet model for CPU-only
25
  try:
26
- llm = LLM(
27
- model="1bitLLM/bitnet_b1_58-3B",
28
- device="cpu",
29
- enforce_eager=True,
30
- tensor_parallel_size=1,
31
- disable_custom_all_reduce=True,
32
- max_model_len=2048,
33
- dtype="float32",
34
  )
 
35
  except Exception as e:
36
  logger.error(f"Failed to load BitNet model: {str(e)}")
37
  raise HTTPException(status_code=500, detail=f"BitNet model initialization failed: {str(e)}")
@@ -107,11 +104,10 @@ Output JSON:
107
  "subcategory_confidence": 0.0
108
  }}
109
  """
110
- outputs = llm.generate(prompts=[prompt])
111
- json_str = outputs[0].outputs[0].text
112
- json_start = json_str.find("{")
113
- json_end = json_str.rfind("}") + 1
114
- result = json.loads(json_str[json_start:json_end])
115
 
116
  # Normalize category and subcategory
117
  def normalize(s):
 
3
  import json
4
  import os
5
  from pydantic import BaseModel
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
+ import torch
8
  import psutil
9
  import cachetools
10
  import hashlib
11
 
12
+ # Set environment variable for cache
13
+ os.environ["HF_HOME"] = "/app/cache"
14
 
15
  app = FastAPI()
16
 
 
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
  logger = logging.getLogger(__name__)
20
 
21
+ # Initialize BitNet model and tokenizer
 
 
 
 
22
  try:
23
+ model_name = "1bitLLM/bitnet_b1_58-3B"
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/app/cache")
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ model_name,
27
+ torch_dtype=torch.float32,
28
+ device_map="cpu",
29
+ cache_dir="/app/cache"
 
30
  )
31
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
32
  except Exception as e:
33
  logger.error(f"Failed to load BitNet model: {str(e)}")
34
  raise HTTPException(status_code=500, detail=f"BitNet model initialization failed: {str(e)}")
 
104
  "subcategory_confidence": 0.0
105
  }}
106
  """
107
+ outputs = pipe(prompt)[0]["generated_text"]
108
+ json_start = outputs.rfind("{")
109
+ json_end = outputs.rfind("}") + 1
110
+ result = json.loads(outputs[json_start:json_end])
 
111
 
112
  # Normalize category and subcategory
113
  def normalize(s):