File size: 1,816 Bytes
10b392a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from dotenv import load_dotenv
load_dotenv()
import os

from langchain_ollama import OllamaEmbeddings
from langchain_openai import ChatOpenAI
from langchain_chroma import Chroma
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain import hub

# β€”β€”β€” CONFIG β€”β€”β€”
PERSIST_DIR    = "chroma_db/"
OLLAMA_URL     = os.getenv("OLLAMA_SERVER")
EMBED_MODEL    = "nomic-embed-text:latest"
LLM_API_KEY    = os.getenv("LLM_API_KEY")
LLM_API_BASE   = os.getenv("LLM_API_BASE", "https://llm.chutes.ai/v1")
LLM_MODEL      = "chutesai/Llama-4-Scout-17B-16E-Instruct"
PROMPT         = hub.pull("langchain-ai/retrieval-qa-chat")
TOP_K          = 5
# β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”

def run_query(query: str):
    """Answer *query* against the persisted Chroma collection via a RAG chain.

    Rebuilds the Ollama embedder, reopens the on-disk vector store, wires a
    retrieval chain around the remote LLM, runs the query, and prints the
    generated answer.

    Args:
        query: The natural-language question to answer.

    Returns:
        The full chain result dict (keys include ``input``, ``context`` and
        ``answer``), so callers can inspect the retrieved documents too.
    """
    # 1) Recreate the embedder — Chroma must query with the same embedding
    #    model that was used when the collection was built.
    embedder = OllamaEmbeddings(base_url=OLLAMA_URL, model=EMBED_MODEL)

    # 2) Open the persisted vector store (no re-indexing happens here).
    vectordb = Chroma(
        persist_directory=PERSIST_DIR,
        collection_name="my_docs",
        embedding_function=embedder,
    )

    # 3) Retriever + LLM combined into a retrieval-QA chain.
    retriever = vectordb.as_retriever(search_kwargs={"k": TOP_K})
    llm = ChatOpenAI(api_key=LLM_API_KEY, base_url=LLM_API_BASE, model=LLM_MODEL)
    combine = create_stuff_documents_chain(llm=llm, prompt=PROMPT)
    rag_chain = create_retrieval_chain(retriever, combine)

    # 4) Run the query. The chain returns a dict; print only the generated
    #    text under "answer" rather than the raw dict, which also carries the
    #    retrieved context documents. Fall back to the whole result if the
    #    key is ever absent.
    print(f"🔍 Query: {query}")
    result = rag_chain.invoke({"input": query})
    print("\n📄 Answer:\n", result.get("answer", result))
    return result

if __name__ == "__main__":
    # Simple REPL: keep answering queries until the user types 'exit'.
    # Uses while/break instead of a flag variable (the original flag was
    # named `exit`, shadowing the builtin), and exits cleanly on EOF/Ctrl-C.
    while True:
        try:
            user_input = input("Enter your query (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            break
        if user_input.strip().lower() == 'exit':
            break
        run_query(user_input)