# Movie_Buff_QA / app.py
# Last commit: 10a9825 "Update app.py" by iisadia (verified) - 4.98 kB
import streamlit as st
import pandas as pd
import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
import os
# --------------------------
# Configuration & Styling
# --------------------------
# Browser-tab title/icon and layout for the whole page. set_page_config must be
# the first Streamlit call the script makes, so it runs before any rendering.
_PAGE_CONFIG = dict(
    page_title="CineMaster AI - Movie Expert",
    page_icon="🎬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Global stylesheet: two theme colours, a gradient `.header` banner class, a
# translucent `.response-box` class, and a gradient restyle of every st.button.
_THEME_CSS = """
<style>
:root {
--primary: #7017ff;
--secondary: #ff2d55;
}
.header {
background: linear-gradient(135deg, var(--primary), var(--secondary));
color: white;
padding: 2rem;
border-radius: 15px;
text-align: center;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.response-box {
background: rgba(255,255,255,0.1);
border-radius: 10px;
padding: 1.5rem;
margin: 1rem 0;
border: 1px solid rgba(255,255,255,0.2);
}
.stButton>button {
background: linear-gradient(45deg, var(--primary), var(--secondary)) !important;
color: white !important;
border-radius: 25px;
padding: 0.8rem 2rem;
font-weight: 600;
transition: transform 0.2s;
}
.stButton>button:hover {
transform: scale(1.05);
}
</style>
"""

st.set_page_config(**_PAGE_CONFIG)
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# --------------------------
# Movie Dataset & Embeddings
# --------------------------
@st.cache_resource
def load_movie_data():
    """Load the ``wiki_movies`` train split and add a ``context`` column.

    Each ``context`` value has the form ``"Title: ...\\nPlot: ...\\nCast: ..."``.
    It is the text that gets embedded for retrieval and later handed to the
    LLM as grounding.

    Returns:
        pandas.DataFrame: the dataset rows plus the derived ``context`` column.

    Raises:
        ValueError: if the dataset lacks any of the ``title``, ``plot`` or
            ``cast`` columns.
    """
    dataset = load_dataset("wiki_movies", split="train")
    df = pd.DataFrame(dataset)

    # NOTE(review): confirm the "wiki_movies" schema really provides these
    # columns. Failing here reports what IS available instead of raising a
    # bare KeyError from inside the context-building step.
    required = ("title", "plot", "cast")
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(
            f"Dataset 'wiki_movies' is missing column(s) {missing}; "
            f"available columns: {list(df.columns)}"
        )

    # Column-wise concatenation: one vectorized pass per column instead of a
    # Python call per row. astype(str) stringifies values exactly as the
    # f-string did, and it also works on an empty frame.
    df['context'] = (
        "Title: " + df['title'].astype(str)
        + "\nPlot: " + df['plot'].astype(str)
        + "\nCast: " + df['cast'].astype(str)
    )
    return df
@st.cache_resource
def setup_retrieval(df):
    """Embed every movie context and build an exact L2 nearest-neighbour index.

    Args:
        df: DataFrame with a ``context`` column of strings.

    Returns:
        ``(embedder, index)``. ``embedder`` is the SentenceTransformer used for
        the corpus; reuse it for queries so both live in the same vector space.
        ``index`` is a ``faiss.IndexFlatL2`` whose vector ``i`` was built from
        ``df.iloc[i]['context']``.

    Raises:
        ValueError: if ``df`` has no rows to index.
    """
    # NOTE(review): st.cache_resource hashes ``df`` on every rerun to build its
    # cache key, which can be slow for a large frame. Renaming the parameter
    # with a leading underscore skips hashing, but that would change the
    # keyword interface, so it is left as is.
    texts = df['context'].tolist()
    if not texts:
        # Without rows, encode() yields an array with no second dimension and
        # shape[1] would fail with an opaque IndexError.
        raise ValueError("setup_retrieval() needs at least one context row to index")

    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    # faiss only accepts C-contiguous float32 matrices of shape (n, d).
    embeddings = np.ascontiguousarray(
        embedder.encode(texts, convert_to_numpy=True), dtype=np.float32
    )
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return embedder, index
# --------------------------
# Groq API Setup
# --------------------------
def get_groq_client():
    """Return a Groq client authenticated from the ``GROQ_API_KEY`` env var.

    Returns:
        groq.Groq: a client configured with the key from the environment.

    Raises:
        RuntimeError: if ``GROQ_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        # Never fall back to a key embedded in source: anything committed to
        # the repository is readable by everyone who can see the code.
        raise RuntimeError(
            "GROQ_API_KEY is not set; provide it as an environment variable "
            "or deployment secret before starting the app."
        )
    return Groq(api_key=api_key)
def movie_expert(query, context, client=None, *, model="llama3-70b-8192", temperature=0.3):
    """Ask the Groq-hosted LLM to answer *query*, grounded in *context*.

    Args:
        query: The user's movie question.
        context: Retrieved movie passages the answer should be based on.
        client: Optional client exposing ``chat.completions.create``. When
            omitted, a new one is obtained from ``get_groq_client()``.
        model: Chat model id sent to the API.
            NOTE(review): Groq has retired some ``llama3-*`` ids; confirm this
            default is still served.
        temperature: Sampling temperature passed to the API.

    Returns:
        str: the text of the first completion choice, or ``""`` if the API
        returned no message content.
    """
    # Resolve a client lazily instead of depending on a module-level
    # ``client`` global.
    if client is None:
        client = get_groq_client()

    prompt = (
        "You are a film expert. Answer using this context:\n"
        f"{context}\n"
        f"Question: {query}\n"
        "Format response with:\n"
        "1. 🎥 Direct Answer\n"
        "2. 📖 Detailed Explanation\n"
        "3. 🏆 Key Cast Members\n"
        "4. 🌟 Trivia (if available)\n"
    )
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=model,
        temperature=temperature,
    )
    # message.content may be None; callers format the result as text.
    return response.choices[0].message.content or ""
# --------------------------
# Main Application
# --------------------------
def _render_header():
    """Draw the static title banner styled by the ``header`` CSS class."""
    # Static, developer-written HTML only, so unsafe_allow_html is acceptable here.
    st.markdown("""
<div class="header">
<h1>🎞️ CineMaster AI</h1>
<h3>Your Personal Movie Encyclopedia</h3>
</div>
""", unsafe_allow_html=True)


def _render_sidebar():
    """Show the app logo and a list of example questions in the sidebar."""
    examples = [
        "Who played the Joker in The Dark Knight?",
        "What's the plot of Inception?",
        "List Christopher Nolan's movies",
        "Who directed The Dark Knight?",
        "What year was Inception released?",
    ]
    with st.sidebar:
        st.image("https://cdn-icons-png.flaticon.com/512/2598/2598702.png", width=120)
        st.subheader("Sample Questions")
        for ex in examples:
            st.code(ex, language="bash")


def _retrieve_contexts(df, embedder, index, query, top_k):
    """Return the ``context`` strings of the ``top_k`` rows nearest to *query*."""
    query_embed = embedder.encode([query])
    _, indices = index.search(query_embed, top_k)
    # faiss pads results with -1 when the index holds fewer than top_k
    # vectors; df.iloc[-1] would silently return the last row, so skip those.
    return [df.iloc[int(i)]['context'] for i in indices[0] if i >= 0]


def _render_results(answer, contexts):
    """Display the model's answer followed by one expander per source passage."""
    st.markdown("---")
    with st.container():
        st.markdown("## 🎬 Expert Analysis")
        # Rendered as plain markdown rather than inside raw HTML with
        # unsafe_allow_html. The text can echo user input via the model, so it
        # must not inject markup, and markdown nested in an HTML <div> is not
        # reliably rendered.
        st.markdown(answer)
        st.markdown("## 📚 Source Materials")
        if not contexts:
            st.info("No source passages were retrieved.")
            return
        # One column per retrieved passage, however many came back.
        cols = st.columns(len(contexts))
        for number, (col, ctx) in enumerate(zip(cols, contexts), start=1):
            with col:
                with st.expander(f"Source {number}", expanded=True):
                    st.write(ctx)


def main():
    """Run the CineMaster AI page.

    Loads the data and index, draws the header and sidebar, and on a button
    click retrieves matching movie passages and asks the LLM to answer.
    """
    top_k = 2  # number of passages retrieved and shown as sources
    df = load_movie_data()
    embedder, index = setup_retrieval(df)

    _render_header()
    _render_sidebar()

    query = st.text_input("🎯 Ask any movie question:",
                          placeholder="e.g., 'Who played the villain in The Dark Knight?'")
    if not st.button("🚀 Get Answer"):
        return

    query = query.strip()
    if not query:
        st.warning("Please enter a movie-related question")
        return

    with st.spinner(f"🔍 Searching through {len(df):,} movie records..."):
        contexts = _retrieve_contexts(df, embedder, index, query, top_k)
    combined_context = "\n\n".join(contexts)

    with st.spinner("🎥 Generating cinematic insights..."):
        try:
            answer = movie_expert(query, combined_context)
        except Exception as exc:  # UI boundary: report the failure instead of a raw traceback
            st.error(f"Could not get an answer from the language model: {exc}")
            return

    _render_results(answer, contexts)


if __name__ == "__main__":
    main()