# Wisal_QA / RAG.py
# Author: afouda — commit 8d41430 ("Update RAG.py", verified)
import os
import asyncio
from dotenv import load_dotenv
# Initialize DeepInfra-compatible OpenAI client
from openai import OpenAI
# Load variables from a local .env file (if one exists) into os.environ so the
# settings below can come from either the real environment or .env.
load_dotenv()

# Service credentials / endpoints. Previously these names were referenced but
# never defined, which made the module fail at import with a NameError.
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")

# Module-wide OpenAI-SDK client pointed at DeepInfra's OpenAI-compatible API.
# Note: the name shadows the `openai` package, which this file only imports
# via `from openai import OpenAI`; other functions here rely on this name.
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)
# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager
@contextmanager
def weaviate_client():
    """Provide a Weaviate Cloud client for the duration of a ``with`` block.

    Connects using the module-level ``WEAVIATE_URL`` and ``WEAVIATE_API_KEY``
    settings. The client is closed in a ``finally`` clause, so the connection
    is released even when the body of the ``with`` statement raises.
    """
    url = WEAVIATE_URL
    credentials = Auth.api_key(WEAVIATE_API_KEY)
    conn = weaviate.connect_to_weaviate_cloud(
        cluster_url=url,
        auth_credentials=credentials,
    )
    try:
        yield conn
    finally:
        conn.close()
def embed_texts(
    texts: list[str],
    batch_size: int = 50,
    model: str = "Qwen/Qwen3-Embedding-8B",
) -> list[list[float]]:
    """Embed ``texts`` in batches using the module-level ``openai`` client.

    Texts are sent in chunks of at most ``batch_size`` per request to stay
    under per-request API limits.

    Args:
        texts: Strings to embed.
        batch_size: Maximum number of texts per API request; must be >= 1.
        model: Embedding model name passed to the API.

    Returns:
        A list with one entry per input text: element ``k`` is the embedding
        of ``texts[k]``, or an empty list if the batch containing it failed.
        The empty placeholders keep positions aligned with the input.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        last = start + len(batch) - 1
        try:
            resp = openai.embeddings.create(
                model=model,
                input=batch,
                encoding_format="float",
            )
            batch_embs = [item.embedding for item in resp.data]
            # A response with the wrong number of vectors would silently shift
            # every later embedding onto the wrong text; treat it as a failure.
            if len(batch_embs) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(batch_embs)}"
                )
        except Exception as e:
            # Best-effort: report the failing index range and keep one empty
            # placeholder per input so the output stays aligned with `texts`.
            print(f"Embedding batch error (items {start}-{last}): {e}")
            all_embeddings.extend([] for _ in batch)
        else:
            all_embeddings.extend(batch_embs)
    return all_embeddings
def encode_query(query: str) -> list[float] | None:
    """Return the embedding vector for a single query string.

    Delegates to ``embed_texts`` with a one-item batch. Prints a short
    preview of the vector on success, or a failure notice otherwise.

    Returns:
        The embedding as a list of floats, or ``None`` if no usable
        (non-empty) embedding came back.
    """
    vectors = embed_texts([query], batch_size=1)
    vector = vectors[0] if vectors else None
    if not vector:
        print("Failed to generate query embedding.")
        return None
    print("Query embedding (first 5 dims):", vector[:5])
    return vector
async def rag_autism(
    query: str,
    top_k: int = 3,
    collection: str = "UserSpecificDocument",
) -> dict:
    """Retrieve the text chunks most similar to ``query`` from Weaviate.

    Embeds ``query`` with ``encode_query`` and runs a near-vector search over
    ``collection`` (default ``"UserSpecificDocument"``).

    Args:
        query: Natural-language query to embed and search with.
        top_k: Maximum number of chunks to return; values < 1 return nothing.
        collection: Name of the Weaviate collection to search.

    Returns:
        ``{"answer": [...]}`` holding up to ``top_k`` values of each matched
        object's ``text`` property, in the order Weaviate returned them
        (``"[No Text]"`` for objects without one). Returns
        ``{"answer": []}`` when embedding fails, nothing matches, or the
        Weaviate query raises.
    """
    if top_k < 1:
        return {"answer": []}

    # encode_query (HTTP embedding request) and the synchronous Weaviate
    # client both block; run them in worker threads so awaiting this
    # coroutine does not stall the event loop and every other task on it.
    qe = await asyncio.to_thread(encode_query, query)
    if not qe:
        return {"answer": []}

    try:
        texts = await asyncio.to_thread(_near_vector_texts, qe, top_k, collection)
    except Exception as e:
        print("RAG Error:", e)
        return {"answer": []}
    return {"answer": texts}


def _near_vector_texts(vector: list[float], top_k: int, collection: str) -> list[str]:
    """Blocking helper: near-vector search in ``collection``, returning ``text`` values."""
    with weaviate_client() as client:
        coll = client.collections.get(collection)
        res = coll.query.near_vector(
            near_vector=vector,
            limit=top_k,
            return_properties=["text"],
        )
        objects = getattr(res, "objects", None) or []
        return [obj.properties.get("text", "[No Text]") for obj in objects]
# Example manual smoke test: uncomment to run a few sample queries against the live DeepInfra and Weaviate services.
# if __name__ == "__main__":
# test_queries = [
# "What are the common early signs of autism in young children?",
# "What diagnostic criteria are used for autism spectrum disorder?",
# "What support strategies help improve communication skills in autistic individuals?"
# ]
# for q in test_queries:
# print(f"\nQuery: {q}")
# out = asyncio.run(rag_autism(q, top_k=3))
# print("Retrieved contexts:")
# for idx, ctx in enumerate(out["answer"], 1):
# print(f"{idx}. {ctx}")