user1729 commited on
Commit
2cb976e
·
1 Parent(s): 1a15d27

Modified model loading

Browse files
Files changed (2) hide show
  1. app/main.py +1 -1
  2. app/model.py +6 -3
app/main.py CHANGED
@@ -30,7 +30,7 @@ class BatchResponse(BaseModel):
30
  # Initialize models
31
  try:
32
  logger.info("Loading classification model...")
33
- classification_pipeline = CancerClassifier("user1729/BiomedBERT-cancer-bert-classifier-v1.0")
34
 
35
  logger.info("Loading extraction model...")
36
  extraction_pipeline = CancerExtractor()
 
30
  # Initialize models
31
  try:
32
  logger.info("Loading classification model...")
33
+ classification_pipeline = CancerClassifier()
34
 
35
  logger.info("Loading extraction model...")
36
  extraction_pipeline = CancerExtractor()
app/model.py CHANGED
@@ -1,12 +1,15 @@
1
- from transformers import pipeline
2
  import os
3
  import re
4
 
5
  class CancerClassifier:
6
- def __init__(self, model_path: str):
 
 
7
  self.classifier = pipeline(
8
  "text-classification",
9
- model=model_path,
 
10
  return_all_scores=True,
11
  device=0 if os.environ.get("USE_GPU", "false").lower() == "true" else -1,
12
  )
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
2
  import os
3
  import re
4
 
5
  class CancerClassifier:
6
+ def __init__(self, model_path: "user1729/BiomedBERT-cancer-bert-classifier-v1.0"):
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
8
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
9
  self.classifier = pipeline(
10
  "text-classification",
11
+ model=model,
12
+ tokenizer=tokenizer,
13
  return_all_scores=True,
14
  device=0 if os.environ.get("USE_GPU", "false").lower() == "true" else -1,
15
  )