masadonline committed
Commit 12fd03c · verified · 1 parent: 750403e

Update app.py

Files changed (1): app.py (+19 -22)
app.py CHANGED
@@ -9,39 +9,39 @@ import pandas as pd
 from sentence_transformers import SentenceTransformer
 from openai import OpenAI
 from dotenv import load_dotenv
+import torch
 
-# Load GROQ API key from .env
+# Load environment variables
 load_dotenv()
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
 # Setup GROQ LLM client
 client = OpenAI(api_key=GROQ_API_KEY, base_url="https://api.groq.com/openai/v1")
 
-# Constants
-EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+# Load embedding model with device specification
+device = "cuda" if torch.cuda.is_available() else "cpu"
+embedder = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
+embedder.to(device)
+
+# LLM model name
 LLM_MODEL = "llama3-8b-8192"
-embedder = SentenceTransformer(EMBEDDING_MODEL)
 
+# Streamlit setup
 st.set_page_config(page_title="🧸 ToyShop Assistant", layout="wide")
 st.title("🧸 ToyShop RAG-Based Assistant")
 
-# --- Load and process uploaded files ---
-
 def extract_pdf_text(file):
     text = ""
     with pdfplumber.open(file) as pdf:
         for page in pdf.pages:
-            text += page.extract_text() + "\n"
+            page_text = page.extract_text()
+            if page_text:
+                text += page_text + "\n"
     return text.strip()
 
 def load_json_orders(json_file):
     data = json.load(json_file)
-    if isinstance(data, list):
-        return data
-    elif isinstance(data, dict):
-        return list(data.values())
-    else:
-        return []
+    return data if isinstance(data, list) else list(data.values())
 
 def build_index(text_chunks):
     vectors = embedder.encode(text_chunks)
@@ -57,8 +57,7 @@ def ask_llm(context, query):
     )
     return response.choices[0].message.content.strip()
 
-# --- File upload UI ---
-
+# File upload
 st.subheader("📁 Upload Customer Orders (JSON)")
 orders_file = st.file_uploader("Upload JSON file", type="json")
 
@@ -67,30 +66,28 @@ pdf_files = st.file_uploader("Upload one or more PDFs", type="pdf", accept_multi
 
 order_chunks, pdf_chunks = [], []
 
-# --- Process files ---
-
+# Handle JSON
 if orders_file:
     try:
         orders = load_json_orders(orders_file)
         order_chunks = [json.dumps(order, ensure_ascii=False) for order in orders]
-        df = pd.DataFrame(orders)
         st.success(f"✅ Loaded {len(order_chunks)} customer order records.")
-        st.dataframe(df, use_container_width=True)
+        st.dataframe(pd.DataFrame(orders), use_container_width=True)
     except Exception as e:
         st.error(f"❌ Error loading JSON: {e}")
 
+# Handle PDFs
 if pdf_files:
     for pdf_file in pdf_files:
         try:
             text = extract_pdf_text(pdf_file)
-            pdf_chunks.extend(text.split("\n\n"))  # chunk by paragraph
+            pdf_chunks.extend(text.split("\n\n"))  # simple paragraph chunking
         except Exception as e:
            st.error(f"❌ Failed to read {pdf_file.name}: {e}")
 
+# Build index if we have content
 combined_chunks = order_chunks + pdf_chunks
 
-# --- Question Answering ---
-
 if combined_chunks:
     index, sources = build_index(combined_chunks)
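Note that the hunks above show only the first line of build_index and none of the retrieval logic, so the indexing path is not fully visible in this commit. For orientation, here is a minimal sketch of how such an index is commonly built and queried with FAISS; faiss, numpy, and the retrieve helper are illustrative assumptions, not necessarily what app.py actually does.

# Sketch only: assumes a FAISS flat L2 index over MiniLM embeddings.
# app.py's real build_index may use a different index or similarity measure.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")

def build_index(text_chunks):
    vectors = np.asarray(embedder.encode(text_chunks), dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search
    index.add(vectors)
    return index, text_chunks  # the chunks double as the "sources" list

def retrieve(index, sources, query, k=3):
    # Hypothetical helper: embed the query and return the k nearest chunks.
    query_vec = np.asarray(embedder.encode([query]), dtype="float32")
    _, ids = index.search(query_vec, k)
    return [sources[i] for i in ids[0]]

Whatever the concrete index type, the invariant visible in the diff is that build_index returns an (index, sources) pair that the question-answering block consumes.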
 
 
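Similarly, only the closing lines of ask_llm appear as context in the second hunk. Given the OpenAI-compatible Groq client configured at the top of the file, the call plausibly looks like the sketch below; the system prompt and temperature are assumptions, while the client.chat.completions.create pattern and the final return expression are grounded in the visible code.

# Sketch only: relies on `client` and LLM_MODEL as defined at the top of app.py;
# the prompt wording and temperature are illustrative, not taken from the diff.
def ask_llm(context, query):
    response = client.chat.completions.create(
        model=LLM_MODEL,  # "llama3-8b-8192"
        messages=[
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
        ],
        temperature=0,  # assumed
    )
    return response.choices[0].message.content.strip()

To try the updated app locally, put GROQ_API_KEY=<your key> in a .env file next to app.py and launch it with streamlit run app.py; the torch import added in this commit only affects which device the embedder runs on.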