Update app.py
app.py CHANGED
@@ -37,7 +37,7 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
 spaBERT_model.to(device)
 spaBERT_model.eval()

-#
+#Load data using SpatialDataset
 spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                tokenizer = bert_tokenizer,
                                max_token_len = 256, #Originally 300
@@ -51,6 +51,48 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,

 data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #FIXME: with num_workers > 0, worker processes do not stop after iteration finishes

+#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
+def process_entity(batch, model, device):
+    input_ids = batch['masked_input'].to(device)
+    attention_mask = batch['attention_mask'].to(device)
+    position_list_x = batch['norm_lng_list'].to(device)
+    position_list_y = batch['norm_lat_list'].to(device)
+    sent_position_ids = batch['sent_position_ids'].to(device)
+    pseudo_sentence = batch['pseudo_sentence'].to(device)
+
+    # Convert the tensor to a list of token IDs and decode them into a readable sentence
+    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)
+
+    with torch.no_grad():
+        outputs = model(#input_ids=input_ids,
+                        input_ids=pseudo_sentence,
+                        attention_mask=attention_mask,
+                        sent_position_ids=sent_position_ids,
+                        position_list_x=position_list_x,
+                        position_list_y=position_list_y)
+        #NOTE: we feed the pseudo_sentence as input_ids and omit the masked input here. Verify that this is correct.
+
+    embeddings = outputs.hidden_states[-1].to(device)
+
+    # Extract the [CLS] token embedding (first token)
+    embedding = embeddings[:, 0, :].detach()  # [batch_size, hidden_size]
+
+    #pivot_token_len = batch['pivot_token_len'].item()
+    #pivot_embeddings = embeddings[:, :pivot_token_len, :]
+
+    #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
+    return embedding.cpu().numpy(), input_ids.cpu().numpy()
+
+all_embeddings = []
+for batch in data_loader:
+    embeddings, input_ids = process_entity(batch, spaBERT_model, device)
+    all_embeddings.append(embeddings)
+
+st.write("SpaBERT Embedding shape:", all_embeddings[0].shape)
+st.write("SpaBERT Embedding:", all_embeddings[0])
+
+
+
 #Get BERT Embedding for review
 def get_bert_embedding(review_text):
     #tokenize review
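With the diff applied, all_embeddings holds one [1, hidden_size] array per geo-entity. For review purposes, here is a minimal sketch of how these vectors could be consumed downstream, stacked into a matrix for cosine-similarity lookups; the most_similar helper and query_embedding are illustrative names, not part of this commit:

import numpy as np

# Stack the per-entity [1, hidden_size] arrays into [num_entities, hidden_size]
embedding_matrix = np.vstack(all_embeddings)

def most_similar(query_embedding, matrix, top_k=5):
    # Cosine similarity: normalize both sides, then one matrix-vector product
    q = query_embedding.ravel()
    q = q / np.linalg.norm(q)
    m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    scores = m @ q
    return np.argsort(-scores)[:top_k]  # indices of the top_k nearest entities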
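One caveat worth flagging on the added code: outputs.hidden_states[-1] assumes the model returns hidden states at all. If spaBERT_model follows the Hugging Face convention (an assumption; the model construction is outside this diff), hidden_states is None unless the config enables it, and the subscript would raise a TypeError. A hedged sketch of the per-call override, if the forward signature supports it:

# Assumption: spaBERT_model accepts the standard Hugging Face
# output_hidden_states flag; if it does not, enable
# output_hidden_states=True in the model config before loading instead.
outputs = spaBERT_model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y,
                        output_hidden_states=True)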