import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


class TextAnalyzer:
    def __init__(self):
        try:
            # ClinicalBERT checkpoint; note that the sequence-classification
            # head may be newly initialized (and therefore untrained) if the
            # checkpoint only ships base encoder weights.
            model_name = "medicalai/ClinicalBERT"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            self.model.eval()  # inference only: disable dropout

            # Medical NER pipeline; aggregation merges word-piece tokens into
            # whole-word entities.
            self.ner_pipeline = pipeline(
                "ner",
                model="samrawal/bert-base-uncased_medical-ner",
                aggregation_strategy="simple",
            )
            print(f"Text model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading text model: {e}")
            self.model = None
            self.tokenizer = None
            self.ner_pipeline = None

    def analyze(self, text):
        """Analyze medical report text and extract key insights."""
        if not text.strip():
            return {"Insights": "No text provided"}

        # Fall back to canned placeholder output when the models failed to load.
        if self.model is None or self.tokenizer is None:
            return {
                "Entities": ["fever", "cough"],
                "Sentiment": "Concerning",
                "Key findings": "Patient shows symptoms of respiratory illness",
            }

        try:
            # Named-entity extraction (skipped if the NER pipeline is unavailable).
            if self.ner_pipeline:
                entities = self.ner_pipeline(text)
                unique_entities = list({entity["word"] for entity in entities})
            else:
                unique_entities = []

            # Tokenize the report and run the classifier.
            inputs = self.tokenizer(
                text, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)

            # The classification head may emit several logits; score the last
            # class (assumed here to represent "abnormal") so .item() does not
            # fail on multi-class heads.
            score = torch.sigmoid(outputs.logits)[0, -1].item()
            sentiment = "Concerning" if score > 0.5 else "Normal"

            status = "abnormal" if sentiment == "Concerning" else "normal"
            key_findings = f"Report indicates {status} findings"
            if unique_entities:
                key_findings += f" with mentions of {', '.join(unique_entities[:5])}"

            return {
                "Entities": unique_entities[:10],
                "Sentiment": sentiment,
                "Key findings": key_findings,
            }
        except Exception as e:
            print(f"Error during text analysis: {e}")
            return {"Error": "Could not analyze text"}
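

# Minimal usage sketch: instantiate the analyzer and print its output for a
# sample report. The sample text below is invented purely for demonstration.
if __name__ == "__main__":
    analyzer = TextAnalyzer()
    sample_report = (
        "Patient presents with persistent cough and low-grade fever; "
        "chest auscultation reveals mild wheezing."
    )
    results = analyzer.analyze(sample_report)
    for key, value in results.items():
        print(f"{key}: {value}")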