Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,7 +23,7 @@ bert_model.to(device)
|
|
| 23 |
bert_model.eval()
|
| 24 |
|
| 25 |
#SpaBERT Initialization Section
|
| 26 |
-
data_file_path = 'models/spabert/datasets/
|
| 27 |
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
| 28 |
|
| 29 |
config = SpatialBertConfig()
|
|
@@ -49,7 +49,7 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
|
| 49 |
label_encoder = None,
|
| 50 |
mode = None) #If set to None it will use the full dataset for mlm
|
| 51 |
|
| 52 |
-
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
|
| 53 |
|
| 54 |
# Create a dictionary to map entity names to indices
|
| 55 |
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
|
|
@@ -87,9 +87,7 @@ def process_entity(batch, model, device):
|
|
| 87 |
return spaBERT_embedding, input_ids
|
| 88 |
|
| 89 |
spaBERT_embeddings = []
|
| 90 |
-
for
|
| 91 |
-
if i >= 2: # Stop after processing 3 batches
|
| 92 |
-
break
|
| 93 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
| 94 |
spaBERT_embeddings.append(spaBERT_embedding)
|
| 95 |
|
|
|
|
| 23 |
bert_model.eval()
|
| 24 |
|
| 25 |
#SpaBERT Initialization Section
|
| 26 |
+
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
|
| 27 |
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
| 28 |
|
| 29 |
config = SpatialBertConfig()
|
|
|
|
| 49 |
label_encoder = None,
|
| 50 |
mode = None) #If set to None it will use the full dataset for mlm
|
| 51 |
|
| 52 |
+
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
|
| 53 |
|
| 54 |
# Create a dictionary to map entity names to indices
|
| 55 |
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
|
|
|
|
| 87 |
return spaBERT_embedding, input_ids
|
| 88 |
|
| 89 |
spaBERT_embeddings = []
|
| 90 |
+
for batch in (data_loader):
|
|
|
|
|
|
|
| 91 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
| 92 |
spaBERT_embeddings.append(spaBERT_embedding)
|
| 93 |
|