Parishri07 commited on
Commit
f1543d7
·
verified ·
1 Parent(s): 7aaee26

Upload flan_suggestor.py

Browse files
Files changed (1) hide show
  1. smart_suggestion/flan_suggestor.py +71 -0
smart_suggestion/flan_suggestor.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+
6
+ # Load model
7
+ model_name = "google/flan-t5-small"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
+
11
+ # Load all product CSVs from nested directories
12
+ def load_all_product_data(base_path="data"):
13
+ all_data = []
14
+ for root, dirs, files in os.walk(base_path):
15
+ for file in files:
16
+ if file.endswith(".csv"):
17
+ full_path = os.path.join(root, file)
18
+ df = pd.read_csv(full_path)
19
+ df["Brand"] = os.path.splitext(file)[0]
20
+ df["Store"] = root.split(os.sep)[-4] # e.g., "Store C"
21
+ df["Category"] = root.split(os.sep)[-2]
22
+ all_data.append(df)
23
+ return pd.concat(all_data, ignore_index=True)
24
+
25
+ df = load_all_product_data("data")
26
+
27
+ # Generate smart suggestion
28
+ def generate_product_description(prompt):
29
+ prompt = prompt.lower()
30
+
31
+ # Basic price filter
32
+ price_limit = 99999
33
+ if "under" in prompt:
34
+ try:
35
+ price_limit = int(prompt.split("under")[-1].split()[0])
36
+ except:
37
+ pass
38
+
39
+ filtered_df = df[df["Price"] <= price_limit]
40
+ filtered_df = filtered_df[df["In Stock"].str.lower() == "yes"]
41
+
42
+ if "dry hair" in prompt:
43
+ filtered_df = filtered_df[filtered_df["Hair Type"].str.lower().str.contains("dry", na=False)]
44
+ elif "oily hair" in prompt:
45
+ filtered_df = filtered_df[filtered_df["Hair Type"].str.lower().str.contains("oily", na=False)]
46
+ elif "normal hair" in prompt:
47
+ filtered_df = filtered_df[filtered_df["Hair Type"].str.lower().str.contains("normal", na=False)]
48
+
49
+ if "gift" in prompt:
50
+ filtered_df = filtered_df[filtered_df["Tags"].str.contains("gift", case=False, na=False)]
51
+ if "budget" in prompt:
52
+ filtered_df = filtered_df[filtered_df["Tags"].str.contains("budget", case=False, na=False)]
53
+
54
+ if filtered_df.empty:
55
+ return "🤷 Sorry, no matching suggestions found."
56
+
57
+ rows = []
58
+ for _, row in filtered_df.iterrows():
59
+ text = f"{row['Brand']} {row['Quantity']} – ₹{row['Price']} (Floor {row['Floor']}, Aisle {row['Aisle']})"
60
+ if pd.notna(row.get("Offer")) and str(row["Offer"]).strip():
61
+ text += f" | 🎉 {row['Offer']}"
62
+ rows.append(text)
63
+
64
+ product_text = "\n".join(rows)
65
+ model_prompt = f"Suggest top products:\n{product_text}"
66
+
67
+ input_ids = tokenizer(model_prompt, return_tensors="pt").input_ids
68
+ with torch.no_grad():
69
+ output_ids = model.generate(input_ids, max_new_tokens=100)
70
+
71
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)