Skip to content

Commit

Permalink
documentation for pytorch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lpdcalves committed May 13, 2021
1 parent 200f733 commit e6d78e8
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions urnai/models/memory_representations/neural_network/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,43 @@
import numpy as np

class PyTorchDeepNeuralNetwork(ABNeuralNetwork):
"""
Implementation of a Generic Deep Neural Network using PyTorch
This class inherits from ABNeuralNetwork, so it already has all abstract methods
necessary for learning, predicting outputs, and building a model. All that this
class does is implement those abstract methods, and implement the methods necessary
for adding Neural Network layers, such as add_input_layer(), add_output_layer(),
add_fully_connected_layer().
Differently from KerasDeepNeuralNetwork, this class is not able to dynamically build
Neural Network architectures that use convolutional layers, due to the complexity
of PyTorch's initialization of convolutional layers and a general difficulty in
fitting that complexity to URNAI's achitecture. To use PyTorch with convolutional
layers, one has to inherit from this class and manually create your model, we recommend
you do that in self.make_model().
This class also implements the methodes necessary for saving and loading the model
from local memory.
Parameters:
action_output_size: int
size of our output
state_input_shape: tuple
shape of our input
build_model: Python dict
A dict representing the NN's layers. Can be generated by the
ModelBuilder.get_model_layout() method from an instantiated ModelBuilder object.
gamma: Float
Gamma parameter for the Deep Q Learning algorithm
alpha: Float
This is the Learning Rate of the model
seed: Integer (default None)
Value to assing to random number generators in Python and our ML libraries to try
and create reproducible experiments
batch_size: Integer
Size of our learning batch to be passed to the Machine Learning library
"""

def __init__(self, action_output_size, state_input_shape, build_model, gamma, alpha, seed = None, batch_size=32):

Expand Down

0 comments on commit e6d78e8

Please sign in to comment.