Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +2 -1
- README.md +4 -0
- app.py +37 -50
- faiss_index/index.faiss +3 -0
- faiss_index/index.pkl +3 -0
- loader.py +123 -0
- requirements.txt +5 -1
- test_loader.py +287 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
vector_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
vector_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
faiss_index/index.faiss filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
@@ -117,4 +117,5 @@ dmypy.json
|
|
117 |
|
118 |
# Pyre type checker
|
119 |
.pyre/
|
120 |
-
.venv
|
|
|
|
117 |
|
118 |
# Pyre type checker
|
119 |
.pyre/
|
120 |
+
.venv
|
121 |
+
documentation
|
README.md
CHANGED
@@ -138,3 +138,7 @@ Contributions are welcome! Please feel free to submit a Pull Request.
|
|
138 |
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
|
139 |
4. Push to the branch (`git push origin feature/amazing-feature`)
|
140 |
5. Open a Pull Request
|
|
|
|
|
|
|
|
|
|
138 |
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
|
139 |
4. Push to the branch (`git push origin feature/amazing-feature`)
|
140 |
5. Open a Pull Request
|
141 |
+
|
142 |
+
## Test Questions
|
143 |
+
|
144 |
+
1. What is the name server in this network? 8.8.8.8
|
app.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
# /// script
|
2 |
# dependencies = [
|
3 |
# "PyYAML",
|
4 |
-
# "
|
5 |
-
# "
|
|
|
|
|
6 |
# "smolagents",
|
7 |
# "gradio",
|
8 |
# "einops",
|
9 |
# "smolagents[litellm]",
|
|
|
10 |
# ]
|
11 |
# ///
|
12 |
|
@@ -28,13 +31,11 @@ with open("prompts.yaml", 'r') as stream:
|
|
28 |
# trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
|
29 |
# SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
|
30 |
|
31 |
-
import
|
32 |
-
from
|
33 |
|
34 |
-
|
35 |
-
EMBEDDING_MODEL_NAME = "
|
36 |
-
model_embeding = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
|
37 |
-
client = chromadb.PersistentClient(path=db_name)
|
38 |
|
39 |
from smolagents import Tool
|
40 |
|
@@ -49,55 +50,41 @@ class RetrieverTool(Tool):
|
|
49 |
}
|
50 |
output_type = "string"
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def forward(self, query: str) -> str:
|
57 |
assert isinstance(query, str), "Your search query must be a string"
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
print("Number of results:", len(result1['embeddings']))
|
62 |
-
query_vector = model_embeding.encode(query)
|
63 |
-
results = collection.query(
|
64 |
-
query_embeddings=[query_vector],
|
65 |
-
n_results=10,
|
66 |
-
include=["metadatas", "documents"]
|
67 |
-
)
|
68 |
response = ""
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
else:
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
77 |
return response
|
78 |
|
79 |
-
|
80 |
-
"""
|
81 |
-
This method return the name of the device if the data belongs to a device if not is global.
|
82 |
-
Args:
|
83 |
-
value: Source of the metadata.
|
84 |
-
Returns:
|
85 |
-
str: The name of the device.
|
86 |
-
"""
|
87 |
-
if not value:
|
88 |
-
return "global"
|
89 |
-
if "/devices/" not in value:
|
90 |
-
return "global"
|
91 |
-
parts = value.split("/devices/")
|
92 |
-
if len(parts) != 2:
|
93 |
-
return "global"
|
94 |
-
device_name = parts[1].replace(".md", "")
|
95 |
-
return device_name
|
96 |
-
|
97 |
-
import yaml
|
98 |
-
|
99 |
-
with open("prompts.yaml", 'r') as stream:
|
100 |
-
prompt_templates = yaml.safe_load(stream)
|
101 |
|
102 |
retriever_tool = RetrieverTool()
|
103 |
from smolagents import CodeAgent, HfApiModel, LiteLLMModel
|
|
|
1 |
# /// script
|
2 |
# dependencies = [
|
3 |
# "PyYAML",
|
4 |
+
# "langchain-community", # For FAISS, HuggingFaceEmbeddings
|
5 |
+
# "langchain", # Core Langchain
|
6 |
+
# "faiss-cpu", # FAISS vector store
|
7 |
+
# "sentence-transformers", # For HuggingFaceEmbeddings
|
8 |
# "smolagents",
|
9 |
# "gradio",
|
10 |
# "einops",
|
11 |
# "smolagents[litellm]",
|
12 |
+
# # "unstructured" # Required by loader.py, not directly by app.py but good for environment consistency
|
13 |
# ]
|
14 |
# ///
|
15 |
|
|
|
31 |
# trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
|
32 |
# SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
|
33 |
|
34 |
+
from langchain_community.vectorstores import FAISS
|
35 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
36 |
|
37 |
+
FAISS_INDEX_PATH = "faiss_index"
|
38 |
+
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" # Must match loader.py
|
|
|
|
|
39 |
|
40 |
from smolagents import Tool
|
41 |
|
|
|
50 |
}
|
51 |
output_type = "string"
|
52 |
|
53 |
+
def __init__(self, **kwargs):
    """Initialise the tool and load the pre-built FAISS index from disk."""
    super().__init__(**kwargs)
    # The embedding model must be identical to the one used by loader.py
    # when the index was built, or query vectors will not match the index.
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    self.embeddings = embeddings
    # Langchain pickles part of the saved index; opting in to
    # deserialization is required (and safe for an index we built ourselves).
    self.db = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
|
62 |
|
63 |
def forward(self, query: str) -> str:
    """Run a semantic search over the documentation index and format the hits."""
    assert isinstance(query, str), "Your search query must be a string"

    hits = self.db.similarity_search_with_score(query, k=10)
    if not hits:
        return "No relevant information found in the documentation for your query."

    parts = []
    for doc, score in hits:
        meta = doc.metadata
        origin = meta.get('source', 'Unknown source')
        device = meta.get('device_name')
        if device:
            header = f"Device: {device} (Source: {origin}, Score: {score:.4f})\n"
        else:
            # Chunks without a device name hold global / fabric-wide information.
            header = f"Global/Fabric Info (Source: {origin}, Score: {score:.4f})\n"
        parts.append(header + f"Result: {doc.page_content}\n\n")

    print(f"Retrieved {len(hits)} results for query: '{query}'")
    return "".join(parts)
|
86 |
|
87 |
+
# The 'device' method is removed as 'device_name' is now directly in metadata.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
retriever_tool = RetrieverTool()
|
90 |
from smolagents import CodeAgent, HfApiModel, LiteLLMModel
|
faiss_index/index.faiss
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:828be1f0d0f7a1249982a3858640ea1164e27a55a68ec7cece4a39ea502c375d
|
3 |
+
size 347181
|
faiss_index/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f95e3a8e6e86b6c1df0cc92b569424dbeda0061a2081e21f107156921c10898b
|
3 |
+
size 183933
|
loader.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Improved loader script for creating FAISS vector database from Markdown documentation.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from langchain_community.document_loaders import UnstructuredMarkdownLoader
|
8 |
+
from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
|
9 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
10 |
+
from langchain_community.vectorstores import FAISS
|
11 |
+
|
12 |
+
# Define the paths to your documentation folders
|
13 |
+
DOCS_DIR = "documentation"
|
14 |
+
DEVICE_DOCS_PATH = os.path.join(DOCS_DIR, "devices")
|
15 |
+
FABRIC_DOCS_PATH = os.path.join(DOCS_DIR, "fabric")
|
16 |
+
FAISS_INDEX_PATH = "faiss_index"
|
17 |
+
|
18 |
+
def load_markdown_documents(file_paths):
    """Load markdown files and tag each resulting document's metadata.

    Every document receives the originating filename under the 'source'
    key; documents coming from device files (basename containing 'DCX-')
    also receive a 'device_name' key derived from the filename.
    """
    documents = []
    for path in file_paths:
        filename = os.path.basename(path)
        for doc in UnstructuredMarkdownLoader(path).load():
            # Some loaders may leave metadata unset; normalise to a dict.
            if doc.metadata is None:
                doc.metadata = {}
            doc.metadata['source'] = filename
            if 'DCX-' in filename:
                doc.metadata['device_name'] = filename.replace('.md', '')
            documents.append(doc)
    return documents
|
40 |
+
|
41 |
+
def create_vector_db():
    """Scan the documentation folders, chunk the markdown files, embed the
    chunks and persist a FAISS vector database to FAISS_INDEX_PATH.

    Chunking happens in two passes: first by markdown headers (keeping a
    section's context together), then by a recursive character splitter
    (bounding chunk size for the embedding model).
    """
    # Collect every markdown file from both documentation trees.
    # (os.walk on a missing directory simply yields nothing.)
    markdown_files = []
    for docs_path in (DEVICE_DOCS_PATH, FABRIC_DOCS_PATH):
        for root, _, files in os.walk(docs_path):
            markdown_files.extend(
                os.path.join(root, f) for f in files if f.endswith(".md")
            )

    if not markdown_files:
        print("No markdown files found in the specified directories.")
        return

    print(f"Found {len(markdown_files)} markdown files to process.")

    documents = load_markdown_documents(markdown_files)
    print(f"Loaded {len(documents)} documents.")

    # Split on up to three heading levels so each chunk keeps its section.
    headers_to_split_on = [
        ("#", "header1"),
        ("##", "header2"),
        ("###", "header3"),
    ]
    header_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

    header_split_docs = []
    for doc in documents:
        try:
            header_split = header_splitter.split_text(doc.page_content)
            for split_doc in header_split:
                # Carry the source/device metadata onto each split chunk.
                split_doc.metadata.update(doc.metadata)
            header_split_docs.extend(header_split)
        except Exception as e:
            # Header splitting is best-effort; fall back to the whole doc.
            print(f"Warning: Could not split by headers: {e}")
            header_split_docs.append(doc)

    # Second pass: bound chunk size, with generous overlap for recall.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
    texts = text_splitter.split_documents(header_split_docs)
    print(f"Split documents into {len(texts)} chunks.")

    # Prepend the device name to device chunks so the embedding captures it.
    for text_chunk in texts:
        device_name = text_chunk.metadata.get('device_name')
        if device_name is None:
            continue
        prefix = f"Device: {device_name}"
        # BUG FIX: the original wrote "\\n\\n" here, embedding literal
        # backslash-n characters instead of blank lines.
        if not text_chunk.page_content.strip().startswith(prefix):
            text_chunk.page_content = f"{prefix}\n\n{text_chunk.page_content}"

    # Check for empty input BEFORE loading the embedding model: loading the
    # sentence-transformer is expensive and pointless with nothing to index.
    if not texts:
        print("No text chunks to process for FAISS index.")
        return

    print("Creating FAISS vector database...")
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    print("Embeddings model loaded.")

    print("Creating FAISS index...")
    vector_db = FAISS.from_documents(texts, embeddings)
    print("FAISS index created.")

    # Persist the index so app.py / test_loader.py can load it later.
    vector_db.save_local(FAISS_INDEX_PATH)
    print(f"FAISS index saved to {FAISS_INDEX_PATH}")
|
121 |
+
|
122 |
+
# Script entry point: build the FAISS index when executed directly.
if __name__ == "__main__":
    create_vector_db()
|
requirements.txt
CHANGED
@@ -4,4 +4,8 @@ sentence-transformers
|
|
4 |
smolagents
|
5 |
gradio
|
6 |
smolagents[litellm]
|
7 |
-
einops
|
|
|
|
|
|
|
|
|
|
4 |
smolagents
|
5 |
gradio
|
6 |
smolagents[litellm]
|
7 |
+
einops
|
8 |
+
langchain-community
|
9 |
+
langchain
|
10 |
+
faiss-cpu
|
11 |
+
unstructured
|
test_loader.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Test script for the FAISS vector database created by loader.py.
|
4 |
+
Allows interactive querying of the documentation and searching for specific strings in results.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
9 |
+
from langchain_community.vectorstores import FAISS
|
10 |
+
|
11 |
+
# Configuration
|
12 |
+
FAISS_INDEX_PATH = "faiss_index"
|
13 |
+
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
14 |
+
|
15 |
+
def load_vector_db():
    """Load the persisted FAISS index from disk.

    Returns the vector store, or None when the index is missing or fails
    to load (an explanatory message is printed in either case).
    """
    if not os.path.exists(FAISS_INDEX_PATH):
        print(f"Error: FAISS index not found at {FAISS_INDEX_PATH}")
        print("Please run loader.py first to create the vector database.")
        return None

    try:
        # Must be the same embedding model that built the index.
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
        vector_db = FAISS.load_local(
            FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True
        )
    except Exception as e:
        print(f"Error loading FAISS index: {e}")
        return None

    print(f"Successfully loaded FAISS index from {FAISS_INDEX_PATH}")
    return vector_db
|
36 |
+
|
37 |
+
def search_documents(vector_db, query, k=3):
    """Return up to *k* (document, score) pairs most similar to *query*.

    Args:
        vector_db: The loaded FAISS vector store.
        query: The search query string.
        k: Number of top results to return.

    Returns:
        A list of (document, score) tuples; empty on any search error
        (the error is printed rather than propagated).
    """
    try:
        return vector_db.similarity_search_with_score(query, k=k)
    except Exception as e:
        print(f"Error during search: {e}")
        return []
|
56 |
+
|
57 |
+
def find_string_in_results(docs_with_scores, search_string):
    """Case-insensitively locate *search_string* inside each result document.

    Args:
        docs_with_scores: (document, score) tuples from a similarity search.
        search_string: Substring to look for in each document's content.

    Returns:
        One match record per occurrence (overlapping occurrences included),
        each carrying the 1-based result index, source, score, the match
        position, and ~100 characters of surrounding context.
    """
    needle = search_string.lower()
    matches = []
    for idx, (doc, score) in enumerate(docs_with_scores, start=1):
        haystack = doc.page_content.lower()
        pos = haystack.find(needle)
        while pos != -1:
            # Context window: up to 100 chars before and after the hit.
            lo = max(0, pos - 100)
            hi = min(len(doc.page_content), pos + len(search_string) + 100)
            matches.append({
                'result_index': idx,
                'source': doc.metadata.get('source', 'Unknown'),
                'similarity_score': score,
                'context': doc.page_content[lo:hi],
                'position': pos,
            })
            # Advance by one (not by the needle length) so that
            # overlapping occurrences are also reported.
            pos = haystack.find(needle, pos + 1)
    return matches
|
98 |
+
|
99 |
+
def print_search_results(docs_with_scores):
    """Pretty-print similarity-search results to stdout."""
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"SEARCH RESULTS ({len(docs_with_scores)} results)")
    print(f"{banner}")

    for i, (doc, score) in enumerate(docs_with_scores, 1):
        print(f"\n--- Result {i} (Similarity Score: {score:.4f}) ---")
        print(f"Source: {doc.metadata.get('source', 'Unknown')}")
        # Only a preview: full chunks can be hundreds of characters long.
        print(f"Content Preview: {doc.page_content[:200]}...")
        print("-" * 50)
|
112 |
+
|
113 |
+
def print_string_matches(matches, search_string):
    """Pretty-print string-search matches to stdout.

    NOTE(review): the original file's status markers were mojibake
    (corrupted multi-byte characters rendered as 'β'); they are restored
    here as the evidently intended ❌ / ✅ emoji — confirm against the
    pre-corruption source if available.
    """
    if not matches:
        print(f"\n❌ No matches found for '{search_string}' in the search results.")
        return

    print(f"\n{'='*60}")
    print(f"STRING SEARCH RESULTS for '{search_string}' ({len(matches)} matches)")
    print(f"{'='*60}")

    for match in matches:
        print(f"\n✅ Match found in Result #{match['result_index']}")
        print(f"Source: {match['source']}")
        print(f"Similarity Score: {match['similarity_score']:.4f}")
        print(f"Context: ...{match['context']}...")
        print("-" * 50)
|
131 |
+
|
132 |
+
# Test cases configuration: each entry pairs a natural-language question
# with a literal string expected to appear somewhere in the retrieved chunks.
TEST_CASES = [
    {"question": "What is the management IP address of DCX-L2LEAF1A?",
     "expected_string": "172.20.20.57"},
    {"question": "What VLANs are on DCX-L2LEAF1A?",
     "expected_string": "VRF10_VLAN11"},
    {"question": "What spanning tree mode is configured?",
     "expected_string": "mstp"},
    {"question": "What is the NTP server configured?",
     "expected_string": "0.pool.ntp.org"},
    {"question": "What VRF is used for management?",
     "expected_string": "MGMT"},
    {"question": "What is the default gateway for management?",
     "expected_string": "172.20.20.1"},
    {"question": "What ethernet interfaces are on DCX-L2LEAF1A?",
     "expected_string": "Ethernet1"},
    {"question": "What port-channel interfaces exist?",
     "expected_string": "Port-Channel1"},
    {"question": "What is the TerminAttr daemon configuration?",
     "expected_string": "apiserver.arista.io"},
    {"question": "What local users are configured?",
     "expected_string": "admin"},
    {"question": "What's the description of Ethernet5 on DCX-L2LEAF1A?",
     "expected_string": "DCX-leaf1-server1_iLO"},
    {"question": "What channel group is configured on DCX-L2LEAF1A Ethernet1?",
     "expected_string": "channel-group 1"},
    {"question": "What VLAN access mode is on DCX-L2LEAF1A Ethernet5?",
     "expected_string": "access vlan 11"},
    {"question": "What is the DNS server configured?",
     "expected_string": "8.8.8.8"},
    {"question": "What protocol is used for management API on DCX-L2LEAF1A?",
     "expected_string": "protocol https"},
]
|
195 |
+
|
196 |
+
def run_automated_tests(vector_db):
    """Run every entry of TEST_CASES against the vector database.

    For each case: semantic-search the question, then look for the
    expected literal string in the returned chunks. Prints a per-test
    verdict and a final summary.

    Returns:
        True when all tests pass, False otherwise.

    NOTE(review): the original file's status markers were mojibake;
    they are restored here as the evidently intended emoji.
    """
    print("\n🧪 Running Automated FAISS Database Tests")
    print("=" * 60)

    total_tests = len(TEST_CASES)
    passed_tests = 0
    failed_tests = 0

    for i, test_case in enumerate(TEST_CASES, 1):
        question = test_case["question"]
        expected_string = test_case["expected_string"]

        print(f"\n🔍 Test {i}/{total_tests}: {question}")
        print(f"Expected to find: '{expected_string}'")
        print("-" * 50)

        try:
            # Wider k than the default to give string matching more chances.
            docs_with_scores = search_documents(vector_db, question, k=10)

            if not docs_with_scores:
                print("❌ FAIL: No search results found")
                failed_tests += 1
                continue

            matches = find_string_in_results(docs_with_scores, expected_string)

            if matches:
                print(f"✅ PASS: Found '{expected_string}' in search results")
                print(f"   Found in: {matches[0]['source']}")
                print(f"   Similarity Score: {matches[0]['similarity_score']:.4f}")
                print(f"   Context: ...{matches[0]['context'][:100]}...")
                passed_tests += 1
            else:
                print(f"❌ FAIL: '{expected_string}' not found in search results")
                print("   Search results sources (top 5):")
                for doc, score in docs_with_scores[:5]:
                    print(f"   - {doc.metadata.get('source', 'Unknown')} (score: {score:.4f})")
                # Debug aid: docs_with_scores is non-empty here (guarded by
                # the `continue` above), so the redundant re-check the
                # original performed has been dropped.
                top_doc = docs_with_scores[0][0]
                print(f"   Top result content preview: {top_doc.page_content[:200]}...")
                failed_tests += 1

        except Exception as e:
            print(f"❌ ERROR: {e}")
            failed_tests += 1

    # Summary.
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY")
    print("=" * 60)
    print(f"Total Tests: {total_tests}")
    print(f"✅ Passed: {passed_tests}")
    print(f"❌ Failed: {failed_tests}")
    print(f"Success Rate: {(passed_tests / total_tests) * 100:.1f}%")

    if failed_tests > 0:
        print(f"\n⚠️ {failed_tests} test(s) failed. Check the results above.")
        return False
    # Plain string: the original used an f-string with no placeholders.
    print("\n🎉 All tests passed!")
    return True
|
264 |
+
|
265 |
+
def main():
    """Load the FAISS database, run the automated tests, and exit.

    Exit status is 0 on success and 1 on any test failure, so the script
    can gate CI pipelines. Raises SystemExit directly instead of calling
    the interactive `exit()` helper, which is injected by the `site`
    module and not guaranteed to exist in every runtime.
    """
    print("🔄 Loading FAISS Vector Database...")

    vector_db = load_vector_db()
    if vector_db is None:
        # load_vector_db already printed the reason; nothing more to do.
        return

    if run_automated_tests(vector_db):
        print("\n✅ All tests completed successfully!")
        raise SystemExit(0)
    raise SystemExit(1)
|
285 |
+
|
286 |
+
# Allow running the automated test suite directly from the command line.
if __name__ == "__main__":
    main()
|