from setuptools import setup, find_packages

import subprocess
import sys
import platform

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

with open("requirements.txt", "r", encoding="utf-8") as fh:
    requirements = [line.strip() for line in fh if line.strip() and not line.startswith("#")]
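

# Post-install helpers: download the models and data the package needs at
# runtime. They run only when this file is executed directly (see the
# __main__ block at the bottom of the file).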
def setup_spacy_models(models=['en_core_web_sm', 'en_core_web_md']):
    """
    Download the specified spaCy models.

    Args:
        models (List[str]): Names of the spaCy models to download.
    """
    for model in models:
        try:
            print(f"Downloading spaCy model: {model}")
            subprocess.check_call([sys.executable, "-m", "spacy", "download", model])
            print(f"Successfully downloaded spaCy model: {model}")
        except subprocess.CalledProcessError as e:
            # Keep going so one failed model does not block the others.
            print(f"Error downloading spaCy model: {model}")
            print(e)
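

# Note: the models above can also be fetched manually with spaCy's CLI, e.g.
#   python -m spacy download en_core_web_sm
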
def setup_gpu_dependencies():
    """Set up GPU-specific dependencies."""
    cuda_available = False

    try:
        import torch
        cuda_available = torch.cuda.is_available()
    except ImportError:
        pass

    if cuda_available:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
            print("Successfully installed faiss-gpu")
        except subprocess.CalledProcessError:
            print("Failed to install faiss-gpu. Falling back to faiss-cpu.")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
    else:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
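

# NOTE: setup_faiss() further below performs the same CUDA probe with an
# additional OS guard, so the __main__ block calls it instead of
# setup_gpu_dependencies().
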
def setup_models():
    """
    Download the other required models and tokenizers.
    """
    import tensorflow_hub as hub
    from transformers import (
        AutoTokenizer,
        GPT2TokenizerFast,
        MarianTokenizer,
        DistilBertTokenizer,
        DistilBertModel,
    )
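
    # Cache the DistilBERT tokenizer and model weights.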
    _ = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    _ = DistilBertModel.from_pretrained('distilbert-base-uncased')
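
    # Cache the Universal Sentence Encoder from TF Hub.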
    _ = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
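
    # Cache the tokenizer for the T5-based paraphrase model.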
    _ = AutoTokenizer.from_pretrained('humarin/chatgpt_paraphraser_on_T5_base')
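
    # Cache MarianMT tokenizers for the en -> de -> es -> en round trip.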
    source_lang, pivot_lang, target_lang = 'en', 'de', 'es'
    model_names = [
        f'Helsinki-NLP/opus-mt-{source_lang}-{pivot_lang}',
        f'Helsinki-NLP/opus-mt-{pivot_lang}-{target_lang}',
        f'Helsinki-NLP/opus-mt-{target_lang}-{source_lang}',
    ]
    for model_name in model_names:
        _ = MarianTokenizer.from_pretrained(model_name)
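
    # Cache the GPT-2 tokenizer.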
    _ = GPT2TokenizerFast.from_pretrained('gpt2')

def setup_nltk():
    """
    Download required NLTK data.
    """
    import nltk

    required_packages = [
        'wordnet',
        'averaged_perceptron_tagger_eng',
    ]

    for package in required_packages:
        try:
            print(f"Downloading {package}...")
            # nltk.download() signals failure via its return value rather
            # than raising, so check it explicitly.
            if nltk.download(package):
                print(f"Successfully downloaded {package}")
            else:
                print(f"Warning: Could not download {package}")
        except Exception as e:
            print(f"Warning: Could not download {package}: {str(e)}")
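

# faiss-gpu is attempted only on Linux hosts with a working CUDA setup;
# prebuilt faiss-gpu wheels are generally not published for macOS or Windows.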
def setup_faiss():
    """
    Install the required faiss library.
    """
    current_os = platform.system()

    def check_cuda():
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    if current_os == "Linux" and check_cuda():
        try:
            print("Attempting to install faiss-gpu...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-gpu>=1.7.0"])
            print("Successfully installed faiss-gpu")
            return
        except subprocess.CalledProcessError:
            print("Failed to install faiss-gpu. Falling back to faiss-cpu.")

    try:
        print("Installing faiss-cpu...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "faiss-cpu>=1.7.0"])
        print("Successfully installed faiss-cpu")
    except subprocess.CalledProcessError as e:
        print("Error installing faiss-cpu")
        print(e)
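

# Package metadata and distribution configuration.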
setup(
    name="retrieval-chatbot",
    version="0.2.0",
    author="Joe Armani",
    author_email="[email protected]",
    description="A retrieval-based chatbot with enhanced validation",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Text Processing :: Linguistic",
    ],
    python_requires=">=3.8",
    install_requires=requirements,
    extras_require={
        'dev': [
            'pytest>=7.0.0',
            'black>=22.0.0',
            'isort>=5.10.0',
            'mypy>=1.0.0',
        ],
        'gpu': [
            'faiss-gpu>=1.7.0',
        ],
    },
    entry_points={
        "console_scripts": [
            "dialogue-augment=dialogue_augmenter.main:main",
            "run-chatbot=chatbot.main:main",
        ],
    },
    include_package_data=True,
    package_data={
        "chatbot": ["config/*.yaml"],
        "dialogue_augmenter": ["data/*.json", "config/*.yaml"],
    },
)
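

# Executing this file directly performs the package setup above, then
# downloads the runtime models and data. setup_faiss() selects the
# appropriate faiss build (GPU on CUDA-enabled Linux, CPU otherwise).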
if __name__ == '__main__':
    setup_spacy_models()
    setup_models()
    setup_nltk()
    setup_faiss()