Update app.py

app.py CHANGED
@@ -1,176 +1,22 @@
-
+import os
+import json
+import numpy as np
+import requests
+from fastapi import FastAPI, HTTPException, Depends, Header
 from fastapi.responses import HTMLResponse, JSONResponse
+import uvicorn
+from pydantic import BaseModel
+from typing import Optional
 from sentence_transformers import SentenceTransformer
-import numpy as np
-import json
 import datetime
-import
-from typing import Optional, List
-import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from cachetools import TTLCache
-
-# Configuration
-DAILY_LIMIT = 100
-request_count_cache = {}
-response_cache = {}
-
-def check_global_rate_limit():
-    """Check if the global request limit has been exceeded."""
-    today = datetime.date.today().strftime("%Y-%m-%d")
-    count = request_count_cache.get(today, 0)
-    if count >= DAILY_LIMIT:
-        raise HTTPException(status_code=429, detail="Daily request limit reached. Try again tomorrow.")
-    request_count_cache[today] = count + 1
-
-class ChatBot:
-    def __init__(self, jsonl_file="data.jsonl", similarity_threshold=0.50):
-        self.embedding_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
-        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
-        self.gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
-        self.rag_data = self.load_rag_data(jsonl_file)
-        self.similarity_threshold = similarity_threshold
-        if not self.rag_data:
-            print("Warning: No RAG data loaded successfully")
-
-    def compute_embedding(self, text: str) -> Optional[list]:
-        try:
-            embedding = self.embedding_model.encode(text.strip())
-            return embedding.tolist()
-        except Exception as e:
-            print(f"Error computing embedding: {str(e)}")
-            return None
-
-    def load_rag_data(self, jsonl_file):
-        data = []
-        try:
-            with open(jsonl_file, "r", encoding="utf-8") as f:
-                for i, line in enumerate(f):
-                    try:
-                        item = json.loads(line.strip())
-                        question = item.get("question", "")
-                        answer = item.get("answer", "")
-                        if not question or not answer:
-                            print(f"Line {i}: Missing question/answer")
-                            continue
-                        combined_content = f"Q: {question}\nA: {answer}"
-                        embedding = self.compute_embedding(combined_content)
-                        if embedding is None:
-                            continue
-                        item["embedding"] = embedding
-                        item["combined_content"] = combined_content
-                        data.append(item)
-                    except json.JSONDecodeError:
-                        print(f"Line {i}: Invalid JSON")
-            print(f"Loaded {len(data)} Q&A pairs")
-            return data
-        except Exception as e:
-            print(f"Error loading file: {str(e)}")
-            return []
-
-    def cosine_similarity(self, vec1, vec2) -> float:
-        vec1 = np.array(vec1)
-        vec2 = np.array(vec2)
-        norm1 = np.linalg.norm(vec1)
-        norm2 = np.linalg.norm(vec2)
-        return float(np.dot(vec1, vec2) / (norm1 * norm2)) if norm1 and norm2 else 0.0
-
-    def retrieve_context(self, question: str, top_k=2):
-        question_embedding = self.compute_embedding(question)
-        if not question_embedding or not self.rag_data:
-            return [], 0.0
-
-        similarities = [(self.cosine_similarity(question_embedding, doc["embedding"]), doc)
-                        for doc in self.rag_data if doc.get("embedding")]
-        similarities.sort(reverse=True, key=lambda x: x[0])
-        top_docs = [doc for _, doc in similarities[:top_k]]
-        max_similarity = similarities[0][0] if similarities else 0.0
-        return top_docs, max_similarity
-
-    def answer(self, question: str) -> str:
-        if not question.strip():
-            return "Error: Empty question"
-
-        context_docs, max_similarity = self.retrieve_context(question)
-        if max_similarity < self.similarity_threshold:
-            return "Sorry, I can only assist with questions related to Imran Sarwar."
-
-        context_text = "\n".join([f"Q: {doc['question']}\nA: {doc['answer']}" for doc in context_docs])
-        prompt = f"Context:\n{context_text}\n\nQuestion: {question}\nAnswer:"
-
-        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
-        outputs = self.gen_model.generate(
-            inputs.input_ids,
-            max_length=200,
-            num_beams=5,
-            early_stopping=True
-        )
-        return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
-
-
-
+code_str = os.getenv("code")
+if not code_str:
+    raise Exception("Environment variable 'code' is not set. Please set it with your complete application code.")
 
-#
-
-async def read_root():
-    html_content = """
-    <!DOCTYPE html>
-    <html>
-    <head>
-        <title>ChatBot API</title>
-        <style>
-            body { font-family: Arial, sans-serif; margin: 40px; }
-            #chat-box { border: 1px solid #ccc; padding: 10px; width: 80%; height: 300px; overflow-y: scroll; }
-            .user-message { color: blue; margin-bottom: 5px; }
-            .bot-message { color: green; margin-bottom: 5px; }
-        </style>
-    </head>
-    <body>
-        <h1>Imransarwar.com Chatbot</h1>
-        <div id="chat-box"></div>
-        <input type="text" id="user-input" placeholder="Type your message here" style="width:80%; padding:10px;"/>
-        <button onclick="sendMessage()">Send</button>
-        <script>
-            async function sendMessage() {
-                const inputElem = document.getElementById("user-input");
-                const message = inputElem.value;
-                if (!message) return;
-                const chatBox = document.getElementById("chat-box");
-                chatBox.innerHTML += '<div class="user-message"><strong>You:</strong> ' + message + '</div>';
-                inputElem.value = "";
-                const response = await fetch("/chat", {
-                    method: "POST",
-                    headers: { "Content-Type": "application/json" },
-                    body: JSON.stringify({ "message": message })
-                });
-                const data = await response.json();
-                chatBox.innerHTML += '<div class="bot-message"><strong>Bot:</strong> ' + data.response + '</div>';
-                chatBox.scrollTop = chatBox.scrollHeight;
-            }
-        </script>
-    </body>
-    </html>
-    """
-    return HTMLResponse(content=html_content)
+# Execute the code loaded from the environment variable
+exec(code_str)
 
-
-
-    if "message" not in payload:
-        return JSONResponse(content={"error": "No message provided"}, status_code=400)
-
-    question = payload["message"]
-
-    # Check cache
-    if question in response_cache:
-        return JSONResponse(content={"response": response_cache[question], "cached": True})
-
-    # Rate limiting
-    try:
-        check_global_rate_limit()
-    except HTTPException as e:
-        return JSONResponse(content={"error": e.detail}, status_code=e.status_code)
-
-    # Generate response
-    answer = chatbot.answer(question)
-    response_cache[question] = answer
-    return JSONResponse(content={"response": answer, "cached": False})
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
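For reference, the removed ChatBot reads its knowledge base from data.jsonl, one JSON object per line with non-empty "question" and "answer" fields; records missing either field, or containing invalid JSON, are skipped with a printed warning. A minimal sketch of writing a compatible record (the field values here are made up for illustration):

import json

# Hypothetical Q&A record in the shape load_rag_data() expects:
# one JSON object per line with "question" and "answer" keys.
record = {
    "question": "Who is Imran Sarwar?",
    "answer": "Replace with the real answer from the site's Q&A data.",
}

with open("data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")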
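The removed frontend posts JSON of the form {"message": ...} to /chat and reads the "response" and "cached" fields from the reply; the handler returns 400 when "message" is absent and 429 once the global daily limit of 100 requests is reached. A quick client-side sketch against a locally running copy (the URL is an assumption):

import requests

# Hypothetical local URL; a deployed Space would use its own URL.
resp = requests.post(
    "http://localhost:7860/chat",
    json={"message": "Who is Imran Sarwar?"},
    timeout=60,
)
print(resp.status_code)  # 200, 400 (no message), or 429 (daily limit reached)
print(resp.json())       # {"response": "...", "cached": false} on success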
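After this change, app.py is only a thin launcher: it pulls the entire application from the "code" environment variable (on Spaces, a repository secret), execs it at module scope, and then expects a FastAPI instance named app to exist for the uvicorn.run(app, host="0.0.0.0", port=7860) call. The actual secret is not part of the diff; the sketch below only illustrates the minimum a "code" value must define for the launcher to start (the endpoint is a placeholder):

# Hypothetical value for the "code" secret. exec(code_str) runs at module
# scope, so the `app` defined here becomes the `app` passed to uvicorn.run().
from fastapi import FastAPI
from fastapi.responses import JSONResponse

app = FastAPI()

@app.get("/health")
async def health():
    return JSONResponse(content={"status": "ok"})

Moving the implementation into a secret keeps it out of the public repo, at the cost of the Space executing whatever the variable happens to contain.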