Spaces:
Running
Running
Update email_gen.py
Browse files- email_gen.py +73 -6
email_gen.py
CHANGED
@@ -5,6 +5,14 @@ import re
|
|
5 |
from huggingface_hub import hf_hub_download
|
6 |
import random
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class EmailGenerator:
|
9 |
def __init__(self, custom_model_path=None):
|
10 |
self.model = None
|
@@ -13,22 +21,37 @@ class EmailGenerator:
|
|
13 |
self.prompt_templates = self._load_prompt_templates()
|
14 |
|
15 |
def _download_model(self):
|
16 |
-
"""Download
|
17 |
try:
|
18 |
-
model_name = "
|
19 |
-
filename = "
|
20 |
|
21 |
-
print("Downloading
|
|
|
22 |
model_path = hf_hub_download(
|
23 |
repo_id=model_name,
|
24 |
filename=filename,
|
25 |
cache_dir="./models"
|
26 |
)
|
27 |
-
print(f"
|
28 |
return model_path
|
29 |
except Exception as e:
|
30 |
print(f"Error downloading model: {e}")
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def _load_model(self):
|
34 |
"""Load the GGUF model using llama-cpp-python"""
|
@@ -178,6 +201,31 @@ class EmailGenerator:
|
|
178 |
|
179 |
return subject, body
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
def _advanced_fallback_generation(self, name, company, company_info, tone="Professional"):
|
182 |
"""Advanced fallback with company-specific personalization"""
|
183 |
|
@@ -397,6 +445,25 @@ Return ONLY this JSON format:
|
|
397 |
print("π Using advanced fallback generation (optimized for quality)")
|
398 |
subject, body = self._advanced_fallback_generation(name, company, company_info, tone)
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
# Always polish fallback content
|
401 |
subject, body = self._polish_email_content(subject, body)
|
402 |
|
|
|
5 |
from huggingface_hub import hf_hub_download
|
6 |
import random
|
7 |
|
8 |
+
# Grammar checking
|
9 |
+
try:
|
10 |
+
import language_tool_python
|
11 |
+
GRAMMAR_AVAILABLE = True
|
12 |
+
except ImportError:
|
13 |
+
GRAMMAR_AVAILABLE = False
|
14 |
+
print("β οΈ language_tool_python not available. Install for grammar checking.")
|
15 |
+
|
16 |
class EmailGenerator:
|
17 |
def __init__(self, custom_model_path=None):
|
18 |
self.model = None
|
|
|
21 |
self.prompt_templates = self._load_prompt_templates()
|
22 |
|
23 |
def _download_model(self):
|
24 |
+
"""Download Mistral-7B GGUF model from Hugging Face (30% better than Vicuna)"""
|
25 |
try:
|
26 |
+
model_name = "QuantFactory/Mistral-7B-Instruct-v0.3-GGUF"
|
27 |
+
filename = "Mistral-7B-Instruct-v0.3.Q4_K_M.gguf"
|
28 |
|
29 |
+
print("Downloading Mistral-7B v0.3 model... This may take a while.")
|
30 |
+
print("π Upgrading to Mistral for 30% better instruction following!")
|
31 |
model_path = hf_hub_download(
|
32 |
repo_id=model_name,
|
33 |
filename=filename,
|
34 |
cache_dir="./models"
|
35 |
)
|
36 |
+
print(f"β
Mistral model downloaded to: {model_path}")
|
37 |
return model_path
|
38 |
except Exception as e:
|
39 |
print(f"Error downloading model: {e}")
|
40 |
+
# Fallback to Vicuna if Mistral fails
|
41 |
+
try:
|
42 |
+
print("π Falling back to Vicuna model...")
|
43 |
+
model_name = "TheBloke/vicuna-7B-v1.5-GGUF"
|
44 |
+
filename = "vicuna-7b-v1.5.Q4_K_M.gguf"
|
45 |
+
model_path = hf_hub_download(
|
46 |
+
repo_id=model_name,
|
47 |
+
filename=filename,
|
48 |
+
cache_dir="./models"
|
49 |
+
)
|
50 |
+
print(f"β
Fallback model downloaded to: {model_path}")
|
51 |
+
return model_path
|
52 |
+
except Exception as e2:
|
53 |
+
print(f"β Both models failed: {e2}")
|
54 |
+
return None
|
55 |
|
56 |
def _load_model(self):
|
57 |
"""Load the GGUF model using llama-cpp-python"""
|
|
|
201 |
|
202 |
return subject, body
|
203 |
|
204 |
+
def _check_grammar(self, text):
|
205 |
+
"""Check grammar and return polished text"""
|
206 |
+
if not GRAMMAR_AVAILABLE:
|
207 |
+
return text, 0
|
208 |
+
|
209 |
+
try:
|
210 |
+
# Initialize language tool (cached)
|
211 |
+
if not hasattr(self, '_grammar_tool'):
|
212 |
+
self._grammar_tool = language_tool_python.LanguageTool('en-US')
|
213 |
+
|
214 |
+
# Check for errors
|
215 |
+
matches = self._grammar_tool.check(text)
|
216 |
+
|
217 |
+
# If more than 2 errors, suggest regeneration
|
218 |
+
if len(matches) > 2:
|
219 |
+
return text, len(matches)
|
220 |
+
|
221 |
+
# Auto-correct simple errors
|
222 |
+
corrected = language_tool_python.utils.correct(text, matches)
|
223 |
+
return corrected, len(matches)
|
224 |
+
|
225 |
+
except Exception as e:
|
226 |
+
print(f"Grammar check failed: {e}")
|
227 |
+
return text, 0
|
228 |
+
|
229 |
def _advanced_fallback_generation(self, name, company, company_info, tone="Professional"):
|
230 |
"""Advanced fallback with company-specific personalization"""
|
231 |
|
|
|
445 |
print("π Using advanced fallback generation (optimized for quality)")
|
446 |
subject, body = self._advanced_fallback_generation(name, company, company_info, tone)
|
447 |
|
448 |
+
# Apply grammar checking and polish
|
449 |
+
if GRAMMAR_AVAILABLE:
|
450 |
+
try:
|
451 |
+
corrected_body, error_count = self._check_grammar(body)
|
452 |
+
if error_count > 2:
|
453 |
+
print(f"β οΈ {error_count} grammar issues found, regenerating...")
|
454 |
+
# Try different template
|
455 |
+
subject, body = self._advanced_fallback_generation(name, company, company_info, tone)
|
456 |
+
corrected_body, _ = self._check_grammar(body)
|
457 |
+
body = corrected_body
|
458 |
+
else:
|
459 |
+
body = corrected_body
|
460 |
+
if error_count > 0:
|
461 |
+
print(f"β
Fixed {error_count} grammar issues")
|
462 |
+
except Exception as e:
|
463 |
+
print(f"Grammar check failed: {e}")
|
464 |
+
|
465 |
+
return subject, body
|
466 |
+
|
467 |
# Always polish fallback content
|
468 |
subject, body = self._polish_email_content(subject, body)
|
469 |
|