"""Cancer text classification and cancer-entity extraction utilities built on Hugging Face transformers pipelines."""

import os
# Avoid cache write permission errors in Hugging Face Spaces
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
import re
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

class CancerClassifier:
    """Binary cancer / non-cancer text classifier (default: a BiomedBERT-based sequence-classification checkpoint)."""

    def __init__(self, model_path="user1729/BiomedBERT-cancer-bert-classifier-v1.0"):
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            return_all_scores=True,  # return a score for every label, not only the top prediction
            # Use GPU 0 when USE_GPU=true in the environment, otherwise run on CPU.
            device=0 if os.environ.get("USE_GPU", "false").lower() == "true" else -1,
        )

    def predict(self, text: str):
        # The pipeline returns one list of per-label scores per input text;
        # indices 0 and 1 follow the model's label order (0 = Non-Cancer, 1 = Cancer).
        results = self.classifier(text)
        return {
            "predicted_labels": ["Non-Cancer", "Cancer"],
            "confidence_scores": {
                "Non-Cancer": results[0][0]["score"],
                "Cancer": results[0][1]["score"],
            },
        }

class CancerExtractor:
    """Extracts cancer mentions from free text using a disease NER model (default: BioBERT diseases NER)."""

    def __init__(self, model_path="alvaroalon2/biobert_diseases_ner"):
        self.extractor = pipeline(
            "ner",
            model=model_path,
            aggregation_strategy="simple",  # group token-level predictions into entity spans
            device=0 if os.environ.get("USE_GPU", "false").lower() == "true" else -1,
        )
        # Keywords matched as substrings against extracted disease mentions.
        # Singular forms are used so plural mentions (e.g. "sarcomas") also match.
        self.cancers = [
            "cancer",
            "astrocytoma",
            "medulloblastoma",
            "meningioma",
            "neoplasm",
            "carcinoma",
            "tumor",
            "melanoma",
            "mesothelioma",
            "leukemia",
            "lymphoma",
            "sarcoma",
        ]

    def predict(self, text: str):
        # Run NER, keep disease mentions, normalize them, then filter to cancer terms.
        results = self.extractor(text)
        extractions = self.extract_diseases(results)
        extractions_cleaned = self.clean_diseases(extractions)
        detections = self.detect_cancer(extractions_cleaned)
        return detections

    def extract_diseases(self, entities):
        entities = self.merge_subwords(entities)
        diseases = [
            entity["word"]
            for entity in entities
            if "disease" in entity["entity_group"].lower()
        ]
        return diseases

    def merge_subwords(self, entities):
        """Merge adjacent disease entity spans (e.g. split wordpieces) into single mentions."""
        merged_entities = []
        current_entity = None
        for entity in entities:
            if current_entity is None:
                current_entity = entity.copy()
            else:
                # Check if this entity is part of the same word as the previous one
                if (
                    entity["start"] == current_entity["end"]
                    and "disease" in entity["entity_group"].lower()
                    and "disease" in current_entity["entity_group"].lower()
                ):
                    # Merge with previous entity
                    current_entity["word"] += entity["word"].replace("##", "")
                    current_entity["end"] = entity["end"]
                    current_entity["score"] = (
                        current_entity["score"] + entity["score"]
                    ) / 2
                else:
                    merged_entities.append(current_entity)
                    current_entity = entity.copy()

        if current_entity is not None:
            merged_entities.append(current_entity)
        return merged_entities

    def clean_diseases(self, text_list):
        # Keep letters only, lowercase, deduplicate, and drop very short or very long strings.
        text_list = [re.sub(r"[^a-zA-Z]", " ", t) for t in text_list]
        unique_text = set([t.lower() for t in text_list])
        cleaned_text = [
            t for t in unique_text if (3 <= len(t.strip()) <= 50 and ("##" not in t))
        ]
        return cleaned_text

    def detect_cancer(self, text_list):
        # A mention counts as cancer if any keyword appears as a substring (case-insensitive).
        detected_cancers = [
            word2.lower()
            for word2 in text_list
            if any(word1.lower() in word2.lower() for word1 in self.cancers)
        ]
        return set(detected_cancers)
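

# A minimal usage sketch, assuming the two default model checkpoints above are reachable;
# the sample sentence and printed outputs are illustrative only. Guarded by __main__ so
# importing this module stays side-effect free.
if __name__ == "__main__":
    sample_text = (
        "The patient was diagnosed with stage II breast carcinoma and started on chemotherapy."
    )

    classifier = CancerClassifier()
    print("Classification:", classifier.predict(sample_text))
    # e.g. {'predicted_labels': ['Non-Cancer', 'Cancer'], 'confidence_scores': {...}}

    extractor = CancerExtractor()
    print("Detected cancer mentions:", extractor.predict(sample_text))
    # e.g. {'breast carcinoma'}; exact spans depend on the NER model's output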