File size: 4,182 Bytes
d243e59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# /// script
# dependencies = [
#     "PyYAML",
#     "chromadb",
#     "sentence-transformers",
#     "smolagents",
#     "gradio",
#     "einops",
#     "smolagents[litellm]",
# ]
# ///

import yaml

# Load the agent prompt templates once at startup.
# An explicit encoding avoids platform-dependent decoding if the YAML file
# ever contains non-ASCII characters.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# # OpenTelemetry
# from opentelemetry import trace
# from opentelemetry.sdk.trace import TracerProvider
# from opentelemetry.sdk.trace.export import BatchSpanProcessor
# from openinference.instrumentation.smolagents import SmolagentsInstrumentor
# from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
# from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
# # Endpoint
# endpoint = "http://0.0.0.0:6006/v1/traces"

# trace_provider = TracerProvider()
# trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
# SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

import chromadb
from sentence_transformers import SentenceTransformer

# Path of the on-disk (persistent) Chroma vector store.
db_name = "vector_db"
# Query embedding model; presumably the same model that was used to build the
# vector DB — TODO confirm, otherwise query/document vectors won't be comparable.
EMBEDDING_MODEL_NAME = "nomic-ai/nomic-embed-text-v1"
# NOTE(review): trust_remote_code=True executes model-repo code at load time —
# acceptable only because the model name is hard-coded above.
model_embeding = SentenceTransformer(EMBEDDING_MODEL_NAME, trust_remote_code=True)
# Module-level client for the persistent store (RetrieverTool.forward below
# currently opens its own client as well).
client = chromadb.PersistentClient(path=db_name)

from smolagents import Tool

class RetrieverTool(Tool):
    """smolagents Tool that answers queries by semantic search over the
    'fabric' ChromaDB collection (network/fabric documentation)."""

    name = "retriever"
    description = "Provide information of our network using semantic search. " 
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Embed *query*, fetch the 10 nearest documents and format them.

        Args:
            query: Natural-language search string.

        Returns:
            str: For each hit, a "Global:"/"Device:" header line followed by
            a "Result:" line with the document text.

        Raises:
            TypeError: If *query* is not a string.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        # Reuse the module-level persistent client rather than re-opening the
        # store on every call (the original built a new PersistentClient here).
        collection = client.get_or_create_collection('fabric')
        # count() is cheap; the original fetched up to 5000 embeddings from
        # disk just to print how many entries the collection holds.
        print("Number of results:", collection.count())
        query_vector = model_embeding.encode(query)
        results = collection.query(
            query_embeddings=[query_vector],
            n_results=10,
            include=["metadatas", "documents"],
        )
        parts = []
        # Walk documents and metadatas pairwise instead of indexing by
        # range(len(...)).
        for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
            device = self.device(meta['source'])
            if device == "global":
                parts.append(f"Global: {meta['source']}\n")
            else:
                parts.append(f"Device: {device}\n")
            parts.append(f"Result: {doc}\n")
        print("Results:", results)
        return "".join(parts)

    def device(self, value):
        """Return the device name encoded in a metadata source path.

        Args:
            value: Source path of the metadata entry (may be empty or None).

        Returns:
            str: The device name when *value* looks like
            ".../devices/<name>.md"; otherwise the literal string "global".
        """
        if not value:
            return "global"
        if "/devices/" not in value:
            return "global"
        parts = value.split("/devices/")
        if len(parts) != 2:
            return "global"
        # removesuffix strips only a trailing ".md"; the original
        # replace(".md", "") also deleted ".md" occurring mid-name.
        return parts[1].removesuffix(".md")
    
# NOTE: a duplicate `import yaml` and a second read of prompts.yaml were
# removed here — `prompt_templates` is already loaded once at the top of
# the script.

from smolagents import CodeAgent, HfApiModel, LiteLLMModel
from smolagents import GradioUI

# Single retriever-tool instance shared by the agent below.
retriever_tool = RetrieverTool()

# NOTE(review): HfApiModel is imported but not used here — kept as-is since
# only part of the project is visible.
model = LiteLLMModel("gemini/gemini-2.0-flash")

# Code-writing agent with access to the fabric retriever.
agent = CodeAgent(
    model=model,
    tools=[retriever_tool],
    max_steps=10,
    verbosity_level=2,
    grammar=None,
    planning_interval=None,
    name="network_information_agent",
    description="Have access to the network information of our fabric.",
    add_base_tools=False,
)

# Example one-off query (uncomment to run instead of the UI):
# print(agent.run(
#     "What is the loopback Pool address used by the fabric, how many ip addresses are in use?"
# ))

# Serve the agent through a Gradio chat UI.
GradioUI(agent).launch()