NimaKL commited on
Commit
b80398e
Β·
1 Parent(s): ef0c030

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer
5
+ from transformers import BertForSequenceClassification
6
+
7
+
8
+ st.set_page_config(layout='wide', initial_sidebar_state='expanded')
9
+ col1, col2= st.columns(2)
10
+
11
+ with col1:
12
+ st.title("FireWatch")
13
+ st.markdown("PREDICT WHETHER HEAT SIGNATURES AROUND THE GLOBE ARE LIKELY TO BE FIRES!")
14
+ st.markdown("Traing Code at:")
15
+ st.markdown("https://colab.research.google.com/drive/1-IfOMJ-X8MKzwm3UjbJbK6RmhT7tk_ye?usp=sharing")
16
+ st.markdown("Try the Model Yourself at:")
17
+ st.markdown("https://colab.research.google.com/drive/1GmweeQrkzs0OXQ_KNZsWd1PQVRLCWDKi?usp=sharing")
18
+
19
+
20
+ table_html = """
21
+ <table style="border-collapse: collapse; width: 100%;">
22
+ <tr style="border: 1px solid orange;">
23
+ <th style="border: 1px solid orange; font-weight: bold;">Category</th>
24
+ <th style="border: 1px solid orange; font-weight: bold;">Latitude, Longitude, Brightness, FRP</th>
25
+ </tr>
26
+ <tr style="border: 1px solid orange;">
27
+ <td style="border: 1px solid orange;">Likely</td>
28
+ <td style="border: 1px solid orange;">-26.76123, 147.15512, 393.02, 203.63</td>
29
+ </tr>
30
+ <tr style="border: 1px solid orange;">
31
+ <td style="border: 1px solid orange;">Likely</td>
32
+ <td style="border: 1px solid orange;">-26.7598, 147.14514, 361.54, 79.4</td>
33
+ </tr>
34
+ <tr style="border: 1px solid orange;">
35
+ <td style="border: 1px solid orange;">Unlikely</td>
36
+ <td style="border: 1px solid orange;">-25.70059, 149.48932, 313.9, 5.15</td>
37
+ </tr>
38
+ <tr style="border: 1px solid orange;">
39
+ <td style="border: 1px solid orange;">Unlikely</td>
40
+ <td style="border: 1px solid orange;">-24.4318, 151.83102, 307.98, 8.79</td>
41
+ </tr>
42
+ <tr style="border: 1px solid orange;">
43
+ <td style="border: 1px solid orange;">Unlikely</td>
44
+ <td style="border: 1px solid orange;">-23.21878, 148.91298, 314.08, 7.4</td>
45
+ </tr>
46
+ <tr style="border: 1px solid orange;">
47
+ <td style="border: 1px solid orange;">Likely</td>
48
+ <td style="border: 1px solid orange;">7.87518, 19.9241, 316.32, 39.63</td>
49
+ </tr>
50
+ <tr style="border: 1px solid orange;">
51
+ <td style="border: 1px solid orange;">Unlikely</td>
52
+ <td style="border: 1px solid orange;">-20.10942, 148.14326, 314.39, 8.8</td>
53
+ </tr>
54
+ <tr style="border: 1px solid orange;">
55
+ <td style="border: 1px solid orange;">Unlikely</td>
56
+ <td style="border: 1px solid orange;">7.87772, 19.9048, 304.14, 13.43</td>
57
+ </tr>
58
+ <tr style="border: 1px solid orange;">
59
+ <td style="border: 1px solid orange;">Likely</td>
60
+ <td style="border: 1px solid yellow;">7.8879, 19.92571, 328.6, 77.78</td>
61
+ </tr>
62
+ </table>
63
+ """
64
+
65
+ st.markdown(table_html, unsafe_allow_html=True)
66
+
67
+
68
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
69
+ def load_model(show_spinner=True):
70
+ MODEL_PATH = "NimaKL/FireWatch_5k"
71
+ model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
72
+ return model
73
+
74
+
75
+
76
+ token_id = []
77
+ attention_masks = []
78
+ def preprocessing(input_text, tokenizer):
79
+ '''
80
+ Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
81
+ - input_ids: list of token ids
82
+ - token_type_ids: list of token type ids
83
+ - attention_mask: list of indices (0,1) specifying which tokens should considered by the model (return_attention_mask = True).
84
+ '''
85
+ return tokenizer.encode_plus(
86
+ input_text,
87
+ add_special_tokens = True,
88
+ max_length = 16,
89
+ pad_to_max_length = True,
90
+ return_attention_mask = True,
91
+ return_tensors = 'pt'
92
+ )
93
+
94
+ def predict(new_sentence):
95
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
96
+ # We need Token IDs and Attention Mask for inference on the new sentence
97
+ test_ids = []
98
+ test_attention_mask = []
99
+ # Apply the tokenizer
100
+ encoding = preprocessing(new_sentence, tokenizer)
101
+ # Extract IDs and Attention Mask
102
+ test_ids.append(encoding['input_ids'])
103
+ test_attention_mask.append(encoding['attention_mask'])
104
+ test_ids = torch.cat(test_ids, dim = 0)
105
+ test_attention_mask = torch.cat(test_attention_mask, dim = 0)
106
+ # Forward pass, calculate logit predictions
107
+ with torch.no_grad():
108
+ output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))
109
+ prediction = 'Likely' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Unlikely'
110
+ pred = 'Predicted Class: '+ prediction
111
+ return pred
112
+
113
+ model = load_model()
114
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
115
+ with col2:
116
+ text = st.text_input('Enter Prediction Data in Correct Format "Latitude, Longtitude, Brightness, FRP".\nExample: 8.81064, -65.07661, 328.04, 18.76 \nPredition Data: ')
117
+ aButton = st.button('Predict')
118
+
119
+ if text or aButton:
120
+ with st.spinner('Wait for it...'):
121
+ st.success(predict(text))
122
+