# Source: uploaded to Hugging Face via huggingface_hub (commit d243e59, verified).
# /// script
# dependencies = [
# "PyYAML",
# "chromadb",
# "sentence-transformers",
# "smolagents",
# "gradio",
# "einops",
# "smolagents[litellm]",
# ]
# ///
import yaml

# Load the agent prompt templates that ship alongside this script.
with open("prompts.yaml", "r") as f:
    prompt_templates = yaml.safe_load(f)
# # OpenTelemetry
# from opentelemetry import trace
# from opentelemetry.sdk.trace import TracerProvider
# from opentelemetry.sdk.trace.export import BatchSpanProcessor
# from openinference.instrumentation.smolagents import SmolagentsInstrumentor
# from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
# from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
# # Endpoint
# endpoint = "http://0.0.0.0:6006/v1/traces"
# trace_provider = TracerProvider()
# trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
# SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
import chromadb
from sentence_transformers import SentenceTransformer

# Location of the persistent Chroma store and the embedding model that
# queries must be encoded with (it must match the model used at index time).
db_name = "vector_db"
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1"

# The Chroma client opens/creates the on-disk store at `db_name`.
client = chromadb.PersistentClient(path=db_name)

# trust_remote_code: the nomic model ships custom modeling code on the Hub.
model_embeding = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
from smolagents import Tool
class RetrieverTool(Tool):
    """Semantic-search tool over the network-fabric vector store.

    Embeds the incoming query with the module-level ``model_embeding`` and
    runs a nearest-neighbour search against the ``'fabric'`` collection of
    the module-level Chroma ``client``, formatting the hits as plain text
    for the agent.
    """

    name = "retriever"
    description = "Provide information of our network using semantic search. "
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Resolve the collection once. The previous implementation opened a
        # fresh PersistentClient and re-resolved the collection on every
        # forward() call, and additionally fetched up to 5000 embeddings per
        # query just to print a count — both removed.
        self.collection = client.get_or_create_collection('fabric')

    def forward(self, query: str) -> str:
        """Run a semantic search and return the formatted results.

        Args:
            query: Natural-language query (affirmative form works best).

        Returns:
            One "Global:"/"Device:" header plus a "Result:" line per hit,
            newline-terminated.

        Raises:
            TypeError: If ``query`` is not a string. (Was an ``assert``,
            which disappears under ``python -O``.)
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        query_vector = model_embeding.encode(query)
        results = self.collection.query(
            query_embeddings=[query_vector],
            n_results=10,
            include=["metadatas", "documents"],
        )
        lines = []
        # Chroma returns one inner list per query embedding; we sent one.
        for metadata, document in zip(results['metadatas'][0],
                                      results['documents'][0]):
            device = self.device(metadata['source'])
            if device == "global":
                lines.append(f"Global: {metadata['source']}")
            else:
                lines.append(f"Device: {device}")
            lines.append(f"Result: {document}")
        return "".join(f"{line}\n" for line in lines)

    def device(self, value):
        """Return the device name encoded in a metadata source path.

        Args:
            value: Source path of the metadata entry, e.g.
                ``".../devices/<name>.md"``.

        Returns:
            str: The device name, or ``"global"`` when the source is empty
            or does not follow the ``/devices/<name>.md`` pattern.
        """
        if not value:
            return "global"
        if "/devices/" not in value:
            return "global"
        parts = value.split("/devices/")
        if len(parts) != 2:
            return "global"
        return parts[1].replace(".md", "")
# NOTE: the duplicate `import yaml` / prompts.yaml load that used to sit here
# was removed — prompt_templates is already loaded at the top of the file
# (and is not passed to the agent below; wire it in via `prompt_templates=`
# if customized prompts are intended).
retriever_tool = RetrieverTool()

from smolagents import CodeAgent, HfApiModel, LiteLLMModel

# Gemini 2.0 Flash via LiteLLM; requires GEMINI_API_KEY in the environment.
model = LiteLLMModel("gemini/gemini-2.0-flash")

agent = CodeAgent(
    model=model,
    tools=[retriever_tool],
    max_steps=10,
    verbosity_level=2,
    grammar=None,
    planning_interval=None,
    name="network_information_agent",
    description="Have access to the network information of our fabric.",
    add_base_tools=False,
)

# # Example usage
# response = agent.run(
#     "What is the loopback Pool address used by the fabric, how many ip addresses are in use?"
# )
# print(response)

from smolagents import GradioUI

# Serve the agent behind a Gradio chat UI (blocks until the server stops).
GradioUI(agent).launch()