|
import os |
|
|
|
# Redirect the Hugging Face caches to locations under /tmp.
# NOTE(review): these assignments sit ahead of the `transformers` import,
# presumably so the library sees them when it is loaded — keep that order.
os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "TRANSFORMERS_CACHE": "/tmp/huggingface/transformers",
        "HF_DATASETS_CACHE": "/tmp/huggingface/datasets",
    }
)
|
import re |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline |
|
|
|
class CancerClassifier:
    """Text classifier that reports "Non-Cancer" and "Cancer" scores.

    Wraps a ``text-classification`` pipeline built from a sequence
    classification model and its tokenizer.
    """

    def __init__(self, model_path="user1729/BiomedBERT-cancer-bert-classifier-v1.0"):
        """Load the model and tokenizer from *model_path* and build the pipeline.

        The pipeline is created with ``device=0`` when the ``USE_GPU``
        environment variable equals ``"true"`` (case-insensitive) and with
        ``device=-1`` otherwise.
        """
        use_gpu = os.environ.get("USE_GPU", "false").lower() == "true"
        self.classifier = pipeline(
            "text-classification",
            model=AutoModelForSequenceClassification.from_pretrained(model_path),
            tokenizer=AutoTokenizer.from_pretrained(model_path),
            return_all_scores=True,
            device=0 if use_gpu else -1,
        )

    def predict(self, text: str):
        """Classify *text* and return both class scores.

        Returns a dict with ``"predicted_labels"`` (always the two label
        names) and ``"confidence_scores"`` (label name -> score).
        """
        # The pipeline output is indexed positionally: entry 0 of the first
        # result is reported as "Non-Cancer", entry 1 as "Cancer".
        # NOTE(review): this assumes the model's label ids follow that
        # order — confirm against the model config.
        per_label = self.classifier(text)[0]
        label_names = ["Non-Cancer", "Cancer"]
        return {
            "predicted_labels": label_names,
            "confidence_scores": {
                name: per_label[position]["score"]
                for position, name in enumerate(label_names)
            },
        }
|
|
|
class CancerExtractor:
    """Extract disease mentions from text and keep the cancer-related ones.

    A ``ner`` pipeline tags entities; entities whose group contains
    "disease" are merged, cleaned, and then matched against a keyword list
    by substring.
    """

    # Substrings that mark a disease mention as cancer-related. The singular
    # "sarcoma" is used so that both "osteosarcoma" and "sarcomas" match.
    CANCER_KEYWORDS = (
        "cancer",
        "astrocytoma",
        "medulloblastoma",
        "meningioma",
        "neoplasm",
        "carcinoma",
        "tumor",
        "melanoma",
        "mesothelioma",
        "leukemia",
        "lymphoma",
        "sarcoma",
    )

    def __init__(self, model_path="alvaroalon2/biobert_diseases_ner"):
        """Build the NER pipeline for *model_path* and set the keyword list.

        The pipeline is created with ``device=0`` when the ``USE_GPU``
        environment variable equals ``"true"`` (case-insensitive) and with
        ``device=-1`` otherwise.
        """
        use_gpu = os.environ.get("USE_GPU", "false").lower() == "true"
        self.extractor = pipeline(
            "ner",
            model=model_path,
            aggregation_strategy="simple",
            device=0 if use_gpu else -1,
        )
        # Per-instance copy so callers may customise the list safely.
        self.cancers = list(self.CANCER_KEYWORDS)

    def predict(self, text: str):
        """Return the set of cleaned, lower-cased cancer mentions in *text*."""
        entities = self.extractor(text)
        diseases = self.extract_diseases(entities)
        cleaned = self.clean_diseases(diseases)
        return self.detect_cancer(cleaned)

    def extract_diseases(self, entities):
        """Return the ``word`` of every merged entity tagged as a disease.

        An entity counts as a disease when its ``entity_group`` contains
        "disease" (case-insensitive).
        """
        return [
            entity["word"]
            for entity in self.merge_subwords(entities)
            if "disease" in entity["entity_group"].lower()
        ]

    @staticmethod
    def _continues(previous, entity):
        """Tell whether *entity* starts exactly where *previous* ends.

        Both must be disease entities. Missing or ``None`` offsets never
        count as adjacent; otherwise ``None == None`` would glue unrelated
        entities together when the tokenizer provides no offsets.
        """
        start = entity.get("start")
        return (
            start is not None
            and start == previous.get("end")
            and "disease" in entity["entity_group"].lower()
            and "disease" in previous["entity_group"].lower()
        )

    def merge_subwords(self, entities):
        """Merge directly adjacent disease entities into single entities.

        When an entity begins at the character offset where the previous
        one ended, its word (with ``##`` markers removed) is appended to the
        previous word, the end offset is extended, and the score becomes the
        pairwise average of the running score and the new score. The input
        dicts are not modified; copies are returned.
        """
        merged_entities = []
        current = None
        for entity in entities:
            if current is not None and self._continues(current, entity):
                current["word"] += entity["word"].replace("##", "")
                current["end"] = entity["end"]
                current["score"] = (current["score"] + entity["score"]) / 2
                continue
            if current is not None:
                merged_entities.append(current)
            current = entity.copy()
        if current is not None:
            merged_entities.append(current)
        return merged_entities

    def clean_diseases(self, text_list):
        """Normalise disease strings and drop duplicates and outliers.

        Every non-letter character becomes a space, text is lower-cased and
        whitespace is collapsed and trimmed *before* de-duplication, so
        "Breast cancer." and "breast cancer" yield a single entry. Only
        strings of 3 to 50 characters are kept. First-seen order is
        preserved.
        """
        # The substitution also turns "#" into a space, so word-piece markers
        # never survive to this point; no separate "##" filter is needed.
        seen = {}
        for text in text_list:
            normalized = " ".join(re.sub(r"[^a-zA-Z]", " ", text).lower().split())
            if 3 <= len(normalized) <= 50:
                seen[normalized] = None
        return list(seen)

    def detect_cancer(self, text_list):
        """Return the lower-cased entries that contain any cancer keyword."""
        keywords = [keyword.lower() for keyword in self.cancers]
        lowered = (text.lower() for text in text_list)
        return {
            text
            for text in lowered
            if any(keyword in text for keyword in keywords)
        }
|
|