# HuggingFace Space application (the "Spaces: Sleeping" lines were page-scrape
# residue, not code).
# /// script
# dependencies = [
#     "PyYAML",
#     "chromadb",
#     "sentence-transformers",
#     "smolagents",
#     "gradio",
#     "einops",
#     "smolagents[litellm]",
# ]
# ///
import yaml

# Load the agent prompt templates shipped alongside this script.
# NOTE(review): prompt_templates is loaded but never passed to the CodeAgent
# below — confirm whether it is still needed.
with open("prompts.yaml", "r", encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)
# # OpenTelemetry tracing (disabled) — uncomment to export smolagents spans
# # to a local Phoenix/OTLP collector.
# from opentelemetry import trace
# from opentelemetry.sdk.trace import TracerProvider
# from opentelemetry.sdk.trace.export import BatchSpanProcessor
# from openinference.instrumentation.smolagents import SmolagentsInstrumentor
# from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
# from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
# # Endpoint
# endpoint = "http://0.0.0.0:6006/v1/traces"
# trace_provider = TracerProvider()
# trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
# SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
import chromadb
from sentence_transformers import SentenceTransformer
from smolagents import Tool

# On-disk location of the persistent Chroma vector store.
db_name = "vector_db"
# Embedding model used both to build and to query the vector store;
# trust_remote_code is required because nomic-embed ships custom model code.
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1"
# NOTE(review): keeps the original "embeding" spelling because RetrieverTool
# below references this name; rename both together if fixing the typo.
model_embeding = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
client = chromadb.PersistentClient(path=db_name)
class RetrieverTool(Tool):
    """smolagents Tool that answers queries with semantic search over the
    'fabric' collection of the local Chroma vector store."""

    name = "retriever"
    description = "Provide information of our network using semantic search. "
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Embed *query*, fetch the 10 nearest documents and format them.

        Args:
            query: Natural-language search text.

        Returns:
            str: A "Global:"/"Device:" header plus a "Result:" line per hit.

        Raises:
            TypeError: If query is not a string (raise, not assert, so the
                check survives python -O).
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        # A fresh client per call keeps the tool stateless; Chroma reopens the
        # same on-disk store each time.
        client = chromadb.PersistentClient(path=db_name)
        collection = client.get_or_create_collection('fabric')
        # count() replaces the original get(..., limit=5000) that pulled up to
        # 5000 embeddings into memory only to measure the collection size.
        print("Number of results:", collection.count())
        query_vector = model_embeding.encode(query)
        results = collection.query(
            query_embeddings=[query_vector],
            n_results=10,
            include=["metadatas", "documents"],
        )
        # Build the response with a list + join instead of repeated string +=.
        parts = []
        for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
            device = self.device(meta['source'])
            if device == "global":
                parts.append(f"Global: {meta['source']}\n")
            else:
                parts.append(f"Device: {device}\n")
            parts.append(f"Result: {doc}\n")
        print("Results:", results)
        return "".join(parts)

    def device(self, value):
        """Return the device name encoded in a metadata source path.

        Sources shaped like ".../devices/<name>.md" belong to a device;
        anything else (including empty values) is treated as "global".

        Args:
            value: The 'source' field of a result's metadata.

        Returns:
            str: The device name, or "global" when no device applies.
        """
        if not value:
            return "global"
        if "/devices/" not in value:
            return "global"
        parts = value.split("/devices/")
        if len(parts) != 2:
            return "global"
        return parts[1].replace(".md", "")
# NOTE(review): the original re-imported yaml and re-read prompts.yaml here,
# duplicating the identical load at the top of the file; the redundant second
# load has been removed (prompt_templates is unchanged by it).
# Instantiate the retrieval tool the agent can call.
retriever_tool = RetrieverTool()

from smolagents import CodeAgent, HfApiModel, LiteLLMModel

# LLM backend: Gemini 2.0 Flash accessed through LiteLLM
# (presumably requires a Gemini API key in the environment — TODO confirm).
model = LiteLLMModel("gemini/gemini-2.0-flash")

# Code-writing agent whose only tool is the retriever.
agent = CodeAgent(
    model=model,
    tools=[retriever_tool],
    max_steps=10,            # hard cap on reasoning/tool-use iterations
    verbosity_level=2,
    grammar=None,
    planning_interval=None,  # no periodic planning step
    name="network_information_agent",
    description="Have access to the network information of our fabric.",
    add_base_tools=False,    # no smolagents built-in tools
)
# # Example usage
# response = agent.run(
#     "What is the loopback Pool address used by the fabric, how many ip addresses are in use?"
# )
# print(response)
from smolagents import GradioUI

# Serve the agent behind a Gradio chat interface; launch() blocks and serves
# HTTP until the process stops (standard top-level entrypoint for HF Spaces).
GradioUI(agent).launch()