File size: 7,262 Bytes
2c4ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0068013
2c4ccb6
0068013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c4ccb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import getpass
import faiss
import numpy as np
import warnings
import logging

# Suppress warnings
# Raise the "pdfminer" logger's threshold to ERROR and ignore all Python warnings
# process-wide. Both statements run before the third-party imports below, so
# warnings emitted while importing those libraries are suppressed as well.
# NOTE(review): a blanket "ignore" also hides deprecation warnings from every
# library; consider narrowing by category/module.
logging.getLogger("pdfminer").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

from google import genai
from google.genai import types
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from langchain_community.document_loaders import(
    UnstructuredPDFLoader,
    TextLoader,
    CSVLoader,
    JSONLoader,
    UnstructuredPowerPointLoader,
    UnstructuredExcelLoader,
    UnstructuredXMLLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter


def authenticate():
  """Create a Gemini API client.

  The key is read from the ``GOOGLE_API_KEY`` environment variable; if that is
  unset or empty, the user is prompted for it with hidden input.

  Returns:
    A ``genai.Client`` configured with the key.
  """
  key_from_env = os.environ.get("GOOGLE_API_KEY")
  api_key = key_from_env if key_from_env else getpass.getpass("Enter your API Key: ")
  return genai.Client(api_key=api_key)


def load_documents_gradio(uploaded_files):
    """Load uploaded files into LangChain documents, picking a loader by file extension.

    Accepts two kinds of file object:
      * FastAPI ``UploadFile``-like objects (with ``.filename`` and ``.file``). Their
        bytes are copied to a temporary file for the loaders, and that file is
        deleted once loading is done.
      * Gradio-style objects that expose a filesystem path as ``.name``.

    Supported extensions are .pdf, .txt, .csv, .json, .pptx, .xlsx/.xls, .xml and
    .docx/.doc. Files with unsupported types are skipped. A file that fails to load
    is reported with its traceback and skipped, so one bad upload does not abort
    the whole batch.

    Args:
        uploaded_files: Sequence of file objects, or None (treated as no files).

    Returns:
        list: Every document loaded from the supported files.
    """
    import traceback

    if uploaded_files is None:
        uploaded_files = []
    docs = []
    print(f"Processing {len(uploaded_files)} files")
    for file in uploaded_files:
        tmp_path = None
        try:
            if hasattr(file, "filename") and hasattr(file, "file"):
                tmp_path = _save_upload_to_temp(file)
                file_path = tmp_path
            else:
                file_path = file.name  # For Gradio or other file types
                print(f"Non-FastAPI file: {file_path}")

            loaded = _load_path(file_path)
            if loaded is None:
                print(f'Unsupported File Type: {file_path}')
                continue
            docs.extend(loaded)
            print(f"Successfully processed {file_path}")
        except Exception as e:
            # Best-effort batch: report the failure and continue with the next file.
            print(f"Error processing file {getattr(file, 'filename', file)}: {e}")
            print(traceback.format_exc())
        finally:
            # .load() has already returned the documents, so the temp copy is no
            # longer needed; previously these files accumulated on disk.
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    print(f"Total documents loaded: {len(docs)}")
    return docs


def _save_upload_to_temp(upload):
    """Copy an upload's bytes into a named temp file (same suffix) and return its path.

    The caller owns the returned file and must delete it. The upload's stream is
    read from the start and rewound afterwards, so other code can read it again.
    """
    import tempfile

    suffix = os.path.splitext(upload.filename)[1].lower()
    print(f"Processing file: {upload.filename} with suffix {suffix}")
    # Read from the beginning even if an earlier consumer left the cursor elsewhere.
    upload.file.seek(0)
    content = upload.file.read()
    print(f"Read {len(content)} bytes from {upload.filename}")
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    # Rewind the file cursor for potential further reads
    upload.file.seek(0)
    return tmp_path


def _load_path(file_path):
    """Load *file_path* with the loader matching its extension; return None if unsupported."""
    lower = file_path.lower()
    if lower.endswith('.pdf'):
        print(f"Loading PDF: {file_path}")
        return UnstructuredPDFLoader(file_path).load()
    if lower.endswith('.txt'):
        print(f"Loading TXT: {file_path}")
        return TextLoader(file_path).load()
    if lower.endswith('.csv'):
        print(f"Loading CSV: {file_path}")
        return CSVLoader(file_path).load()
    if lower.endswith('.json'):
        print(f"Loading JSON: {file_path}")
        return JSONLoader(file_path, jq_schema='.', text_content=False).load()
    if lower.endswith('.pptx'):
        print(f"Loading PPTX: {file_path}")
        return UnstructuredPowerPointLoader(file_path).load()
    if lower.endswith(('.xlsx', '.xls')):
        print(f"Loading Excel: {file_path}")
        return UnstructuredExcelLoader(file_path).load()
    if lower.endswith('.xml'):
        print(f"Loading XML: {file_path}")
        return UnstructuredXMLLoader(file_path).load()
    if lower.endswith(('.docx', '.doc')):
        print(f"Loading Word: {file_path}")
        return UnstructuredWordDocumentLoader(file_path).load()
    return None


def split_documents(docs, chunk_size=500, chunk_overlap=100):
  """Cut each document into overlapping chunks.

  Args:
    docs: Documents to split.
    chunk_size: Maximum chunk size handed to the splitter.
    chunk_overlap: Overlap between consecutive chunks handed to the splitter.

  Returns:
    The chunk documents produced by ``RecursiveCharacterTextSplitter.split_documents``.
  """
  text_splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=chunk_overlap,
    chunk_size=chunk_size,
  )
  chunks = text_splitter.split_documents(docs)
  return chunks


def build_vectorstore(docs, embedding_model_name="all-MiniLM-L6-v2"):
  """Embed the non-blank chunk texts and index them in an exact L2 FAISS index.

  Args:
    docs: Chunk documents; each one's ``page_content`` is stripped, and blank
      ones are skipped.
    embedding_model_name: SentenceTransformer model used to embed the texts.

  Returns:
    dict with keys "index" (``faiss.IndexFlatL2``), "texts" (the stripped texts,
    aligned with index ids), "embedding_model", "embeddings" (raw encoder output)
    and "chunks" (number of texts).

  Raises:
    ValueError: if no document has non-whitespace content.
  """
  texts = []
  for doc in docs:
    cleaned = doc.page_content.strip()
    if cleaned:
      texts.append(cleaned)
  if not texts:
    raise ValueError("No valid text found in the documents.")

  chunk_count = len(texts)
  print(f"No. of Chunks: {chunk_count}")

  encoder = SentenceTransformer(embedding_model_name)
  vectors = encoder.encode(texts)
  print(vectors.shape)

  dim = vectors.shape[1]
  flat_index = faiss.IndexFlatL2(dim)
  # FAISS works on float32 matrices.
  flat_index.add(np.array(vectors).astype("float32"))

  return dict(
    index=flat_index,
    texts=texts,
    embedding_model=encoder,
    embeddings=vectors,
    chunks=chunk_count,
  )


def retrieve_context(query, store, k=6):
  """Return the stored chunks closest to *query*, using the store's exact FAISS index.

  Args:
    query: Question text, embedded with ``store["embedding_model"]``.
    store: Dict produced by ``build_vectorstore``; reads "index", "texts" and
      "embedding_model".
    k: Maximum number of chunks to return. It is clamped to the number of stored chunks.

  Returns:
    List of chunk texts, nearest first. The list is empty when ``k`` <= 0.
  """
  k = min(k, len(store["texts"]))
  if k <= 0:
    # An empty request needs no search (and FAISS does not accept k <= 0).
    return []
  # FAISS expects a C-contiguous float32 query matrix; coerce whatever the encoder returns.
  query_vec = np.ascontiguousarray(store["embedding_model"].encode([query]), dtype="float32")
  _, indices = store["index"].search(query_vec, k)
  # FAISS pads missing results with -1. Drop them rather than letting texts[-1]
  # silently return the last chunk.
  return [store["texts"][i] for i in indices[0] if i >= 0]


def retrieve_context_approx(query, store, k=6, ncells=50, nprobe=10):
  """Retrieve the chunks nearest to *query* using an IVF (inverted-file) FAISS index.

  On each call, an ``IndexIVFFlat`` with a flat L2 quantizer is trained on
  ``store["embeddings"]``, populated with them, and searched.

  Args:
    query: Question text, embedded with ``store["embedding_model"]``.
    store: Dict produced by ``build_vectorstore``; reads "index" (for the
      dimension), "texts", "embeddings" and "embedding_model".
    k: Maximum number of chunks to return. It is clamped to the number of stored chunks.
    ncells: Requested number of inverted lists (clusters). It is clamped to the
      number of stored vectors, because k-means training needs at least as many
      points as clusters.
    nprobe: Number of clusters visited per query. It is clamped to the cluster count.

  Returns:
    Up to ``k`` chunk texts, nearest first. Fewer than ``k`` are returned when the
    probed clusters hold fewer vectors than requested.
  """
  k = min(k, len(store["texts"]))
  if k <= 0:
    return []

  vectors = np.ascontiguousarray(store["embeddings"], dtype="float32")
  dim = store["index"].d
  # With a fixed 50 cells, training failed for corpora of fewer than 50 chunks.
  nlist = max(1, min(ncells, vectors.shape[0]))

  # NOTE(review): the IVF index is rebuilt and retrained on every query, which costs
  # more than the exact search in retrieve_context; consider building it once.
  quantizer = faiss.IndexFlatL2(dim)
  ivf_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
  ivf_index.train(vectors)  # a freshly constructed IVF index is always untrained
  ivf_index.add(vectors)
  ivf_index.nprobe = max(1, min(nprobe, nlist))

  query_vec = np.ascontiguousarray(store["embedding_model"].encode([query]), dtype="float32")
  _, indices = ivf_index.search(query_vec, k)
  # Probing only some clusters can yield fewer than k hits, which FAISS pads with -1.
  # Indexing texts with -1 would silently return the last chunk, so drop those.
  return [store["texts"][i] for i in indices[0] if i >= 0]


def build_prompt(context_chunks, query):
  """Assemble the Gemini prompt: the fixed instructions, the retrieved context, then the question.

  Args:
    context_chunks: Chunk strings, joined with newlines into the context section.
    query: The user's question.

  Returns:
    The prompt text. The Context/Question/Answer lines carry a two-space indent.
  """
  instructions = (
    "You are a highly knowledgeable and helpful assistant. Use the following "
    "context to generate a **detailed and step-by-step** answer to the user's "
    "question. Include explanations, examples, and reasoning wherever helpful."
  )
  joined_context = "\n".join(context_chunks)
  sections = [
    instructions,
    "",
    "  Context:",
    "  " + joined_context,
    "",
    "  Question: " + query,
    "  Answer:",
  ]
  return "\n".join(sections)


def ask_gemini(prompt, client, *, model="gemini-2.0-flash"):
  """Send *prompt* to a Gemini model and return the generated text.

  Args:
    prompt: Full prompt string.
    client: A ``genai.Client`` instance.
    model: Gemini model name. It defaults to the previously hard-coded
      "gemini-2.0-flash", so existing callers behave the same.

  Returns:
    ``response.text`` of the API response. NOTE(review): this may be None when
    the response carries no text parts; confirm against the google-genai docs.
  """
  # Output is capped at 2048 tokens. A fixed seed asks the API for best-effort repeatable answers.
  config = types.GenerateContentConfig(max_output_tokens=2048, temperature=0.5, seed=42)
  response = client.models.generate_content(
    model=model,
    contents=[prompt],
    config=config,
  )
  return response.text

# Speech2Text:
def transcribe(audio, model="openai/whisper-base.en"):
  """Transcribe a recorded clip to text with a Hugging Face speech-recognition pipeline.

  Args:
    audio: ``(sampling_rate, samples)`` tuple, with the rate in Hz. Presumably this
      is Gradio's numpy audio format, with samples shaped (n,) or (n, channels);
      confirm with the caller.
    model: Hugging Face model id for the "automatic-speech-recognition" pipeline.

  Returns:
    The ``"text"`` field of the pipeline result.

  Raises:
    ValueError: if ``audio`` is None or contains no samples.
  """
  if audio is None:
    raise ValueError("No audio detected!")

  sr, y = audio  # sampling rate (Hz) and amplitude array
  y = np.asarray(y)
  # Validate before loading the model, so empty input fails fast and cheaply.
  if y.size == 0:
    raise ValueError("No audio detected!")

  if y.ndim > 1:  # average over axis 1 to get mono (assumes a (samples, channels) layout)
    y = y.mean(1)

  y = y.astype(np.float32)
  peak = np.max(np.abs(y))
  # Normalise into [-1, 1]. Skip this for silent (all-zero) input, where dividing
  # by a zero peak would turn every sample into NaN.
  if peak > 0:
    y /= peak

  result = _get_transcriber(model)({"sampling_rate": sr, "raw": y})
  return result["text"]


# Constructing a pipeline loads the model, so keep one per model id for the life
# of the process instead of rebuilding it on every call.
_TRANSCRIBER_CACHE = {}


def _get_transcriber(model):
  """Return the cached ASR pipeline for *model*, creating it on first use."""
  transcriber = _TRANSCRIBER_CACHE.get(model)
  if transcriber is None:
    transcriber = pipeline("automatic-speech-recognition", model=model)
    _TRANSCRIBER_CACHE[model] = transcriber
  return transcriber