hbertrand committed on
Commit
1b88635
·
unverified ·
1 Parent(s): fde3910

PR: mock chatbot test data (#63)

Browse files

* mock chatbot test data

* source key

* mock and real test

Files changed (1) hide show
  1. tests/test_chatbot.py +67 -1
tests/test_chatbot.py CHANGED
@@ -1,13 +1,45 @@
1
  import os
2
  from pathlib import Path
3
 
 
 
 
4
  from buster.chatbot import Chatbot, ChatbotConfig
 
5
 
6
  TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
7
  DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
8
 
9
 
10
- def test_chatbot_simple():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  hf_transformers_cfg = ChatbotConfig(
12
  documents_file=DOCUMENTS_FILE,
13
  unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
@@ -31,3 +63,37 @@ def test_chatbot_simple():
31
  chatbot = Chatbot(hf_transformers_cfg)
32
  answer = chatbot.process_input("What is a transformer?")
33
  assert isinstance(answer, str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from pathlib import Path
3
 
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
  from buster.chatbot import Chatbot, ChatbotConfig
8
+ from buster.documents import DocumentsManager
9
 
10
  TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
11
  DOCUMENTS_FILE = os.path.join(str(TEST_DATA_DIR), "document_embeddings_huggingface_subset.tar.gz")
12
 
13
 
14
def get_fake_embedding(length=1536, seed=None):
    """Return a random vector that stands in for a real embedding.

    Args:
        length: Number of components in the vector.
        seed: Optional seed forwarded to ``np.random.default_rng``. With the
            default ``None`` every call yields a fresh, unpredictable vector;
            passing an int makes the result reproducible across calls.

    Returns:
        A list of ``length`` ``np.float32`` values drawn uniformly from
        the half-open interval [0, 1).
    """
    rng = np.random.default_rng(seed)
    # Wrapping in list() keeps the elements as np.float32 scalars.
    return list(rng.random(length, dtype=np.float32))
17
+
18
+
19
class DocumentsMock(DocumentsManager):
    """In-memory documents manager holding synthetic rows for tests.

    The table is built directly in ``__init__``; nothing is read from
    ``filepath``.
    """

    def __init__(self, filepath):
        """Store ``filepath`` and build a table of 100 fake documents.

        NOTE(review): the parent ``__init__`` is not called here; presumably
        this avoids opening ``filepath`` -- confirm against DocumentsManager.
        """
        self.filepath = filepath

        n_samples = 100
        self.documents = pd.DataFrame.from_dict(
            {
                "title": ["test"] * n_samples,
                "url": ["http://url.com"] * n_samples,
                "content": ["cool text"] * n_samples,
                # One independent vector per row. `[vec] * n` would put the
                # very same list object in every row, so all embeddings would
                # be identical and mutating one would mutate them all.
                "embedding": [get_fake_embedding() for _ in range(n_samples)],
                "n_tokens": [10] * n_samples,
                "source": ["fake source"] * n_samples,
            }
        )

    def add(self, documents):
        """Accept ``documents`` and do nothing; the mock table is fixed."""
        pass

    def get_documents(self, source):
        """Return the full fake table; ``source`` is ignored."""
        return self.documents
40
+
41
+
42
+ def test_chatbot_real_data():
43
  hf_transformers_cfg = ChatbotConfig(
44
  documents_file=DOCUMENTS_FILE,
45
  unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
 
63
  chatbot = Chatbot(hf_transformers_cfg)
64
  answer = chatbot.process_input("What is a transformer?")
65
  assert isinstance(answer, str)
66
+
67
+
68
def test_chatbot_mock_data(tmp_path, monkeypatch):
    """Run the chatbot with documents, embeddings and completion all stubbed.

    The documents manager lookup, the embedding call and the OpenAI
    completion call inside ``buster.chatbot`` are monkeypatched, and the
    configured documents file is a path under ``tmp_path`` that is never
    created.
    """
    canned_completion = "this is GPT answer"

    def fake_manager_lookup(filepath):
        # Hands back the mock class itself, not an instance.
        return DocumentsMock

    def fake_get_embedding(x, engine):
        return get_fake_embedding()

    def fake_completion_create(**kwargs):
        return {"choices": [{"text": canned_completion}]}

    patches = (
        ("buster.chatbot.get_documents_manager_from_extension", fake_manager_lookup),
        ("buster.chatbot.get_embedding", fake_get_embedding),
        ("buster.chatbot.openai.Completion.create", fake_completion_create),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    prompt_preamble = (
        "You are a slack chatbot assistant answering technical questions about huggingface transformers, a library to train transformers in python.\n"
        "Make sure to format your answers in Markdown format, including code block and snippets.\n"
        "Do not include any links to urls or hyperlinks in your answers.\n\n"
        "Now answer the following question:\n"
    )
    completion_options = {
        "temperature": 0,
        "engine": "text-davinci-003",
        "max_tokens": 100,
    }

    mock_cfg = ChatbotConfig(
        documents_file=tmp_path / "not_a_real_file.tar.gz",
        unknown_prompt="This doesn't seem to be related to the huggingface library. I am not sure how to answer.",
        embedding_model="text-embedding-ada-002",
        top_k=3,
        thresh=0.7,
        max_words=3000,
        completion_kwargs=completion_options,
        response_format="slack",
        text_before_prompt=prompt_preamble,
    )

    reply = Chatbot(mock_cfg).process_input("What is a transformer?")
    assert isinstance(reply, str)
    assert reply.startswith(canned_completion)