import os
import sqlite3
import json
from datetime import datetime, timedelta

import praw
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

load_dotenv()
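# Expected .env contents (placeholder values; the variable names match
# the os.getenv() calls below):
#   GROQ_API_KEY=...
#   REDDIT_CLIENT_ID=...
#   REDDIT_CLIENT_SECRET=...
#   REDDIT_USER_AGENT=...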
# Initialize the LLM via LangChain (using Groq)
llm = ChatGroq(
    groq_api_key=os.getenv("GROQ_API_KEY"),
    model_name="meta-llama/llama-4-maverick-17b-128e-instruct",
    temperature=0.2
)

# Embedding model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Reddit API setup
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent=os.getenv("REDDIT_USER_AGENT")
)
# SQLite DB connection
def get_db_conn():
    return sqlite3.connect("reddit_data.db", check_same_thread=False)

# Set up the database schema
def setup_db():
    conn = get_db_conn()
    cur = conn.cursor()
    try:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS reddit_posts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                reddit_id TEXT UNIQUE,
                keyword TEXT,
                title TEXT,
                post_text TEXT,
                comments TEXT,
                created_at TEXT,
                embedding TEXT,
                metadata TEXT
            );
        """)
        conn.commit()
    except Exception as e:
        print("DB Setup Error:", e)
    finally:
        cur.close()
        conn.close()
# Keyword filter: match in the post title/body or in any comment
def keyword_in_post_or_comments(post, keyword):
    keyword_lower = keyword.lower()
    combined_text = (post.title + " " + post.selftext).lower()
    if keyword_lower in combined_text:
        return True
    # replace_more(limit=None) resolves every "load more comments" stub,
    # which can be slow on large threads
    post.comments.replace_more(limit=None)
    for comment in post.comments.list():
        if keyword_lower in comment.body.lower():
            return True
    return False
# Fetch and process Reddit data
def fetch_reddit_data(keyword, days=7, limit=None):
    end_time = datetime.utcnow()
    start_time = end_time - timedelta(days=days)
    subreddit = reddit.subreddit("all")
    posts_generator = subreddit.search(keyword, sort="new", time_filter="all", limit=limit)
    data = []
    for post in posts_generator:
        created = datetime.utcfromtimestamp(post.created_utc)
        # Results arrive newest-first (sort="new"), so stop at the first
        # post older than the requested window
        if created < start_time:
            break
        if not keyword_in_post_or_comments(post, keyword):
            continue
        post.comments.replace_more(limit=None)
        comments = [comment.body for comment in post.comments.list()]
        combined_text = f"{post.title}\n{post.selftext}\n{' '.join(comments)}"
        embedding = embedder.encode(combined_text).tolist()
        metadata = {
            "url": post.url,
            "subreddit": post.subreddit.display_name,
            "comments_count": len(comments)
        }
        data.append({
            "reddit_id": post.id,
            "keyword": keyword,
            "title": post.title,
            "post_text": post.selftext,
            "comments": comments,
            "created_at": created.isoformat(),
            "embedding": embedding,
            "metadata": metadata
        })
    if data:
        save_to_db(data)
# Save data into SQLite
def save_to_db(posts):
    conn = get_db_conn()
    cur = conn.cursor()
    for post in posts:
        try:
            cur.execute("""
                INSERT OR IGNORE INTO reddit_posts
                (reddit_id, keyword, title, post_text, comments, created_at, embedding, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
            """, (
                post["reddit_id"],
                post["keyword"],
                post["title"],
                post["post_text"],
                json.dumps(post["comments"]),
                post["created_at"],
                json.dumps(post["embedding"]),
                json.dumps(post["metadata"])
            ))
        except Exception as e:
            print("Insert Error:", e)
    conn.commit()
    cur.close()
    conn.close()
# Retrieve context from the DB (by post id, or by keyword and recency;
# the stored embeddings are not used here)
def retrieve_context(question, keyword, reddit_id=None, top_k=10):
    lower_q = question.lower()
    # Broad questions pull in a larger window of posts
    requested_top_k = 50 if any(word in lower_q for word in ["summarize", "overview", "all posts"]) else top_k
    conn = get_db_conn()
    cur = conn.cursor()
    if reddit_id:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE reddit_id = ?;
        """, (reddit_id,))
    else:
        cur.execute("""
            SELECT title, post_text, comments FROM reddit_posts
            WHERE keyword = ? ORDER BY datetime(created_at) DESC LIMIT ?;
        """, (keyword, requested_top_k))
    results = cur.fetchall()
    cur.close()
    conn.close()
    return results
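# Optional: similarity search over the stored embeddings.
# retrieve_context() above filters only by keyword and recency, so the
# vectors saved in the `embedding` column go unused. Below is a minimal
# sketch of cosine-similarity retrieval over them, assuming numpy is
# installed and the rows were embedded with the same all-MiniLM-L6-v2
# model; the function name is illustrative, not part of the app above.
import numpy as np

def retrieve_context_by_similarity(question, keyword, top_k=10):
    q_vec = embedder.encode(question)
    conn = get_db_conn()
    cur = conn.cursor()
    cur.execute(
        "SELECT title, post_text, comments, embedding FROM reddit_posts WHERE keyword = ?;",
        (keyword,)
    )
    rows = cur.fetchall()
    cur.close()
    conn.close()
    scored = []
    for title, post_text, comments, emb_json in rows:
        vec = np.array(json.loads(emb_json))
        # Cosine similarity between the question and the stored post vector
        sim = float(np.dot(q_vec, vec) / (np.linalg.norm(q_vec) * np.linalg.norm(vec) + 1e-10))
        scored.append((sim, (title, post_text, comments)))
    scored.sort(key=lambda s: s[0], reverse=True)
    return [row for _, row in scored[:top_k]]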
# Summarizer
summarize_prompt = ChatPromptTemplate.from_template("""
You are a summarizer. Summarize the following context from Reddit posts into a concise summary that preserves the key insights. Do not add extra commentary.

Context:
{context}

Summary:
""")
summarize_chain = LLMChain(llm=llm, prompt=summarize_prompt)
# Chatbot memory and prompt
memory = ConversationBufferMemory(memory_key="chat_history")

chat_prompt = ChatPromptTemplate.from_template("""
Chat History:
{chat_history}

Context from Reddit and User Question:
{input}

Act as a professional assistant in an ongoing conversation. Reason over the context and the chat history, and answer clearly. Keep the response accurate, concise, and relevant.
""")

chat_chain = LLMChain(
    llm=llm,
    prompt=chat_prompt,
    memory=memory,
    verbose=True
)
# Chatbot response
def get_chatbot_response(question, keyword, reddit_id=None):
    context_posts = retrieve_context(question, keyword, reddit_id)
    context = "\n\n".join([f"{p[0]}:\n{p[1]}" for p in context_posts])
    # Compress long contexts first; LLMChain.invoke returns a dict, so
    # pull the generated text out of the "text" key
    if len(context) > 3000:
        context = summarize_chain.invoke({"context": context})["text"]
    combined_input = f"Context:\n{context}\n\nUser Question: {question}"
    response = chat_chain.invoke({"input": combined_input})["text"]
    return response, context_posts
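# Minimal usage sketch. Assumes the .env variables above are set; the
# keyword "langchain" and the question are placeholder examples.
if __name__ == "__main__":
    setup_db()
    fetch_reddit_data("langchain", days=7, limit=25)
    answer, sources = get_chatbot_response("What are people saying about this?", keyword="langchain")
    print(answer)
    print(f"({len(sources)} posts retrieved as context)")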