diff --git a/urnai/models/memory_representations/neural_network/pytorch.py b/urnai/models/memory_representations/neural_network/pytorch.py index 5d817be0..c1648a81 100644 --- a/urnai/models/memory_representations/neural_network/pytorch.py +++ b/urnai/models/memory_representations/neural_network/pytorch.py @@ -6,6 +6,43 @@ import numpy as np class PyTorchDeepNeuralNetwork(ABNeuralNetwork): + """ + Implementation of a Generic Deep Neural Network using PyTorch + + This class inherits from ABNeuralNetwork, so it already has all abstract methods + necessary for learning, predicting outputs, and building a model. All that this + class does is implement those abstract methods, and implement the methods necessary + for adding Neural Network layers, such as add_input_layer(), add_output_layer(), and + add_fully_connected_layer(). + + Unlike KerasDeepNeuralNetwork, this class is not able to dynamically build + Neural Network architectures that use convolutional layers, due to the complexity + of PyTorch's initialization of convolutional layers and a general difficulty in + fitting that complexity to URNAI's architecture. To use PyTorch with convolutional + layers, one has to inherit from this class and manually create the model; we recommend + doing that in self.make_model(). + + This class also implements the methods necessary for saving and loading the model + from local memory. + + Parameters: + action_output_size: int + size of our output + state_input_shape: tuple + shape of our input + build_model: Python dict + A dict representing the NN's layers. Can be generated by the + ModelBuilder.get_model_layout() method from an instantiated ModelBuilder object. 
+ gamma: Float + Gamma parameter for the Deep Q Learning algorithm + alpha: Float + This is the Learning Rate of the model + seed: Integer (default None) + Value to assign to random number generators in Python and our ML libraries to try + and create reproducible experiments + batch_size: Integer + Size of our learning batch to be passed to the Machine Learning library + """ def __init__(self, action_output_size, state_input_shape, build_model, gamma, alpha, seed = None, batch_size=32):