GuglielmoTor committed on
Commit
75a5661
·
verified ·
1 Parent(s): fff82df

Update posts_categorization.py

Browse files
Files changed (1) hide show
  1. posts_categorization.py +120 -45
posts_categorization.py CHANGED
@@ -4,34 +4,59 @@ import instructor
4
  from pydantic import BaseModel
5
  import os
6
 
 
7
  api_key = os.getenv('GROQ_API_KEY')
8
 
 
 
 
9
  # Create single patched Groq client with instructor for structured output
 
10
  client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
11
 
 
12
  class SummaryOutput(BaseModel):
13
  summary: str
14
 
15
- # Define pydantic schema for classification output
16
  class ClassificationOutput(BaseModel):
17
  category: str
18
 
 
19
  PRIMARY_SUMMARIZER_MODEL = "deepseek-r1-distill-llama-70b"
20
  FALLBACK_SUMMARIZER_MODEL = "llama-3.3-70b-versatile"
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Summarize post text
23
- def summarize_post(text):
24
- if pd.isna(text) or text is None:
 
 
 
 
 
25
  return None
26
 
27
- text = str(text)[:500] # truncate to avoid token overflow
 
28
 
29
  prompt = f"""
30
  Summarize the following LinkedIn post in 5 to 10 words.
31
  Only return the summary inside a JSON field called 'summary'.
32
 
33
  Post Text:
34
- \"\"\"{text}\"\"\"
35
  """
36
 
37
  try:
@@ -43,7 +68,7 @@ def summarize_post(text):
43
  messages=[
44
  {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
45
  {"role": "user", "content": prompt}
46
- ],
47
  temperature=0.3
48
  )
49
  return response.summary
@@ -57,8 +82,8 @@ def summarize_post(text):
57
  messages=[
58
  {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
59
  {"role": "user", "content": prompt}
60
- ],
61
- temperature=0.3 # Keep temperature consistent or adjust as needed for fallback
62
  )
63
  print(f"Summarization successful with fallback model: {FALLBACK_SUMMARIZER_MODEL}")
64
  return response.summary
@@ -70,63 +95,113 @@ def summarize_post(text):
70
  return None
71
  except Exception as e_primary:
72
  print(f"Error during summarization with primary model ({PRIMARY_SUMMARIZER_MODEL}): {e_primary}")
73
- # You could also try fallback here for non-rate-limit errors if desired
74
  return None
75
 
 
 
 
 
 
 
 
 
 
 
76
 
77
-
78
- # Classify post summary into structured categories
79
- def classify_post(summary, labels):
80
- if pd.isna(summary) or summary is None:
81
- return None
82
 
83
  prompt = f"""
84
  Post Summary: "{summary}"
85
 
86
  Available Categories:
87
- {', '.join(labels)}
88
 
89
- Task: Choose the single most relevant category from the list above that applies to this summary. Return only one category in a structured JSON format under the field 'category'.
90
- If no category applies, return 'None'.
 
 
91
  """
92
  try:
 
 
 
 
 
93
  result = client.chat.completions.create(
94
- model="meta-llama/llama-4-maverick-17b-128e-instruct",
95
  response_model=ClassificationOutput,
96
  messages=[
97
- {"role": "system", "content": "You are a strict classifier. Return only one matching category name under the field 'category'."},
98
  {"role": "user", "content": prompt}
99
  ],
100
- temperature=0
101
  )
102
- return result.category
 
 
 
 
 
 
 
103
  except Exception as e:
104
- print(f"Classification error: {e}")
105
- return None
106
 
107
- def summarize_and_classify_post(text, labels):
108
- summary = summarize_post(text)
109
- category = classify_post(summary, labels) if summary else None
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  return {
111
- "summary": summary,
112
- "category": category
113
  }
114
 
115
- def batch_summarize_and_classify(posts):
116
-
117
- labels = [
118
- "Company Culture and Values",
119
- "Employee Stories and Spotlights",
120
- "Work-Life Balance, Flexibility, and Well-being",
121
- "Diversity, Equity, and Inclusion (DEI)",
122
- "Professional Development and Growth Opportunities",
123
- "Mission, Vision, and Social Responsibility",
124
- "None"
125
- ]
126
 
127
  results = []
128
- for post in posts:
129
- text = post.get("text")
130
- result = summarize_and_classify_post(text, labels)
131
- results.append(result)
132
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from pydantic import BaseModel
5
  import os
6
 
7
+ # Ensure GROQ_API_KEY is set in your environment variables
8
  api_key = os.getenv('GROQ_API_KEY')
9
 
10
+ if not api_key:
11
+ raise ValueError("GROQ_API_KEY environment variable not set.")
12
+
13
  # Create single patched Groq client with instructor for structured output
14
+ # Using Mode.JSON for structured output based on Pydantic models
15
  client = instructor.from_groq(Groq(api_key=api_key), mode=instructor.Mode.JSON)
16
 
17
+ # Pydantic model for summarization output
18
  class SummaryOutput(BaseModel):
19
  summary: str
20
 
21
+ # Pydantic model for classification output
22
  class ClassificationOutput(BaseModel):
23
  category: str
24
 
25
+ # Define model names (as per your original code)
26
  PRIMARY_SUMMARIZER_MODEL = "deepseek-r1-distill-llama-70b"
27
  FALLBACK_SUMMARIZER_MODEL = "llama-3.3-70b-versatile"
28
+ CLASSIFICATION_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct" # Or your preferred classification model
29
+
30
+ # Define the standard list of categories, including "None"
31
+ CLASSIFICATION_LABELS = [
32
+ "Company Culture and Values",
33
+ "Employee Stories and Spotlights",
34
+ "Work-Life Balance, Flexibility, and Well-being",
35
+ "Diversity, Equity, and Inclusion (DEI)",
36
+ "Professional Development and Growth Opportunities",
37
+ "Mission, Vision, and Social Responsibility",
38
+ "None" # Represents no applicable category or cases where classification isn't possible
39
+ ]
40
 
41
+ def summarize_post(text: str) -> str | None:
42
+ """
43
+ Summarizes the given post text using a primary model with a fallback.
44
+ Returns the summary string or None if summarization fails or input is invalid.
45
+ """
46
+ # Check for NaN, None, or empty/whitespace-only string
47
+ if pd.isna(text) or text is None or not str(text).strip():
48
+ print("Summarizer: Input text is empty or None. Returning None.")
49
  return None
50
 
51
+ # Truncate text to a reasonable length to avoid token overflow and reduce costs
52
+ processed_text = str(text)[:500]
53
 
54
  prompt = f"""
55
  Summarize the following LinkedIn post in 5 to 10 words.
56
  Only return the summary inside a JSON field called 'summary'.
57
 
58
  Post Text:
59
+ \"\"\"{processed_text}\"\"\"
60
  """
61
 
62
  try:
 
68
  messages=[
69
  {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
70
  {"role": "user", "content": prompt}
71
+ ],
72
  temperature=0.3
73
  )
74
  return response.summary
 
82
  messages=[
83
  {"role": "system", "content": "You are a precise summarizer. Only return a JSON object with a 'summary' string."},
84
  {"role": "user", "content": prompt}
85
+ ],
86
+ temperature=0.3
87
  )
88
  print(f"Summarization successful with fallback model: {FALLBACK_SUMMARIZER_MODEL}")
89
  return response.summary
 
95
  return None
96
  except Exception as e_primary:
97
  print(f"Error during summarization with primary model ({PRIMARY_SUMMARIZER_MODEL}): {e_primary}")
98
+ # Consider if fallback should be attempted for other errors too, or just return None
99
  return None
100
 
101
+ def classify_post(summary: str | None, labels: list[str]) -> str:
102
+ """
103
+ Classifies the post summary into one of the provided labels.
104
+ Ensures the returned category is one of the labels, defaulting to "None".
105
+ """
106
+ # If the summary is None (e.g., from a failed summarization or empty input),
107
+ # or if the summary is an empty string after stripping, classify as "None".
108
+ if pd.isna(summary) or summary is None or not str(summary).strip():
109
+ print("Classifier: Input summary is empty or None. Returning 'None' category.")
110
+ return "None" # Return the string "None" to match the label
111
 
112
+ # Join labels for the prompt to ensure the LLM knows the exact expected strings
113
+ labels_string = "', '".join(labels)
 
 
 
114
 
115
  prompt = f"""
116
  Post Summary: "{summary}"
117
 
118
  Available Categories:
119
+ '{labels_string}'
120
 
121
+ Task: Choose the single most relevant category from the list above that applies to this summary.
122
+ Return ONLY ONE category string in a structured JSON format under the field 'category'.
123
+ The category MUST be one of the following: '{labels_string}'.
124
+ If no specific category applies, or if you are unsure, return "None".
125
  """
126
  try:
127
+ system_message = (
128
+ f"You are a very strict classifier. Your ONLY job is to return a JSON object "
129
+ f"with a 'category' field. The value of 'category' MUST be one of these "
130
+ f"exact strings: '{labels_string}'."
131
+ )
132
  result = client.chat.completions.create(
133
+ model=CLASSIFICATION_MODEL,
134
  response_model=ClassificationOutput,
135
  messages=[
136
+ {"role": "system", "content": system_message},
137
  {"role": "user", "content": prompt}
138
  ],
139
+ temperature=0 # Temperature 0 for deterministic classification
140
  )
141
+
142
+ returned_category = result.category
143
+
144
+ # Validate the output against the provided labels
145
+ if returned_category not in labels:
146
+ print(f"Warning: Classifier returned '{returned_category}', which is not in the predefined labels. Forcing to 'None'. Summary: '{summary}'")
147
+ return "None" # Force to "None" if the LLM returns an unexpected category
148
+ return returned_category
149
  except Exception as e:
150
+ print(f"Classification error: {e}. Summary: '{summary}'. Defaulting to 'None' category.")
151
+ return "None" # Default to "None" on any exception during classification
152
 
153
def summarize_and_classify_post(text: str | None, labels: list[str]) -> dict:
    """
    Summarize one post and then classify the resulting summary.

    Returns a dict with two keys: "summary" (the value returned by
    summarize_post, which may be None) and "category" (the value returned
    by classify_post, or the string "None" when there is no usable summary).
    """
    summary = summarize_post(text)

    # Only a non-None summary with visible characters is worth classifying.
    has_usable_summary = summary is not None and bool(summary.strip())
    category = classify_post(summary, labels) if has_usable_summary else "None"

    return {"summary": summary, "category": category}
173
 
174
def batch_summarize_and_classify(posts_data: list[dict]) -> list[dict]:
    """
    Summarize and classify every post in a batch.

    Args:
        posts_data: List of dictionaries. Each item is read with
            ``.get("id")`` and ``.get("text")``, so a missing key yields None
            rather than raising. Items that are not dictionaries are skipped
            with a warning, so the result may be shorter than the input.

    Returns:
        A list of dictionaries with the keys "id", "summary" and "category",
        one per processed item, in input order. An empty or falsy
        ``posts_data`` yields an empty list.
    """
    results = []
    if not posts_data:
        print("Input 'posts_data' is empty. Returning empty results.")
        return results

    for i, post_item in enumerate(posts_data):
        if not isinstance(post_item, dict):
            print(f"Warning: Item at index {i} is not a dictionary. Skipping.")
            continue

        post_id = post_item.get("id")
        text_to_process = post_item.get("text")

        # Only an absent id counts as missing; falsy-but-valid ids such as 0
        # or "" must still be shown as-is in the progress log.
        id_for_log = post_id if post_id is not None else 'N/A (ID missing)'
        print(f"\nProcessing Post ID: {id_for_log}, Text: '{str(text_to_process)[:50]}...'")

        # None/empty text is handled downstream and ends up as category "None".
        summary_and_category_result = summarize_and_classify_post(text_to_process, CLASSIFICATION_LABELS)

        results.append({
            "id": post_id,  # Kept so callers can map results back to their data
            "summary": summary_and_category_result["summary"],
            "category": summary_and_category_result["category"],
        })
        print(f"Result for Post ID {post_id}: Summary='{summary_and_category_result['summary']}', Category='{summary_and_category_result['category']}'")

    return results