Update summarizer/summarizer.py
Browse files- summarizer/summarizer.py +1 -19
summarizer/summarizer.py
CHANGED
@@ -6,36 +6,27 @@ class Summarizer:
|
|
6 |
def __init__(self):
    """Set up the inference endpoint, request headers, and translator.

    Reads the API token from the ``API_KEY`` environment variable.

    Raises:
        ValueError: If ``API_KEY`` is unset or empty.
    """
    self.api_url = "https://api-inference.huggingface.co/models/Aadityaramrame/carecompanion-summarizer"
    self.api_token = os.environ.get("API_KEY")
    if not self.api_token:
        # Fail at construction time rather than on the first request.
        raise ValueError("Hugging Face API Key not found in environment variables.")

    bearer = "Bearer " + self.api_token
    self.headers = {
        "Authorization": bearer,
        "Content-Type": "application/json",
    }
    self.translator = TextTranslator()
|
18 |
-
|
19 |
def clean_text(self, text: str) -> str:
    """Collapse every run of whitespace (including newlines) to one space.

    Leading and trailing whitespace is dropped as well.
    """
    # str.split() with no separator splits on any whitespace, newlines
    # included, and discards empty pieces.
    words = text.split()
    return " ".join(words)
|
22 |
-
|
23 |
def format_summary(self, summary: str) -> str:
    """Apply simple formatting rules to the model output.

    Rules, applied in order:

    1. Strip surrounding whitespace.
    2. Upper-case the first character.
    3. If the summary ends with the dangling phrase
       "expected to recover within" (optionally followed by dots or
       spaces), complete it with " 7–10 days.".
    4. If the summary mentions "antibiotic" and "supportive" but not
       "treatment" (case-insensitive), append a treatment sentence.

    Args:
        summary: Raw text returned by the summarization model.

    Returns:
        The formatted summary; an empty string stays empty.
    """
    summary = summary.strip()
    if not summary:
        return summary

    if not summary[0].isupper():
        summary = summary[0].upper() + summary[1:]

    # Only complete the sentence when it is actually cut off. Matching the
    # phrase anywhere in the text would corrupt summaries that already
    # state a duration (e.g. "... within 2 weeks." -> "... 2 weeks 7–10 days.").
    # NOTE(review): the "7–10 days" figure is hard-coded, not derived from
    # the input text - confirm this default is clinically acceptable.
    trimmed = summary.rstrip('. ')
    if trimmed.endswith("expected to recover within"):
        summary = trimmed + " 7–10 days."

    lowered = summary.lower()
    if "antibiotic" in lowered and "supportive" in lowered and "treatment" not in lowered:
        summary += " Treatment includes antibiotics and supportive care."

    return summary
|
38 |
-
|
39 |
def summarize_text(self, text: str, target_lang: str = 'en') -> str:
    """Detect language, translate if needed, summarize, format, and retranslate if needed.

    Args:
        text: The clinical text to summarize.
        target_lang: Language code for the returned summary; anything
            other than ``'en'`` triggers a translation of the result.

    Returns:
        The formatted summary, or a human-readable error message string
        when the request fails or the response cannot be interpreted.
        This method does not raise; every failure is reported as a string.
    """
    try:
        # The prompt sent to the model is English, so translate first.
        detected_lang = self.translator.detect_language(text)
        if detected_lang != 'en':
            text = self.translator.translate_to_english(text)

        cleaned_text = self.clean_text(text)

        payload = {
            "inputs": f"summarize the clinical case with diagnosis, comorbidities, and treatment plan: {cleaned_text}"
        }

        response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
        response.raise_for_status()

        response_data = response.json()

        # Expected shape is a non-empty list of dicts, e.g.
        # [{'generated_text': 'summary'}]. Check the shape before indexing
        # so an empty list or non-dict element is reported as a format
        # problem instead of surfacing as an IndexError/TypeError message.
        first = None
        if isinstance(response_data, list) and response_data:
            first = response_data[0]

        summary = None
        if isinstance(first, dict):
            # Summarization pipelines use 'summary_text'; text-generation
            # style pipelines use 'generated_text'. Accept either.
            summary = first.get("generated_text", first.get("summary_text"))

        if not isinstance(summary, str):
            return "Unexpected response format from Hugging Face API."

        formatted_summary = self.format_summary(summary)

        # Translate back to target language if needed
        if target_lang != 'en':
            formatted_summary = self.translator.translate_from_english(formatted_summary, target_lang)

        return formatted_summary

    except requests.exceptions.Timeout:
        return "Summarization request timed out."
    except requests.exceptions.RequestException as e:
        return f"Summarization request failed: {str(e)}"
    except Exception as e:
        return f"An error occurred during summarization: {str(e)}"
|
|
|
6 |
def __init__(self):
    """Set up the inference endpoint, request headers, and translator.

    Reads the API token from the ``API_KEY`` environment variable.

    Raises:
        ValueError: If ``API_KEY`` is unset or empty.
    """
    self.api_url = "https://api-inference.huggingface.co/models/Aadityaramrame/carecompanion-summarizer"
    self.api_token = os.environ.get("API_KEY")
    if not self.api_token:
        # Fail at construction time rather than on the first request.
        raise ValueError("Hugging Face API Key not found in environment variables.")

    bearer = "Bearer " + self.api_token
    self.headers = {
        "Authorization": bearer,
        "Content-Type": "application/json",
    }
    self.translator = TextTranslator()
|
|
|
16 |
def clean_text(self, text: str) -> str:
    """Collapse every run of whitespace (including newlines) to one space.

    Leading and trailing whitespace is dropped as well.
    """
    # str.split() with no separator splits on any whitespace, newlines
    # included, and discards empty pieces.
    words = text.split()
    return " ".join(words)
|
|
|
19 |
def format_summary(self, summary: str) -> str:
    """Apply simple formatting rules to the model output.

    Rules, applied in order:

    1. Strip surrounding whitespace.
    2. Upper-case the first character.
    3. If the summary ends with the dangling phrase
       "expected to recover within" (optionally followed by dots or
       spaces), complete it with " 7–10 days.".
    4. If the summary mentions "antibiotic" and "supportive" but not
       "treatment" (case-insensitive), append a treatment sentence.

    Args:
        summary: Raw text returned by the summarization model.

    Returns:
        The formatted summary; an empty string stays empty.
    """
    summary = summary.strip()
    if not summary:
        return summary

    if not summary[0].isupper():
        summary = summary[0].upper() + summary[1:]

    # Only complete the sentence when it is actually cut off. Matching the
    # phrase anywhere in the text would corrupt summaries that already
    # state a duration (e.g. "... within 2 weeks." -> "... 2 weeks 7–10 days.").
    # NOTE(review): the "7–10 days" figure is hard-coded, not derived from
    # the input text - confirm this default is clinically acceptable.
    trimmed = summary.rstrip('. ')
    if trimmed.endswith("expected to recover within"):
        summary = trimmed + " 7–10 days."

    lowered = summary.lower()
    if "antibiotic" in lowered and "supportive" in lowered and "treatment" not in lowered:
        summary += " Treatment includes antibiotics and supportive care."

    return summary
|
|
|
30 |
def summarize_text(self, text: str, target_lang: str = 'en') -> str:
    """Detect language, translate if needed, summarize, format, and retranslate if needed.

    Args:
        text: The clinical text to summarize.
        target_lang: Language code for the returned summary; anything
            other than ``'en'`` triggers a translation of the result.

    Returns:
        The formatted summary, or a human-readable error message string
        when the request fails or the response cannot be interpreted.
        This method does not raise; every failure is reported as a string.
    """
    try:
        # The prompt sent to the model is English, so translate first.
        detected_lang = self.translator.detect_language(text)
        if detected_lang != 'en':
            text = self.translator.translate_to_english(text)

        cleaned_text = self.clean_text(text)

        payload = {
            "inputs": f"summarize the clinical case with diagnosis, comorbidities, and treatment plan: {cleaned_text}"
        }

        response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
        response.raise_for_status()

        response_data = response.json()

        # Expected shape is a non-empty list of dicts, e.g.
        # [{'generated_text': 'summary'}]. Check the shape before indexing
        # so an empty list or non-dict element is reported as a format
        # problem instead of surfacing as an IndexError/TypeError message.
        first = None
        if isinstance(response_data, list) and response_data:
            first = response_data[0]

        summary = None
        if isinstance(first, dict):
            # Summarization pipelines use 'summary_text'; text-generation
            # style pipelines use 'generated_text'. Accept either.
            summary = first.get("generated_text", first.get("summary_text"))

        if not isinstance(summary, str):
            return "Unexpected response format from Hugging Face API."

        formatted_summary = self.format_summary(summary)

        # Translate back to target language if needed
        if target_lang != 'en':
            formatted_summary = self.translator.translate_from_english(formatted_summary, target_lang)

        return formatted_summary

    except requests.exceptions.Timeout:
        return "Summarization request timed out."
    except requests.exceptions.RequestException as e:
        return f"Summarization request failed: {str(e)}"
    except Exception as e:
        return f"An error occurred during summarization: {str(e)}"
|