anniecia committed on
Commit a3efe0f · verified · 1 Parent(s): 9a8a97d

initial commit

Files changed (2)
  1. app.py +290 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,290 @@
+ from unsloth import FastModel
+ import torch
+ import gc
+
+ # Raise the torch dynamo cache size limit to avoid "FailOnRecompileLimitHit: recompile_limit reached with one_graph=True." when doing inference on images
+ torch._dynamo.config.cache_size_limit = 32
+
+ # Initialize the model
+ model, tokenizer = FastModel.from_pretrained(
+     model_name = "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
+     # model_name = "unsloth/gemma-3n-E2B-it", # the E2B model runs out of memory for the recommend/analyze chats
+     dtype = None,            # None for auto detection
+     max_seq_length = 1024,   # maximum context length
+     load_in_4bit = True,     # 4-bit quantization to reduce memory
+     full_finetuning = False, # inference only, no finetuning
+     # token = "hf_...",      # needed for gated models
+ )
+
+ # Helper function for inference
+ def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt = True, # must add for generation
+         tokenize = True,
+         return_dict = True,
+         return_tensors = "pt",
+     ).to("cuda")
+
+     with torch.no_grad(): # disable gradient calculation during inference
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens = max_new_tokens,
+             temperature = 1.0, top_p = 0.95, top_k = 64,
+             return_dict_in_generate = True, # crucial: get the full output
+         )
+
+     # Decode generated tokens, excluding the input tokens
+     outputs_excluding_inputs = outputs.sequences[:, inputs.input_ids.shape[1]:]
+     generated_text = tokenizer.batch_decode(outputs_excluding_inputs, skip_special_tokens=True)[0]
+
+     # Cleanup to reduce VRAM usage
+     del inputs
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     return generated_text
+
+ import ast
+
+ def query_ai_text_image(text, image_path=None):
+     ''' Query AI with a prompt that includes text and an image. '''
+     if image_path is None:
+         return "No image uploaded."
+     messages = [{
+         "role" : "user",
+         "content": [
+             { "type": "image", "image" : image_path },
+             { "type": "text",  "text"  : text }
+         ]
+     }]
+     text = do_gemma_3n_inference(model, messages, max_new_tokens = 256)
+     return ast.literal_eval(text) # the prompt asks the model for a Python tuple literal, e.g. ('title', 'author')
+
+
+ def query_ai_text(text):
+     ''' Query AI with a text prompt. '''
+     messages = [{
+         "role" : "user",
+         "content": [
+             { "type": "text", "text" : text }
+         ]
+     }]
+     text = do_gemma_3n_inference(model, messages, max_new_tokens = 256)
+     return text
+
+ import pandas as pd
+
+ class Inventory:
+     column_names = ['title', 'author', 'year_published', 'isbn', 'description', 'copies_on_shelf', 'total_copies']
+
+     def __init__(self, input_file_path, output_file_path):
+         ''' Initialize library inventory with data from an input csv file. Specify the file path for storing updated inventory. '''
+
+         # Load input file, keeping only the relevant columns
+         data = pd.read_csv(input_file_path)
+         data = data[ [col for col in data.columns if col in self.column_names] ]
+
+         # Check that the input contains the required fields "title" and "description"
+         for col in ['title', 'description']:
+             if col not in data.columns:
+                 raise Exception(f"Input book info must contain '{col}'.")
+
+         # If the number of copies is not available in the input data, set it to the default value of 1
+         for col in ['copies_on_shelf', 'total_copies']:
+             if col not in data.columns:
+                 print(f"Input {col} not found. Setting to default value 1.")
+                 data[col] = 1
+
+         # self.data = data
+         # NOTE: Due to runtime memory limitations, we only demonstrate the application on the subset of books that have short descriptions.
+         self.data = data[data.description.str.count(' ') < 50]
+         self.file_path = output_file_path
+         self.save()
+
+     def save(self):
+         ''' Save inventory data to file. '''
+         self.data.to_csv(self.file_path, index=False)
+
+     def get_index(self, title):
+         ''' Return the index of the single book matching a given title, or None if the title is not found. '''
+         idx = self.data[self.data.title.str.lower() == title.lower()].index
+         if idx.size == 0:
+             return None
+         if idx.size > 1:
+             raise Exception(f"Found {idx.size} books with the title '{title}'.") # TODO: Match on author as well.
+         return idx[0]
+
+     def check_out(self, title):
+         ''' Check out one copy of the given title and return a status message. '''
+         i = self.get_index(title)
+         if i is None:
+             return "ERROR: Title not found in library collection." # TODO: Add book to collection
+         if self.data.loc[i, 'copies_on_shelf'] == 0:
+             return "ERROR: Check out unsuccessful. There are 0 copies on shelf."
+         self.data.loc[i, 'copies_on_shelf'] -= 1
+         self.save()
+         return f"Check out successful. {self.data.loc[i, 'copies_on_shelf']} of {self.data.loc[i, 'total_copies']} copies remaining."
+
+     def check_in(self, title):
+         ''' Check in one copy of the given title and return a status message. '''
+         i = self.get_index(title)
+         if i is None:
+             return "ERROR: Title not found in library collection."
+         row = self.data.loc[i]
+         if row.copies_on_shelf == row.total_copies:
+             return f"ERROR: Check in unsuccessful. {row.copies_on_shelf} of {row.total_copies} copies already on shelf."
+         self.data.loc[i, 'copies_on_shelf'] += 1
+         self.save()
+         return f"Check in successful. {self.data.loc[i, 'copies_on_shelf']} of {self.data.loc[i, 'total_copies']} copies on shelf."
+
+     def get_on_shelf_book_info(self):
+         ''' Return the title/author/description info of all books with available copies on shelf, in csv format. '''
+         columns = ['title', 'author', 'description']
+         return self.data[self.data.copies_on_shelf > 0][columns].to_csv()
+
+     def get_df(self):
+         ''' Return inventory data. '''
+         return self.data
+
+     def get_dtypes(self):
+         ''' Get data types for each column. '''
+         return self.data.dtypes
+
+     def set_df(self, data):
+         ''' Set inventory as the input DataFrame. '''
+         self.data = data
+
+ # Initialize mobile library Inventory object
+ initial_book_list = '/kaggle/input/caldecott-medal-winners-1938-2019/caldecott_winners.csv'
+ inventory_file_path = '/kaggle/working/inventory.csv'
+ inventory = Inventory(initial_book_list, inventory_file_path)
+
+ import gradio as gr
+ from datetime import datetime
+
+
+ # --- "Scan" tab ---
+ def scan_book(image, action):
+
+     # Guard against a missing image (query_ai_text_image would otherwise return a plain string that breaks the tuple unpacking below)
+     if image is None:
+         return "No image uploaded."
+
+     # Query AI to extract the title and author
+     prompt = "Extract the title and author from this book cover image. Format the output as ('[title]', '[author]'). If unsuccessful, output ('Unknown Title', 'Unknown Author')."
+     title, author = query_ai_text_image(prompt, image)
+
+     # AI query success check
+     if title == "Unknown Title" or author == "Unknown Author":
+         return "Could not reliably extract book information from the image. Please try again with a clearer cover."
+
+     # Get the right function (check out or check in)
+     if action == 'out':
+         fn = inventory.check_out
+     elif action == 'in':
+         fn = inventory.check_in
+     else:
+         raise Exception(f'Unknown action {action}. Valid options are "out" or "in".')
+
+     # Perform action and return results
+     return f"Title: {title}\nAuthor: {author}\n" + fn(title)
+
+
+ # --- "Recommend" tab ---
+ recommend_examples = [
+     ["Suggest five books for a toddler who loves animals."],
+     ["Find 3 books for a preschooler interested in space."],
+     ["What are some books about adventures?"]
+ ]
+
+ def recommend_chat_response(message, history):
+     prompt = "You are a helpful librarian making book recommendations based on the user's description of the reader's background and interests. Respond with 3-5 books, unless otherwise specified by the user. Respond with a bullet point list formatted '[title] by [author]', followed by a short sentence of less than 20 words about why this book was chosen. You must only choose books from the following csv file: " + inventory.get_on_shelf_book_info()
+     return query_ai_text(f"{prompt} \n User question: {message}")
+
+
+ # --- "Analyze" tab ---
+ analyze_examples = [
+     ["What is the newest book we have?"],
+     ["Summarize the common themes in our collection."]
+ ]
+
+ def analyze_chat_response(message, history):
+     prompt = "You are a helpful librarian answering questions about the library's collection of books, based only on this inventory data: " + inventory.get_df().to_csv(index=False)
+     return query_ai_text(f"{prompt} \n User question: {message}")
+
+
+ # --- "Manage" tab ---
+ def save_inventory(df_input):
+     ''' Save the user-edited DataFrame as the inventory DataFrame. ''' # TODO: More robust error checks
+     df = pd.DataFrame(df_input)
+
+     # Explicitly convert columns to desired data types
+     col_type = inventory.get_dtypes().to_list()
+     for i, col in enumerate(df.columns):
+         df[col] = df[col].astype(col_type[i])
+
+     # Save DataFrame
+     inventory.set_df(df)
+     inventory.save()
+
+
+ # --- Main gradio app ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🚐 MoLi: Mobile Librarian 📚")
+     gr.Markdown("Scan to check out/in, get book recommendations, and analyze your collection, powered by Google's Gemma 3n AI!")
+
+     with gr.Tabs() as tabs:
+
+         # Scan book to check out or check in
+         actions = ['out', 'in']
+         with gr.Tab(label='Scan'):
+             image_input = gr.Image(type='filepath', label="Upload book cover or take a photo", sources=['upload', 'webcam'], width=300)
+             with gr.Row():
+                 button = {a: gr.Button(f'Check {a}') for a in actions}
+             status_text = gr.Textbox(show_label=False)
+             button['out'].click(fn=lambda x: scan_book(x, 'out'), inputs=image_input, outputs=status_text)
+             button['in'].click(fn=lambda x: scan_book(x, 'in'), inputs=image_input, outputs=status_text)
+             # NOTE: the loop version below does not work because the lambda captures `a` by reference (late binding),
+             # so both buttons would end up calling scan_book with the loop's final value of `a`.
+             # Binding the value as a default argument (lambda x, a=a: scan_book(x, a)) would avoid this.
+             # for a, b in button.items():
+             #     b.click(fn=lambda x: scan_book(x, a), inputs=image_input, outputs=status_text)
+
+         with gr.Tab(label='Recommend'):
+             recommend_greeting = "Tell me the reader's background and interests, and I'll recommend some books available for check out!"
+             gr.ChatInterface(
+                 fn=recommend_chat_response,
+                 type='messages',
+                 examples=recommend_examples,
+                 chatbot=gr.Chatbot(type='messages', placeholder=recommend_greeting),
+             )
+
+         with gr.Tab(label='Analyze'):
+             analyze_greeting = "Ask me anything about the library collection!"
+             gr.ChatInterface(
+                 fn=analyze_chat_response,
+                 type='messages',
+                 examples=analyze_examples,
+                 chatbot=gr.Chatbot(type='messages', placeholder=analyze_greeting),
+             )
+
+         with gr.Tab(label='Manage'):
+
+             # Buttons
+             with gr.Row():
+                 reload_button = gr.Button('Reload')
+                 save_button = gr.Button('Save changes')
+
+             # Textbox to display status messages
+             status_message = gr.Textbox(show_label=False, value='Please reload after check out or check in.')
+
+             # Inventory table
+             inventory_table = gr.DataFrame(
+                 value=inventory.get_df(),
+                 interactive=True, # allow editing
+                 label="Current Library Inventory",
+                 wrap=True
+                 # column_widths=["1fr"]*len(inventory.get_dtypes())
+             )
+
+             # Attach functions to buttons
+             reload_button.click(fn=inventory.get_df, outputs=inventory_table).then(fn=lambda: f"Reloaded on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", outputs=[status_message])
+             save_button.click(fn=save_inventory, inputs=inventory_table).then(fn=lambda: f"Saved on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", outputs=[status_message])
+
+
+ if __name__ == '__main__':
+     demo.launch()
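A quick, self-contained sketch (not part of this commit) of how the Inventory class above can be exercised without loading the model; the sample row, book title, and /tmp output path are made up for illustration, and it assumes Inventory is already defined in the session:

import io

# Hypothetical sample data standing in for caldecott_winners.csv
sample_csv = io.StringIO(
    "title,author,year_published,description\n"
    "Example Book,Jane Doe,2001,A short picture book about friendship.\n"
)

inv = Inventory(sample_csv, "/tmp/inventory.csv")  # pd.read_csv also accepts file-like objects
print(inv.check_out("Example Book"))  # Check out successful. 0 of 1 copies remaining.
print(inv.check_out("Example Book"))  # ERROR: Check out unsuccessful. There are 0 copies on shelf.
print(inv.check_in("Example Book"))   # Check in successful. 1 of 1 copies on shelf.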
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ unsloth
+ transformers==4.54.1 # the latest version, 4.55.0.dev0, results in an error when running image processing
+ timm
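One fragile spot in app.py is that query_ai_text_image feeds the raw model reply straight into ast.literal_eval, which raises if Gemma ever answers with anything other than a valid tuple literal. A minimal sketch of a more defensive parser (the helper name parse_title_author is hypothetical, not part of the commit):

import ast

def parse_title_author(raw):
    ''' Parse a reply expected to look like "('[title]', '[author]')", falling back to Unknown values. '''
    try:
        parsed = ast.literal_eval(raw.strip())
        if isinstance(parsed, tuple) and len(parsed) == 2:
            return str(parsed[0]), str(parsed[1])
    except (ValueError, SyntaxError):
        pass  # the reply was not a valid Python literal
    return "Unknown Title", "Unknown Author"

print(parse_title_author("('The Snowy Day', 'Ezra Jack Keats')"))  # ('The Snowy Day', 'Ezra Jack Keats')
print(parse_title_author("Sorry, I cannot read the cover."))       # ('Unknown Title', 'Unknown Author')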