How to use callback? #1086
-
I am trying to stop the training when the total training loss is less than 1e-2. I could not find a built-in function to do this, so I thought I would use a callback. I tried this implementation:

```python
class LossConvergence(dde.callbacks.Callback):
    def __init__(self):
        super().__init__()

    def on_epoch_begin(self):
        print("Training loss", sum(self.model.train_state.loss_train))
        if sum(self.model.train_state.loss_train) < 1e-2:
            break

losshistory, train_state = model.train(epochs=100, display_every=1, callbacks=[LossConvergence])
```

But I get the following error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/hell/Desktop/PhD/PhD work/PINNs/16. December 2022/3. Week 3/1. UKACM abstract work/2. deepxde solution/2. nice framework/2D heat conduction.ipynb Cell 16 in <cell line: 2>()
      1 # Training with Adam
----> 2 losshistory, train_state = model.train(epochs=100, display_every=1, callbacks=[LossConvergence])

File ~/anaconda3/lib/python3.9/site-packages/deepxde/utils/internal.py:22, in timing.<locals>.wrapper(*args, **kwargs)
     19 @wraps(f)
     20 def wrapper(*args, **kwargs):
     21     ts = timeit.default_timer()
---> 22     result = f(*args, **kwargs)
     23     te = timeit.default_timer()
     24     print("%r took %f s\n" % (f.__name__, te - ts))

File ~/anaconda3/lib/python3.9/site-packages/deepxde/model.py:543, in Model.train(self, iterations, batch_size, display_every, disregard_previous_best, callbacks, model_restore_path, model_save_path, epochs)
    541 self.batch_size = batch_size
    542 self.callbacks = CallbackList(callbacks=callbacks)
--> 543 self.callbacks.set_model(self)
    544 if disregard_previous_best:
    545     self.train_state.disregard_best()

File ~/anaconda3/lib/python3.9/site-packages/deepxde/callbacks.py:70, in CallbackList.set_model(self, model)
     68 self.model = model
     69 for callback in self.callbacks:
---> 70     callback.set_model(model)

TypeError: set_model() missing 1 required positional argument: 'model'
```

How do I correctly implement this callback? Is there an easier way than modifying the source code?
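The `TypeError` in the traceback can be reproduced without deepxde at all. The following is a minimal mock (the `Callback` class here is a stand-in, not the real deepxde class) showing why passing the class object instead of an instance fails: `CallbackList` calls `callback.set_model(model)` on each entry, and when the entry is the class itself, `model` binds to the `self` parameter, leaving the real `model` argument missing.

```python
# Mock stand-ins, not the real deepxde classes.
class Callback:
    def set_model(self, model):
        self.model = model

class LossConvergence(Callback):
    pass

model = object()  # stand-in for a dde.Model

# Passing the class: set_model is an unbound function here, so the single
# argument binds to `self` and the `model` parameter goes unfilled.
try:
    LossConvergence.set_model(model)  # what CallbackList effectively does
except TypeError as e:
    print(e)  # message mentions a missing positional argument 'model'

# Passing an instance works: `self` is already bound.
cb = LossConvergence()
cb.set_model(model)
assert cb.model is model
```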
Replies: 1 comment
-
Ok, I solved this issue. I was being silly: I was passing the class definition itself to `model.train()`. I have to create an object from that class and then pass the instance to `model.train()`. This is my final code:

```python
# https://github.com/lululxvi/deepxde/discussions/1086#discussion-4657568
class LossConvergence(dde.callbacks.Callback):
    def __init__(self):
        super().__init__()

    def on_epoch_begin(self):
        print("Training loss : ", sum(self.model.train_state.loss_train))
        if sum(self.model.train_state.loss_train) < 0.5:
            print("Loss converged")
            self.model.stop_training = True

loss_early_stop = LossConvergence()
losshistory, train_state = model.train(epochs=100, display_every=1000, callbacks=[loss_early_stop])
```

The output looks like this:

```
Warning: epochs is deprecated and will be removed in a future version. Use iterations instead.
Training model...

Step      Train loss                                            Test loss                                             Test metric
200       [9.69e-03, 3.57e-02, 3.81e-02, 3.79e-02, 4.14e-02]    [9.69e-03, 3.57e-02, 3.81e-02, 3.79e-02, 4.14e-02]    []
Training loss :  0.1628757631406188
Loss converged

Best model at step 200:
  train loss: 1.63e-01
  test loss: 1.63e-01
  test metric: []

'train' took 0.076761 s
```
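The instance-plus-`stop_training` pattern above can be sketched end to end without deepxde. In this self-contained mock, `LossThresholdStop` and `MockModel` are illustrative names (not deepxde API); the mock model's loss halves each epoch so the threshold callback fires after a few iterations.

```python
# Generic sketch of the stop_training pattern; all classes here are mocks.
class Callback:
    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self):
        pass

class LossThresholdStop(Callback):
    """Stop training once the summed training loss drops below a threshold."""
    def __init__(self, threshold=1e-2):
        self.threshold = threshold

    def on_epoch_begin(self):
        if sum(self.model.train_state.loss_train) < self.threshold:
            self.model.stop_training = True

class MockModel:
    """Tiny stand-in for dde.Model whose loss halves every epoch."""
    class _State:
        def __init__(self):
            self.loss_train = [1.0]

    def __init__(self):
        self.train_state = MockModel._State()
        self.stop_training = False

    def train(self, epochs, callbacks):
        for cb in callbacks:
            cb.set_model(self)  # instance required, as in the error above
        for epoch in range(epochs):
            for cb in callbacks:
                cb.on_epoch_begin()
            if self.stop_training:
                return epoch
            self.train_state.loss_train = [l / 2 for l in self.train_state.loss_train]
        return epochs

model = MockModel()
stopped_at = model.train(epochs=100, callbacks=[LossThresholdStop(1e-2)])
print("stopped at epoch", stopped_at)  # loss 1/2**7 ≈ 0.0078 < 1e-2 → epoch 7
```

As an aside, deepxde also ships a `dde.callbacks.EarlyStopping` callback modeled on Keras's; depending on your version it may already cover this use case, so it is worth checking its options in the current docs before rolling your own.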