amiguel commited on
Commit
d765c07
·
verified ·
1 Parent(s): 837d9dd

Upload 5 files

Browse files
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+
3
+
src/file_loader.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import pdfplumber
3
+
4
+ def load_file(uploaded_file):
5
+ ext = uploaded_file.name.split(".")[-1].lower()
6
+ if ext == "pdf":
7
+ with pdfplumber.open(uploaded_file) as pdf:
8
+ return [page.extract_text() for page in pdf.pages if page.extract_text()]
9
+ elif ext == "csv":
10
+ df = pd.read_csv(uploaded_file)
11
+ return df.astype(str).apply(" ".join, axis=1).tolist()
12
+ elif ext == "xlsx":
13
+ df = pd.read_excel(uploaded_file)
14
+ return df.astype(str).apply(" ".join, axis=1).tolist()
15
+ else:
16
+ raise ValueError("Unsupported file type")
src/model_utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
+
3
+ def load_hf_model(model_name, device="cpu"):
4
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
5
+ model = AutoModelForCausalLM.from_pretrained(model_name)
6
+ return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device=="cuda" else -1)
7
+
8
+ def generate_answer(text_gen, question, context):
9
+ prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
10
+ result = text_gen(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
11
+ return result[0]["generated_text"].split("Answer:")[-1].strip()
src/rag_pipeline.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
2
+ from langchain.embeddings import HuggingFaceEmbeddings
3
+ from langchain.vectorstores import FAISS
4
+
5
+ def build_rag_pipeline(docs, embedding_model):
6
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
7
+ chunks = splitter.split_documents(docs)
8
+ embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
9
+ db = FAISS.from_documents(chunks, embeddings)
10
+ return db.as_retriever()
11
+
12
+ def get_relevant_docs(retriever, query, k=4):
13
+ return retriever.get_relevant_documents(query)[:k]
src/utils.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ def get_font_css():
2
+ return """
3
+ <style>
4
+ @import url('https://fonts.googleapis.com/css2?family=Tw+Cen+MT:wght@400;700&display=swap');
5
+ html, body, [class*='css'] {
6
+ font-family: 'Tw Cen MT', sans-serif !important;
7
+ }
8
+ </style>
9
+ """