import os
import sqlite3
import praw
import json
from datetime import datetime, timedelta, timezone
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
load_dotenv()

# Initialize the LLM via LangChain (using Groq)
llm = ChatGroq(
    groq_api_key=os.getenv("GROQ_API_KEY"),
    model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
    temperature=0.2
)

# Embedding model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Reddit API setup
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT")
)
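
# The pipeline expects the following variables in a local .env file.
# The names match the os.getenv calls above; the values shown are
# placeholders, not real credentials:
#
#   GROQ_API_KEY=gsk_...
#   REDDIT_CLIENT_ID=...
#   REDDIT_CLIENT_SECRET=...
#   REDDIT_USER_AGENT=script:reddit-rag:v1.0 (by u/your_username)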
# SQLite DB connection
def get_db_conn():
    return sqlite3.connect("reddit_data.db", check_same_thread=False)

# Set up the database schema
def setup_db():
    conn = get_db_conn()
    cur = conn.cursor()
    try:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS reddit_posts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                reddit_id TEXT UNIQUE,
                keyword TEXT,
                title TEXT,
                post_text TEXT,
                comments TEXT,
                created_at TEXT,
                embedding TEXT,
                metadata TEXT
            );
        """)
        conn.commit()
    except Exception as e:
        print("DB Setup Error:", e)
    finally:
        cur.close()
        conn.close()
# Keyword filter: check the post title/body first, then fall back to comments
def keyword_in_post_or_comments(post, keyword):
    keyword_lower = keyword.lower()
    combined_text = (post.title + " " + post.selftext).lower()
    if keyword_lower in combined_text:
        return True
    # replace_more(limit=None) resolves every "load more comments" stub,
    # which can be slow and API-heavy on large threads
    post.comments.replace_more(limit=None)
    for comment in post.comments.list():
        if keyword_lower in comment.body.lower():
            return True
    return False
# Fetch and process Reddit data
def fetch_reddit_data(keyword, days=7, limit=None):
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=days)
    subreddit = reddit.subreddit("all")
    # Results are sorted newest-first, so we can stop iterating once a
    # post falls outside the requested time window
    posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
    data = []
    for post in posts_generator:
        created = datetime.fromtimestamp(post.created_utc, tz=timezone.utc)
        if created < start_time:
            break
        if not keyword_in_post_or_comments(post, keyword):
            continue
        post.comments.replace_more(limit=None)
        comments = [comment.body for comment in post.comments.list()]
        combined_text = f"{post.title}\n{post.selftext}\n{' '.join(comments)}"
        embedding = embedder.encode(combined_text).tolist()
        metadata = {
            "url": post.url,
            "subreddit": post.subreddit.display_name,
            "comments_count": len(comments)
        }
        data.append({
            "reddit_id": post.id,
            "keyword": keyword,
            "title": post.title,
            "post_text": post.selftext,
            "comments": comments,
            "created_at": created.isoformat(),
            "embedding": embedding,
            "metadata": metadata
        })
    if data:
        save_to_db(data)
# Save data into SQLite
def save_to_db(posts):
    conn = get_db_conn()
    cur = conn.cursor()
    for post in posts:
        try:
            cur.execute("""
                INSERT OR IGNORE INTO reddit_posts
                (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
            """, (
                post["reddit_id"],
                post["keyword"],
                post["title"],
                post["post_text"],
                json.dumps(post["comments"]),
                post["created_at"],
                json.dumps(post["embedding"]),
                json.dumps(post["metadata"])
            ))
        except Exception as e:
            print("Insert Error:", e)
    conn.commit()
    cur.close()
    conn.close()
# Retrieve context from the DB (a single post by id, or the most recent
# posts for a keyword)
def retrieve_context(question, keyword, reddit_id=None, top_k=10):
    lower_q = question.lower()
    # Broad questions ("summarize", "overview") pull a larger window of posts
    requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
    conn = get_db_conn()
    cur = conn.cursor()
    if reddit_id:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE reddit_id = ?;
        """, (reddit_id,))
    else:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE keyword = ? ORDER BY datetime(created_at) DESC LIMIT ?;
        """, (keyword, requested_top_k))
    results = cur.fetchall()
    cur.close()
    conn.close()
    return results
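
# Note: retrieve_context above ranks by recency only; the vectors stored in
# the `embedding` column are never consulted. Below is a minimal sketch of
# cosine-similarity retrieval over those stored vectors.
# `retrieve_similar_context` is a hypothetical helper, not part of the
# original pipeline; it assumes numpy is available (it ships as a dependency
# of sentence-transformers).
import numpy as np

def retrieve_similar_context(question, keyword, top_k=10):
    q_vec = np.array(embedder.encode(question))
    conn = get_db_conn()
    cur = conn.cursor()
    cur.execute("""
        SELECT title, post_text, comments, embedding FROM reddit_posts
        WHERE keyword = ?;
    """, (keyword,))
    rows = cur.fetchall()
    cur.close()
    conn.close()
    scored = []
    for title, post_text, comments, emb_json in rows:
        vec = np.array(json.loads(emb_json))
        # Cosine similarity between the question and the stored post embedding
        sim = float(np.dot(q_vec, vec) / (np.linalg.norm(q_vec) * np.linalg.norm(vec) + 1e-10))
        scored.append((sim, (title, post_text, comments)))
    scored.sort(key=lambda item: item[0], reverse=True)
    return [row for _, row in scored[:top_k]]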
# Summarizer
summarize_prompt = ChatPromptTemplate.from_template("""
You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.
Context:
{context}
Summary:
""")
summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
# Chatbot memory and prompt
memory = ConversationBufferMemory(memory_key="chat_history")
chat_prompt = ChatPromptTemplate.from_template("""
Chat History:
{chat_history}
Context from Reddit and User Question:
{input}
Act as a professional assistant in an incremental chat. Reason over the context and chat history, then answer clearly. Your response should be accurate, concise, and relevant.
""")
chat_chain = LLMChain(
    llm=llm,
    prompt=chat_prompt,
    memory=memory,
    verbose=True
)
# Chatbot response
def get_chatbot_response(question, keyword, reddit_id=None):
    context_posts = retrieve_context(question, keyword, reddit_id)
    context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
    # LLMChain.invoke returns a dict; the generated text lives under "text"
    if len(context) > 3000:
        context = summarize_chain.invoke({"context": context})["text"]
    combined_input = f"Context:\n{context}\n\nUser Question: {question}"
    response = chat_chain.invoke({"input": combined_input})["text"]
    return response, context_posts
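
# A minimal usage sketch (assumed entry point, not in the original file):
# create the table, ingest a week of posts for a keyword, then ask a
# question. The keyword and question are placeholders.
if __name__ == "__main__":
    setup_db()
    fetch_reddit_data("langchain", days=7, limit=50)
    answer, sources = get_chatbot_response("What are people saying about LangChain?", "langchain")
    print(answer)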