ritaycw commited on
Commit
411fe56
·
verified ·
1 Parent(s): 132f5c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -127
app.py CHANGED
@@ -1,147 +1,246 @@
1
- import io
2
- import random
3
- from typing import List, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- import aiohttp
6
- import panel as pn
7
- from PIL import Image
8
- from transformers import CLIPModel, CLIPProcessor
9
 
10
- pn.extension(design="bootstrap", sizing_mode="stretch_width")
 
 
 
 
 
 
 
 
11
 
12
- ICON_URLS = {
13
- "brand-github": "https://github.com/holoviz/panel",
14
- "brand-twitter": "https://twitter.com/Panel_Org",
15
- "brand-linkedin": "https://www.linkedin.com/company/panel-org",
16
- "message-circle": "https://discourse.holoviz.org/",
17
- "brand-discord": "https://discord.gg/AXRHnJU6sP",
18
- }
19
 
 
 
 
20
 
21
- async def random_url(_):
22
- pet = random.choice(["cat", "dog"])
23
- api_url = f"https://api.the{pet}api.com/v1/images/search"
24
- async with aiohttp.ClientSession() as session:
25
- async with session.get(api_url) as resp:
26
- return (await resp.json())[0]["url"]
27
 
 
28
 
29
- @pn.cache
30
- def load_processor_model(
31
- processor_name: str, model_name: str
32
- ) -> Tuple[CLIPProcessor, CLIPModel]:
33
- processor = CLIPProcessor.from_pretrained(processor_name)
34
- model = CLIPModel.from_pretrained(model_name)
35
- return processor, model
36
 
37
 
38
- async def open_image_url(image_url: str) -> Image:
39
- async with aiohttp.ClientSession() as session:
40
- async with session.get(image_url) as resp:
41
- return Image.open(io.BytesIO(await resp.read()))
42
 
 
43
 
44
- def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
45
- processor, model = load_processor_model(
46
- "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
47
- )
48
- inputs = processor(
49
- text=class_items,
50
- images=[image],
51
- return_tensors="pt", # pytorch tensors
52
- )
53
- outputs = model(**inputs)
54
- logits_per_image = outputs.logits_per_image
55
- class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
56
- return class_likelihoods[0]
57
-
58
-
59
- async def process_inputs(class_names: List[str], image_url: str):
60
- """
61
- High level function that takes in the user inputs and returns the
62
- classification results as panel objects.
63
- """
64
- try:
65
- main.disabled = True
66
- if not image_url:
67
- yield "##### ⚠️ Provide an image URL"
68
- return
69
-
70
- yield "##### ⚙ Fetching image and running model..."
71
- try:
72
- pil_img = await open_image_url(image_url)
73
- img = pn.pane.Image(pil_img, height=400, align="center")
74
- except Exception as e:
75
- yield f"##### 😔 Something went wrong, please try a different URL!"
76
- return
77
-
78
- class_items = class_names.split(",")
79
- class_likelihoods = get_similarity_scores(class_items, pil_img)
80
-
81
- # build the results column
82
- results = pn.Column("##### 🎉 Here are the results!", img)
83
-
84
- for class_item, class_likelihood in zip(class_items, class_likelihoods):
85
- row_label = pn.widgets.StaticText(
86
- name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
87
- )
88
- row_bar = pn.indicators.Progress(
89
- value=int(class_likelihood * 100),
90
- sizing_mode="stretch_width",
91
- bar_color="secondary",
92
- margin=(0, 10),
93
- design=pn.theme.Material,
94
- )
95
- results.append(pn.Column(row_label, row_bar))
96
- yield results
97
- finally:
98
- main.disabled = False
99
-
100
-
101
- # create widgets
102
- randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
103
-
104
- image_url = pn.widgets.TextInput(
105
- name="Image URL to classify",
106
- value=pn.bind(random_url, randomize_url),
 
 
 
107
  )
108
- class_names = pn.widgets.TextInput(
109
- name="Comma separated class names",
110
- placeholder="Enter possible class names, e.g. cat, dog",
111
- value="cat, dog, parrot",
 
112
  )
113
 
114
- input_widgets = pn.Column(
115
- "##### 😊 Click randomize or paste a URL to start classifying!",
116
- pn.Row(image_url, randomize_url),
117
- class_names,
118
  )
119
 
120
- # add interactivity
121
- interactive_result = pn.panel(
122
- pn.bind(process_inputs, image_url=image_url, class_names=class_names),
123
- height=600,
 
124
  )
125
 
126
- # add footer
127
- footer_row = pn.Row(pn.Spacer(), align="center")
128
- for icon, url in ICON_URLS.items():
129
- href_button = pn.widgets.Button(icon=icon, width=35, height=35)
130
- href_button.js_on_click(code=f"window.open('{url}')")
131
- footer_row.append(href_button)
132
- footer_row.append(pn.Spacer())
133
-
134
- # create dashboard
135
- main = pn.WidgetBox(
136
- input_widgets,
137
- interactive_result,
138
- footer_row,
139
  )
140
 
141
- title = "Panel Demo - Image Classification"
142
- pn.template.BootstrapTemplate(
143
- title=title,
144
- main=main,
145
- main_max_width="min(50%, 698px)",
146
- header_background="#F08080",
147
- ).servable(title=title)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # start with the setup
2
+
3
+ # supress warnings about future deprecations
4
+ import warnings
5
+ warnings.simplefilter(action='ignore', category=FutureWarning)
6
+
7
+ import pandas as pd
8
+ import altair as alt
9
+ import numpy as np
10
+ import pprint
11
+ import datetime as dt
12
+ from vega_datasets import data
13
+ import matplotlib.pyplot as plt
14
+
15
+ # Solve a javascript error by explicitly setting the renderer
16
+ alt.renderers.enable('jupyterlab')
17
+
18
+ #load data
19
+ df1=pd.read_csv("https://raw.githubusercontent.com/dallascard/SI649_public/main/altair_hw3/approval_polllist.csv")
20
+ df2=pd.read_csv("https://raw.githubusercontent.com/dallascard/SI649_public/main/altair_hw3/approval_topline.csv")
21
+
22
+ #change the approval ratings into percentage
23
+ df1['approve_percent']=df1['approve']/100
24
+ df1.head()
25
+
26
+ df2['timestamp']=pd.to_datetime(df2['timestamp'])
27
+ df2=pd.melt(df2, id_vars=['president', 'subgroup', 'timestamp'], value_vars=['approve','disapprove']).rename(columns={'variable':'choice', 'value':'rate'})
28
+ df2.head()
29
+
30
+
31
+ ##TODO: replicate vis 1
32
+
33
+ ##Static Component - Bars
34
+ barchart1_1 = alt.Chart(df1).transform_joinaggregate(
35
+ groupby=['pollster']
36
+ ).mark_bar(height=15).encode(
37
+ alt.X('mean(approve_percent):Q', axis=alt.Axis(labels=True, title=None)),
38
+ alt.Y('pollster:N', axis=alt.Axis(labels=True, title=None)),
39
+ alt.Tooltip('mean(approve_percent):Q', format='.0%')
40
+ ).properties(
41
+ title='Average Approval Ratings for Joe Biden'
42
+ )
43
 
44
+ ##Static Component - Vertical Line
45
+ vline1_1 = alt.Chart(df1).mark_rule(size=3, color="firebrick").encode(
46
+ alt.X('mean(approve_percent):Q')
47
+ )
48
 
49
+ ##Static Component - Text
50
+ text1_1 = vline1_1.mark_text(
51
+ color='firebrick',
52
+ fontSize=12,
53
+ align='left',
54
+ dx=7
55
+ ).encode(
56
+ alt.Text('mean(approve_percent):Q', format='.2%'),
57
+ )
58
 
59
+ ##Put all together
60
+ selection1 = alt.selection_interval(encodings=["y"])
61
+ condition1 = alt.condition(selection1, alt.value(1.0), alt.value(0.6))
 
 
 
 
62
 
63
+ barchart1_2 = barchart1_1.add_params(selection1).encode(
64
+ opacity = condition1
65
+ )
66
 
67
+ vline1_2 = vline1_1.add_params(selection1).transform_filter(selection1)
 
 
 
 
 
68
 
69
+ text1_2 = text1_1.transform_filter(selection1)
70
 
71
+ final_viz1 = barchart1_2 + vline1_2 + text1_2
72
+ final_viz1
 
 
 
 
 
73
 
74
 
 
 
 
 
75
 
76
+ #TODO: replicate vis2
77
 
78
+ # Create selection and condition
79
+ selection2 = alt.selection_interval(encodings=["x"])
80
+ condition2 = alt.condition(selection2, alt.value(1.0), alt.value(0.6))
81
+
82
+ # scatter plot
83
+ scatter2_1 = alt.Chart(df1).mark_point().transform_joinaggregate(
84
+ groupby=['pollster']
85
+ ).encode(
86
+ alt.X('startdate:T', axis=alt.Axis(labels=True, title=None)),
87
+ alt.Y('mean(adjusted_approve):Q', axis=alt.Axis(title='Approval ratings')),
88
+ color='pollster:N',
89
+ )
90
+
91
+ # bar chart
92
+ bar2_1 = alt.Chart(df1).mark_bar().transform_joinaggregate(
93
+ groupby=['pollster:N']
94
+ ).encode(
95
+ alt.X('mean(adjusted_approve):Q', axis=alt.Axis(title = 'Mean of Approval Ratings')),
96
+ alt.Y('pollster:N', axis=alt.Axis(labels=True, title=None)),
97
+ color='pollster:N'
98
+ )
99
+
100
+ # Put them all together
101
+ # scatter2_1 & bar2_1
102
+
103
+ scatter2_2 = scatter2_1.add_params(selection2).encode(
104
+ opacity = condition2
105
+ )
106
+
107
+ bar2_2 = bar2_1.add_params(selection2).transform_filter(selection2)
108
+
109
+ final_viz2 = (scatter2_2 & bar2_2).properties(
110
+ title='Recently Reported Approval Ratings for Joe Biden'
111
+ )
112
+
113
+ final_viz2
114
+
115
+
116
+ #TODO: replicate vis3
117
+ # https://altair-viz.github.io/gallery/multiline_tooltip.html
118
+
119
+ # Create a selection for zooming and panning across the x-axis
120
+ scale = alt.selection_interval(bind='scales', encodings=['x'])
121
+
122
+ # Create a selection and condition for the vertical line, annotation dots, and text annotations
123
+ nearest = alt.selection_point(on='mouseover', encodings=['x'], nearest=True, empty=False)
124
+ opacityCondition = alt.condition(nearest, alt.value(1), alt.value(0))
125
+
126
+ # Create the base chart and filter to All polls
127
+ base3 = alt.Chart(df2).mark_line(size=2.5).transform_filter(
128
+ alt.datum.subgroup =='All polls'
129
+ ).encode(
130
+ alt.X('timestamp:T', axis=alt.Axis(labels=True, title=None)),
131
+ y='rate:Q',
132
+ color='choice:N'
133
+ ).add_params(scale).properties(
134
+ title='Approval Ratings for Joe Biden 2021-2023'
135
+ )
136
+
137
+ # Static line chart
138
+ # Vertical line
139
+ selectors = alt.Chart(df2).mark_point().encode(
140
+ x='timestamp:T',
141
+ opacity=alt.value(0),
142
+ ).add_params(
143
+ nearest
144
  )
145
+
146
+ rules = alt.Chart(df2).mark_rule(size=4, color='lightgray').encode(
147
+ x='timestamp:T'
148
+ ).transform_filter(
149
+ nearest
150
  )
151
 
152
+ #interaction dots
153
+ points = base3.mark_point(size=90).encode(
154
+ opacity= opacityCondition # alt.condition(nearest, alt.value(1), alt.value(0))
 
155
  )
156
 
157
+ #interaction text labels
158
+ text = base3.mark_text(fontSize=14, align='left', dx=7).transform_filter(
159
+ alt.datum.subgroup =='All polls'
160
+ ).encode(
161
+ text=alt.condition(nearest, 'rate:Q', alt.value(' '), format='.2f')
162
  )
163
 
164
+ #Put them all together
165
+ alt.layer(
166
+ base3, selectors, points, rules, text
167
+ ).properties(
168
+ width=400, height=300
 
 
 
 
 
 
 
 
169
  )
170
 
171
+
172
+
173
+ ## Viz 4
174
+ # Import panel and vega datasets
175
+
176
+ import panel as pn
177
+ import vega_datasets
178
+
179
+ # Enable Panel extensions
180
+ pn.extension()
181
+
182
+ # Define a function to create and return a plot
183
+
184
+ def create_plot(subgroup, date_range, moving_av_window):
185
+
186
+ # Apply any required transformations to the data in pandas
187
+ approve_data = df2[df2['choice']=='approve']
188
+ filtered_data = approve_data[approve_data['subgroup'] == subgroup]
189
+ filtered_data = filtered_data[(filtered_data['timestamp'].dt.date >= date_range[0]) & \
190
+ (filtered_data['timestamp'].dt.date <= date_range[1])]
191
+ filtered_data['mov_avg'] = filtered_data['rate'].rolling(window=moving_av_window).mean().shift(-moving_av_window//2)
192
+
193
+ # Line chart
194
+ smoothed_line = alt.Chart(filtered_data).mark_line(color='red', size=2).encode(
195
+ x='timestamp:T',
196
+ y='mov_avg:Q'
197
+ )
198
+
199
+ # Scatter plot with individual polls
200
+ scatter4 = alt.Chart(filtered_data).mark_point(size=2, opacity=0.7, color='grey').encode(
201
+ x='timestamp:T',
202
+ y=alt.Y('rate:Q', title='approve', scale=alt.Scale(domain=[30, 60])),
203
+ ).properties(width=600, height=400)
204
+
205
+ # Put them togetehr
206
+ plot = (scatter4+smoothed_line).encode(
207
+ y=alt.Y(axis=alt.Axis(title='approve, mov_avg'))
208
+ )
209
+
210
+ # Return the combined chart
211
+ return plot
212
+
213
+ # Create the selection widget
214
+ dropdown = pn.widgets.Select(options=['All polls', 'Adults', 'Voters'], name='Select')
215
+
216
+ # Create the slider for the date range
217
+ date_range_slider = pn.widgets.DateRangeSlider(name='Date Range Slider',
218
+ start=df2['timestamp'].dt.date.min(),
219
+ end=df2['timestamp'].dt.date.max(),
220
+ value=(df2['timestamp'].dt.date.min(), df2['timestamp'].dt.date.max()),
221
+ step=1)
222
+
223
+ # Create the slider for the moving average window
224
+ window_size_slider = pn.widgets.IntSlider(name='Moving average window', start=1, end=100, value=1)
225
+
226
+ # Bind the widgets to the create_plot function
227
+ final_test = pn.Row(pn.bind(create_plot, subgroup=dropdown,
228
+ date_range=date_range_slider,
229
+ moving_av_window=window_size_slider))
230
+
231
+ # window_size_slider
232
+
233
+ # Combine everything in a Panel Column to create an app
234
+ maincol = pn.Column()
235
+ maincol.append(final_test)
236
+ maincol.append(dropdown)
237
+ maincol.append(date_range_slider)
238
+ maincol.append(window_size_slider)
239
+
240
+
241
+ # set the app to be servable
242
+ template = pn.template.BootstrapTemplate(
243
+ title='SI649 Altair Assignment 3',
244
+ )
245
+ template.main.append(maincol)
246
+ template.servable(title="SI649 Altair Assignment 3")