Skip to content

Commit

Permalink
Update dlfilter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bbfrederick committed Oct 2, 2024
1 parent fcd9098 commit f51291a
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions rapidtide/dlfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

os.environ["TF_USE_LEGACY_KERAS"] = "1"

try:
"""try:
import tensorflow.compat.v1 as tf
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
Expand All @@ -72,28 +72,21 @@
raise ImportError("no backend found - exiting")
if tfversion == 2:
LGR.debug("using tensorflow v2x")
tf.disable_v2_behavior()
from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN
from tensorflow.keras.layers import (
LSTM,
Activation,
BatchNormalization,
Bidirectional,
Convolution1D,
Dense,
Dropout,
GlobalMaxPool1D,
MaxPooling1D,
TimeDistributed,
UpSampling1D,
)
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.optimizers import RMSprop
if tfversion == 2:"""

LGR.debug(f"tensorflow version: >>>{tf.__version__}<<<")
elif tfversion == 1:
import tensorflow.compat.v1 as tf

LGR.debug("using tensorflow v2x")
# tf.disable_v2_behavior()
from tensorflow.keras.callbacks import ModelCheckpoint, TerminateOnNaN
from tensorflow.keras.layers import LSTM, Activation, BatchNormalization, Bidirectional, Convolution1D, Dense, Dropout, GlobalMaxPool1D, MaxPooling1D, TimeDistributed, UpSampling1D,

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.optimizers import RMSprop

LGR.debug(f"tensorflow version: >>>{tf.__version__}<<<")

"""elif tfversion == 1:
LGR.debug("using tensorflow v1x")
from keras.callbacks import ModelCheckpoint, TerminateOnNaN
from keras.layers import (
Expand All @@ -118,7 +111,7 @@
elif tfversion == 0:
pass
else:
raise ImportError("could not find backend - exiting")
raise ImportError("could not find backend - exiting")"""


class DeepLearningFilter:
Expand Down Expand Up @@ -371,7 +364,6 @@ def loadmodel(self, modelname, usehdf=True, verbose=False):
self.model.load_weights(os.path.join(self.modelname, "model_weights.h5"))
if verbose:
self.model.summary()
LGR.info(f"{modelname} loaded")

# now load additional information
self.infodict = tide_io.readdictfromjson(
Expand All @@ -383,6 +375,7 @@ def loadmodel(self, modelname, usehdf=True, verbose=False):
# model is ready to use
self.initialized = True
self.trained = True
LGR.info(f"{modelname} loaded")

def initialize(self):
self.getname()
Expand Down

0 comments on commit f51291a

Please sign in to comment.