upload phasenet

- .gitattributes +6 -0
- model/.DS_Store +0 -0
- model/190703-214543/checkpoint +3 -0
- model/190703-214543/config.log +3 -0
- model/190703-214543/loss.log +3 -0
- model/190703-214543/model_95.ckpt.data-00000-of-00001 +3 -0
- model/190703-214543/model_95.ckpt.index +3 -0
- model/190703-214543/model_95.ckpt.meta +3 -0
- phasenet/.DS_Store +0 -0
- phasenet/__init__.py +1 -0
- phasenet/__pycache__/__init__.cpython-39.pyc +0 -0
- phasenet/__pycache__/detect_peaks.cpython-39.pyc +0 -0
- phasenet/__pycache__/model.cpython-39.pyc +0 -0
- phasenet/__pycache__/postprocess.cpython-39.pyc +0 -0
- phasenet/app.py +331 -0
- phasenet/data_reader.py +964 -0
- phasenet/detect_peaks.py +207 -0
- phasenet/model.py +489 -0
- phasenet/postprocess.py +383 -0
- phasenet/predict.py +266 -0
- phasenet/slide_window.py +88 -0
- phasenet/test_app.py +47 -0
- phasenet/train.py +246 -0
- phasenet/util.py +238 -0
- phasenet/visulization.py +481 -0
- pipeline.py +40 -2
- requirements.txt +1 -1
.gitattributes CHANGED
@@ -32,3 +32,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/checkpoint filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/config.log filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/loss.log filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/model_95.ckpt.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/model_95.ckpt.index filter=lfs diff=lfs merge=lfs -text
+model/190703-214543/model_95.ckpt.meta filter=lfs diff=lfs merge=lfs -text
model/.DS_Store ADDED
Binary file (6.15 kB)
model/190703-214543/checkpoint ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1606ccb25e1533fa0398c5dbce7f3a45ac77f90b78b99f81a044294ba38a2c0c
+size 83

model/190703-214543/config.log ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed9dfa705053a5025facc9952c7da6abef19ec5f672d9e50386bf3f2d80294f2
+size 345

model/190703-214543/loss.log ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccb6f19117497571e19bec5da6012ac7af91f1bd29e931ffd0b23c6b657bb401
+size 8101

model/190703-214543/model_95.ckpt.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ee2c15dd78fb15de45a55ad64a446f1a0ced152ba4ac5c506d82b9194da85b4
+size 3226256

model/190703-214543/model_95.ckpt.index ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f96b553b76be4ebae9a455eaf8d83cfa8c0e110f06cfba958de2568e5b6b2780
+size 7223

model/190703-214543/model_95.ckpt.meta ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7ebd154a5ba0721ba8bbb627ba61b556ee60660eb34bbcd1b1f50396b07cc4ed
+size 2172055
phasenet/.DS_Store ADDED
Binary file (6.15 kB)

phasenet/__init__.py ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.0"

phasenet/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (178 Bytes)

phasenet/__pycache__/detect_peaks.cpython-39.pyc ADDED
Binary file (6.71 kB)

phasenet/__pycache__/model.cpython-39.pyc ADDED
Binary file (12 kB)

phasenet/__pycache__/postprocess.cpython-39.pyc ADDED
Binary file (10.6 kB)
phasenet/app.py ADDED
@@ -0,0 +1,331 @@
+import os
+from collections import defaultdict, namedtuple
+from datetime import datetime, timedelta
+from json import dumps
+from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional
+
+import numpy as np
+import requests
+import tensorflow as tf
+from fastapi import FastAPI
+from kafka import KafkaProducer
+from pydantic import BaseModel
+from scipy.interpolate import interp1d
+
+from model import ModelConfig, UNet
+from postprocess import extract_picks
+
+tf.compat.v1.disable_eager_execution()
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+JSONObject = Dict[AnyStr, Any]
+JSONArray = List[Any]
+JSONStructure = Union[JSONArray, JSONObject]
+
+app = FastAPI()
+X_SHAPE = [3000, 1, 3]
+SAMPLING_RATE = 100
+
+# load model
+model = UNet(mode="pred")
+sess_config = tf.compat.v1.ConfigProto()
+sess_config.gpu_options.allow_growth = True
+
+sess = tf.compat.v1.Session(config=sess_config)
+saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+init = tf.compat.v1.global_variables_initializer()
+sess.run(init)
+latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
+print(f"restoring model {latest_check_point}")
+saver.restore(sess, latest_check_point)
+
+# GAMMA API Endpoint
+GAMMA_API_URL = "http://gamma-api:8001"
+# GAMMA_API_URL = 'http://localhost:8001'
+# GAMMA_API_URL = "http://gamma.quakeflow.com"
+# GAMMA_API_URL = "http://127.0.0.1:8001"
+
+# Kafka producer
+use_kafka = False
+
+try:
+    print("Connecting to k8s kafka")
+    BROKER_URL = "quakeflow-kafka-headless:9092"
+    # BROKER_URL = "34.83.137.139:9094"
+    producer = KafkaProducer(
+        bootstrap_servers=[BROKER_URL],
+        key_serializer=lambda x: dumps(x).encode("utf-8"),
+        value_serializer=lambda x: dumps(x).encode("utf-8"),
+    )
+    use_kafka = True
+    print("k8s kafka connection success!")
+except BaseException:
+    print("k8s Kafka connection error")
+    try:
+        print("Connecting to local kafka")
+        producer = KafkaProducer(
+            bootstrap_servers=["localhost:9092"],
+            key_serializer=lambda x: dumps(x).encode("utf-8"),
+            value_serializer=lambda x: dumps(x).encode("utf-8"),
+        )
+        use_kafka = True
+        print("local kafka connection success!")
+    except BaseException:
+        print("local Kafka connection error")
+print(f"Kafka status: {use_kafka}")
+
+
+def normalize_batch(data, window=3000):
+    """
+    data: nsta, nt, nch
+    """
+    shift = window // 2
+    nsta, nt, nch = data.shape
+
+    # std in sliding windows
+    data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0)), mode="reflect")
+    t = np.arange(0, nt, shift, dtype="int")
+    std = np.zeros([nsta, len(t) + 1, nch])
+    mean = np.zeros([nsta, len(t) + 1, nch])
+    for i in range(1, len(t)):
+        std[:, i, :] = np.std(data_pad[:, i * shift : i * shift + window, :], axis=1)
+        mean[:, i, :] = np.mean(data_pad[:, i * shift : i * shift + window, :], axis=1)
+
+    t = np.append(t, nt)
+    # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
+    # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
+    std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
+    std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
+    std[std == 0] = 1
+
+    # normalize data with interpolated std
+    t_interp = np.arange(nt, dtype="int")
+    std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
+    mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
+    data = (data - mean_interp) / std_interp
+
+    return data
+
+
+def preprocess(data):
+    raw = data.copy()
+    data = normalize_batch(data)
+    if len(data.shape) == 3:
+        data = data[:, :, np.newaxis, :]
+        raw = raw[:, :, np.newaxis, :]
+    return data, raw
+
+
+def calc_timestamp(timestamp, sec):
+    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
+    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+
+
+def format_picks(picks, dt, amplitudes):
+    picks_ = []
+    for pick, amplitude in zip(picks, amplitudes):
+        for idxs, probs, amps in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
+            for idx, prob, amp in zip(idxs, probs, amps):
+                picks_.append(
+                    {
+                        "id": pick.fname,
+                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
+                        "prob": prob,
+                        "amp": amp,
+                        "type": "p",
+                    }
+                )
+        for idxs, probs, amps in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
+            for idx, prob, amp in zip(idxs, probs, amps):
+                picks_.append(
+                    {
+                        "id": pick.fname,
+                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
+                        "prob": prob,
+                        "amp": amp,
+                        "type": "s",
+                    }
+                )
+    return picks_
+
+
+def format_data(data):
+
+    # chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
+    #            "123": {"3":0, "2":1, "1":2},
+    #            "12Z": {"1":0, "2":1, "Z":2}}
+    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
+    Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])
+
+    # Group by station
+    chn_ = defaultdict(list)
+    t0_ = defaultdict(list)
+    vv_ = defaultdict(list)
+    for i in range(len(data.id)):
+        key = data.id[i][:-1]
+        chn_[key].append(data.id[i][-1])
+        t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
+        vv_[key].append(np.array(data.vec[i]))
+
+    # Merge to Data tuple
+    id_ = []
+    timestamp_ = []
+    vec_ = []
+    for k in chn_:
+        id_.append(k)
+        min_t0 = min(t0_[k])
+        timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
+        vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
+        for i in range(len(chn_[k])):
+            # vec[int(t0_[k][i]-min_t0):len(vv_[k][i]), chn2idx[chn_[k][i]]] = vv_[k][i][int(t0_[k][i]-min_t0):X_SHAPE[0]] - np.mean(vv_[k][i])
+            shift = int(t0_[k][i] - min_t0)
+            vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
+                vv_[k][i][: X_SHAPE[0] - shift]
+            )
+        vec_.append(vec.tolist())
+
+    return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
+    # return {"id": id_, "timestamp": timestamp_, "vec": vec_, "dt": 1 / SAMPLING_RATE}
+
+
+def get_prediction(data, return_preds=False):
+
+    vec = np.array(data.vec)
+    vec, vec_raw = preprocess(vec)
+
+    feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
+    preds = sess.run(model.preds, feed_dict=feed)
+
+    picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
+
+    picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
+
+    if return_preds:
+        return picks, preds
+
+    return picks
+
+
+class Data(BaseModel):
+    # id: Union[List[str], str]
+    # timestamp: Union[List[str], str]
+    # vec: Union[List[List[List[float]]], List[List[float]]]
+    id: List[str]
+    timestamp: List[str]
+    vec: Union[List[List[List[float]]], List[List[float]]]
+    dt: Optional[float] = 0.01
+    ## gamma
+    stations: Optional[List[Dict[str, Union[float, str]]]] = None
+    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
+
+
+# @app.on_event("startup")
+# def set_default_executor():
+#     from concurrent.futures import ThreadPoolExecutor
+#     import asyncio
+#
+#     loop = asyncio.get_running_loop()
+#     loop.set_default_executor(
+#         ThreadPoolExecutor(max_workers=2)
+#     )
+
+
+@app.post("/predict")
+def predict(data: Data):
+
+    picks = get_prediction(data)
+
+    return picks
+
+
+@app.post("/predict_prob")
+def predict(data: Data):
+
+    picks, preds = get_prediction(data, True)
+
+    return picks, preds.tolist()
+
+
+@app.post("/predict_phasenet2gamma")
+def predict(data: Data):
+
+    picks = get_prediction(data)
+
+    # if use_kafka:
+    #     print("Push picks to kafka...")
+    #     for pick in picks:
+    #         producer.send("phasenet_picks", key=pick["id"], value=pick)
+    try:
+        catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
+                                                                  "stations": data.stations,
+                                                                  "config": data.config})
+        print(catalog.json()["catalog"])
+        return catalog.json()
+    except Exception as error:
+        print(error)
+
+    return {}
+
+
+@app.post("/predict_phasenet2gamma2ui")
+def predict(data: Data):
+
+    picks = get_prediction(data)
+
+    try:
+        catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
+                                                                  "stations": data.stations,
+                                                                  "config": data.config})
+        print(catalog.json()["catalog"])
+        return catalog.json()
+    except Exception as error:
+        print(error)
+
+    if use_kafka:
+        print("Push picks to kafka...")
+        for pick in picks:
+            producer.send("phasenet_picks", key=pick["id"], value=pick)
+        print("Push waveform to kafka...")
+        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
+            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
+
+    return {}
+
+
+@app.post("/predict_stream_phasenet2gamma")
+def predict(data: Data):
+
+    data = format_data(data)
+    # for i in range(len(data.id)):
+    #     plt.clf()
+    #     plt.subplot(311)
+    #     plt.plot(np.array(data.vec)[i, :, 0])
+    #     plt.subplot(312)
+    #     plt.plot(np.array(data.vec)[i, :, 1])
+    #     plt.subplot(313)
+    #     plt.plot(np.array(data.vec)[i, :, 2])
+    #     plt.savefig(f"{data.id[i]}.png")
+
+    picks = get_prediction(data)
+
+    return_value = {}
+    try:
+        catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
+        print("GMMA:", catalog.json()["catalog"])
+        return_value = catalog.json()
+    except Exception as error:
+        print(error)
+
+    if use_kafka:
+        print("Push picks to kafka...")
+        for pick in picks:
+            producer.send("phasenet_picks", key=pick["id"], value=pick)
+        print("Push waveform to kafka...")
+        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
+            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})
+
+    return return_value
+
+
+@app.get("/healthz")
+def healthz():
+    return {"status": "ok"}
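A quick way to sanity-check the /predict route above is to post one synthetic trace shaped like the Data model. This is a minimal client sketch, not part of the commit; the port (uvicorn's default 8000), the station id, and the random waveform are all assumptions:

import numpy as np
import requests

# One station: 3000 samples x 3 channels; values are random placeholders.
vec = np.random.randn(3000, 3).tolist()
payload = {
    "id": ["CI.STA01..BHZ"],                   # hypothetical station id
    "timestamp": ["2019-07-03T00:00:00.000"],  # window start, %Y-%m-%dT%H:%M:%S.%f
    "vec": [vec],                              # [n_station][nt][n_channel]
    "dt": 0.01,                                # 100 Hz sampling
}
resp = requests.post("http://localhost:8000/predict", json=payload)
print(resp.json())  # list of picks: station_id, phase_time, phase_score, phase_type, dt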
phasenet/data_reader.py ADDED
@@ -0,0 +1,964 @@
+import tensorflow as tf
+
+tf.compat.v1.disable_eager_execution()
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+import logging
+import os
+
+import numpy as np
+import pandas as pd
+
+pd.options.mode.chained_assignment = None
+import json
+
+# import s3fs
+import h5py
+import obspy
+from scipy.interpolate import interp1d
+from tqdm import tqdm
+
+
+def py_func_decorator(output_types=None, output_shapes=None, name=None):
+    def decorator(func):
+        def call(*args, **kwargs):
+            nonlocal output_shapes
+            # flat_output_types = nest.flatten(output_types)
+            flat_output_types = tf.nest.flatten(output_types)
+            # flat_values = tf.py_func(
+            flat_values = tf.numpy_function(func, inp=args, Tout=flat_output_types, name=name)
+            if output_shapes is not None:
+                for v, s in zip(flat_values, output_shapes):
+                    v.set_shape(s)
+            # return nest.pack_sequence_as(output_types, flat_values)
+            return tf.nest.pack_sequence_as(output_types, flat_values)
+
+        return call
+
+    return decorator
+
+
+def dataset_map(iterator, output_types, output_shapes=None, num_parallel_calls=None, name=None, shuffle=False):
+    dataset = tf.data.Dataset.range(len(iterator))
+    if shuffle:
+        dataset = dataset.shuffle(len(iterator), reshuffle_each_iteration=True)
+
+    @py_func_decorator(output_types, output_shapes, name=name)
+    def index_to_entry(idx):
+        return iterator[idx]
+
+    return dataset.map(index_to_entry, num_parallel_calls=num_parallel_calls)
+
+
+def normalize(data, axis=(0,)):
+    """data shape: (nt, nsta, nch)"""
+    data -= np.mean(data, axis=axis, keepdims=True)
+    std_data = np.std(data, axis=axis, keepdims=True)
+    std_data[std_data == 0] = 1
+    data /= std_data
+    # data /= (std_data + 1e-12)
+    return data
+
+
+def normalize_long(data, axis=(0,), window=3000):
+    """
+    data: nt, nar, nch
+    """
+    nt, nar, nch = data.shape
+    if window is None:
+        window = nt
+    shift = window // 2
+
+    ## std in sliding windows
+    data_pad = np.pad(data, ((window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
+    t = np.arange(0, nt, shift, dtype="int")
+    std = np.zeros([len(t) + 1, nar, nch])
+    mean = np.zeros([len(t) + 1, nar, nch])
+    for i in range(1, len(std)):
+        std[i, :] = np.std(data_pad[i * shift : i * shift + window, :, :], axis=axis)
+        mean[i, :] = np.mean(data_pad[i * shift : i * shift + window, :, :], axis=axis)
+
+    t = np.append(t, nt)
+    # std[-1, :] = np.std(data_pad[-window:, :], axis=0)
+    # mean[-1, :] = np.mean(data_pad[-window:, :], axis=0)
+    std[-1, ...], mean[-1, ...] = std[-2, ...], mean[-2, ...]
+    std[0, ...], mean[0, ...] = std[1, ...], mean[1, ...]
+    # std[std == 0] = 1.0
+
+    ## normalize data with interpolated std
+    t_interp = np.arange(nt, dtype="int")
+    std_interp = interp1d(t, std, axis=0, kind="slinear")(t_interp)
+    # std_interp = np.exp(interp1d(t, np.log(std), axis=0, kind="slinear")(t_interp))
+    mean_interp = interp1d(t, mean, axis=0, kind="slinear")(t_interp)
+    tmp = np.sum(std_interp, axis=(0, 1))
+    std_interp[std_interp == 0] = 1.0
+    data = (data - mean_interp) / std_interp
+    # data = (data - mean_interp)/(std_interp + 1e-12)
+
+    ### dropout effect of < 3 channels
+    nonzero = np.count_nonzero(tmp)
+    if (nonzero < 3) and (nonzero > 0):
+        data *= 3.0 / nonzero
+
+    return data
+
+
+def normalize_batch(data, window=3000):
+    """
+    data: nsta, nt, nar, nch
+    """
+    nsta, nt, nar, nch = data.shape
+    if window is None:
+        window = nt
+    shift = window // 2
+
+    ## std in sliding windows
+    data_pad = np.pad(data, ((0, 0), (window // 2, window // 2), (0, 0), (0, 0)), mode="reflect")
+    t = np.arange(0, nt, shift, dtype="int")
+    std = np.zeros([nsta, len(t) + 1, nar, nch])
+    mean = np.zeros([nsta, len(t) + 1, nar, nch])
+    for i in range(1, len(t)):
+        std[:, i, :, :] = np.std(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
+        mean[:, i, :, :] = np.mean(data_pad[:, i * shift : i * shift + window, :, :], axis=1)
+
+    t = np.append(t, nt)
+    # std[:, -1, :] = np.std(data_pad[:, -window:, :], axis=1)
+    # mean[:, -1, :] = np.mean(data_pad[:, -window:, :], axis=1)
+    std[:, -1, :, :], mean[:, -1, :, :] = std[:, -2, :, :], mean[:, -2, :, :]
+    std[:, 0, :, :], mean[:, 0, :, :] = std[:, 1, :, :], mean[:, 1, :, :]
+    # std[std == 0] = 1
+
+    ## normalize data with interpolated std
+    t_interp = np.arange(nt, dtype="int")
+    std_interp = interp1d(t, std, axis=1, kind="slinear")(t_interp)
+    # std_interp = np.exp(interp1d(t, np.log(std), axis=1, kind="slinear")(t_interp))
+    mean_interp = interp1d(t, mean, axis=1, kind="slinear")(t_interp)
+    tmp = np.sum(std_interp, axis=(1, 2))
+    std_interp[std_interp == 0] = 1.0
+    data = (data - mean_interp) / std_interp
+    # data = (data - mean_interp)/(std_interp + 1e-12)
+
+    ### dropout effect of < 3 channels
+    nonzero = np.count_nonzero(tmp, axis=-1)
+    data[nonzero > 0, ...] *= 3.0 / nonzero[nonzero > 0][:, np.newaxis, np.newaxis, np.newaxis]
+
+    return data
+
+
+class DataConfig:
+
+    seed = 123
+    use_seed = True
+    n_channel = 3
+    n_class = 3
+    sampling_rate = 100
+    dt = 1.0 / sampling_rate
+    X_shape = [3000, 1, n_channel]
+    Y_shape = [3000, 1, n_class]
+    min_event_gap = 3 * sampling_rate
+    label_shape = "gaussian"
+    label_width = 30
+    dtype = "float32"
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+
+class DataReader:
+    def __init__(self, format="numpy", config=DataConfig(), **kwargs):
+        self.buffer = {}
+        self.n_channel = config.n_channel
+        self.n_class = config.n_class
+        self.X_shape = config.X_shape
+        self.Y_shape = config.Y_shape
+        self.dt = config.dt
+        self.dtype = config.dtype
+        self.label_shape = config.label_shape
+        self.label_width = config.label_width
+        self.config = config
+        self.format = format
+        if "highpass_filter" in kwargs:
+            self.highpass_filter = kwargs["highpass_filter"]
+        if format in ["numpy", "mseed", "sac"]:
+            self.data_dir = kwargs["data_dir"]
+            try:
+                csv = pd.read_csv(kwargs["data_list"], header=0, sep="[,|\s+]", engine="python")
+            except:
+                csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
+            self.data_list = csv["fname"]
+            self.num_data = len(self.data_list)
+        elif format == "hdf5":
+            self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
+            self.h5_data = self.h5[kwargs["hdf5_group"]]
+            self.data_list = list(self.h5_data.keys())
+            self.num_data = len(self.data_list)
+        elif format == "s3":
+            self.s3fs = s3fs.S3FileSystem(
+                anon=kwargs["anon"],
+                key=kwargs["key"],
+                secret=kwargs["secret"],
+                client_kwargs={"endpoint_url": kwargs["s3_url"]},
+                use_ssl=kwargs["use_ssl"],
+            )
+            self.num_data = 0
+        else:
+            raise ValueError(f"{format} not supported!")
+
+    def __len__(self):
+        return self.num_data
+
+    def read_numpy(self, fname):
+        # try:
+        if fname not in self.buffer:
+            npz = np.load(fname)
+            meta = {}
+            if len(npz["data"].shape) == 2:
+                meta["data"] = npz["data"][:, np.newaxis, :]
+            else:
+                meta["data"] = npz["data"]
+            if "p_idx" in npz.files:
+                if len(npz["p_idx"].shape) == 0:
+                    meta["itp"] = [[npz["p_idx"]]]
+                else:
+                    meta["itp"] = npz["p_idx"]
+            if "s_idx" in npz.files:
+                if len(npz["s_idx"].shape) == 0:
+                    meta["its"] = [[npz["s_idx"]]]
+                else:
+                    meta["its"] = npz["s_idx"]
+            if "itp" in npz.files:
+                if len(npz["itp"].shape) == 0:
+                    meta["itp"] = [[npz["itp"]]]
+                else:
+                    meta["itp"] = npz["itp"]
+            if "its" in npz.files:
+                if len(npz["its"].shape) == 0:
+                    meta["its"] = [[npz["its"]]]
+                else:
+                    meta["its"] = npz["its"]
+            if "station_id" in npz.files:
+                meta["station_id"] = npz["station_id"]
+            if "sta_id" in npz.files:
+                meta["station_id"] = npz["sta_id"]
+            if "t0" in npz.files:
+                meta["t0"] = npz["t0"]
+            self.buffer[fname] = meta
+        else:
+            meta = self.buffer[fname]
+        return meta
+        # except:
+        #     logging.error("Failed reading {}".format(fname))
+        #     return None
+
+    def read_hdf5(self, fname):
+        data = self.h5_data[fname][()]
+        attrs = self.h5_data[fname].attrs
+        meta = {}
+        if len(data.shape) == 2:
+            meta["data"] = data[:, np.newaxis, :]
+        else:
+            meta["data"] = data
+        if "p_idx" in attrs:
+            if len(attrs["p_idx"].shape) == 0:
+                meta["itp"] = [[attrs["p_idx"]]]
+            else:
+                meta["itp"] = attrs["p_idx"]
+        if "s_idx" in attrs:
+            if len(attrs["s_idx"].shape) == 0:
+                meta["its"] = [[attrs["s_idx"]]]
+            else:
+                meta["its"] = attrs["s_idx"]
+        if "itp" in attrs:
+            if len(attrs["itp"].shape) == 0:
+                meta["itp"] = [[attrs["itp"]]]
+            else:
+                meta["itp"] = attrs["itp"]
+        if "its" in attrs:
+            if len(attrs["its"].shape) == 0:
+                meta["its"] = [[attrs["its"]]]
+            else:
+                meta["its"] = attrs["its"]
+        if "t0" in attrs:
+            meta["t0"] = attrs["t0"]
+        return meta
+
+    def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
+        with self.s3fs.open(bucket + "/" + fname, "rb") as fp:
+            if format == "numpy":
+                meta = self.read_numpy(fp)
+            elif format == "mseed":
+                meta = self.read_mseed(fp)
+            else:
+                raise ValueError(f"Format {format} not supported")
+        return meta
+
+    def read_mseed(self, fname):
+
+        mseed = obspy.read(fname)
+        mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
+        mseed = mseed.merge(fill_value=0)
+        if self.highpass_filter > 0:
+            mseed = mseed.filter("highpass", freq=self.highpass_filter)
+        starttime = min([st.stats.starttime for st in mseed])
+        endtime = max([st.stats.endtime for st in mseed])
+        mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
+        if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
+            logging.warning(
+                f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz"
+            )
+
+        order = ["3", "2", "1", "E", "N", "Z"]
+        order = {key: i for i, key in enumerate(order)}
+        comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
+
+        t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+        nt = len(mseed[0].data)
+        data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
+        ids = [x.get_id() for x in mseed]
+
+        for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
+            if len(ids) != 3:
+                if len(ids) > 3:
+                    logging.warning(f"More than 3 channels {ids}!")
+                j = comp2idx[id[-1]]
+            data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
+
+        data = data[:, np.newaxis, :]
+        meta = {"data": data, "t0": t0}
+        return meta
+
+    def read_sac(self, fname):
+
+        mseed = obspy.read(fname)
+        mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
+        mseed = mseed.merge(fill_value=0)
+        if self.highpass_filter > 0:
+            mseed = mseed.filter("highpass", freq=self.highpass_filter)
+        starttime = min([st.stats.starttime for st in mseed])
+        endtime = max([st.stats.endtime for st in mseed])
+        mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
+        if abs(mseed[0].stats.sampling_rate - self.config.sampling_rate) > 1:
+            logging.warning(
+                f"Sampling rate mismatch in {fname.split('/')[-1]}: {mseed[0].stats.sampling_rate}Hz != {self.config.sampling_rate}Hz"
+            )
+
+        order = ["3", "2", "1", "E", "N", "Z"]
+        order = {key: i for i, key in enumerate(order)}
+        comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
+
+        t0 = starttime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+        nt = len(mseed[0].data)
+        data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
+        ids = [x.get_id() for x in mseed]
+        for j, id in enumerate(sorted(ids, key=lambda x: order[x[-1]])):
+            if len(ids) != 3:
+                if len(ids) > 3:
+                    logging.warning(f"More than 3 channels {ids}!")
+                j = comp2idx[id[-1]]
+            data[:, j] = mseed.select(id=id)[0].data.astype(self.dtype)
+
+        data = data[:, np.newaxis, :]
+        meta = {"data": data, "t0": t0}
+        return meta
+
+    def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
+
+        data = []
+        station_id = []
+        t0 = []
+        raw_amp = []
+
+        try:
+            mseed = obspy.read(fname)
+            read_success = True
+        except Exception as e:
+            read_success = False
+            print(e)
+
+        if read_success:
+            try:
+                mseed = mseed.merge(fill_value=0)
+            except Exception as e:
+                print(e)
+
+            for i in range(len(mseed)):
+                if mseed[i].stats.sampling_rate != self.config.sampling_rate:
+                    logging.warning(
+                        f"Resampling {mseed[i].id} from {mseed[i].stats.sampling_rate} to {self.config.sampling_rate} Hz"
+                    )
+                    try:
+                        mseed[i] = mseed[i].interpolate(self.config.sampling_rate, method="linear")
+                    except Exception as e:
+                        print(e)
+                        mseed[i].data = mseed[i].data.astype(float) * 0.0  ## set to zero if resampling fails
+
+            if self.highpass_filter == 0:
+                try:
+                    mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
+                except:
+                    logging.error(f"Error: spline detrend failed at file {fname}")
+                    mseed = mseed.detrend("demean")
+            else:
+                mseed = mseed.filter("highpass", freq=self.highpass_filter)
+
+            starttime = min([st.stats.starttime for st in mseed])
+            endtime = max([st.stats.endtime for st in mseed])
+            mseed = mseed.trim(starttime, endtime, pad=True, fill_value=0)
+
+            order = ["3", "2", "1", "E", "N", "Z"]
+            order = {key: i for i, key in enumerate(order)}
+            comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}
+
+            nsta = len(stations)
+            nt = len(mseed[0].data)
+            # for i in range(nsta):
+            for sta in stations:
+                trace_data = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
+                if amplitude:
+                    trace_amp = np.zeros([nt, self.config.n_channel], dtype=self.dtype)
+                empty_station = True
+                # sta = stations.iloc[i]["station"]
+                # comp = stations.iloc[i]["component"].split(",")
+                comp = stations[sta]["component"]
+                if amplitude:
+                    # resp = stations.iloc[i]["response"].split(",")
+                    resp = stations[sta]["response"]
+
+                for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
+
+                    if amplitude:
+                        resp_j = resp[j]
+                    if len(comp) != 3:  ## fewer than 3 components
+                        j = comp2idx[c]
+
+                    if len(mseed.select(id=sta + c)) == 0:
+                        print(f"Empty trace: {sta+c} {starttime}")
+                        continue
+                    else:
+                        empty_station = False
+
+                    tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
+                    trace_data[: len(tmp), j] = tmp[:nt]
+                    if amplitude:
+                        # if stations.iloc[i]["unit"] == "m/s**2":
+                        if stations[sta]["unit"] == "m/s**2":
+                            tmp = mseed.select(id=sta + c)[0]
+                            tmp = tmp.integrate()
+                            tmp = tmp.filter("highpass", freq=1.0)
+                            tmp = tmp.data.astype(self.dtype)
+                            trace_amp[: len(tmp), j] = tmp[:nt]
+                        # elif stations.iloc[i]["unit"] == "m/s":
+                        elif stations[sta]["unit"] == "m/s":
+                            tmp = mseed.select(id=sta + c)[0].data.astype(self.dtype)
+                            trace_amp[: len(tmp), j] = tmp[:nt]
+                        else:
+                            print(
+                                f"Error in {sta}\n{stations[sta]['unit']} should be m/s**2 or m/s!"
+                            )
+                    if amplitude and remove_resp:
+                        # trace_amp[:, j] /= float(resp[j])
+                        trace_amp[:, j] /= float(resp_j)
+
+                if not empty_station:
+                    data.append(trace_data)
+                    if amplitude:
+                        raw_amp.append(trace_amp)
+                    station_id.append(sta)
+                    t0.append(starttime.datetime.isoformat(timespec="milliseconds"))
+
+        if len(data) > 0:
+            data = np.stack(data)
+            if len(data.shape) == 3:
+                data = data[:, :, np.newaxis, :]
+            if amplitude:
+                raw_amp = np.stack(raw_amp)
+                if len(raw_amp.shape) == 3:
+                    raw_amp = raw_amp[:, :, np.newaxis, :]
+        else:
+            nt = 60 * 60 * self.config.sampling_rate  # assume 1 hour data
+            data = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
+            if amplitude:
+                raw_amp = np.zeros([1, nt, 1, self.config.n_channel], dtype=self.dtype)
+            t0 = ["1970-01-01T00:00:00.000"]
+            station_id = ["None"]
+
+        if amplitude:
+            meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1], "raw_amp": raw_amp}
+        else:
+            meta = {"data": data, "t0": t0, "station_id": station_id, "fname": fname.split("/")[-1]}
+        return meta
+
+    def generate_label(self, data, phase_list, mask=None):
+        # target = np.zeros(self.Y_shape, dtype=self.dtype)
+        target = np.zeros_like(data)
+
+        if self.label_shape == "gaussian":
+            label_window = np.exp(
+                -((np.arange(-self.label_width // 2, self.label_width // 2 + 1)) ** 2)
+                / (2 * (self.label_width / 5) ** 2)
+            )
+        elif self.label_shape == "triangle":
+            label_window = 1 - np.abs(
+                2 / self.label_width * (np.arange(-self.label_width // 2, self.label_width // 2 + 1))
+            )
+        else:
+            raise ValueError(f"Label shape {self.label_shape} should be gaussian or triangle")
+
+        for i, phases in enumerate(phase_list):
+            for j, idx_list in enumerate(phases):
+                for idx in idx_list:
+                    if np.isnan(idx):
+                        continue
+                    idx = int(idx)
+                    if (idx - self.label_width // 2 >= 0) and (idx + self.label_width // 2 + 1 <= target.shape[0]):
+                        target[idx - self.label_width // 2 : idx + self.label_width // 2 + 1, j, i + 1] = label_window
+
+        target[..., 0] = 1 - np.sum(target[..., 1:], axis=-1)
+        if mask is not None:
+            target[:, mask == 0, :] = 0
+
+        return target
+
+    def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range=None):
+        # anchor = np.round(1/2 * (min(itp[~np.isnan(itp.astype(float))]) + min(its[~np.isnan(its.astype(float))]))).astype(int)
+        flattern = lambda x: np.array([i for trace in x for i in trace], dtype=float)
+        shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
+        itp_flat = flattern(itp)
+        its_flat = flattern(its)
+        if (itp_old is None) and (its_old is None):
+            hi = np.round(np.median(itp_flat[~np.isnan(itp_flat)])).astype(int)
+            lo = -(sample.shape[0] - np.round(np.median(its_flat[~np.isnan(its_flat)])).astype(int))
+            if shift_range is None:
+                shift = np.random.randint(low=lo, high=hi + 1)
+            else:
+                shift = np.random.randint(low=max(lo, shift_range[0]), high=min(hi + 1, shift_range[1]))
+        else:
+            itp_old_flat = flattern(itp_old)
+            its_old_flat = flattern(its_old)
+            itp_ref = np.round(np.min(itp_flat[~np.isnan(itp_flat)])).astype(int)
+            its_ref = np.round(np.max(its_flat[~np.isnan(its_flat)])).astype(int)
+            itp_old_ref = np.round(np.min(itp_old_flat[~np.isnan(itp_old_flat)])).astype(int)
+            its_old_ref = np.round(np.max(its_old_flat[~np.isnan(its_old_flat)])).astype(int)
+            # min_event_gap = np.round(self.min_event_gap*(its_ref-itp_ref)).astype(int)
+            # min_event_gap_old = np.round(self.min_event_gap*(its_old_ref-itp_old_ref)).astype(int)
+            if shift_range is None:
+                hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), itp_ref))
+                lo = list(range(-(sample.shape[0] - its_ref), -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
+            else:
+                lo_ = max(-(sample.shape[0] - its_ref), shift_range[0])
+                hi_ = min(itp_ref, shift_range[1])
+                hi = list(range(max(its_ref - itp_old_ref + self.min_event_gap, 0), hi_))
+                lo = list(range(lo_, -(max(its_old_ref - itp_ref + self.min_event_gap, 0))))
+            if len(hi + lo) > 0:
+                shift = np.random.choice(hi + lo)
+            else:
+                shift = 0
+
+        shifted_sample = np.zeros_like(sample)
+        if shift > 0:
+            shifted_sample[:-shift, ...] = sample[shift:, ...]
+        elif shift < 0:
+            shifted_sample[-shift:, ...] = sample[:shift, ...]
+        else:
+            shifted_sample[...] = sample[...]
+
+        return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
+
+    def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
+
+        i = np.random.randint(self.num_data)
+        base_name = self.data_list[i]
+        if self.format == "numpy":
+            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
+        elif self.format == "hdf5":
+            meta = self.read_hdf5(base_name)
+        if meta == -1:
+            return sample_old, itp_old, its_old, mask_old
+
+        sample = np.copy(meta["data"])
+        itp = meta["itp"]
+        its = meta["its"]
+        if mask_old is not None:
+            mask = np.copy(meta["mask"])
+        sample = normalize(sample)
+        sample, itp, its, shift = self.random_shift(sample, itp, its, itp_old, its_old, shift_range)
+
+        if shift != 0:
+            sample_old += sample
+            # itp_old = [np.hstack([i, j]) for i,j in zip(itp_old, itp)]
+            # its_old = [np.hstack([i, j]) for i,j in zip(its_old, its)]
+            itp_old = [i + j for i, j in zip(itp_old, itp)]
+            its_old = [i + j for i, j in zip(its_old, its)]
+            if mask_old is not None:
+                mask_old = mask_old * mask
+
+        return sample_old, itp_old, its_old, mask_old
+
+    def cut_window(self, sample, target, itp, its, select_range):
+        shift_pick = lambda x, shift: [[i - shift for i in trace] for trace in x]
+        sample = sample[select_range[0] : select_range[1]]
+        target = target[select_range[0] : select_range[1]]
+        return (sample, target, shift_pick(itp, select_range[0]), shift_pick(its, select_range[0]))
+
+
+class DataReader_train(DataReader):
+    def __init__(self, format="numpy", config=DataConfig(), **kwargs):
+
+        super().__init__(format=format, config=config, **kwargs)
+
+        self.min_event_gap = config.min_event_gap
+        self.buffer_channels = {}
+        self.shift_range = [-2000 + self.label_width * 2, 1000 - self.label_width * 2]
+        self.select_range = [5000, 8000]
+
+    def __getitem__(self, i):
+
+        base_name = self.data_list[i]
+        if self.format == "numpy":
+            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
+        elif self.format == "hdf5":
+            meta = self.read_hdf5(base_name)
+        if meta is None:
+            return (np.zeros(self.X_shape, dtype=self.dtype), np.zeros(self.Y_shape, dtype=self.dtype), base_name)
+
+        sample = np.copy(meta["data"])
+        itp_list = meta["itp"]
+        its_list = meta["its"]
+
+        sample = normalize(sample)
+        if np.random.random() < 0.95:
+            sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
+            sample, itp_list, its_list, _ = self.stack_events(sample, itp_list, its_list, shift_range=self.shift_range)
+            target = self.generate_label(sample, [itp_list, its_list])
+            sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
+        else:
+            ## noise
+            assert self.X_shape[0] <= min(min(itp_list))
+            sample = sample[: self.X_shape[0], ...]
+            target = np.zeros(self.Y_shape).astype(self.dtype)
+            itp_list = [[]]
+            its_list = [[]]
+
+        sample = normalize(sample)
+        return (sample.astype(self.dtype), target.astype(self.dtype), base_name)
+
+    def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder=True):
+        dataset = dataset_map(
+            self,
+            output_types=(self.dtype, self.dtype, "string"),
+            output_shapes=(self.X_shape, self.Y_shape, None),
+            num_parallel_calls=num_parallel_calls,
+            shuffle=shuffle,
+        )
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
+        return dataset
+
+
+class DataReader_test(DataReader):
+    def __init__(self, format="numpy", config=DataConfig(), **kwargs):
+
+        super().__init__(format=format, config=config, **kwargs)
+
+        self.select_range = [5000, 8000]
+
+    def __getitem__(self, i):
+
+        base_name = self.data_list[i]
+        if self.format == "numpy":
+            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
+        elif self.format == "hdf5":
+            meta = self.read_hdf5(base_name)
+        if meta == -1:
+            return (np.zeros(self.Y_shape, dtype=self.dtype), np.zeros(self.X_shape, dtype=self.dtype), base_name)
+
+        sample = np.copy(meta["data"])
+        itp_list = meta["itp"]
+        its_list = meta["its"]
+
+        # sample, itp_list, its_list, _ = self.random_shift(sample, itp_list, its_list, shift_range=self.shift_range)
+        target = self.generate_label(sample, [itp_list, its_list])
+        sample, target, itp_list, its_list = self.cut_window(sample, target, itp_list, its_list, self.select_range)
+
+        sample = normalize(sample)
+        return (sample, target, base_name, itp_list, its_list)
+
+    def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
+        dataset = dataset_map(
+            self,
+            output_types=(self.dtype, self.dtype, "string", "int64", "int64"),
+            output_shapes=(self.X_shape, self.Y_shape, None, None, None),
+            num_parallel_calls=num_parallel_calls,
+            shuffle=shuffle,
+        )
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
+        return dataset
+
+
+class DataReader_pred(DataReader):
+    def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
+
+        super().__init__(format=format, config=config, **kwargs)
+
+        self.amplitude = amplitude
+        self.X_shape = self.get_data_shape()
+
+    def get_data_shape(self):
+        base_name = self.data_list[0]
+        if self.format == "numpy":
+            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
+        elif self.format == "mseed":
+            meta = self.read_mseed(os.path.join(self.data_dir, base_name))
+        elif self.format == "sac":
+            meta = self.read_sac(os.path.join(self.data_dir, base_name))
+        elif self.format == "hdf5":
+            meta = self.read_hdf5(base_name)
+        return meta["data"].shape
+
+    def adjust_missingchannels(self, data):
+        tmp = np.max(np.abs(data), axis=0, keepdims=True)
+        assert tmp.shape[-1] == data.shape[-1]
+        if np.count_nonzero(tmp) > 0:
+            data *= data.shape[-1] / np.count_nonzero(tmp)
+        return data
+
+    def __getitem__(self, i):
+
+        base_name = self.data_list[i]
+
+        if self.format == "numpy":
+            meta = self.read_numpy(os.path.join(self.data_dir, base_name))
+        elif self.format == "mseed":
+            meta = self.read_mseed(os.path.join(self.data_dir, base_name))
+        elif self.format == "sac":
+            meta = self.read_sac(os.path.join(self.data_dir, base_name))
+        elif self.format == "hdf5":
+            meta = self.read_hdf5(base_name)
+        else:
+            raise ValueError(f"{self.format} not supported!")
+        if meta == -1:
+            return (np.zeros(self.X_shape, dtype=self.dtype), base_name)
+
+        raw_amp = np.zeros(self.X_shape, dtype=self.dtype)
+        raw_amp[: meta["data"].shape[0], ...] = meta["data"][: self.X_shape[0], ...]
+        sample = np.zeros(self.X_shape, dtype=self.dtype)
+        sample[: meta["data"].shape[0], ...] = normalize_long(meta["data"])[: self.X_shape[0], ...]
+        if abs(meta["data"].shape[0] - self.X_shape[0]) > 1:
+            logging.warning(f"Data length mismatch in {base_name}: {meta['data'].shape[0]} != {self.X_shape[0]}")
+
+        if "t0" in meta:
+            t0 = meta["t0"]
+        else:
+            t0 = "1970-01-01T00:00:00.000"
+
+        if "station_id" in meta:
+            station_id = meta["station_id"].split("/")[-1].rstrip("*")
+        else:
+            # station_id = base_name.split("/")[-1].rstrip("*")
+            station_id = os.path.basename(base_name).rstrip("*")
+
+        if np.isnan(sample).any() or np.isinf(sample).any():
+            logging.warning(f"Data error: Nan or Inf found in {base_name}")
+            sample[np.isnan(sample)] = 0
+            sample[np.isinf(sample)] = 0
+
+        # sample = self.adjust_missingchannels(sample)
+        if self.amplitude:
+            return (sample[: self.X_shape[0], ...], raw_amp[: self.X_shape[0], ...], base_name, t0, station_id)
+        else:
+            return (sample[: self.X_shape[0], ...], base_name, t0, station_id)
+
+    def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainder=False):
+        if self.amplitude:
+            dataset = dataset_map(
+                self,
+                output_types=(self.dtype, self.dtype, "string", "string", "string"),
+                output_shapes=(self.X_shape, self.X_shape, None, None, None),
+                num_parallel_calls=num_parallel_calls,
+                shuffle=shuffle,
+            )
+        else:
+            dataset = dataset_map(
+                self,
+                output_types=(self.dtype, "string", "string", "string"),
+                output_shapes=(self.X_shape, None, None, None),
+                num_parallel_calls=num_parallel_calls,
+                shuffle=shuffle,
+            )
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder).prefetch(batch_size * 2)
+        return dataset
+
+
+class DataReader_mseed_array(DataReader):
+    def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
+
+        super().__init__(format="mseed", config=config, **kwargs)
+
+        # self.stations = pd.read_json(stations)
+        with open(stations, "r") as f:
+            self.stations = json.load(f)
+        print(pd.DataFrame.from_dict(self.stations, orient="index").to_string())
+
+        self.amplitude = amplitude
+        self.remove_resp = remove_resp
+        self.X_shape = self.get_data_shape()
+
+    def get_data_shape(self):
+        fname = os.path.join(self.data_dir, self.data_list[0])
+        meta = self.read_mseed_array(fname, self.stations, self.amplitude, self.remove_resp)
+        return meta["data"].shape
+
+    def __getitem__(self, i):
+
+        fp = os.path.join(self.data_dir, self.data_list[i])
| 813 |
+
# try:
|
| 814 |
+
meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
|
| 815 |
+
# except Exception as e:
|
| 816 |
+
# logging.error(f"Failed reading {fp}: {e}")
|
| 817 |
+
# if self.amplitude:
|
| 818 |
+
# return (np.zeros(self.X_shape).astype(self.dtype), np.zeros(self.X_shape).astype(self.dtype),
|
| 819 |
+
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))], ["0" for i in range(len(self.stations))])
|
| 820 |
+
# else:
|
| 821 |
+
# return (np.zeros(self.X_shape).astype(self.dtype), ["" for i in range(len(self.stations))],
|
| 822 |
+
# [self.stations.iloc[i]["station"] for i in range(len(self.stations))])
|
| 823 |
+
|
| 824 |
+
sample = np.zeros([len(meta["data"]), *self.X_shape[1:]], dtype=self.dtype)
|
| 825 |
+
sample[:, : meta["data"].shape[1], :, :] = normalize_batch(meta["data"])[:, : self.X_shape[1], :, :]
|
| 826 |
+
if np.isnan(sample).any() or np.isinf(sample).any():
|
| 827 |
+
logging.warning(f"Data error: Nan or Inf found in {fp}")
|
| 828 |
+
sample[np.isnan(sample)] = 0
|
| 829 |
+
sample[np.isinf(sample)] = 0
|
| 830 |
+
t0 = meta["t0"]
|
| 831 |
+
base_name = meta["fname"]
|
| 832 |
+
station_id = meta["station_id"]
|
| 833 |
+
# base_name = [self.stations.iloc[i]["station"]+"."+t0[i] for i in range(len(self.stations))]
|
| 834 |
+
# base_name = [self.stations.iloc[i]["station"] for i in range(len(self.stations))]
|
| 835 |
+
|
| 836 |
+
if self.amplitude:
|
| 837 |
+
raw_amp = np.zeros([len(meta["raw_amp"]), *self.X_shape[1:]], dtype=self.dtype)
|
| 838 |
+
raw_amp[:, : meta["raw_amp"].shape[1], :, :] = meta["raw_amp"][:, : self.X_shape[1], :, :]
|
| 839 |
+
if np.isnan(raw_amp).any() or np.isinf(raw_amp).any():
|
| 840 |
+
logging.warning(f"Data error: Nan or Inf found in {fp}")
|
| 841 |
+
raw_amp[np.isnan(raw_amp)] = 0
|
| 842 |
+
raw_amp[np.isinf(raw_amp)] = 0
|
| 843 |
+
return (sample, raw_amp, base_name, t0, station_id)
|
| 844 |
+
else:
|
| 845 |
+
return (sample, base_name, t0, station_id)
|
| 846 |
+
|
| 847 |
+
def dataset(self, num_parallel_calls=1, shuffle=False):
|
| 848 |
+
if self.amplitude:
|
| 849 |
+
dataset = dataset_map(
|
| 850 |
+
self,
|
| 851 |
+
output_types=(self.dtype, self.dtype, "string", "string", "string"),
|
| 852 |
+
output_shapes=([None, *self.X_shape[1:]], [None, *self.X_shape[1:]], None, None, None),
|
| 853 |
+
num_parallel_calls=num_parallel_calls,
|
| 854 |
+
)
|
| 855 |
+
else:
|
| 856 |
+
dataset = dataset_map(
|
| 857 |
+
self,
|
| 858 |
+
output_types=(self.dtype, "string", "string", "string"),
|
| 859 |
+
output_shapes=([None, *self.X_shape[1:]], None, None, None),
|
| 860 |
+
num_parallel_calls=num_parallel_calls,
|
| 861 |
+
)
|
| 862 |
+
dataset = dataset.prefetch(1)
|
| 863 |
+
# dataset = dataset.prefetch(len(self.stations)*2)
|
| 864 |
+
return dataset
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
###### test ########
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def test_DataReader():
|
| 871 |
+
import os
|
| 872 |
+
import timeit
|
| 873 |
+
|
| 874 |
+
import matplotlib.pyplot as plt
|
| 875 |
+
|
| 876 |
+
if not os.path.exists("test_figures"):
|
| 877 |
+
os.mkdir("test_figures")
|
| 878 |
+
|
| 879 |
+
def plot_sample(sample, fname, label=None):
|
| 880 |
+
plt.clf()
|
| 881 |
+
plt.subplot(211)
|
| 882 |
+
plt.plot(sample[:, 0, -1])
|
| 883 |
+
if label is not None:
|
| 884 |
+
plt.subplot(212)
|
| 885 |
+
plt.plot(label[:, 0, 0])
|
| 886 |
+
plt.plot(label[:, 0, 1])
|
| 887 |
+
plt.plot(label[:, 0, 2])
|
| 888 |
+
plt.savefig(f"test_figures/{fname.decode()}.png")
|
| 889 |
+
|
| 890 |
+
def read(data_reader, batch=1):
|
| 891 |
+
start_time = timeit.default_timer()
|
| 892 |
+
if batch is None:
|
| 893 |
+
dataset = data_reader.dataset(shuffle=False)
|
| 894 |
+
else:
|
| 895 |
+
dataset = data_reader.dataset(1, shuffle=False)
|
| 896 |
+
sess = tf.compat.v1.Session()
|
| 897 |
+
|
| 898 |
+
print(len(data_reader))
|
| 899 |
+
print("-------", tf.data.Dataset.cardinality(dataset))
|
| 900 |
+
num = 0
|
| 901 |
+
x = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
|
| 902 |
+
while True:
|
| 903 |
+
num += 1
|
| 904 |
+
# print(num)
|
| 905 |
+
try:
|
| 906 |
+
out = sess.run(x)
|
| 907 |
+
if len(out) == 2:
|
| 908 |
+
sample, fname = out[0], out[1]
|
| 909 |
+
for i in range(len(sample)):
|
| 910 |
+
plot_sample(sample[i], fname[i])
|
| 911 |
+
else:
|
| 912 |
+
sample, label, fname = out[0], out[1], out[2]
|
| 913 |
+
for i in range(len(sample)):
|
| 914 |
+
plot_sample(sample[i], fname[i], label[i])
|
| 915 |
+
except tf.errors.OutOfRangeError:
|
| 916 |
+
break
|
| 917 |
+
print("End of dataset")
|
| 918 |
+
print("Tensorflow Dataset:\nexecution time = ", timeit.default_timer() - start_time)
|
| 919 |
+
|
| 920 |
+
data_reader = DataReader_train(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
| 921 |
+
|
| 922 |
+
read(data_reader)
|
| 923 |
+
|
| 924 |
+
data_reader = DataReader_train(format="hdf5", hdf5="test_data/data.h5", group="data")
|
| 925 |
+
|
| 926 |
+
read(data_reader)
|
| 927 |
+
|
| 928 |
+
data_reader = DataReader_test(data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
| 929 |
+
|
| 930 |
+
read(data_reader)
|
| 931 |
+
|
| 932 |
+
data_reader = DataReader_test(format="hdf5", hdf5="test_data/data.h5", group="data")
|
| 933 |
+
|
| 934 |
+
read(data_reader)
|
| 935 |
+
|
| 936 |
+
data_reader = DataReader_pred(format="numpy", data_list="test_data/selected_phases.csv", data_dir="test_data/data/")
|
| 937 |
+
|
| 938 |
+
read(data_reader)
|
| 939 |
+
|
| 940 |
+
data_reader = DataReader_pred(
|
| 941 |
+
format="mseed", data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
read(data_reader)
|
| 945 |
+
|
| 946 |
+
data_reader = DataReader_pred(
|
| 947 |
+
format="mseed", amplitude=True, data_list="test_data/mseed_station.csv", data_dir="test_data/waveforms/"
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
read(data_reader)
|
| 951 |
+
|
| 952 |
+
data_reader = DataReader_mseed_array(
|
| 953 |
+
data_list="test_data/mseed.csv",
|
| 954 |
+
data_dir="test_data/waveforms/",
|
| 955 |
+
stations="test_data/stations.csv",
|
| 956 |
+
remove_resp=False,
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
read(data_reader, batch=None)
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
if __name__ == "__main__":
|
| 963 |
+
|
| 964 |
+
test_DataReader()
|
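The test harness above doubles as usage documentation. For prediction only, a minimal sketch of the same loop, assuming the repo's test_data paths exist and that phasenet.model has already been imported (it disables eager execution, so a TF1 session loop is needed):

```python
import tensorflow as tf

# a minimal sketch, assuming test_data/ contains the files listed in the csv
reader = DataReader_pred(format="numpy",
                         data_list="test_data/selected_phases.csv",
                         data_dir="test_data/data/",
                         amplitude=False)
dataset = reader.dataset(batch_size=4, shuffle=False)
batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

with tf.compat.v1.Session() as sess:
    while True:
        try:
            sample, fname, t0, station_id = sess.run(batch)
            print(fname, sample.shape)  # normalized windows of shape (4, Nt, 1, 3)
        except tf.errors.OutOfRangeError:
            break
```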
phasenet/detect_peaks.py ADDED
@@ -0,0 +1,207 @@
"""Detect peaks in data based on their amplitude and other features."""

from __future__ import division, print_function
import warnings
import numpy as np

__author__ = "Marcos Duarte, https://github.com/demotu"
__version__ = "1.0.6"
__license__ = "MIT"


def detect_peaks(x, mph=None, mpd=1, threshold=0, edge='rising',
                 kpsh=False, valley=False, show=False, ax=None, title=True):

    """Detect peaks in data based on their amplitude and other features.

    Parameters
    ----------
    x : 1D array_like
        data.
    mph : {None, number}, optional (default = None)
        detect peaks that are greater than minimum peak height (if parameter
        `valley` is False) or peaks that are smaller than maximum peak height
        (if parameter `valley` is True).
    mpd : positive integer, optional (default = 1)
        detect peaks that are at least separated by minimum peak distance (in
        number of data).
    threshold : positive number, optional (default = 0)
        detect peaks (valleys) that are greater (smaller) than `threshold`
        in relation to their immediate neighbors.
    edge : {None, 'rising', 'falling', 'both'}, optional (default = 'rising')
        for a flat peak, keep only the rising edge ('rising'), only the
        falling edge ('falling'), both edges ('both'), or don't detect a
        flat peak (None).
    kpsh : bool, optional (default = False)
        keep peaks with same height even if they are closer than `mpd`.
    valley : bool, optional (default = False)
        if True (1), detect valleys (local minima) instead of peaks.
    show : bool, optional (default = False)
        if True (1), plot data in matplotlib figure.
    ax : a matplotlib.axes.Axes instance, optional (default = None).
    title : bool or string, optional (default = True)
        if True, show standard title. If False or empty string, doesn't show
        any title. If string, shows string as title.

    Returns
    -------
    ind : 1D array_like
        indices of the peaks in `x`.
    x[ind] : 1D array_like
        amplitudes of `x` at the detected peaks.

    Notes
    -----
    The detection of valleys instead of peaks is performed internally by simply
    negating the data: `ind_valleys = detect_peaks(-x)`

    The function can handle NaN's.

    See this IPython Notebook [1]_.

    References
    ----------
    .. [1] http://nbviewer.ipython.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb

    Examples
    --------
    >>> from detect_peaks import detect_peaks
    >>> x = np.random.randn(100)
    >>> x[60:81] = np.nan
    >>> # detect all peaks and plot data
    >>> ind = detect_peaks(x, show=True)
    >>> print(ind)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # set minimum peak height = 0 and minimum peak distance = 20
    >>> detect_peaks(x, mph=0, mpd=20, show=True)

    >>> x = [0, 1, 0, 2, 0, 3, 0, 2, 0, 1, 0]
    >>> # set minimum peak distance = 2
    >>> detect_peaks(x, mpd=2, show=True)

    >>> x = np.sin(2*np.pi*5*np.linspace(0, 1, 200)) + np.random.randn(200)/5
    >>> # detection of valleys instead of peaks
    >>> detect_peaks(x, mph=-1.2, mpd=20, valley=True, show=True)

    >>> x = [0, 1, 1, 0, 1, 1, 0]
    >>> # detect both edges
    >>> detect_peaks(x, edge='both', show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> # set threshold = 2
    >>> detect_peaks(x, threshold=2, show=True)

    >>> x = [-2, 1, -2, 2, 1, 1, 3, 0]
    >>> fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 4))
    >>> detect_peaks(x, show=True, ax=axs[0], threshold=0.5, title=False)
    >>> detect_peaks(x, show=True, ax=axs[1], threshold=1.5, title=False)

    Version history
    ---------------
    '1.0.6':
        Fix issue of when specifying ax object only the first plot was shown
        Add parameter to choose if a title is shown and input a title
    '1.0.5':
        The sign of `mph` is inverted if parameter `valley` is True

    """

    x = np.atleast_1d(x).astype('float64')
    if x.size < 3:
        # keep the (indices, heights) return shape consistent with the normal path
        return np.array([], dtype=int), np.array([])
    if valley:
        x = -x
        if mph is not None:
            mph = -mph
    # find indices of all peaks
    dx = x[1:] - x[:-1]
    # handle NaN's
    indnan = np.where(np.isnan(x))[0]
    if indnan.size:
        x[indnan] = np.inf
        dx[np.where(np.isnan(dx))[0]] = np.inf
    ine, ire, ife = np.array([[], [], []], dtype=int)
    if not edge:
        ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0]
    else:
        if edge.lower() in ['rising', 'both']:
            ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0]
        if edge.lower() in ['falling', 'both']:
            ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0]
    ind = np.unique(np.hstack((ine, ire, ife)))
    # handle NaN's
    if ind.size and indnan.size:
        # NaN's and values close to NaN's cannot be peaks
        ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan-1, indnan+1))), invert=True)]
    # first and last values of x cannot be peaks
    if ind.size and ind[0] == 0:
        ind = ind[1:]
    if ind.size and ind[-1] == x.size-1:
        ind = ind[:-1]
    # remove peaks < minimum peak height
    if ind.size and mph is not None:
        ind = ind[x[ind] >= mph]
    # remove peaks - neighbors < threshold
    if ind.size and threshold > 0:
        dx = np.min(np.vstack([x[ind]-x[ind-1], x[ind]-x[ind+1]]), axis=0)
        ind = np.delete(ind, np.where(dx < threshold)[0])
    # detect small peaks closer than minimum peak distance
    if ind.size and mpd > 1:
        ind = ind[np.argsort(x[ind])][::-1]  # sort ind by peak height
        idel = np.zeros(ind.size, dtype=bool)
        for i in range(ind.size):
            if not idel[i]:
                # keep peaks with the same height if kpsh is True
                idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \
                    & (x[ind[i]] > x[ind] if kpsh else True)
                idel[i] = 0  # keep current peak
        # remove the small peaks and sort back the indices by their occurrence
        ind = np.sort(ind[~idel])

    if show:
        if indnan.size:
            x[indnan] = np.nan
        if valley:
            x = -x
            if mph is not None:
                mph = -mph
        _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title)

    return ind, x[ind]


def _plot(x, mph, mpd, threshold, edge, valley, ax, ind, title):
    """Plot results of the detect_peaks function, see its help."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print('matplotlib is not available.')
    else:
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=(8, 4))
            no_ax = True
        else:
            no_ax = False

        ax.plot(x, 'b', lw=1)
        if ind.size:
            label = 'valley' if valley else 'peak'
            label = label + 's' if ind.size > 1 else label
            ax.plot(ind, x[ind], '+', mfc=None, mec='r', mew=2, ms=8,
                    label='%d %s' % (ind.size, label))
            ax.legend(loc='best', framealpha=.5, numpoints=1)
        ax.set_xlim(-.02*x.size, x.size*1.02-1)
        ymin, ymax = x[np.isfinite(x)].min(), x[np.isfinite(x)].max()
        yrange = ymax - ymin if ymax > ymin else 1
        ax.set_ylim(ymin - 0.1*yrange, ymax + 0.1*yrange)
        ax.set_xlabel('Data #', fontsize=14)
        ax.set_ylabel('Amplitude', fontsize=14)
        if title:
            if not isinstance(title, str):
                mode = 'Valley detection' if valley else 'Peak detection'
                title = "%s (mph=%s, mpd=%d, threshold=%s, edge='%s')" % \
                        (mode, str(mph), mpd, str(threshold), edge)
            ax.set_title(title)
        # plt.grid()
        if no_ax:
            plt.show()
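As a quick illustration of the thresholding knobs, a self-contained example (synthetic trace, not from the repo) using the same mph/mpd values that postprocess.py passes in:

```python
import numpy as np
# from phasenet.detect_peaks import detect_peaks

# synthetic "probability" trace: one Gaussian bump at sample 500
t = np.arange(3000)
prob = 0.9 * np.exp(-0.5 * ((t - 500) / 10.0) ** 2)

# mph=0.3, mpd=50 are the defaults used by extract_picks when config is None
idx, heights = detect_peaks(prob, mph=0.3, mpd=50, show=False)
print(idx, heights)  # -> [500] [0.9]
```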
phasenet/model.py ADDED
@@ -0,0 +1,489 @@
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import numpy as np
import logging
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

class ModelConfig:

    batch_size = 20
    depths = 5
    filters_root = 8
    kernel_size = [7, 1]
    pool_size = [4, 1]
    dilation_rate = [1, 1]
    class_weights = [1.0, 1.0, 1.0]
    loss_type = "cross_entropy"
    weight_decay = 0.0
    optimizer = "adam"
    momentum = 0.9
    learning_rate = 0.01
    decay_step = 1e9
    decay_rate = 0.9
    drop_rate = 0.0
    summary = True

    X_shape = [3000, 1, 3]
    n_channel = X_shape[-1]
    Y_shape = [3000, 1, 3]
    n_class = Y_shape[-1]

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def update_args(self, args):
        for k, v in vars(args).items():
            setattr(self, k, v)


def crop_and_concat(net1, net2):
    """
    the size(net1) <= size(net2)
    """
    # static-shape version:
    # net1_shape = net1.get_shape().as_list()
    # net2_shape = net2.get_shape().as_list()
    # offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    # size = [-1, net1_shape[1], net1_shape[2], -1]
    # net2_resize = tf.slice(net2, offsets, size)
    # return tf.concat([net1, net2_resize], 3)

    ## dynamic shape: center-crop net2 to the size of net1, then concatenate channels
    chn1 = net1.get_shape().as_list()[-1]
    chn2 = net2.get_shape().as_list()[-1]
    net1_shape = tf.shape(net1)
    net2_shape = tf.shape(net2)
    offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    size = [-1, net1_shape[1], net1_shape[2], -1]
    net2_resize = tf.slice(net2, offsets, size)

    out = tf.concat([net1, net2_resize], 3)
    out.set_shape([None, None, None, chn1 + chn2])

    return out

    # else:
    #     offsets = [0, (net1_shape[1] - net2_shape[1]) // 2, (net1_shape[2] - net2_shape[2]) // 2, 0]
    #     size = [-1, net2_shape[1], net2_shape[2], -1]
    #     net1_resize = tf.slice(net1, offsets, size)
    #     return tf.concat([net1_resize, net2], 3)


def crop_only(net1, net2):
    """
    the size(net1) <= size(net2)
    """
    net1_shape = net1.get_shape().as_list()
    net2_shape = net2.get_shape().as_list()
    offsets = [0, (net2_shape[1] - net1_shape[1]) // 2, (net2_shape[2] - net1_shape[2]) // 2, 0]
    size = [-1, net1_shape[1], net1_shape[2], -1]
    net2_resize = tf.slice(net2, offsets, size)
    # return tf.concat([net1, net2_resize], 3)
    return net2_resize

class UNet:
    def __init__(self, config=ModelConfig(), input_batch=None, mode='train'):
        self.depths = config.depths
        self.filters_root = config.filters_root
        self.kernel_size = config.kernel_size
        self.dilation_rate = config.dilation_rate
        self.pool_size = config.pool_size
        self.X_shape = config.X_shape
        self.Y_shape = config.Y_shape
        self.n_channel = config.n_channel
        self.n_class = config.n_class
        self.class_weights = config.class_weights
        self.batch_size = config.batch_size
        self.loss_type = config.loss_type
        self.weight_decay = config.weight_decay
        self.optimizer = config.optimizer
        self.learning_rate = config.learning_rate
        self.decay_step = config.decay_step
        self.decay_rate = config.decay_rate
        self.momentum = config.momentum
        self.global_step = tf.compat.v1.get_variable(name="global_step", initializer=0, dtype=tf.int32)
        self.summary_train = []
        self.summary_valid = []

        self.build(input_batch, mode=mode)

    def add_placeholders(self, input_batch=None, mode="train"):
        if input_batch is None:
            # self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.X_shape[-3], self.X_shape[-2], self.X_shape[-1]], name='X')
            # self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, self.Y_shape[-3], self.Y_shape[-2], self.n_class], name='y')
            self.X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.X_shape[-1]], name='X')
            self.Y = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, None, None, self.n_class], name='y')
        else:
            self.X = input_batch[0]
            if mode in ["train", "valid", "test"]:
                self.Y = input_batch[1]
            self.input_batch = input_batch

        self.is_training = tf.compat.v1.placeholder(dtype=tf.bool, name="is_training")
        # self.keep_prob = tf.compat.v1.placeholder(dtype=tf.float32, name="keep_prob")
        self.drop_rate = tf.compat.v1.placeholder(dtype=tf.float32, name="drop_rate")

    def add_prediction_op(self):
        logging.info("Model: depths {depths}, filters {filters}, "
                     "filter size {kernel_size[0]}x{kernel_size[1]}, "
                     "pool size: {pool_size[0]}x{pool_size[1]}, "
                     "dilation rate: {dilation_rate[0]}x{dilation_rate[1]}".format(
                         depths=self.depths,
                         filters=self.filters_root,
                         kernel_size=self.kernel_size,
                         dilation_rate=self.dilation_rate,
                         pool_size=self.pool_size))

        if self.weight_decay > 0:
            weight_decay = tf.constant(self.weight_decay, dtype=tf.float32, name="weight_constant")
            self.regularizer = tf.keras.regularizers.l2(l=0.5 * (weight_decay))
        else:
            self.regularizer = None

        self.initializer = tf.compat.v1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")

        # down sample layers
        convs = [None] * self.depths  # store output of each depth

        with tf.compat.v1.variable_scope("Input"):
            net = self.X
            net = tf.compat.v1.layers.conv2d(net,
                                             filters=self.filters_root,
                                             kernel_size=self.kernel_size,
                                             activation=None,
                                             padding='same',
                                             dilation_rate=self.dilation_rate,
                                             kernel_initializer=self.initializer,
                                             kernel_regularizer=self.regularizer,
                                             name="input_conv")
            net = tf.compat.v1.layers.batch_normalization(net,
                                                          training=self.is_training,
                                                          name="input_bn")
            net = tf.nn.relu(net, name="input_relu")
            # net = tf.nn.dropout(net, self.keep_prob)
            net = tf.compat.v1.layers.dropout(net,
                                              rate=self.drop_rate,
                                              training=self.is_training,
                                              name="input_dropout")

        for depth in range(0, self.depths):
            with tf.compat.v1.variable_scope("DownConv_%d" % depth):
                filters = int(2**(depth) * self.filters_root)

                net = tf.compat.v1.layers.conv2d(net,
                                                 filters=filters,
                                                 kernel_size=self.kernel_size,
                                                 activation=None,
                                                 use_bias=False,
                                                 padding='same',
                                                 dilation_rate=self.dilation_rate,
                                                 kernel_initializer=self.initializer,
                                                 kernel_regularizer=self.regularizer,
                                                 name="down_conv1_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="down_bn1_{}".format(depth + 1))
                net = tf.nn.relu(net, name="down_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
                                                  name="down_dropout1_{}".format(depth + 1))

                convs[depth] = net

                if depth < self.depths - 1:
                    net = tf.compat.v1.layers.conv2d(net,
                                                     filters=filters,
                                                     kernel_size=self.kernel_size,
                                                     strides=self.pool_size,
                                                     activation=None,
                                                     use_bias=False,
                                                     padding='same',
                                                     dilation_rate=self.dilation_rate,
                                                     kernel_initializer=self.initializer,
                                                     kernel_regularizer=self.regularizer,
                                                     name="down_conv3_{}".format(depth + 1))
                    net = tf.compat.v1.layers.batch_normalization(net,
                                                                  training=self.is_training,
                                                                  name="down_bn3_{}".format(depth + 1))
                    net = tf.nn.relu(net, name="down_relu3_{}".format(depth + 1))
                    net = tf.compat.v1.layers.dropout(net,
                                                      rate=self.drop_rate,
                                                      training=self.is_training,
                                                      name="down_dropout3_{}".format(depth + 1))

        # up layers
        for depth in range(self.depths - 2, -1, -1):
            with tf.compat.v1.variable_scope("UpConv_%d" % depth):
                filters = int(2**(depth) * self.filters_root)
                net = tf.compat.v1.layers.conv2d_transpose(net,
                                                           filters=filters,
                                                           kernel_size=self.kernel_size,
                                                           strides=self.pool_size,
                                                           activation=None,
                                                           use_bias=False,
                                                           padding="same",
                                                           kernel_initializer=self.initializer,
                                                           kernel_regularizer=self.regularizer,
                                                           name="up_conv0_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="up_bn0_{}".format(depth + 1))
                net = tf.nn.relu(net, name="up_relu0_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
                                                  name="up_dropout0_{}".format(depth + 1))

                # skip connection
                net = crop_and_concat(convs[depth], net)
                # net = crop_only(convs[depth], net)

                net = tf.compat.v1.layers.conv2d(net,
                                                 filters=filters,
                                                 kernel_size=self.kernel_size,
                                                 activation=None,
                                                 use_bias=False,
                                                 padding='same',
                                                 dilation_rate=self.dilation_rate,
                                                 kernel_initializer=self.initializer,
                                                 kernel_regularizer=self.regularizer,
                                                 name="up_conv1_{}".format(depth + 1))
                net = tf.compat.v1.layers.batch_normalization(net,
                                                              training=self.is_training,
                                                              name="up_bn1_{}".format(depth + 1))
                net = tf.nn.relu(net, name="up_relu1_{}".format(depth + 1))
                net = tf.compat.v1.layers.dropout(net,
                                                  rate=self.drop_rate,
                                                  training=self.is_training,
                                                  name="up_dropout1_{}".format(depth + 1))

        # Output Map
        with tf.compat.v1.variable_scope("Output"):
            net = tf.compat.v1.layers.conv2d(net,
                                             filters=self.n_class,
                                             kernel_size=(1, 1),
                                             activation=None,
                                             padding='same',
                                             # dilation_rate=self.dilation_rate,
                                             kernel_initializer=self.initializer,
                                             kernel_regularizer=self.regularizer,
                                             name="output_conv")
            # net = tf.nn.relu(net, name="output_relu")
            # net = tf.compat.v1.layers.dropout(net, rate=self.drop_rate, training=self.is_training, name="output_dropout")
            # net = tf.compat.v1.layers.batch_normalization(net, training=self.is_training, name="output_bn")
            output = net

        with tf.compat.v1.variable_scope("representation"):
            self.representation = convs[-1]

        with tf.compat.v1.variable_scope("logits"):
            self.logits = output
            tmp = tf.compat.v1.summary.histogram("logits", self.logits)
            self.summary_train.append(tmp)

        with tf.compat.v1.variable_scope("preds"):
            self.preds = tf.nn.softmax(output)
            tmp = tf.compat.v1.summary.histogram("preds", self.preds)
            self.summary_train.append(tmp)

    def add_loss_op(self):
        if self.loss_type == "cross_entropy":
            with tf.compat.v1.variable_scope("cross_entropy"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                if (np.array(self.class_weights) != 1).any():
                    class_weights = tf.constant(np.array(self.class_weights, dtype=np.float32), name="class_weights")
                    weight_map = tf.multiply(flat_labels, class_weights)
                    weight_map = tf.reduce_sum(input_tensor=weight_map, axis=1)
                    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                                       labels=flat_labels)
                    weighted_loss = tf.multiply(loss_map, weight_map)
                    loss = tf.reduce_mean(input_tensor=weighted_loss)
                else:
                    loss = tf.reduce_mean(input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits,
                                                                                               labels=flat_labels))
        elif self.loss_type == "IOU":
            with tf.compat.v1.variable_scope("IOU"):
                eps = 1e-7
                loss = 0
                for i in range(1, self.n_class):
                    intersection = eps + tf.reduce_sum(input_tensor=self.preds[:, :, :, i] * self.Y[:, :, :, i], axis=[1, 2])
                    union = eps + tf.reduce_sum(input_tensor=self.preds[:, :, :, i], axis=[1, 2]) \
                        + tf.reduce_sum(input_tensor=self.Y[:, :, :, i], axis=[1, 2])
                    loss += 1 - tf.reduce_mean(input_tensor=intersection / union)
        elif self.loss_type == "mean_squared":
            with tf.compat.v1.variable_scope("mean_squared"):
                flat_logits = tf.reshape(self.logits, [-1, self.n_class], name="logits")
                flat_labels = tf.reshape(self.Y, [-1, self.n_class], name="labels")
                loss = tf.compat.v1.losses.mean_squared_error(labels=flat_labels, predictions=flat_logits)
        else:
            # the original `"..." % self.loss_type` dropped the %s placeholder
            raise ValueError("Unknown loss function: %s" % self.loss_type)

        tmp = tf.compat.v1.summary.scalar("train_loss", loss)
        self.summary_train.append(tmp)
        tmp = tf.compat.v1.summary.scalar("valid_loss", loss)
        self.summary_valid.append(tmp)

        if self.weight_decay > 0:
            with tf.compat.v1.name_scope('weight_loss'):
                tmp = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
                weight_loss = tf.add_n(tmp, name="weight_loss")
            self.loss = loss + weight_loss
        else:
            self.loss = loss

    def add_training_op(self):
        if self.optimizer == "momentum":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
                                                                           global_step=self.global_step,
                                                                           decay_steps=self.decay_step,
                                                                           decay_rate=self.decay_rate,
                                                                           staircase=True)
            optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=self.learning_rate_node,
                                                             momentum=self.momentum)
        elif self.optimizer == "adam":
            self.learning_rate_node = tf.compat.v1.train.exponential_decay(learning_rate=self.learning_rate,
                                                                           global_step=self.global_step,
                                                                           decay_steps=self.decay_step,
                                                                           decay_rate=self.decay_rate,
                                                                           staircase=True)
            optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate_node)
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)
        tmp = tf.compat.v1.summary.scalar("learning_rate", self.learning_rate_node)
        self.summary_train.append(tmp)

    def add_metrics_op(self):
        with tf.compat.v1.variable_scope("metrics"):

            Y = tf.argmax(input=self.Y, axis=-1)
            confusion_matrix = tf.cast(tf.math.confusion_matrix(
                labels=tf.reshape(Y, [-1]),
                # compare class labels, not raw probabilities (the original flattened self.preds directly)
                predictions=tf.reshape(tf.argmax(input=self.preds, axis=-1), [-1]),
                num_classes=self.n_class, name='confusion_matrix'),
                dtype=tf.float32)

            # with tf.variable_scope("P"):
            c = tf.constant(1e-7, dtype=tf.float32)
            precision_P = (confusion_matrix[1, 1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:, 1]) + c)
            recall_P = (confusion_matrix[1, 1] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[1, :]) + c)
            f1_P = 2 * precision_P * recall_P / (precision_P + recall_P)

            tmp1 = tf.compat.v1.summary.scalar("train_precision_p", precision_P)
            tmp2 = tf.compat.v1.summary.scalar("train_recall_p", recall_P)
            tmp3 = tf.compat.v1.summary.scalar("train_f1_p", f1_P)
            self.summary_train.extend([tmp1, tmp2, tmp3])

            tmp1 = tf.compat.v1.summary.scalar("valid_precision_p", precision_P)
            tmp2 = tf.compat.v1.summary.scalar("valid_recall_p", recall_P)
            tmp3 = tf.compat.v1.summary.scalar("valid_f1_p", f1_P)
            self.summary_valid.extend([tmp1, tmp2, tmp3])

            # with tf.variable_scope("S"):
            precision_S = (confusion_matrix[2, 2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[:, 2]) + c)
            recall_S = (confusion_matrix[2, 2] + c) / (tf.reduce_sum(input_tensor=confusion_matrix[2, :]) + c)
            f1_S = 2 * precision_S * recall_S / (precision_S + recall_S)

            tmp1 = tf.compat.v1.summary.scalar("train_precision_s", precision_S)
            tmp2 = tf.compat.v1.summary.scalar("train_recall_s", recall_S)
            tmp3 = tf.compat.v1.summary.scalar("train_f1_s", f1_S)
            self.summary_train.extend([tmp1, tmp2, tmp3])

            tmp1 = tf.compat.v1.summary.scalar("valid_precision_s", precision_S)
            tmp2 = tf.compat.v1.summary.scalar("valid_recall_s", recall_S)
            tmp3 = tf.compat.v1.summary.scalar("valid_f1_s", f1_S)
            self.summary_valid.extend([tmp1, tmp2, tmp3])

            self.precision = [precision_P, precision_S]
            self.recall = [recall_P, recall_S]
            self.f1 = [f1_P, f1_S]

    def train_on_batch(self, sess, inputs_batch, labels_batch, summary_writer, drop_rate=0.0):
        feed = {self.X: inputs_batch,
                self.Y: labels_batch,
                self.drop_rate: drop_rate,
                self.is_training: True}

        _, step_summary, step, loss = sess.run([self.train_op,
                                                self.summary_train,
                                                self.global_step,
                                                self.loss],
                                               feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss

    def valid_on_batch(self, sess, inputs_batch, labels_batch, summary_writer):
        feed = {self.X: inputs_batch,
                self.Y: labels_batch,
                self.drop_rate: 0,
                self.is_training: False}

        step_summary, step, loss, preds = sess.run([self.summary_valid,
                                                    self.global_step,
                                                    self.loss,
                                                    self.preds],
                                                   feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss, preds

    def test_on_batch(self, sess, summary_writer):
        feed = {self.drop_rate: 0,
                self.is_training: False}
        step_summary, step, loss, preds, \
            X_batch, Y_batch, fname_batch, \
            itp_batch, its_batch = sess.run([self.summary_valid,
                                             self.global_step,
                                             self.loss,
                                             self.preds,
                                             self.X,
                                             self.Y,
                                             self.input_batch[2],
                                             self.input_batch[3],
                                             self.input_batch[4]],
                                            feed_dict=feed)
        summary_writer.add_summary(step_summary, step)
        return loss, preds, X_batch, Y_batch, fname_batch, itp_batch, its_batch

    def build(self, input_batch=None, mode='train'):
        self.add_placeholders(input_batch, mode)
        self.add_prediction_op()
        if mode in ["train", "valid", "test"]:
            self.add_loss_op()
            self.add_training_op()
            # self.add_metrics_op()
            self.summary_train = tf.compat.v1.summary.merge(self.summary_train)
            self.summary_valid = tf.compat.v1.summary.merge(self.summary_valid)
        return 0
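A minimal inference sketch for the class above: this assumes a mode string outside train/valid/test (here "pred") so the loss and optimizer ops are skipped, and a hypothetical checkpoint prefix "model_95.ckpt"; everything else is illustrative:

```python
import numpy as np
import tensorflow as tf

# a minimal sketch, not the repo's predict.py: build the graph in "pred" mode
config = ModelConfig(X_shape=[3000, 1, 3])
model = UNet(config=config, mode="pred")

with tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.Saver()
    saver.restore(sess, "model_95.ckpt")  # placeholder checkpoint prefix

    window = np.random.randn(1, 3000, 1, 3).astype(np.float32)  # fake 3-component waveform
    preds = sess.run(model.preds, feed_dict={model.X: window,
                                             model.drop_rate: 0.0,
                                             model.is_training: False})
    print(preds.shape)  # (1, 3000, 1, 3): per-sample noise/P/S probabilities
```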
phasenet/postprocess.py ADDED
@@ -0,0 +1,383 @@
import json
import logging
import os
from collections import namedtuple
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
from .detect_peaks import detect_peaks

# def extract_picks(preds, fnames=None, station_ids=None, t0=None, config=None):

#     if preds.shape[-1] == 4:
#         record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob", "ps_idx", "ps_prob"])
#     else:
#         record = namedtuple("phase", ["fname", "station_id", "t0", "p_idx", "p_prob", "s_idx", "s_prob"])

#     picks = []
#     for i, pred in enumerate(preds):

#         if config is None:
#             mph_p, mph_s, mpd = 0.3, 0.3, 50
#         else:
#             mph_p, mph_s, mpd = config.min_p_prob, config.min_s_prob, config.mpd

#         if (fnames is None):
#             fname = f"{i:04d}"
#         else:
#             if isinstance(fnames[i], str):
#                 fname = fnames[i]
#             else:
#                 fname = fnames[i].decode()

#         if (station_ids is None):
#             station_id = f"{i:04d}"
#         else:
#             if isinstance(station_ids[i], str):
#                 station_id = station_ids[i]
#             else:
#                 station_id = station_ids[i].decode()

#         if (t0 is None):
#             start_time = "1970-01-01T00:00:00.000"
#         else:
#             if isinstance(t0[i], str):
#                 start_time = t0[i]
#             else:
#                 start_time = t0[i].decode()

#         p_idx, p_prob, s_idx, s_prob = [], [], [], []
#         for j in range(pred.shape[1]):
#             p_idx_, p_prob_ = detect_peaks(pred[:,j,1], mph=mph_p, mpd=mpd, show=False)
#             s_idx_, s_prob_ = detect_peaks(pred[:,j,2], mph=mph_s, mpd=mpd, show=False)
#             p_idx.append(list(p_idx_))
#             p_prob.append(list(p_prob_))
#             s_idx.append(list(s_idx_))
#             s_prob.append(list(s_prob_))

#         if pred.shape[-1] == 4:
#             ps_idx, ps_prob = detect_peaks(pred[:,0,3], mph=0.3, mpd=mpd, show=False)
#             picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob), list(ps_idx), list(ps_prob)))
#         else:
#             picks.append(record(fname, station_id, start_time, list(p_idx), list(p_prob), list(s_idx), list(s_prob)))

#     return picks


def extract_picks(
    preds,
    file_names=None,
    begin_times=None,
    station_ids=None,
    dt=0.01,
    phases=["P", "S"],
    config=None,
    waveforms=None,
    use_amplitude=False,
    upload_waveform=False,
):
    """Extract picks from prediction results.

    Args:
        preds ([type]): [Nb, Nt, Ns, Nc] "batch, time, station, channel"
        file_names ([type], optional): [Nb]. Defaults to None.
        begin_times ([type], optional): [Nb]. Defaults to None.
        station_ids ([type], optional): [Ns]. Defaults to None.
        config ([type], optional): [description]. Defaults to None.

    Returns:
        picks [type]: {file_name, station_id, pick_time, pick_prob, pick_type}
    """

    mph = {}
    if config is None:
        for x in phases:
            mph[x] = 0.3
        mpd = 50
        ## upload waveform
        pre_idx = int(1 / dt)
        post_idx = int(4 / dt)
    else:
        mph["P"] = config.min_p_prob
        mph["S"] = config.min_s_prob
        mph["PS"] = 0.3
        mpd = config.mpd
        pre_idx = int(config.pre_sec / dt)
        post_idx = int(config.post_sec / dt)

    Nb, Nt, Ns, Nc = preds.shape

    if file_names is None:
        file_names = [f"{i:04d}" for i in range(Nb)]
    elif not (isinstance(file_names, np.ndarray) or isinstance(file_names, list)):
        if isinstance(file_names, bytes):
            file_names = file_names.decode()
        file_names = [file_names] * Nb
    else:
        file_names = [x.decode() if isinstance(x, bytes) else x for x in file_names]

    if begin_times is None:
        begin_times = ["1970-01-01T00:00:00.000+00:00"] * Nb
    else:
        begin_times = [x.decode() if isinstance(x, bytes) else x for x in begin_times]

    picks = []
    for i in range(Nb):

        file_name = file_names[i]
        begin_time = datetime.fromisoformat(begin_times[i])

        for j in range(Ns):
            if (station_ids is None) or (len(station_ids[i]) == 0):
                station_id = f"{j:04d}"
            else:
                station_id = station_ids[i].decode() if isinstance(station_ids[i], bytes) else station_ids[i]

            if (waveforms is not None) and use_amplitude:
                amp = np.max(np.abs(waveforms[i, :, j, :]), axis=-1)  ## amplitude over three channels
            for k in range(Nc - 1):  # 0-th channel is noise
                idxs, probs = detect_peaks(preds[i, :, j, k + 1], mph=mph[phases[k]], mpd=mpd, show=False)
                for l, (phase_index, phase_prob) in enumerate(zip(idxs, probs)):
                    pick_time = begin_time + timedelta(seconds=phase_index * dt)
                    pick = {
                        "file_name": file_name,
                        "station_id": station_id,
                        "begin_time": begin_time.isoformat(timespec="milliseconds"),
                        "phase_index": int(phase_index),
                        "phase_time": pick_time.isoformat(timespec="milliseconds"),
                        "phase_score": round(phase_prob, 3),
                        "phase_type": phases[k],
                        "dt": dt,
                    }

                    ## process waveform
                    if waveforms is not None:
                        tmp = np.zeros((pre_idx + post_idx, 3))
                        lo = phase_index - pre_idx
                        hi = phase_index + post_idx
                        insert_idx = 0
                        if lo < 0:
                            insert_idx = -lo  # compute the offset before clamping lo (the original clamped first, so insert_idx was always 0)
                            lo = 0
                        if hi > Nt:
                            hi = Nt
                        tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
                        if upload_waveform:
                            pick["waveform"] = tmp.tolist()
                            # use the keys that actually exist in this dict ("timestamp"/"type" were stale names)
                            pick["_id"] = f"{pick['station_id']}_{pick['phase_time']}_{pick['phase_type']}"
                    if use_amplitude and (waveforms is not None):
                        next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
                        pick["phase_amp"] = np.max(
                            amp[phase_index : min(phase_index + post_idx * 3, next_pick)]
                        ).item()  ## peak amplitude

                    picks.append(pick)

    return picks
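To see the shape conventions of extract_picks in action, a small synthetic example (no config object, so the 0.3 probability and 50-sample distance defaults apply; the file name and begin time are made up):

```python
import numpy as np

# preds: [batch=1, time=3000, station=1, channel=3] -> noise, P, S probabilities
preds = np.random.uniform(0, 0.2, size=(1, 3000, 1, 3))
preds[0, 1000, 0, 1] = 0.95  # one sharp P peak
preds[0, 1500, 0, 2] = 0.80  # one sharp S peak

picks = extract_picks(preds, file_names=["demo"],
                      begin_times=["2019-07-03T00:00:00.000"], dt=0.01)
for p in picks:
    print(p["phase_type"], p["phase_index"], p["phase_score"], p["phase_time"])
# e.g. P 1000 0.95 2019-07-03T00:00:10.000
```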

def extract_amplitude(data, picks, window_p=10, window_s=5, config=None):
    record = namedtuple("amplitude", ["p_amp", "s_amp"])
    dt = 0.01 if config is None else config.dt
    window_p = int(window_p / dt)
    window_s = int(window_s / dt)
    amps = []
    for i, (da, pi) in enumerate(zip(data, picks)):
        p_amp, s_amp = [], []
        for j in range(da.shape[1]):
            amp = np.max(np.abs(da[:, j, :]), axis=-1)
            # amp = np.median(np.abs(da[:,j,:]), axis=-1)
            # amp = np.linalg.norm(da[:,j,:], axis=-1)
            tmp = []
            for k in range(len(pi.p_idx[j]) - 1):
                tmp.append(np.max(amp[pi.p_idx[j][k] : min(pi.p_idx[j][k] + window_p, pi.p_idx[j][k + 1])]))
            if len(pi.p_idx[j]) >= 1:
                tmp.append(np.max(amp[pi.p_idx[j][-1] : pi.p_idx[j][-1] + window_p]))
            p_amp.append(tmp)
            tmp = []
            for k in range(len(pi.s_idx[j]) - 1):
                tmp.append(np.max(amp[pi.s_idx[j][k] : min(pi.s_idx[j][k] + window_s, pi.s_idx[j][k + 1])]))
            if len(pi.s_idx[j]) >= 1:
                tmp.append(np.max(amp[pi.s_idx[j][-1] : pi.s_idx[j][-1] + window_s]))
            s_amp.append(tmp)
        amps.append(record(p_amp, s_amp))
    return amps


def save_picks(picks, output_dir, amps=None, fname=None):
    if fname is None:
        fname = "picks.csv"

    int2s = lambda x: ",".join(["[" + ",".join(map(str, i)) + "]" for i in x])
    flt2s = lambda x: ",".join(["[" + ",".join(map("{:0.3f}".format, i)) + "]" for i in x])
    sci2s = lambda x: ",".join(["[" + ",".join(map("{:0.3e}".format, i)) + "]" for i in x])
    if amps is None:
        if hasattr(picks[0], "ps_idx"):
            with open(os.path.join(output_dir, fname), "w") as fp:
                fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tps_idx\tps_prob\n")
                for pick in picks:
                    fp.write(
                        f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{int2s(pick.ps_idx)}\t{flt2s(pick.ps_prob)}\n"
                    )
        else:
            with open(os.path.join(output_dir, fname), "w") as fp:
                fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\n")
                for pick in picks:
                    fp.write(
                        f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\n"
                    )
    else:
        with open(os.path.join(output_dir, fname), "w") as fp:
            fp.write("fname\tt0\tp_idx\tp_prob\ts_idx\ts_prob\tp_amp\ts_amp\n")
            for pick, amp in zip(picks, amps):
                fp.write(
                    f"{pick.fname}\t{pick.t0}\t{int2s(pick.p_idx)}\t{flt2s(pick.p_prob)}\t{int2s(pick.s_idx)}\t{flt2s(pick.s_prob)}\t{sci2s(amp.p_amp)}\t{sci2s(amp.s_amp)}\n"
                )
    # the `with` blocks close the files; the explicit fp.close() calls were redundant

    return 0


def calc_timestamp(timestamp, sec):
    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]


def save_picks_json(picks, output_dir, dt=0.01, amps=None, fname=None):
    if fname is None:
        fname = "picks.json"

    picks_ = []
    if amps is None:
        for pick in picks:
            for idxs, probs in zip(pick.p_idx, pick.p_prob):
                for idx, prob in zip(idxs, probs):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "type": "p",
                        }
                    )
            for idxs, probs in zip(pick.s_idx, pick.s_prob):
                for idx, prob in zip(idxs, probs):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "type": "s",
                        }
                    )
    else:
        for pick, amplitude in zip(picks, amps):
            # inner loop variable renamed from `amps` to avoid shadowing the argument
            for idxs, probs, amp_list in zip(pick.p_idx, pick.p_prob, amplitude.p_amp):
                for idx, prob, amp in zip(idxs, probs, amp_list):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "amp": amp.astype(float),
                            "type": "p",
                        }
                    )
            for idxs, probs, amp_list in zip(pick.s_idx, pick.s_prob, amplitude.s_amp):
                for idx, prob, amp in zip(idxs, probs, amp_list):
                    picks_.append(
                        {
                            "id": pick.station_id,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob.astype(float),
                            "amp": amp.astype(float),
                            "type": "s",
                        }
                    )
    with open(os.path.join(output_dir, fname), "w") as fp:
        json.dump(picks_, fp)

    return 0


def convert_true_picks(fname, itp, its, itps=None):
    true_picks = []
    if itps is None:
        record = namedtuple("phase", ["fname", "p_idx", "s_idx"])
        for i in range(len(fname)):
            true_picks.append(record(fname[i].decode(), itp[i], its[i]))
    else:
        record = namedtuple("phase", ["fname", "p_idx", "s_idx", "ps_idx"])
        for i in range(len(fname)):
            true_picks.append(record(fname[i].decode(), itp[i], its[i], itps[i]))

    return true_picks


def calc_metrics(nTP, nP, nT):
|
| 320 |
+
"""
|
| 321 |
+
nTP: true positive
|
| 322 |
+
nP: number of positive picks
|
| 323 |
+
nT: number of true picks
|
| 324 |
+
"""
|
| 325 |
+
precision = nTP / nP
|
| 326 |
+
recall = nTP / nT
|
| 327 |
+
f1 = 2 * precision * recall / (precision + recall)
|
| 328 |
+
return [precision, recall, f1]
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def calc_performance(picks, true_picks, tol=3.0, dt=1.0):
|
| 332 |
+
assert len(picks) == len(true_picks)
|
| 333 |
+
logging.info("Total records: {}".format(len(picks)))
|
| 334 |
+
|
| 335 |
+
count = lambda picks: sum([len(x) for x in picks])
|
| 336 |
+
metrics = {}
|
| 337 |
+
for phase in true_picks[0]._fields:
|
| 338 |
+
if phase == "fname":
|
| 339 |
+
continue
|
| 340 |
+
true_positive, positive, true = 0, 0, 0
|
| 341 |
+
residual = []
|
| 342 |
+
for i in range(len(true_picks)):
|
| 343 |
+
true += count(getattr(true_picks[i], phase))
|
| 344 |
+
positive += count(getattr(picks[i], phase))
|
| 345 |
+
# print(i, phase, getattr(picks[i], phase), getattr(true_picks[i], phase))
|
| 346 |
+
diff = dt * (
|
| 347 |
+
np.array(getattr(picks[i], phase))[:, np.newaxis, :]
|
| 348 |
+
- np.array(getattr(true_picks[i], phase))[:, :, np.newaxis]
|
| 349 |
+
)
|
| 350 |
+
residual.extend(list(diff[np.abs(diff) <= tol]))
|
| 351 |
+
true_positive += np.sum(np.abs(diff) <= tol)
|
| 352 |
+
metrics[phase] = calc_metrics(true_positive, positive, true)
|
| 353 |
+
|
| 354 |
+
logging.info(f"{phase}-phase:")
|
| 355 |
+
logging.info(f"True={true}, Positive={positive}, True Positive={true_positive}")
|
| 356 |
+
logging.info(f"Precision={metrics[phase][0]:.3f}, Recall={metrics[phase][1]:.3f}, F1={metrics[phase][2]:.3f}")
|
| 357 |
+
logging.info(f"Residual mean={np.mean(residual):.4f}, std={np.std(residual):.4f}")
|
| 358 |
+
|
| 359 |
+
return metrics
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def save_prob_h5(probs, fnames, output_h5):
|
| 363 |
+
if fnames is None:
|
| 364 |
+
fnames = [f"{i:04d}" for i in range(len(probs))]
|
| 365 |
+
elif type(fnames[0]) is bytes:
|
| 366 |
+
fnames = [f.decode().rstrip(".npz") for f in fnames]
|
| 367 |
+
else:
|
| 368 |
+
fnames = [f.rstrip(".npz") for f in fnames]
|
| 369 |
+
for prob, fname in zip(probs, fnames):
|
| 370 |
+
output_h5.create_dataset(fname, data=prob, dtype="float32")
|
| 371 |
+
return 0
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def save_prob(probs, fnames, prob_dir):
|
| 375 |
+
if fnames is None:
|
| 376 |
+
fnames = [f"{i:04d}" for i in range(len(probs))]
|
| 377 |
+
elif type(fnames[0]) is bytes:
|
| 378 |
+
fnames = [f.decode().rstrip(".npz") for f in fnames]
|
| 379 |
+
else:
|
| 380 |
+
fnames = [f.rstrip(".npz") for f in fnames]
|
| 381 |
+
for prob, fname in zip(probs, fnames):
|
| 382 |
+
np.savez(os.path.join(prob_dir, fname + ".npz"), prob=prob)
|
| 383 |
+
return 0
|
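As a sanity check on calc_metrics above, here is a quick worked example with made-up counts (illustrative only, not from any real run):

from postprocess import calc_metrics

# Say 90 predicted picks matched a true pick within tolerance,
# out of 100 predicted picks and 120 true picks:
precision, recall, f1 = calc_metrics(nTP=90, nP=100, nT=120)
print(precision)     # 90 / 100 = 0.9
print(recall)        # 90 / 120 = 0.75
print(round(f1, 3))  # 2 * 0.9 * 0.75 / (0.9 + 0.75) ≈ 0.818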
phasenet/predict.py
ADDED
@@ -0,0 +1,266 @@
import argparse
import logging
import multiprocessing
import os
import pickle
import time
from functools import partial

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from data_reader import DataReader_mseed_array, DataReader_pred
from model import ModelConfig, UNet
from postprocess import (
    extract_amplitude,
    extract_picks,
    save_picks,
    save_picks_json,
    save_prob_h5,
)
from pymongo import MongoClient
from tqdm import tqdm
from visulization import plot_waveform

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

username = "root"
password = "quakeflow123"
# client = MongoClient(f"mongodb://{username}:{password}@127.0.0.1:27017")
client = MongoClient(f"mongodb://{username}:{password}@quakeflow-mongodb-headless.default.svc.cluster.local:27017")

# db = client["quakeflow"]
# collection = db["waveform"]


def upload_mongodb(picks):
    db = client["quakeflow"]
    collection = db["waveform"]
    try:
        collection.insert_many(picks)
    except Exception as e:
        # on duplicate _id errors, replace the existing documents and retry
        print("Warning:", e)
        collection.delete_many({"_id": {"$in": [p["_id"] for p in picks]}})
        collection.insert_many(picks)


def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=20, type=int, help="batch size")
    parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
    parser.add_argument("--data_dir", default="", help="Input file directory")
    parser.add_argument("--data_list", default="", help="Input csv file")
    parser.add_argument("--hdf5_file", default="", help="Input hdf5 file")
    parser.add_argument("--hdf5_group", default="data", help="data group name in hdf5 file")
    parser.add_argument("--result_dir", default="results", help="Output directory")
    parser.add_argument("--result_fname", default="picks", help="Output file")
    parser.add_argument("--highpass_filter", default=0.0, type=float, help="Highpass filter")
    parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
    parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
    parser.add_argument("--mpd", default=50, type=float, help="Minimum peak distance")
    parser.add_argument("--amplitude", action="store_true", help="if return amplitude value")
    parser.add_argument("--format", default="numpy", help="input format")
    parser.add_argument("--s3_url", default="localhost:9000", help="s3 url")
    parser.add_argument("--stations", default="", help="seismic station info")
    parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
    parser.add_argument("--save_prob", action="store_true", help="If save result for test")
    parser.add_argument("--upload_waveform", action="store_true", help="If upload waveform to mongodb")
    parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
    parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")
    args = parser.parse_args()

    return args


def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    if log_dir is None:
        log_dir = os.path.join(args.log_dir, "pred", current_time)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if (args.plot_figure == True) and (figure_dir is None):
        figure_dir = os.path.join(log_dir, "figures")
        if not os.path.exists(figure_dir):
            os.makedirs(figure_dir)
    if (args.save_prob == True) and (prob_dir is None):
        prob_dir = os.path.join(log_dir, "probs")
        if not os.path.exists(prob_dir):
            os.makedirs(prob_dir)
    if args.save_prob:
        h5 = h5py.File(os.path.join(args.result_dir, "result.h5"), "w", libver="latest")
        prob_h5 = h5.create_group("/prob")
    logging.info("Pred log: %s" % log_dir)
    logging.info("Dataset size: {}".format(data_reader.num_data))

    with tf.compat.v1.name_scope("Input_Batch"):
        if args.format == "mseed_array":
            batch_size = 1
        else:
            batch_size = args.batch_size
        dataset = data_reader.dataset(batch_size)
        batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    config = ModelConfig(X_shape=data_reader.X_shape)
    with open(os.path.join(log_dir, "config.log"), "w") as fp:
        fp.write("\n".join("%s: %s" % item for item in vars(config).items()))

    model = UNet(config=config, input_batch=batch, mode="pred")
    # model = UNet(config=config, mode="pred")
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        logging.info(f"restoring model {latest_check_point}")
        saver.restore(sess, latest_check_point)

        picks = []
        amps = [] if args.amplitude else None
        if args.plot_figure:
            multiprocessing.set_start_method("spawn")
            pool = multiprocessing.Pool(multiprocessing.cpu_count())

        for _ in tqdm(range(0, data_reader.num_data, batch_size), desc="Pred"):
            if args.amplitude:
                pred_batch, X_batch, amp_batch, fname_batch, t0_batch, station_batch = sess.run(
                    [model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
                    feed_dict={model.drop_rate: 0, model.is_training: False},
                )
                # X_batch, amp_batch, fname_batch, t0_batch = sess.run([batch[0], batch[1], batch[2], batch[3]])
            else:
                pred_batch, X_batch, fname_batch, t0_batch, station_batch = sess.run(
                    [model.preds, batch[0], batch[1], batch[2], batch[3]],
                    feed_dict={model.drop_rate: 0, model.is_training: False},
                )
                # X_batch, fname_batch, t0_batch = sess.run([model.preds, batch[0], batch[1], batch[2]])
            # pred_batch = []
            # for i in range(0, len(X_batch), 1):
            #     pred_batch.append(sess.run(model.preds, feed_dict={model.X: X_batch[i:i+1], model.drop_rate: 0, model.is_training: False}))
            # pred_batch = np.vstack(pred_batch)

            waveforms = None
            if args.upload_waveform:
                waveforms = X_batch
                if args.amplitude:
                    waveforms = amp_batch

            picks_ = extract_picks(
                preds=pred_batch,
                file_names=fname_batch,
                station_ids=station_batch,
                begin_times=t0_batch,
                config=args,
                waveforms=waveforms,
                use_amplitude=args.amplitude,
                upload_waveform=args.upload_waveform,
            )

            if args.upload_waveform:
                upload_mongodb(picks_)
            picks.extend(picks_)

            if args.plot_figure:
                if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
                    fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
                else:
                    fname_batch = [x.decode() for x in fname_batch]
                pool.starmap(
                    partial(
                        plot_waveform,
                        figure_dir=figure_dir,
                    ),
                    # zip(X_batch, pred_batch, [x.decode() for x in fname_batch]),
                    zip(X_batch, pred_batch, fname_batch),
                )

            if args.save_prob:
                # save_prob(pred_batch, fname_batch, prob_dir=prob_dir)
                if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
                    fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
                else:
                    fname_batch = [x.decode() for x in fname_batch]
                save_prob_h5(pred_batch, fname_batch, prob_h5)

        if len(picks) > 0:
            # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
            # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
            df = pd.DataFrame(picks)
            # df["fname"] = df["file_name"]
            # df["id"] = df["station_id"]
            # df["timestamp"] = df["phase_time"]
            # df["prob"] = df["phase_prob"]
            # df["type"] = df["phase_type"]
            if args.amplitude:
                # df["amp"] = df["phase_amp"]
                df = df[
                    [
                        "file_name",
                        "begin_time",
                        "station_id",
                        "phase_index",
                        "phase_time",
                        "phase_score",
                        "phase_amp",
                        "phase_type",
                    ]
                ]
            else:
                df = df[
                    ["file_name", "begin_time", "station_id", "phase_index", "phase_time", "phase_score", "phase_type"]
                ]
            # if args.amplitude:
            #     df = df[["file_name","station_id","phase_index","phase_time","phase_prob","phase_amplitude", "phase_type","dt",]]
            # else:
            #     df = df[["file_name","station_id","phase_index","phase_time","phase_prob","phase_type","dt"]]
            df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)

            print(
                f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
            )
        else:
            print("Done with 0 P-picks and 0 S-picks")
    return 0


def main(args):
    logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)

    with tf.compat.v1.name_scope("create_inputs"):
        if args.format == "mseed_array":
            data_reader = DataReader_mseed_array(
                data_dir=args.data_dir,
                data_list=args.data_list,
                stations=args.stations,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
            )
        else:
            data_reader = DataReader_pred(
                format=args.format,
                data_dir=args.data_dir,
                data_list=args.data_list,
                hdf5_file=args.hdf5_file,
                hdf5_group=args.hdf5_group,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
            )

    pred_fn(args, data_reader, log_dir=args.result_dir)

    return


if __name__ == "__main__":
    args = read_args()
    main(args)
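read_args and main above define the full CLI surface, and the same pipeline can be driven from Python. A minimal sketch, not part of the commit: the namespace simply mirrors the defaults in read_args(), the checkpoint directory is the one shipped under model/ in this commit, and the "demo/" data paths are placeholders (pymongo's MongoClient connects lazily, so no database should be needed unless --upload_waveform is set).

import argparse

from data_reader import DataReader_pred
from predict import pred_fn

# Mirrors the argparse defaults above; "demo/" paths are placeholders.
args = argparse.Namespace(
    batch_size=20, model_dir="model/190703-214543",
    data_dir="demo/waveforms", data_list="demo/fname.csv",
    hdf5_file="", hdf5_group="data",
    result_dir="results", result_fname="picks",
    highpass_filter=0.0, min_p_prob=0.3, min_s_prob=0.3, mpd=50,
    amplitude=False, format="numpy", s3_url="localhost:9000", stations="",
    plot_figure=False, save_prob=False, upload_waveform=False,
    pre_sec=1, post_sec=4,
)
data_reader = DataReader_pred(
    format=args.format, data_dir=args.data_dir, data_list=args.data_list,
    hdf5_file=args.hdf5_file, hdf5_group=args.hdf5_group,
    amplitude=args.amplitude, highpass_filter=args.highpass_filter,
)
pred_fn(args, data_reader, log_dir=args.result_dir)  # writes results/picks.csv when picks are found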
phasenet/slide_window.py
ADDED
@@ -0,0 +1,88 @@
import os
from collections import defaultdict, namedtuple
from datetime import datetime, timedelta
from json import dumps

import numpy as np
import tensorflow as tf

from model import ModelConfig, UNet
from postprocess import extract_amplitude, extract_picks
import pandas as pd
import obspy


tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))

# load model
model = UNet(mode="pred")
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.allow_growth = True

sess = tf.compat.v1.Session(config=sess_config)
saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543")
print(f"restoring model {latest_check_point}")
saver.restore(sess, latest_check_point)


def calc_timestamp(timestamp, sec):
    timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f") + timedelta(seconds=sec)
    return timestamp.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]


def format_picks(picks, dt):
    picks_ = []
    for pick in picks:
        for idxs, probs in zip(pick.p_idx, pick.p_prob):
            for idx, prob in zip(idxs, probs):
                picks_.append(
                    {
                        "id": pick.fname,
                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                        "prob": prob,
                        "type": "p",
                    }
                )
        for idxs, probs in zip(pick.s_idx, pick.s_prob):
            for idx, prob in zip(idxs, probs):
                picks_.append(
                    {
                        "id": pick.fname,
                        "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                        "prob": prob,
                        "type": "s",
                    }
                )
    return picks_


stream = obspy.read()
stream = stream.sort()  ## assume sorting yields (E, N, Z) channel order
assert len(stream) == 3
data = []
for trace in stream:
    data.append(trace.data)
data = np.array(data).T
assert data.shape[-1] == 3

# data_id = stream[0].get_id()[:-1]
# timestamp = stream[0].stats.starttime.datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

data = np.stack([data for i in range(10)])  ## assume 10 windows
data = data[:, :, np.newaxis, :]  ## batch, nt, dummy_dim, channel
print(f"{data.shape = }")
data = (data - data.mean(axis=1, keepdims=True)) / data.std(axis=1, keepdims=True)

feed = {model.X: data, model.drop_rate: 0, model.is_training: False}
preds = sess.run(model.preds, feed_dict=feed)

# keyword names follow postprocess.extract_picks (file_names / begin_times);
# format_picks above still assumes attribute-style picks (p_idx, p_prob, ...)
picks = extract_picks(preds, file_names=None, station_ids=None, begin_times=None)
picks = format_picks(picks, dt=0.01)


picks = pd.DataFrame(picks)
print(picks)
phasenet/test_app.py
ADDED
@@ -0,0 +1,47 @@
import requests
import obspy
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

### Start running the model first:
### FLASK_ENV=development FLASK_APP=app.py flask run

def read_data(mseed):
    data = []
    mseed = mseed.sort()
    for c in ["E", "N", "Z"]:
        data.append(mseed.select(channel="*" + c)[0].data)
    return np.array(data).T

timestamp = lambda x: x.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]

## prepare some test data
mseed = obspy.read()
data = []
for i in range(1):
    data.append(read_data(mseed))
data = {
    "id": ["test01"],
    "timestamp": [timestamp(datetime.now())],
    "vec": np.array(data).tolist(),
    "dt": 0.01,
}

## run prediction
print(data["id"])
resp = requests.get("http://localhost:8000/predict", json=data)
picks = resp.json()["picks"]  # the plotting below needs `picks`
print(resp.json())


## plot figure
plt.figure()
plt.plot(np.array(data["vec"])[0, :, 1])  # the request payload key is "vec"
ylim = plt.ylim()
plt.plot([picks[0][0][0], picks[0][0][0]], ylim, label="P-phase")
plt.text(picks[0][0][0], ylim[1] * 0.9, f"{picks[0][1][0]:.2f}")
plt.plot([picks[0][2][0], picks[0][2][0]], ylim, label="S-phase")
plt.text(picks[0][2][0], ylim[1] * 0.9, f"{picks[0][3][0]:.2f}")  # S-phase probability
plt.legend()
plt.savefig("test.png")
phasenet/train.py
ADDED
@@ -0,0 +1,246 @@
import numpy as np
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import argparse, os, time, logging
from tqdm import tqdm
import pandas as pd
import multiprocessing
from functools import partial
import pickle
from model import UNet, ModelConfig
from data_reader import DataReader_train, DataReader_test
from postprocess import extract_picks, save_picks, save_picks_json, extract_amplitude, convert_true_picks, calc_performance
from visulization import plot_waveform
from util import EMA, LMA

def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="train", help="train/train_valid/test/debug")
    parser.add_argument("--epochs", default=100, type=int, help="number of epochs (default: 100)")
    parser.add_argument("--batch_size", default=20, type=int, help="batch size")
    parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate")
    parser.add_argument("--drop_rate", default=0.0, type=float, help="dropout rate")
    parser.add_argument("--decay_step", default=-1, type=int, help="decay step")
    parser.add_argument("--decay_rate", default=0.9, type=float, help="decay rate")
    parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
    parser.add_argument("--optimizer", default="adam", help="optimizer: adam, momentum")
    parser.add_argument("--summary", default=True, type=bool, help="summary")
    parser.add_argument("--class_weights", nargs="+", default=[1, 1, 1], type=float, help="class weights")
    parser.add_argument("--model_dir", default=None, help="Checkpoint directory (default: None)")
    parser.add_argument("--load_model", action="store_true", help="Load checkpoint")
    parser.add_argument("--log_dir", default="log", help="Log directory (default: log)")
    parser.add_argument("--num_plots", default=10, type=int, help="Plotting training results")
    parser.add_argument("--min_p_prob", default=0.3, type=float, help="Probability threshold for P pick")
    parser.add_argument("--min_s_prob", default=0.3, type=float, help="Probability threshold for S pick")
    parser.add_argument("--format", default="numpy", help="Input data format")
    parser.add_argument("--train_dir", default="./dataset/waveform_train/", help="Input file directory")
    parser.add_argument("--train_list", default="./dataset/waveform.csv", help="Input csv file")
    parser.add_argument("--valid_dir", default=None, help="Input file directory")
    parser.add_argument("--valid_list", default=None, help="Input csv file")
    parser.add_argument("--test_dir", default=None, help="Input file directory")
    parser.add_argument("--test_list", default=None, help="Input csv file")
    parser.add_argument("--result_dir", default="results", help="result directory")
    parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
    parser.add_argument("--save_prob", action="store_true", help="If save result for test")
    args = parser.parse_args()

    return args


def train_fn(args, data_reader, data_reader_valid=None):
    current_time = time.strftime("%y%m%d-%H%M%S")
    log_dir = os.path.join(args.log_dir, current_time)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logging.info("Training log: {}".format(log_dir))
    model_dir = os.path.join(log_dir, 'models')
    os.makedirs(model_dir)

    figure_dir = os.path.join(log_dir, 'figures')
    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
    if args.decay_step == -1:
        args.decay_step = data_reader.num_data // args.batch_size
    config.update_args(args)
    with open(os.path.join(log_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        dataset = data_reader.dataset(args.batch_size, shuffle=True).repeat()
        batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
        if data_reader_valid is not None:
            dataset_valid = data_reader_valid.dataset(args.batch_size, shuffle=False).repeat()
            valid_batch = tf.compat.v1.data.make_one_shot_iterator(dataset_valid).get_next()

    model = UNet(config, input_batch=batch)
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph)
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        if args.model_dir is not None:
            logging.info("restoring models...")
            latest_check_point = tf.train.latest_checkpoint(args.model_dir)
            saver.restore(sess, latest_check_point)

        if args.plot_figure:
            multiprocessing.set_start_method('spawn')
            pool = multiprocessing.Pool(multiprocessing.cpu_count())

        flog = open(os.path.join(log_dir, 'loss.log'), 'w')
        train_loss = EMA(0.9)
        best_valid_loss = np.inf
        for epoch in range(args.epochs):
            progressbar = tqdm(range(0, data_reader.num_data, args.batch_size),
                               desc="{}: epoch {}".format(log_dir.split("/")[-1], epoch))
            for _ in progressbar:
                loss_batch, _, _ = sess.run([model.loss, model.train_op, model.global_step],
                                            feed_dict={model.drop_rate: args.drop_rate, model.is_training: True})
                train_loss(loss_batch)
                progressbar.set_description("{}: epoch {}, loss={:.6f}, mean={:.6f}".format(
                    log_dir.split("/")[-1], epoch, loss_batch, train_loss.value))
            flog.write("epoch: {}, mean loss: {}\n".format(epoch, train_loss.value))

            if data_reader_valid is not None:
                valid_loss = LMA()
                progressbar = tqdm(range(0, data_reader_valid.num_data, args.batch_size), desc="Valid:")
                for _ in progressbar:
                    loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run(
                        [model.loss, model.preds, valid_batch[0], valid_batch[1], valid_batch[2]],
                        feed_dict={model.drop_rate: 0, model.is_training: False})
                    valid_loss(loss_batch)
                    progressbar.set_description("valid, loss={:.6f}, mean={:.6f}".format(loss_batch, valid_loss.value))
                if valid_loss.value < best_valid_loss:
                    best_valid_loss = valid_loss.value
                    saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))
                flog.write("Valid: mean loss: {}\n".format(valid_loss.value))
            else:
                loss_batch, preds_batch, X_batch, Y_batch, fname_batch = sess.run(
                    [model.loss, model.preds, batch[0], batch[1], batch[2]],
                    feed_dict={model.drop_rate: 0, model.is_training: False})
                saver.save(sess, os.path.join(model_dir, "model_{}.ckpt".format(epoch)))

            if args.plot_figure:
                pool.starmap(
                    partial(
                        plot_waveform,
                        figure_dir=figure_dir,
                    ),
                    zip(X_batch, preds_batch, [x.decode() for x in fname_batch], Y_batch),
                )
            # plot_waveform(X_batch, preds_batch, fname_batch, label=Y_batch, figure_dir=figure_dir)
            flog.flush()

        flog.close()

    return 0

def test_fn(args, data_reader):
    current_time = time.strftime("%y%m%d-%H%M%S")
    logging.info("{} log: {}".format(args.mode, current_time))
    if args.model_dir is None:
        logging.error("model_dir = None!")
        return -1
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    figure_dir = os.path.join(args.result_dir, "figures")
    if not os.path.exists(figure_dir):
        os.makedirs(figure_dir)

    config = ModelConfig(X_shape=data_reader.X_shape, Y_shape=data_reader.Y_shape)
    config.update_args(args)
    with open(os.path.join(args.result_dir, 'config.log'), 'w') as fp:
        fp.write('\n'.join("%s: %s" % item for item in vars(config).items()))

    with tf.compat.v1.name_scope('Input_Batch'):
        dataset = data_reader.dataset(args.batch_size, shuffle=False)
        batch = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

    model = UNet(config, input_batch=batch, mode='test')
    sess_config = tf.compat.v1.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    # sess_config.log_device_placement = False

    with tf.compat.v1.Session(config=sess_config) as sess:
        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)

        logging.info("restoring models...")
        latest_check_point = tf.train.latest_checkpoint(args.model_dir)
        if latest_check_point is None:
            logging.error(f"No models found in model_dir: {args.model_dir}")
            return -1
        saver.restore(sess, latest_check_point)

        flog = open(os.path.join(args.result_dir, 'loss.log'), 'w')
        test_loss = LMA()
        progressbar = tqdm(range(0, data_reader.num_data, args.batch_size), desc=args.mode)
        picks = []
        true_picks = []
        for _ in progressbar:
            loss_batch, preds_batch, X_batch, Y_batch, fname_batch, itp_batch, its_batch = sess.run(
                [model.loss, model.preds, batch[0], batch[1], batch[2], batch[3], batch[4]],
                feed_dict={model.drop_rate: 0, model.is_training: False})

            test_loss(loss_batch)
            progressbar.set_description("{}, loss={:.6f}, mean loss={:.6f}".format(args.mode, loss_batch, test_loss.value))

            picks_ = extract_picks(preds_batch, fname_batch)
            picks.extend(picks_)
            true_picks.extend(convert_true_picks(fname_batch, itp_batch, its_batch))
            if args.plot_figure:
                plot_waveform(data_reader.config, X_batch, preds_batch, label=Y_batch, fname=fname_batch,
                              itp=itp_batch, its=its_batch, figure_dir=figure_dir)

        save_picks(picks, args.result_dir)
        metrics = calc_performance(picks, true_picks, tol=3.0, dt=data_reader.config.dt)
        flog.write("mean loss: {}\n".format(test_loss.value))  # write the numeric mean, not the LMA object
        flog.close()

    return 0

def main(args):
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    coord = tf.train.Coordinator()

    if (args.mode == "train") or (args.mode == "train_valid"):
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_train(format=args.format,
                                           data_dir=args.train_dir,
                                           data_list=args.train_list)
            if args.mode == "train_valid":
                data_reader_valid = DataReader_train(format=args.format,
                                                     data_dir=args.valid_dir,
                                                     data_list=args.valid_list)
                logging.info("Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
            else:
                data_reader_valid = None
                logging.info("Dataset size: train {}".format(data_reader.num_data))
        train_fn(args, data_reader, data_reader_valid)

    elif args.mode == "test":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_test(format=args.format,
                                          data_dir=args.test_dir,
                                          data_list=args.test_list)
        test_fn(args, data_reader)

    else:
        print("mode should be: train, train_valid, or test")

    return


if __name__ == '__main__':
    args = read_args()
    main(args)
phasenet/util.py
ADDED
@@ -0,0 +1,238 @@
from __future__ import division
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
from data_reader import DataConfig
from detect_peaks import detect_peaks
import logging

class EMA(object):
    # exponential moving average: x <- alpha * x + (1 - alpha) * new value
    def __init__(self, alpha):
        self.alpha = alpha
        self.x = 0.
        self.count = 0

    @property
    def value(self):
        return self.x

    def __call__(self, x):
        if self.count == 0:
            self.x = x
        else:
            self.x = self.alpha * self.x + (1 - self.alpha) * x
        self.count += 1
        return self.x

class LMA(object):
    # linear moving average: the running mean of all values seen so far
    def __init__(self):
        self.x = 0.
        self.count = 0

    @property
    def value(self):
        return self.x

    def __call__(self, x):
        if self.count == 0:
            self.x = x
        else:
            self.x += (x - self.x) / (self.count + 1)
        self.count += 1
        return self.x

def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
    if args is None:
        itp, prob_p = detect_peaks(pred[i, :, 0, 1], mph=0.5, mpd=0.5 / DataConfig().dt, show=False)
        its, prob_s = detect_peaks(pred[i, :, 0, 2], mph=0.5, mpd=0.5 / DataConfig().dt, show=False)
    else:
        itp, prob_p = detect_peaks(pred[i, :, 0, 1], mph=args.tp_prob, mpd=0.5 / DataConfig().dt, show=False)
        its, prob_s = detect_peaks(pred[i, :, 0, 2], mph=args.ts_prob, mpd=0.5 / DataConfig().dt, show=False)
    if (fname is not None) and (result_dir is not None):
        # np.savez(os.path.join(result_dir, fname[i].decode().split('/')[-1]), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
        try:
            np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
        except FileNotFoundError:
            # if not os.path.exists(os.path.dirname(os.path.join(result_dir, fname[i].decode()))):
            os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True)
            np.savez(os.path.join(result_dir, fname[i].decode()), pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
    return [(itp, prob_p), (its, prob_s)]

def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
                       itp_pred=None, its_pred=None, fname=None, figure_dir=None):
    dt = DataConfig().dt
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    plt.figure(i)
    plt.clf()
    # fig_size = plt.gcf().get_size_inches()
    # plt.gcf().set_size_inches(fig_size*[1, 1.2])
    plt.subplot(411)
    plt.plot(t, X[i, :, 0, 0], 'k', label='E', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 0])
    tmp_max = np.max(X[i, :, 0, 0])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            if j == 0:
                plt.plot([itp[i][j] * dt, itp[i][j] * dt], [tmp_min, tmp_max], 'b', label='P', linewidth=0.5)
            else:
                plt.plot([itp[i][j] * dt, itp[i][j] * dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            if j == 0:
                plt.plot([its[i][j] * dt, its[i][j] * dt], [tmp_min, tmp_max], 'r', label='S', linewidth=0.5)
            else:
                plt.plot([its[i][j] * dt, its[i][j] * dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(412)
    plt.plot(t, X[i, :, 0, 1], 'k', label='N', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 1])
    tmp_max = np.max(X[i, :, 0, 1])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            plt.plot([itp[i][j] * dt, itp[i][j] * dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            plt.plot([its[i][j] * dt, its[i][j] * dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(413)
    plt.plot(t, X[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(X[i, :, 0, 2])
    tmp_max = np.max(X[i, :, 0, 2])
    if (itp is not None) and (its is not None):
        for j in range(len(itp[i])):
            plt.plot([itp[i][j] * dt, itp[i][j] * dt], [tmp_min, tmp_max], 'b', linewidth=0.5)
        for j in range(len(its[i])):
            plt.plot([its[i][j] * dt, its[i][j] * dt], [tmp_min, tmp_max], 'r', linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.subplot(414)
    if Y is not None:
        plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
        plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 1], '--g', label=r'$\hat{P}$', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 2], '-.m', label=r'$\hat{S}$', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for j in range(len(itp_pred)):
            plt.plot([itp_pred[j] * dt, itp_pred[j] * dt], [-0.1, 1.1], '--g', linewidth=0.5)
        for j in range(len(its_pred)):
            plt.plot([its_pred[j] * dt, its_pred[j] * dt], [-0.1, 1.1], '-.m', linewidth=0.5)
    plt.ylim([-0.05, 1.05])
    plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.legend(loc='upper right', fontsize='small')
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')

    plt.tight_layout()
    plt.gcf().align_labels()

    try:
        plt.savefig(os.path.join(figure_dir,
                                 fname[i].decode().rstrip('.npz') + '.png'),
                    bbox_inches='tight')
    except FileNotFoundError:
        # if not os.path.exists(os.path.dirname(os.path.join(figure_dir, fname[i].decode()))):
        os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].decode())), exist_ok=True)
        plt.savefig(os.path.join(figure_dir,
                                 fname[i].decode().rstrip('.npz') + '.png'),
                    bbox_inches='tight')
    # plt.savefig(os.path.join(figure_dir,
    #                          fname[i].decode().split('/')[-1].rstrip('.npz')+'.png'),
    #             bbox_inches='tight')
    # plt.savefig(os.path.join(figure_dir,
    #                          fname[i].decode().split('/')[-1].rstrip('.npz')+'.pdf'),
    #             bbox_inches='tight')
    plt.close(i)
    return 0

def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
    (itp_pred, prob_p), (its_pred, prob_s) = detect_peaks_thread(i, pred, fname, result_dir, args)
    if (fname is not None) and (figure_dir is not None):
        plot_result_thread(i, pred, X, Y, itp, its, itp_pred, its_pred, fname, figure_dir)
    return [(itp_pred, prob_p), (its_pred, prob_s)]


def clean_queue(picks):
    clean = []
    for i in range(len(picks)):
        tmp = []
        for j in picks[i]:
            if j != 0:
                tmp.append(j)
        clean.append(tmp)
    return clean

def clean_queue_thread(picks):
    tmp = []
    for j in picks:
        if j != 0:
            tmp.append(j)
    return tmp


def metrics(TP, nP, nT):
    '''
    TP: number of true positives
    nP: number of positive (predicted) picks
    nT: number of true picks
    '''
    precision = TP / nP
    recall = TP / nT
    F1 = 2 * precision * recall / (precision + recall)
    return [precision, recall, F1]

def correct_picks(picks, true_p, true_s, tol):
    dt = DataConfig().dt
    if len(true_p) != len(true_s):
        print("The lengths of the true P and S pick lists are not the same")
    num = len(true_p)
    TP_p, TP_s = 0, 0
    nP_p, nP_s = 0, 0
    nT_p, nT_s = 0, 0
    diff_p, diff_s = [], []
    for i in range(num):
        nT_p += len(true_p[i])
        nT_s += len(true_s[i])
        nP_p += len(picks[i][0][0])
        nP_s += len(picks[i][1][0])

        if len(true_p[i]) > 1 or len(true_s[i]) > 1:
            print(i, picks[i], true_p[i], true_s[i])
        tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:, np.newaxis]
        tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:, np.newaxis]
        TP_p += np.sum(np.abs(tmp_p) < tol / dt)
        TP_s += np.sum(np.abs(tmp_s) < tol / dt)
        diff_p.append(tmp_p[np.abs(tmp_p) < 0.5 / dt])
        diff_s.append(tmp_s[np.abs(tmp_s) < 0.5 / dt])

    return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]

def calculate_metrics(picks, itp, its, tol=0.1):
    TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s = correct_picks(picks, itp, its, tol)
    precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
    precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)

    logging.info("Total records: {}".format(len(picks)))
    logging.info("P-phase:")
    logging.info("True={}, Predict={}, TruePositive={}".format(nT_p, nP_p, TP_p))
    logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_p, recall_p, f1_p))
    logging.info("S-phase:")
    logging.info("True={}, Predict={}, TruePositive={}".format(nT_s, nP_s, TP_s))
    logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(precision_s, recall_s, f1_s))
    return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]
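EMA and LMA above differ only in how much history they keep: at alpha=0.9 the exponential average is dominated by older values, while LMA is the plain running mean (train.py uses the former for the training loss and the latter for validation/test). A small illustration with a made-up loss sequence:

from util import EMA, LMA

losses = [1.0, 0.8, 0.9, 0.6]  # made-up per-batch losses
ema, lma = EMA(0.9), LMA()
for loss in losses:
    ema(loss)
    lma(loss)
print(round(ema.value, 3))  # 0.935 -- weighted toward history
print(round(lma.value, 3))  # 0.825 -- the plain mean of all four values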
phasenet/visulization.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
import os


def plot_residual(diff_p, diff_s, diff_ps, tol, dt):
    # Histograms of P/S/PS pick-time residuals within +/- tol seconds.
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.07, 0.95]
    plt.figure(figsize=(8,3))
    plt.subplot(1,3,1)
    plt.hist(diff_p, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.ylabel("Number of picks")
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(i)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("P-phase")
    plt.subplot(1,3,2)
    plt.hist(diff_s, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(ii)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("S-phase")
    plt.subplot(1,3,3)
    plt.hist(diff_ps, range=(-tol, tol), bins=int(2*tol/dt)+1, facecolor='b', edgecolor='black', linewidth=1)
    plt.xlabel("Residual (s)")
    plt.text(text_loc[0], text_loc[1], "(iii)", horizontalalignment='left', verticalalignment='top',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.title("PS-phase")
    plt.tight_layout()
    plt.savefig("residuals.png", dpi=300)
    plt.savefig("residuals.pdf")


# def plot_waveform(config, data, pred, label=None,
#                   itp=None, its=None, itps=None,
#                   itp_pred=None, its_pred=None, itps_pred=None,
#                   fname=None, figure_dir="./", epoch=0, max_fig=10):

#     dt = config.dt if hasattr(config, "dt") else 1.0
#     t = np.arange(0, pred.shape[1]) * dt
#     box = dict(boxstyle='round', facecolor='white', alpha=1)
#     text_loc = [0.05, 0.77]
#     if fname is None:
#         fname = [f"{epoch:03d}_{i:02d}" for i in range(len(data))]
#     else:
#         fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]

#     for i in range(min(len(data), max_fig)):
#         plt.figure(i)

#         plt.subplot(411)
#         plt.plot(t, data[i, :, 0, 0], 'k', label='E', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 0])
#         tmp_max = np.max(data[i, :, 0, 0])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j==0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j==0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j==0 else ""
#                 plt.plot([itps[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

#         plt.subplot(412)
#         plt.plot(t, data[i, :, 0, 1], 'k', label='N', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 1])
#         tmp_max = np.max(data[i, :, 0, 1])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j==0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j==0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j==0 else ""
#                 plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

#         plt.subplot(413)
#         plt.plot(t, data[i, :, 0, 2], 'k', label='Z', linewidth=0.5)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         tmp_min = np.min(data[i, :, 0, 2])
#         tmp_max = np.max(data[i, :, 0, 2])
#         if (itp is not None) and (its is not None):
#             for j in range(len(itp[i])):
#                 lb = "P" if j==0 else ""
#                 plt.plot([itp[i][j]*dt, itp[i][j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
#             for j in range(len(its[i])):
#                 lb = "S" if j==0 else ""
#                 plt.plot([its[i][j]*dt, its[i][j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
#         if (itps is not None):
#             for j in range(len(itps[i])):
#                 lb = "PS" if j==0 else ""
#                 plt.plot([itps[i][j]*dt, itps[i][j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
#         plt.ylabel('Amplitude')
#         plt.legend(loc='upper right', fontsize='small')
#         plt.gca().set_xticklabels([])
#         plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

#         plt.subplot(414)
#         if label is not None:
#             plt.plot(t, label[i, :, 0, 1], 'C0', label='P', linewidth=1)
#             plt.plot(t, label[i, :, 0, 2], 'C1', label='S', linewidth=1)
#             if label.shape[-1] == 4:
#                 plt.plot(t, label[i, :, 0, 3], 'C2', label='PS', linewidth=1)
#         plt.plot(t, pred[i, :, 0, 1], '--C0', label='$\hat{P}$', linewidth=1)
#         plt.plot(t, pred[i, :, 0, 2], '--C1', label='$\hat{S}$', linewidth=1)
#         if pred.shape[-1] == 4:
#             plt.plot(t, pred[i, :, 0, 3], '--C2', label='$\hat{PS}$', linewidth=1)
#         plt.autoscale(enable=True, axis='x', tight=True)
#         if (itp_pred is not None) and (its_pred is not None):
#             for j in range(len(itp_pred)):
#                 plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
#             for j in range(len(its_pred)):
#                 plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
#         if (itps_pred is not None):
#             for j in range(len(itps_pred)):
#                 plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
#         plt.ylim([-0.05, 1.05])
#         plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
#                  transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
#         plt.legend(loc='upper right', fontsize='small', ncol=2)
#         plt.xlabel('Time (s)')
#         plt.ylabel('Probability')
#         plt.tight_layout()
#         plt.gcf().align_labels()

#         try:
#             plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
#         except FileNotFoundError:
#             os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
#             plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')

#         plt.close(i)
#     return 0


def plot_waveform(data, pred, fname, label=None,
                  itp=None, its=None, itps=None,
                  itp_pred=None, its_pred=None, itps_pred=None,
                  figure_dir="./", dt=0.01):
    # Single-sample version: data is (nt, 1, 3), pred/label are (nt, 1, n_classes),
    # and itp/its/itps are flat lists of pick indices.
    t = np.arange(0, pred.shape[0]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    plt.figure()

    plt.subplot(411)
    plt.plot(t, data[:, 0, 0], 'k', label='E', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 0])
    tmp_max = np.max(data[:, 0, 0])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j==0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):
            lb = "S" if j==0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j==0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(412)
    plt.plot(t, data[:, 0, 1], 'k', label='N', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 1])
    tmp_max = np.max(data[:, 0, 1])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j==0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):
            lb = "S" if j==0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j==0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(413)
    plt.plot(t, data[:, 0, 2], 'k', label='Z', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(data[:, 0, 2])
    tmp_max = np.max(data[:, 0, 2])
    if (itp is not None) and (its is not None):
        for j in range(len(itp)):
            lb = "P" if j==0 else ""
            plt.plot([itp[j]*dt, itp[j]*dt], [tmp_min, tmp_max], 'C0', label=lb, linewidth=0.5)
        for j in range(len(its)):
            lb = "S" if j==0 else ""
            plt.plot([its[j]*dt, its[j]*dt], [tmp_min, tmp_max], 'C1', label=lb, linewidth=0.5)
    if (itps is not None):
        for j in range(len(itps)):
            lb = "PS" if j==0 else ""
            plt.plot([itps[j]*dt, itps[j]*dt], [tmp_min, tmp_max], 'C2', label=lb, linewidth=0.5)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0], text_loc[1], '(iii)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

    plt.subplot(414)
    if label is not None:
        plt.plot(t, label[:, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, label[:, 0, 2], 'C1', label='S', linewidth=1)
        if label.shape[-1] == 4:
            plt.plot(t, label[:, 0, 3], 'C2', label='PS', linewidth=1)
    plt.plot(t, pred[:, 0, 1], '--C0', label=r'$\hat{P}$', linewidth=1)
    plt.plot(t, pred[:, 0, 2], '--C1', label=r'$\hat{S}$', linewidth=1)
    if pred.shape[-1] == 4:
        plt.plot(t, pred[:, 0, 3], '--C2', label=r'$\hat{PS}$', linewidth=1)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for j in range(len(itp_pred)):
            plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
        for j in range(len(its_pred)):
            plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
    if (itps_pred is not None):
        for j in range(len(itps_pred)):
            plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
    plt.ylim([-0.05, 1.05])
    plt.text(text_loc[0], text_loc[1], '(iv)', horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
    plt.legend(loc='upper right', fontsize='small', ncol=2)
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')
    plt.tight_layout()
    plt.gcf().align_labels()

    try:
        plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')
    except FileNotFoundError:
        os.makedirs(os.path.dirname(os.path.join(figure_dir, fname)), exist_ok=True)
        plt.savefig(os.path.join(figure_dir, fname+'.png'), bbox_inches='tight')

    plt.close()
    return 0


def plot_array(config, data, pred, label=None,
               itp=None, its=None, itps=None,
               itp_pred=None, its_pred=None, itps_pred=None,
               fname=None, figure_dir="./", epoch=0):
    # Batched version: waveforms on the left, per-trace probabilities on the right.
    dt = config.dt if hasattr(config, "dt") else 1.0
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.95]
    if fname is None:
        fname = [f"{epoch:03d}_{i:03d}" for i in range(len(data))]
    else:
        fname = [fname[i].decode().rstrip(".npz") for i in range(len(fname))]

    for i in range(len(data)):
        plt.figure(i, figsize=(10, 5))
        plt.clf()

        plt.subplot(121)
        for j in range(data.shape[-2]):
            plt.plot(t, data[i, :, j, 0]/10 + j, 'k', label='E', linewidth=0.5)
        plt.autoscale(enable=True, axis='x', tight=True)
        tmp_min = np.min(data[i, :, 0, 0])
        tmp_max = np.max(data[i, :, 0, 0])
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        # plt.legend(loc='upper right', fontsize='small')
        # plt.gca().set_xticklabels([])
        plt.text(text_loc[0], text_loc[1], '(i)', horizontalalignment='center', verticalalignment="top",
                 transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)

        plt.subplot(122)
        for j in range(pred.shape[-2]):
            if label is not None:
                plt.plot(t, label[i, :, j, 1]+j, 'C2', label='P', linewidth=0.5)
                plt.plot(t, label[i, :, j, 2]+j, 'C3', label='S', linewidth=0.5)
                # plt.plot(t, label[i, :, j, 0]+j, 'C4', label='N', linewidth=0.5)
            plt.plot(t, pred[i, :, j, 1]+j, 'C0', label=r'$\hat{P}$', linewidth=1)
            plt.plot(t, pred[i, :, j, 2]+j, 'C1', label=r'$\hat{S}$', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
            for j in range(len(itp_pred)):
                plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], '--C0', linewidth=1)
            for j in range(len(its_pred)):
                plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '--C1', linewidth=1)
            for j in range(len(itps_pred)):
                plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C2', linewidth=1)
        # plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], '(ii)', horizontalalignment='center', verticalalignment="top",
                 transform=plt.gca().transAxes, fontsize="large", fontweight="normal", bbox=box)
        # plt.legend(loc='upper right', fontsize='small', ncol=2)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0


def plot_spectrogram(config, data, pred, label=None,
                     itp=None, its=None, itps=None,
                     itp_pred=None, its_pred=None, itps_pred=None,
                     time=None, freq=None,
                     fname=None, figure_dir="./", epoch=0):
    # Spectrogram input: data[..., :3] and data[..., 3:] hold the real and
    # imaginary parts of the three components.
    # dt = config.dt
    # df = config.df
    # t = np.arange(0, data.shape[1]) * dt
    # f = np.arange(0, data.shape[2]) * df
    t, f = time, freq
    dt = t[1] - t[0]
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.75]
    if fname is None:
        fname = [f"{i:03d}" for i in range(len(data))]
    elif type(fname[0]) is bytes:
        fname = [f.decode() for f in fname]

    numbers = ["(i)", "(ii)", "(iii)", "(iv)"]
    for i in range(len(data)):
        fig = plt.figure(i)
        # gs = fig.add_gridspec(4, 1)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(4,1,j+1)
            plt.pcolormesh(t, f, np.abs(data[i, :, :, j]+1j*data[i, :, :, j+3]).T, vmax=2*np.std(data[i, :, :, j]+1j*data[i, :, :, j+3]), cmap="jet", shading='auto')
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            if j == 1:
                plt.ylabel('Frequency (Hz)')
            plt.text(text_loc[0], text_loc[1], numbers[j], horizontalalignment='center',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        # fig.add_subplot(gs[-1, 0])
        plt.subplot(4,1,4)
        if label is not None:
            plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
            plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
            plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
        plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
        plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
        plt.plot(t, t*0, 'k', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        if (itp_pred is not None) and (its_pred is not None) and (itps_pred is not None):
            for j in range(len(itp_pred)):
                plt.plot([itp_pred[j]*dt, itp_pred[j]*dt], [-0.1, 1.1], ':C3', linewidth=1)
            for j in range(len(its_pred)):
                plt.plot([its_pred[j]*dt, its_pred[j]*dt], [-0.1, 1.1], '-.C6', linewidth=1)
            for j in range(len(itps_pred)):
                plt.plot([itps_pred[j]*dt, itps_pred[j]*dt], [-0.1, 1.1], '--C8', linewidth=1)
        plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='center',
                 transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
        plt.legend(loc='upper right', fontsize='small', ncol=1)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        # plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0


def plot_spectrogram_waveform(config, spectrogram, waveform, pred, label=None,
                              itp=None, its=None, itps=None, picks=None,
                              time=None, freq=None,
                              fname=None, figure_dir="./", epoch=0):
    # Interleaves waveform traces (odd rows) with their spectrograms (even rows),
    # with the probability curves in the bottom panel.
    # dt = config.dt
    # df = config.df
    # t = np.arange(0, spectrogram.shape[1]) * dt
    # f = np.arange(0, spectrogram.shape[2]) * df
    t, f = time, freq
    dt = t[1] - t[0]
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.02, 0.90]
    if fname is None:
        fname = [f"{i:03d}" for i in range(len(spectrogram))]
    elif type(fname[0]) is bytes:
        fname = [f.decode() for f in fname]

    numbers = ["(i)", "(ii)", "(iii)", "(iv)", "(v)", "(vi)", "(vii)"]
    for i in range(len(spectrogram)):
        fig = plt.figure(i, figsize=(6.4, 10))
        # gs = fig.add_gridspec(4, 1)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(7,1,j*2+1)
            plt.plot(waveform[i,:,j], 'k', linewidth=0.5)
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            plt.ylabel('')
            plt.text(text_loc[0], text_loc[1], numbers[j*2], horizontalalignment='left', verticalalignment='top',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        for j in range(3):
            # fig.add_subplot(gs[j, 0])
            plt.subplot(7,1,j*2+2)
            plt.pcolormesh(t, f, np.abs(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]).T, vmax=2*np.std(spectrogram[i, :, :, j]+1j*spectrogram[i, :, :, j+3]), cmap="jet", shading='auto')
            plt.autoscale(enable=True, axis='x', tight=True)
            plt.gca().set_xticklabels([])
            if j == 1:
                plt.ylabel('Frequency (Hz) or Amplitude')
            plt.text(text_loc[0], text_loc[1], numbers[j*2+1], horizontalalignment='left', verticalalignment='top',
                     transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)

        # fig.add_subplot(gs[-1, 0])
        plt.subplot(7,1,7)
        if label is not None:
            plt.plot(t, label[i, :, 0, 1], '--C0', linewidth=1)
            plt.plot(t, label[i, :, 0, 2], '--C3', linewidth=1)
            plt.plot(t, label[i, :, 0, 3], '--C1', linewidth=1)
        plt.plot(t, pred[i, :, 0, 1], 'C0', label='P', linewidth=1)
        plt.plot(t, pred[i, :, 0, 2], 'C3', label='S', linewidth=1)
        plt.plot(t, pred[i, :, 0, 3], 'C1', label='PS', linewidth=1)
        plt.plot(t, t*0, 'k', linewidth=1)
        plt.autoscale(enable=True, axis='x', tight=True)
        plt.ylim([-0.05, 1.05])
        plt.text(text_loc[0], text_loc[1], numbers[-1], horizontalalignment='left', verticalalignment='top',
                 transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)
        plt.legend(loc='upper right', fontsize='small', ncol=1)
        plt.xlabel('Time (s)')
        plt.ylabel('Probability')
        # plt.tight_layout()
        plt.gcf().align_labels()

        try:
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')
        except FileNotFoundError:
            os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i])), exist_ok=True)
            plt.savefig(os.path.join(figure_dir, f'{epoch:02d}_'+fname[i]+'.png'), bbox_inches='tight')

        plt.close(i)
    return 0
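For reference, a minimal smoke test for the two active plotting helpers above; array shapes are inferred from the function bodies, and the file/directory names are hypothetical:

    import numpy as np
    from phasenet.visulization import plot_residual, plot_waveform

    rng = np.random.default_rng(0)
    # residual histograms within a +/- 0.5 s tolerance at dt = 0.01 s
    plot_residual(rng.normal(0, 0.1, 100), rng.normal(0, 0.1, 100),
                  rng.normal(0, 0.1, 100), tol=0.5, dt=0.01)  # writes residuals.png/.pdf
    # single-station waveform: data is (nt, 1, 3), pred is (nt, 1, n_classes)
    data = rng.normal(size=(3000, 1, 3))
    pred = rng.random(size=(3000, 1, 3))
    plot_waveform(data, pred, fname="demo", itp=[1000], its=[2000], figure_dir="./figures")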
pipeline.py
CHANGED
@@ -2,6 +2,12 @@ from typing import Dict, List
 import numpy as np
 import tensorflow as tf
 
+from phasenet.model import ModelConfig, UNet
+from phasenet.postprocess import extract_picks
+
+tf.compat.v1.disable_eager_execution()
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
 class PreTrainedPipeline():
     def __init__(self, path=""):
         # IMPLEMENT_THIS
@@ -11,7 +17,23 @@
         # raise NotImplementedError(
         # "Please implement PreTrainedPipeline __init__ function"
         # )
-
+
+        ## load model
+        model = UNet(mode="pred")
+        sess_config = tf.compat.v1.ConfigProto()
+        sess_config.gpu_options.allow_growth = True
+
+        sess = tf.compat.v1.Session(config=sess_config)
+        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
+        init = tf.compat.v1.global_variables_initializer()
+        sess.run(init)
+        latest_check_point = tf.train.latest_checkpoint("model/190703-214543")
+        print(f"restoring model {latest_check_point}")
+        saver.restore(sess, latest_check_point)
+
+        ##
+        self.sess = sess
+        self.model = model
 
     def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
         """
@@ -27,4 +49,20 @@
         # raise NotImplementedError(
         # "Please implement PreTrainedPipeline __call__ function"
         # )
-
+
+        vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]
+
+        feed = {self.model.X: vec, self.model.drop_rate: 0, self.model.is_training: False}
+        preds = self.sess.run(self.model.preds, feed_dict=feed)
+
+        picks = extract_picks(preds)  # , station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)
+
+        # picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
+
+        return picks
+
+
+if __name__ == "__main__":
+    pipeline = PreTrainedPipeline()
+    inputs = np.random.rand(1000, 3).tolist()
+    picks = pipeline(inputs)
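Shape note on `__call__` above: `inputs` is a nested `(nt, 3)` list of three-component samples, and `np.array(inputs)[np.newaxis, :, np.newaxis, :]` expands it to the four-dimensional `(1, nt, 1, 3)` tensor fed to the UNet placeholder (batch, time, singleton axis, channels; axis naming assumed). For example:

    import numpy as np

    inputs = np.random.rand(1000, 3).tolist()  # matches the __main__ example above
    vec = np.array(inputs)[np.newaxis, :, np.newaxis, :]
    assert vec.shape == (1, 1000, 1, 3)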
requirements.txt
CHANGED
@@ -1 +1 @@
-tensorflow
+tensorflow
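Since pipeline.py drives the graph entirely through `tf.compat.v1` (with eager execution disabled), the unpinned `tensorflow` requirement presumably resolves to a 2.x release and relies on its v1 compatibility shims.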