alexander-hm committed on
Commit
d66ac69
·
1 Parent(s): 1dd90bb

Add application file

Browse files
Files changed (2) hide show
  1. app.py +101 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app_pure_rag.py
2
+ import numpy as np
3
+ import faiss
4
+ import gradio as gr
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from sentence_transformers import SentenceTransformer
7
+
8
# --- Load and Prepare Data ---
# The corpus is read, chunked, embedded and indexed once, at module load.
with open("gen_agents.txt", mode="r", encoding="utf-8") as f:
    full_text = f.read()

# Chunk the corpus on blank lines (chunk_size=512, chunk_overlap=20).
text_splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=512,
    chunk_overlap=20,
)
docs = text_splitter.create_documents([full_text])
passages = [doc.page_content for doc in docs]

# Embed every passage and store the float32 vectors in a flat L2 FAISS index.
# Row i of the index corresponds to passages[i].
embedder = SentenceTransformer('all-MiniLM-L6-v2')
passage_embeddings = np.array(
    embedder.encode(passages, convert_to_tensor=False, show_progress_bar=True)
).astype("float32")
d = passage_embeddings.shape[1]  # embedding dimensionality
index = faiss.IndexFlatL2(d)
index.add(passage_embeddings)
24
+
25
+ # --- Provided Functions ---
26
def retrieve_passages(query, embedder, index, passages, top_k=3):
    """
    Retrieve the top-k most relevant passages based on the query.

    Args:
        query: Query text to embed.
        embedder: Object whose ``encode(list_of_str, convert_to_tensor=False)``
            returns one embedding per input string.
        index: Object whose ``search(vectors, k)`` returns ``(distances, indices)``
            with one row per query vector (e.g. a FAISS index).
        passages: Sequence of passage strings, positionally aligned with the
            vectors stored in ``index``.
        top_k: Maximum number of passages to return.

    Returns:
        A list of passages in the order the index ranked them. It holds fewer
        than ``top_k`` items when fewer valid hits exist, and is empty when
        ``top_k`` is not positive or ``passages`` is empty.
    """
    # Never ask for more neighbours than there are passages, and skip the
    # search entirely when there is nothing to return.
    k = min(top_k, len(passages))
    if k <= 0:
        return []

    query_embedding = embedder.encode([query], convert_to_tensor=False)
    query_embedding = np.array(query_embedding).astype('float32')
    _, indices = index.search(query_embedding, k)

    # FAISS-style indexes fill unfilled result slots with -1. Indexing
    # passages[-1] would silently return the last passage as a bogus hit,
    # so keep only ids that point at a real passage.
    return [passages[i] for i in indices[0] if 0 <= i < len(passages)]
35
+
36
+ # --- Gradio App Function ---
37
def get_pure_rag_output(query):
    """
    Build the HTML shown in the output box for a query.

    Retrieves up to three passages with ``retrieve_passages`` using the
    module-level ``embedder``, ``index`` and ``passages``, labels them
    "Passage 1", "Passage 2", ... and joins them with newlines.

    Args:
        query: Query text entered by the user.

    Returns:
        An HTML string: a ``<div>`` with ``white-space: pre-wrap`` wrapping the
        labelled passages. Passage text is HTML-escaped.
    """
    # Function-scope import so this block does not depend on the file's import list.
    import html

    retrieved = retrieve_passages(query, embedder, index, passages, top_k=3)
    # The result is rendered as raw HTML, so corpus text must be escaped:
    # a literal "<", ">" or "&" in a passage would otherwise be parsed as markup.
    rag_text = "\n".join(
        f"Passage {i}: {html.escape(p)}" for i, p in enumerate(retrieved, start=1)
    )
    # Wrap text in a styled div so the newlines between passages are preserved.
    return f"<div style='white-space: pre-wrap;'>{rag_text}</div>"
42
+
43
def clear_output():
    """
    Produce the empty value used to blank the output component.
    """
    blank = ""
    return blank
45
+
46
# --- Custom CSS for a ChatGPT-like Dark Theme ---
# Stylesheet handed to gr.Blocks(css=...). Two selectors are hooks into the
# layout: "#container" targets the column created with elem_id="container",
# and ".output-box" targets the HTML component created with
# elem_classes="output-box". The remaining rules style generic elements
# (body, label, textarea/input, button).
# NOTE(review): leading whitespace inside this string was lost in the scraped
# diff; CSS does not depend on it, but confirm against the original file.
custom_css = """
body {
background-color: #343541 !important;
color: #ECECEC !important;
margin: 0;
padding: 0;
font-family: 'Inter', sans-serif;
}
#container {
max-width: 900px;
margin: 0 auto;
padding: 20px;
}
label {
color: #ECECEC;
font-weight: 600;
}
textarea, input {
background-color: #40414F;
color: #ECECEC;
border: 1px solid #565869;
}
button {
background-color: #565869;
color: #ECECEC;
border: none;
font-weight: 600;
transition: background-color 0.2s ease;
}
button:hover {
background-color: #6e7283;
}
.output-box {
border: 1px solid #565869;
border-radius: 4px;
padding: 10px;
margin-top: 8px;
background-color: #40414F;
}
"""
87
+
88
# --- Build Gradio Interface ---
# NOTE(review): indentation was lost in the scraped diff; the nesting below is
# reconstructed — confirm against the original file.
with gr.Blocks(css=custom_css) as demo:
    # elem_id="container" is the hook for the "#container" rule in custom_css.
    with gr.Column(elem_id="container"):
        gr.Markdown("## Pure RAG Output\nDisplays the retrieved passages from the corpus.")
        query_input = gr.Textbox(label="Query", placeholder="Enter your query here...", lines=1)
        with gr.Column():
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")
        # elem_classes="output-box" is the hook for the ".output-box" rule in custom_css.
        output_box = gr.HTML(label="Retrieved Passages", elem_classes="output-box")

    # Submit: pass the query text to get_pure_rag_output and show its HTML result.
    submit_button.click(fn=get_pure_rag_output, inputs=query_input, outputs=output_box)
    # Clear: overwrite the HTML box with clear_output()'s value; the query
    # textbox is not listed as an output, so it is left as-is.
    clear_button.click(fn=clear_output, inputs=[], outputs=output_box)

# Runs at module load; there is no __main__ guard around the launch.
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ openai
3
+ faiss-cpu
4
+ sentence-transformers
5
+ langchain
6
+ numpy