
How to use callback? #1086

Answered by praksharma
praksharma asked this question in Q&A

OK, I solved this issue. It was a silly mistake: I was passing the class itself to model.train(). I need to create an object (an instance) of that class and then pass it to model.train(). This is my final code:

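# https://github.com/lululxvi/deepxde/discussions/1086#discussion-4657568
import deepxde as dde


class LossConvergence(dde.callbacks.Callback):
    """Stop training once the summed training loss falls below 0.5."""

    def __init__(self):
        super().__init__()

    def on_epoch_begin(self):
        # loss_train holds one value per loss term; sum them for the total.
        total_loss = sum(self.model.train_state.loss_train)
        print('Training loss : ', total_loss)
        if total_loss < 0.5:
            print('Loss converged')
            # Setting this flag makes DeepXDE end the training loop.
            self.model.stop_training = True


# Pass an instance of the callback, not the class itself.
loss_early_stop = LossConvergence()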

losshistory, train_state = model.train(epochs=100, display…
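Why the class alone fails: a training loop calls methods such as `on_epoch_begin()` on every object in the callbacks list, and a bare class has no bound `self`, so the call raises a `TypeError`. The sketch below illustrates this with hypothetical, simplified stand-ins (`Callback`, `TrainState`, `ToyModel`) that only mimic the pattern of DeepXDE's `model.train(..., callbacks=[...])`. They are not DeepXDE's actual implementation, and the halving "training step" is invented purely so the loss shrinks.

```python
# Hypothetical, simplified stand-ins for DeepXDE's Model/Callback machinery.
# They only mimic the pattern (callbacks get a reference to the model and
# training stops when model.stop_training is set); NOT the real DeepXDE API.


class Callback:
    def __init__(self):
        self.model = None

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self):
        pass


class TrainState:
    def __init__(self):
        # Two made-up loss terms, e.g. PDE residual and boundary loss.
        self.loss_train = [1.0, 1.0]


class ToyModel:
    def __init__(self):
        self.train_state = TrainState()
        self.stop_training = False

    def train(self, epochs, callbacks=()):
        """Run a fake training loop; return the epoch at which it stopped."""
        self.stop_training = False
        for cb in callbacks:
            cb.set_model(self)  # fails if cb is a class, not an instance
        for epoch in range(epochs):
            for cb in callbacks:
                cb.on_epoch_begin()
            if self.stop_training:
                return epoch
            # Fake "training step": halve every loss term.
            self.train_state.loss_train = [
                loss / 2 for loss in self.train_state.loss_train
            ]
        return epochs


class LossConvergence(Callback):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def on_epoch_begin(self):
        if sum(self.model.train_state.loss_train) < self.threshold:
            self.model.stop_training = True


# Passing an instance works: the loss sum goes 2.0, 1.0, 0.5, 0.25,
# so training stops at epoch 3.
print(ToyModel().train(epochs=100, callbacks=[LossConvergence()]))  # prints 3

# Passing the class itself fails, because set_model() has no instance to bind.
try:
    ToyModel().train(epochs=100, callbacks=[LossConvergence])
except TypeError as err:
    print("Passing the class fails:", err)
```

The threshold is made a constructor argument here only for illustration; the original callback hard-codes 0.5.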

Answer selected by praksharma