initial commit
Browse files- app.py +290 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unsloth import FastModel
|
2 |
+
import torch
|
3 |
+
import gc
|
4 |
+
|
5 |
+
# Set torch parameter to avoid error message, "FailOnRecompileLimitHit: recompile_limit reached with one_graph=True." when doing inference on images
|
6 |
+
torch._dynamo.config.cache_size_limit = 32
|
7 |
+
|
8 |
+
# Initialize model
|
9 |
+
model, tokenizer = FastModel.from_pretrained(
|
10 |
+
model_name = "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
|
11 |
+
# model_name = "unsloth/gemma-3n-E2B-it", # This runs out of memory for the recommend/analyze chats
|
12 |
+
dtype = None, # None for auto detection
|
13 |
+
max_seq_length = 1024, # Choose any for long context!
|
14 |
+
load_in_4bit = True, # 4 bit quantization to reduce memory
|
15 |
+
full_finetuning = False, # [NEW!] We have full finetuning now!
|
16 |
+
# token = "hf_...", # use one if using gated models
|
17 |
+
)
|
18 |
+
|
19 |
+
# Helper function for inference
|
20 |
+
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
|
21 |
+
inputs = tokenizer.apply_chat_template(
|
22 |
+
messages,
|
23 |
+
add_generation_prompt = True, # Must add for generation
|
24 |
+
tokenize = True,
|
25 |
+
return_dict = True,
|
26 |
+
return_tensors = "pt",
|
27 |
+
).to("cuda")
|
28 |
+
|
29 |
+
with torch.no_grad(): # Disable gradient calculation during inference
|
30 |
+
outputs = model.generate(
|
31 |
+
**inputs,
|
32 |
+
max_new_tokens = max_new_tokens,
|
33 |
+
temperature = 1.0, top_p = 0.95, top_k = 64,
|
34 |
+
return_dict_in_generate=True, # Crucial: Get the full output
|
35 |
+
)
|
36 |
+
|
37 |
+
# Decode generated tokens
|
38 |
+
outputs_excluding_inputs = outputs.sequences[:, inputs.input_ids.shape[1]:] # exclude input tokens
|
39 |
+
generated_text = tokenizer.batch_decode(outputs_excluding_inputs, skip_special_tokens=True)[0]
|
40 |
+
|
41 |
+
# Cleanup to reduce VRAM usage
|
42 |
+
del inputs
|
43 |
+
torch.cuda.empty_cache()
|
44 |
+
gc.collect()
|
45 |
+
|
46 |
+
return generated_text
|
47 |
+
|
48 |
+
import ast
|
49 |
+
|
50 |
+
def query_ai_text_image(text, image_path=None):
|
51 |
+
''' Query AI with a prompt that includes text and an image. '''
|
52 |
+
if image_path is None:
|
53 |
+
return "No image uploaded."
|
54 |
+
messages = [{
|
55 |
+
"role" : "user",
|
56 |
+
"content": [
|
57 |
+
{ "type": "image", "image" : image_path },
|
58 |
+
{ "type": "text", "text" : text }
|
59 |
+
]
|
60 |
+
}]
|
61 |
+
text = do_gemma_3n_inference(model, messages, max_new_tokens = 256)
|
62 |
+
return ast.literal_eval(text)
|
63 |
+
|
64 |
+
|
65 |
+
def query_ai_text(text):
|
66 |
+
''' Query AI with a text prompt. '''
|
67 |
+
messages = [{
|
68 |
+
"role" : "user",
|
69 |
+
"content": [
|
70 |
+
{ "type": "text", "text" : text }
|
71 |
+
]
|
72 |
+
}]
|
73 |
+
text = do_gemma_3n_inference(model, messages, max_new_tokens = 256)
|
74 |
+
return text
|
75 |
+
|
76 |
+
import pandas as pd
|
77 |
+
|
78 |
+
class Inventory:
|
79 |
+
column_names = ['title', 'author', 'year_published', 'isbn', 'description', 'copies_on_shelf', 'total_copies']
|
80 |
+
|
81 |
+
def __init__(self, input_file_path, output_file_path):
|
82 |
+
''' Initialize library inventory with data from an input csv file. Specify the file path for storing updated inventory. '''
|
83 |
+
|
84 |
+
# Load input file, keeping only the relevant columns
|
85 |
+
data = pd.read_csv(input_file_path)
|
86 |
+
data = data[ [col for col in data.columns if col in self.column_names] ]
|
87 |
+
|
88 |
+
# Check if input contains the required fields of "title" and "description"
|
89 |
+
for col in ['title', 'description']:
|
90 |
+
if col not in data.columns:
|
91 |
+
raise Exception(f"Input book info must contain '{col}'.")
|
92 |
+
|
93 |
+
# If the number of copies is not available in the input data, set it to the default value of 1
|
94 |
+
for col in ['copies_on_shelf', 'total_copies']:
|
95 |
+
if col not in data.columns:
|
96 |
+
print(f"Input {col} not found. Setting to default value 1.")
|
97 |
+
data[col] = 1
|
98 |
+
|
99 |
+
# self.data = data
|
100 |
+
# NOTE: Due to runtime memory limitations, we only demonstrate the application on the subset of books that have short descriptions.
|
101 |
+
self.data = data[data.description.str.count(' ') < 50]
|
102 |
+
self.file_path = output_file_path
|
103 |
+
self.save()
|
104 |
+
|
105 |
+
|
106 |
+
def save(self):
|
107 |
+
''' Save inventory data to file. '''
|
108 |
+
self.data.to_csv(self.file_path, index=False)
|
109 |
+
|
110 |
+
def get_index(self, title):
|
111 |
+
''' Return a pandas Index list of book(s) that match a given title. '''
|
112 |
+
idx = self.data[self.data.title.str.lower() == title.lower()].index
|
113 |
+
if idx.size == 0:
|
114 |
+
return None
|
115 |
+
if idx.size > 1:
|
116 |
+
raise Exception(f"Found {idx.size} books with the title '{title}'.") #TODO: Match on author as well.
|
117 |
+
return idx[0]
|
118 |
+
|
119 |
+
def check_out(self, title):
|
120 |
+
i = self.get_index(title)
|
121 |
+
if i is None:
|
122 |
+
return "ERROR: Title not found in library collection." # TODO: Add book to collection
|
123 |
+
if self.data.loc[i, 'copies_on_shelf'] == 0:
|
124 |
+
return "ERROR: Check out unsuccessful. There are 0 copies on shelf."
|
125 |
+
self.data.loc[i, 'copies_on_shelf'] -= 1
|
126 |
+
self.save()
|
127 |
+
return f"Check out successful. {self.data.loc[i, 'copies_on_shelf']} of {self.data.loc[i, 'total_copies']} copies remaining."
|
128 |
+
|
129 |
+
def check_in(self, title):
|
130 |
+
i = self.get_index(title)
|
131 |
+
if i is None:
|
132 |
+
return "ERROR: Title not found in library collection."
|
133 |
+
row = self.data.loc[i]
|
134 |
+
if row.copies_on_shelf == row.total_copies:
|
135 |
+
return f"ERROR: Check in unsuccessful. {row.copies_on_shelf} of {row.total_copies} copies already on shelf."
|
136 |
+
self.data.loc[i, 'copies_on_shelf'] += 1
|
137 |
+
self.save()
|
138 |
+
return f"Check in successful. {self.data.loc[i, 'copies_on_shelf']} of {self.data.loc[i, 'total_copies']} copies on shelf."
|
139 |
+
|
140 |
+
def get_on_shelf_book_info(self):
|
141 |
+
''' Return the title/author/description info of all books with available copies on shelf, in csv format. '''
|
142 |
+
columns = ['title', 'author', 'description']
|
143 |
+
return self.data[self.data.copies_on_shelf > 0][columns].to_csv()
|
144 |
+
|
145 |
+
def get_df(self):
|
146 |
+
''' Return inventory data. '''
|
147 |
+
return self.data
|
148 |
+
|
149 |
+
def get_dtypes(self):
|
150 |
+
''' Get data types for each column. '''
|
151 |
+
return self.data.dtypes
|
152 |
+
|
153 |
+
def set_df(self, data):
|
154 |
+
''' Set inventory as the input DataFrame. '''
|
155 |
+
self.data = data
|
156 |
+
|
157 |
+
# Initialize mobile library Inventory object
|
158 |
+
initial_book_list = '/kaggle/input/caldecott-medal-winners-1938-2019/caldecott_winners.csv'
|
159 |
+
inventory_file_path = '/kaggle/working/inventory.csv'
|
160 |
+
inventory = Inventory(initial_book_list, inventory_file_path)
|
161 |
+
|
162 |
+
import gradio as gr
|
163 |
+
from datetime import datetime
|
164 |
+
|
165 |
+
|
166 |
+
# --- "Scan" tab ---
|
167 |
+
def scan_book(image, action):
|
168 |
+
|
169 |
+
# Query AI to extract the title and author
|
170 |
+
prompt = "Extract the title and author from this book cover image. Format the output as ('[title]', '[author]'). If unsuccessful, output ('Unknown Title', 'Unknown Author')."
|
171 |
+
title, author = query_ai_text_image(prompt, image)
|
172 |
+
|
173 |
+
# AI query success check
|
174 |
+
if title == "Unknown Title" or author == "Unknown Author":
|
175 |
+
return "Could not reliably extract book information from the image. Please try again with a clearer cover."
|
176 |
+
|
177 |
+
# Get the right function (check out or check in)
|
178 |
+
if action == 'out':
|
179 |
+
fn = inventory.check_out
|
180 |
+
elif action == 'in':
|
181 |
+
fn = inventory.check_in
|
182 |
+
else:
|
183 |
+
raise Exception(f'Unknown action {action}. Valid options are "out" or "in".')
|
184 |
+
|
185 |
+
# Perform action and return results
|
186 |
+
return f"Title: {title}\nAuthor: {author}\n" + fn(title)
|
187 |
+
|
188 |
+
|
189 |
+
# --- "Recommend" tab ---
|
190 |
+
recommend_examples = [
|
191 |
+
["Suggest five books for a toddler who loves animals."],
|
192 |
+
["Find 3 books for a preschooler interested in space."],
|
193 |
+
["What are some books about adventures?"]
|
194 |
+
]
|
195 |
+
|
196 |
+
def recommend_chat_response(message, history):
|
197 |
+
prompt = "You are a helpful librarian making book recommendations based on the user's description of the reader's background and interests. Respond with 3-5 books, unless otherwise specified by the user. Respond with a bullet point list formatted '[title] by [author]', followed by a short sentence of less than 20 words about why this book was chosen. You must only choose books from the following csv file: " + inventory.get_on_shelf_book_info()
|
198 |
+
return query_ai_text(f"{prompt} \n User question: {message}")
|
199 |
+
|
200 |
+
|
201 |
+
# --- "Analyze" tab ---
|
202 |
+
analyze_examples = [
|
203 |
+
["What is the newest book we have?"],
|
204 |
+
["Summarize the common themes in our collection."]
|
205 |
+
]
|
206 |
+
|
207 |
+
def analyze_chat_response(message, history):
|
208 |
+
prompt = "You are a helpful librarian answering questions about the library's collection of books, based only on this inventory data: " + inventory.get_df().to_csv(index=False)
|
209 |
+
return query_ai_text(f"{prompt} \n User question: {message}")
|
210 |
+
|
211 |
+
|
212 |
+
# --- "Manage" tab ---
|
213 |
+
def save_inventory(df_input):
|
214 |
+
''' Save the user-edited DataFrame as the inventory DataFrame. ''' # TODO: More robust error checks
|
215 |
+
df = pd.DataFrame(df_input)
|
216 |
+
|
217 |
+
# Explicitly convert columns to desired data types
|
218 |
+
col_type = inventory.get_dtypes().to_list()
|
219 |
+
for i,col in enumerate(df.columns):
|
220 |
+
df[col] = df[col].astype(col_type[i])
|
221 |
+
|
222 |
+
# Save DataFrame
|
223 |
+
inventory.set_df(df)
|
224 |
+
inventory.save()
|
225 |
+
|
226 |
+
|
227 |
+
# --- Main gradio app ---
|
228 |
+
with gr.Blocks() as demo:
|
229 |
+
gr.Markdown("# 🚐 MoLi: Mobile Librarian 📚")
|
230 |
+
gr.Markdown("Scan to check out/in, get book recommendations, and analyze your collection, powered by Google's Gemma 3n AI!")
|
231 |
+
|
232 |
+
with gr.Tabs() as tabs:
|
233 |
+
|
234 |
+
# Scan book to check out or check in
|
235 |
+
actions = ['out', 'in']
|
236 |
+
with gr.Tab(label='Scan'):
|
237 |
+
image_input = gr.Image(type='filepath', label="Upload book cover or take a photo", sources=['upload', 'webcam'], width=300)
|
238 |
+
with gr.Row():
|
239 |
+
button = {a: gr.Button(f'Check {a}') for a in actions}
|
240 |
+
status_text = gr.Textbox(show_label=False)
|
241 |
+
button['out'].click(fn=lambda x: scan_book(x, 'out'), inputs=image_input, outputs=status_text)
|
242 |
+
button['in'].click(fn=lambda x: scan_book(x, 'in'), inputs=image_input, outputs=status_text)
|
243 |
+
# # Somehow the following does not work:
|
244 |
+
# for a, b in button.items():
|
245 |
+
# b.click(fn=lambda x: scan_book(x, a), inputs=image_input, outputs=status_text)
|
246 |
+
|
247 |
+
with gr.Tab(label='Recommend'):
|
248 |
+
recommend_greeting = "Tell me the reader's background and interests, and I'll recommend some books available for check out!"
|
249 |
+
gr.ChatInterface(
|
250 |
+
fn=recommend_chat_response,
|
251 |
+
type='messages',
|
252 |
+
examples=recommend_examples,
|
253 |
+
chatbot=gr.Chatbot(type='messages', placeholder=recommend_greeting),
|
254 |
+
)
|
255 |
+
|
256 |
+
with gr.Tab(label='Analyze'):
|
257 |
+
analyze_greeting = "Ask me anything about the library collection!"
|
258 |
+
gr.ChatInterface(
|
259 |
+
fn=analyze_chat_response,
|
260 |
+
type='messages',
|
261 |
+
examples=analyze_examples,
|
262 |
+
chatbot=gr.Chatbot(type='messages', placeholder=analyze_greeting),
|
263 |
+
)
|
264 |
+
|
265 |
+
with gr.Tab(label='Manage'):
|
266 |
+
|
267 |
+
# Buttons
|
268 |
+
with gr.Row():
|
269 |
+
reload_button = gr.Button('Reload')
|
270 |
+
save_button = gr.Button('Save changes')
|
271 |
+
|
272 |
+
# Textbox to display status messages
|
273 |
+
status_message = gr.Textbox(show_label=False, value='Please reload after check out or check in.')
|
274 |
+
|
275 |
+
# Inventory table
|
276 |
+
inventory_table = gr.DataFrame(
|
277 |
+
value=inventory.get_df(),
|
278 |
+
interactive=True, # Allow editing
|
279 |
+
label="Current Library Inventory",
|
280 |
+
wrap=True
|
281 |
+
# column_widths=["1fr"]*len(inventory.get_dtypes())
|
282 |
+
)
|
283 |
+
|
284 |
+
# Attach functions to buttons
|
285 |
+
reload_button.click(fn=inventory.get_df, outputs=inventory_table).then(fn=lambda:f"Reloaded on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", outputs=[status_message])
|
286 |
+
save_button.click(fn=save_inventory, inputs=inventory_table).then(fn=lambda:f"Saved on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", outputs=[status_message])
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == '__main__':
|
290 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
unsloth
|
2 |
+
transformers==4.54.1 # the latest version, 4.55.0.dev0, results in error when running image processing
|
3 |
+
timm
|