Spaces:
Running
Running
Update posts_categorization.py
Browse files- posts_categorization.py +120 -45
posts_categorization.py
CHANGED
@@ -4,34 +4,59 @@ import instructor
|
|
4 |
from pydantic import BaseModel
|
5 |
import os
|
6 |
|
|
|
7 |
api_key = os.getenv('GROQ_API_KEY')
|
8 |
|
|
|
|
|
|
|
9 |
# Create single patched Groq client with instructor for structured output
|
|
|
10 |
client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
|
11 |
|
|
|
12 |
class SummaryOutput(BaseModel):
|
13 |
summary: str
|
14 |
|
15 |
-
#
|
16 |
class ClassificationOutput(BaseModel):
|
17 |
category: str
|
18 |
|
|
|
19 |
PRIMARY_SUMMARIZER_MODEL = "deepseek-r1-distill-llama-70b"
|
20 |
FALLBACK_SUMMARIZER_MODEL = "llama-3.3-70b-versatile"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
25 |
return None
|
26 |
|
27 |
-
|
|
|
28 |
|
29 |
prompt = f"""
|
30 |
Summarize the following LinkedIn post in 5 to 10 words.
|
31 |
Only return the summary inside a JSON field called 'summary'.
|
32 |
|
33 |
Post Text:
|
34 |
-
\"\"\"{
|
35 |
"""
|
36 |
|
37 |
try:
|
@@ -43,7 +68,7 @@ def summarize_post(text):
|
|
43 |
messages=[
|
44 |
{"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
|
45 |
{"role": "user", "content": prompt}
|
46 |
-
|
47 |
temperature=0.3
|
48 |
)
|
49 |
return response.summary
|
@@ -57,8 +82,8 @@ def summarize_post(text):
|
|
57 |
messages=[
|
58 |
{"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
|
59 |
{"role": "user", "content": prompt}
|
60 |
-
|
61 |
-
temperature=0.3
|
62 |
)
|
63 |
print(f"Summarization successful with fallback model: {FALLBACK_SUMMARIZER_MODEL}")
|
64 |
return response.summary
|
@@ -70,63 +95,113 @@ def summarize_post(text):
|
|
70 |
return None
|
71 |
except Exception as e_primary:
|
72 |
print(f"Error during summarization with primary model ({PRIMARY_SUMMARIZER_MODEL}): {e_primary}")
|
73 |
-
#
|
74 |
return None
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
def classify_post(summary, labels):
|
80 |
-
if pd.isna(summary) or summary is None:
|
81 |
-
return None
|
82 |
|
83 |
prompt = f"""
|
84 |
Post Summary: "{summary}"
|
85 |
|
86 |
Available Categories:
|
87 |
-
{'
|
88 |
|
89 |
-
Task: Choose the single most relevant category from the list above that applies to this summary.
|
90 |
-
|
|
|
|
|
91 |
"""
|
92 |
try:
|
|
|
|
|
|
|
|
|
|
|
93 |
result = client.chat.completions.create(
|
94 |
-
model=
|
95 |
response_model=ClassificationOutput,
|
96 |
messages=[
|
97 |
-
{"role": "system", "content":
|
98 |
{"role": "user", "content": prompt}
|
99 |
],
|
100 |
-
temperature=0
|
101 |
)
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
except Exception as e:
|
104 |
-
print(f"Classification error: {e}")
|
105 |
-
return None
|
106 |
|
107 |
-
def summarize_and_classify_post(text, labels):
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
return {
|
111 |
-
"summary": summary,
|
112 |
-
"category": category
|
113 |
}
|
114 |
|
115 |
-
def batch_summarize_and_classify(
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
"
|
121 |
-
"Diversity, Equity, and Inclusion (DEI)",
|
122 |
-
"Professional Development and Growth Opportunities",
|
123 |
-
"Mission, Vision, and Social Responsibility",
|
124 |
-
"None"
|
125 |
-
]
|
126 |
|
127 |
results = []
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from pydantic import BaseModel
|
5 |
import os
|
6 |
|
7 |
+
# The Groq credential comes from the GROQ_API_KEY environment variable.
# Importing this module fails immediately when it is unset or empty.
api_key = os.getenv('GROQ_API_KEY')
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable not set.")

# A single Groq client shared by the whole module, patched by instructor.
# JSON mode makes each completion come back parsed into the Pydantic model
# passed as `response_model`.
client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
|
16 |
|
17 |
+
# Structured-output schema for summarization: a JSON object with a single
# string field named 'summary'.
# NOTE(review): documented with comments rather than a class docstring on
# purpose - pydantic copies a model docstring into the generated JSON schema.
class SummaryOutput(BaseModel):
    summary: str
|
20 |
|
21 |
+
# Structured-output schema for classification: a JSON object with a single
# string field named 'category'. The field is a plain `str`, so the schema
# itself does not restrict the value to a fixed label set.
# NOTE(review): documented with comments rather than a class docstring on
# purpose - pydantic copies a model docstring into the generated JSON schema.
class ClassificationOutput(BaseModel):
    category: str
|
24 |
|
25 |
+
# Groq model identifiers: a primary and a fallback model for summarization,
# and a separate model for classification (swap in another one if preferred).
PRIMARY_SUMMARIZER_MODEL = "deepseek-r1-distill-llama-70b"
FALLBACK_SUMMARIZER_MODEL = "llama-3.3-70b-versatile"
CLASSIFICATION_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct"

# The standard category list offered to the classifier.
# The final "None" entry stands for "no applicable category" and for cases
# where classification is not possible.
CLASSIFICATION_LABELS = [
    "Company Culture and Values",
    "Employee Stories and Spotlights",
    "Work-Life Balance, Flexibility, and Well-being",
    "Diversity, Equity, and Inclusion (DEI)",
    "Professional Development and Growth Opportunities",
    "Mission, Vision, and Social Responsibility",
    "None",
]
|
40 |
|
41 |
+
def summarize_post(text: str) -> str | None:
|
42 |
+
"""
|
43 |
+
Summarizes the given post text using a primary model with a fallback.
|
44 |
+
Returns the summary string or None if summarization fails or input is invalid.
|
45 |
+
"""
|
46 |
+
# Check for NaN, None, or empty/whitespace-only string
|
47 |
+
if pd.isna(text) or text is None or not str(text).strip():
|
48 |
+
print("Summarizer: Input text is empty or None. Returning None.")
|
49 |
return None
|
50 |
|
51 |
+
# Truncate text to a reasonable length to avoid token overflow and reduce costs
|
52 |
+
processed_text = str(text)[:500]
|
53 |
|
54 |
prompt = f"""
|
55 |
Summarize the following LinkedIn post in 5 to 10 words.
|
56 |
Only return the summary inside a JSON field called 'summary'.
|
57 |
|
58 |
Post Text:
|
59 |
+
\"\"\"{processed_text}\"\"\"
|
60 |
"""
|
61 |
|
62 |
try:
|
|
|
68 |
messages=[
|
69 |
{"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
|
70 |
{"role": "user", "content": prompt}
|
71 |
+
],
|
72 |
temperature=0.3
|
73 |
)
|
74 |
return response.summary
|
|
|
82 |
messages=[
|
83 |
{"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
|
84 |
{"role": "user", "content": prompt}
|
85 |
+
],
|
86 |
+
temperature=0.3
|
87 |
)
|
88 |
print(f"Summarization successful with fallback model: {FALLBACK_SUMMARIZER_MODEL}")
|
89 |
return response.summary
|
|
|
95 |
return None
|
96 |
except Exception as e_primary:
|
97 |
print(f"Error during summarization with primary model ({PRIMARY_SUMMARIZER_MODEL}): {e_primary}")
|
98 |
+
# Consider if fallback should be attempted for other errors too, or just return None
|
99 |
return None
|
100 |
|
101 |
+
def classify_post(summary: str | None, labels: list[str]) -> str:
|
102 |
+
"""
|
103 |
+
Classifies the post summary into one of the provided labels.
|
104 |
+
Ensures the returned category is one of the labels, defaulting to "None".
|
105 |
+
"""
|
106 |
+
# If the summary is None (e.g., from a failed summarization or empty input),
|
107 |
+
# or if the summary is an empty string after stripping, classify as "None".
|
108 |
+
if pd.isna(summary) or summary is None or not str(summary).strip():
|
109 |
+
print("Classifier: Input summary is empty or None. Returning 'None' category.")
|
110 |
+
return "None" # Return the string "None" to match the label
|
111 |
|
112 |
+
# Join labels for the prompt to ensure the LLM knows the exact expected strings
|
113 |
+
labels_string = "', '".join(labels)
|
|
|
|
|
|
|
114 |
|
115 |
prompt = f"""
|
116 |
Post Summary: "{summary}"
|
117 |
|
118 |
Available Categories:
|
119 |
+
'{labels_string}'
|
120 |
|
121 |
+
Task: Choose the single most relevant category from the list above that applies to this summary.
|
122 |
+
Return ONLY ONE category string in a structured JSON format under the field 'category'.
|
123 |
+
The category MUST be one of the following: '{labels_string}'.
|
124 |
+
If no specific category applies, or if you are unsure, return "None".
|
125 |
"""
|
126 |
try:
|
127 |
+
system_message = (
|
128 |
+
f"You are a very strict classifier. Your ONLY job is to return a JSON object "
|
129 |
+
f"with a 'category' field. The value of 'category' MUST be one of these "
|
130 |
+
f"exact strings: '{labels_string}'."
|
131 |
+
)
|
132 |
result = client.chat.completions.create(
|
133 |
+
model=CLASSIFICATION_MODEL,
|
134 |
response_model=ClassificationOutput,
|
135 |
messages=[
|
136 |
+
{"role": "system", "content": system_message},
|
137 |
{"role": "user", "content": prompt}
|
138 |
],
|
139 |
+
temperature=0 # Temperature 0 for deterministic classification
|
140 |
)
|
141 |
+
|
142 |
+
returned_category = result.category
|
143 |
+
|
144 |
+
# Validate the output against the provided labels
|
145 |
+
if returned_category not in labels:
|
146 |
+
print(f"Warning: Classifier returned '{returned_category}', which is not in the predefined labels. Forcing to 'None'. Summary: '{summary}'")
|
147 |
+
return "None" # Force to "None" if the LLM returns an unexpected category
|
148 |
+
return returned_category
|
149 |
except Exception as e:
|
150 |
+
print(f"Classification error: {e}. Summary: '{summary}'. Defaulting to 'None' category.")
|
151 |
+
return "None" # Default to "None" on any exception during classification
|
152 |
|
153 |
+
def summarize_and_classify_post(text: str | None, labels: list[str]) -> dict:
    """Summarize one post, then classify the resulting summary.

    Args:
        text: Raw post text; passed straight to ``summarize_post``.
        labels: Category strings handed to ``classify_post``.

    Returns:
        A dict with two keys: ``"summary"`` (whatever ``summarize_post``
        returned, possibly ``None``) and ``"category"`` (the result of
        ``classify_post``, or the string ``"None"`` when there is no usable
        summary).
    """
    summary = summarize_post(text)

    # Only a non-None summary with visible characters is worth classifying;
    # anything else is assigned the "None" category without a model call.
    has_summary = summary is not None and bool(summary.strip())
    category = classify_post(summary, labels) if has_summary else "None"

    return {"summary": summary, "category": category}
|
173 |
|
174 |
+
def batch_summarize_and_classify(posts_data: list[dict]) -> list[dict]:
    """Summarize and classify every post in a batch.

    Args:
        posts_data: List of dicts. Each dict is read with ``.get("id")`` and
            ``.get("text")``, so either key may be absent.

    Returns:
        One dict per processed post with keys ``"id"``, ``"summary"`` and
        ``"category"``. Items that are not dicts are skipped with a warning,
        so the result can be shorter than the input. An empty or ``None``
        input yields an empty list.
    """
    results = []
    if not posts_data:
        print("Input 'posts_data' is empty. Returning empty results.")
        return results

    for i, post_item in enumerate(posts_data):
        if not isinstance(post_item, dict):
            print(f"Warning: Item at index {i} is not a dictionary. Skipping.")
            continue

        post_id = post_item.get("id")
        text_to_process = post_item.get("text")

        # Only an absent or empty ID counts as missing. A plain truthiness
        # test would also mislabel a valid ID of 0 as missing in the log.
        id_missing = post_id is None or post_id == ""
        shown_id = 'N/A (ID missing)' if id_missing else post_id
        print(f"\nProcessing Post ID: {shown_id}, Text: '{str(text_to_process)[:50]}...'")

        # None/empty text is handled inside summarize_and_classify_post,
        # which reports the "None" category for it.
        outcome = summarize_and_classify_post(text_to_process, CLASSIFICATION_LABELS)

        results.append({
            "id": post_id,  # kept so results can be mapped back to the input
            "summary": outcome["summary"],
            "category": outcome["category"],
        })
        print(f"Result for Post ID {post_id}: Summary='{outcome['summary']}', Category='{outcome['category']}'")

    return results
|