gflownet is a library built upon PyTorch to easily train and extend GFlowNets, also known as GFN or generative flow networks. GFlowNets are a machine learning framework for probabilistic and generative modelling, with a wide range of applications, especially in scientific discovery problems.
In a nutshell, GFlowNets can be regarded as a generative model designed to sample objects with probability proportional to a reward function.
GFlowNets rely on the principle of compositionality to generate samples. A meaningful decomposition of samples
Consider the problem of generating Tetris-like boards. A natural decomposition of the sample generation process would be to add one piece at a time, starting from an empty board. For any state representing a board with pieces, we could identify its valid parents and children, as illustrated in the figure below.
We could define a reward function
The GFlowNet library comprises four core components: environment, proxy, policy models (forward and backward), and GFlowNet agent.
The environment defines the state space
The Scrabble environment simulates a simple letter arrangement game where words are constructed by adding one letter at a time, up to a maximum sequence length (typically 7). Therefore, the action space is the set of all English letters plus a special end-of-sequence (EOS) action; and the state space is the set of all possible words with up to 7 letters. We can represent each state as a list of indices corresponding to the letters, padded with zeroes to the maximum length. For example, the state for the word "CAT" would be represented as [3, 1, 20, 0, 0, 0, 0]. Actions in the Scrabble environment are single-element tuples containing the index of the letter, plus the end-of-sequence (EOS) action (-1,).
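The state representation described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper names word_to_state and state_to_word are hypothetical and are not part of the gflownet library.

```python
import string

MAX_LENGTH = 7  # maximum word length mentioned in the text
PAD = 0         # index used for padding

def word_to_state(word, max_length=MAX_LENGTH):
    """Map a word to a list of 1-based letter indices, zero-padded (hypothetical helper)."""
    word = word.upper()
    if len(word) > max_length:
        raise ValueError(f"'{word}' is longer than {max_length} letters")
    indices = [string.ascii_uppercase.index(letter) + 1 for letter in word]
    return indices + [PAD] * (max_length - len(indices))

def state_to_word(state):
    """Invert word_to_state by dropping the padding (hypothetical helper)."""
    return "".join(string.ascii_uppercase[i - 1] for i in state if i != PAD)

print(word_to_state("CAT"))  # [3, 1, 20, 0, 0, 0, 0]
print(state_to_word([3, 1, 20, 0, 0, 0, 0]))  # CAT
```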
Using the gflownet library for a new task will typically require implementing your own environment. The library is particularly designed to make such extensions as easy as possible. In the documentation, we show how to do it step by step. You can also watch this live-coding tutorial on how to code the Scrabble environment.
We use the term "proxy" to refer to the function or model that provides the rewards for the states of an environment. In the context of GFlowNets, the proxy can be thought of as a function from which the rewards are obtained through a transformation g. In the Scrabble environment, the ScrabbleScorer proxy computes the sum of the scores of the letters of a word. For the word "CAT", that is the score of C plus the score of A plus the score of T.
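As a rough illustration of a letter-sum proxy, the sketch below scores a word with the standard English Scrabble letter values. Both the function name score_word and the score table are assumptions: the table used by the library's ScrabbleScorer is not shown here, although these values do reproduce the proxy outputs printed later in this README.

```python
# Standard English Scrabble letter values (an assumption, not read from the library).
LETTER_SCORES = {
    **dict.fromkeys("AEILNORSTU", 1),
    **dict.fromkeys("DG", 2),
    **dict.fromkeys("BCMP", 3),
    **dict.fromkeys("FHVWY", 4),
    "K": 5,
    **dict.fromkeys("JX", 8),
    **dict.fromkeys("QZ", 10),
}

def score_word(word):
    """Sum the scores of the letters of a word (hypothetical stand-in for a proxy)."""
    return sum(LETTER_SCORES[letter] for letter in word.upper())

print(score_word("CAT"))   # 3 + 1 + 1 = 5
print(score_word("TTUC"))  # 1 + 1 + 1 + 3 = 6
```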
Adapting the gflownet library for a new task will also likely require implementing your own proxy, which is usually fairly simple, as illustrated in the documentation.
The policy models are neural networks that model the forward and backward transitions between states.
The GFlowNet Agent is the central component that ties all others together. It orchestrates the interaction between the environment, policies, and proxy, as well as other auxiliary components such as the Evaluator and the Logger. The agent can construct training batches by sampling trajectories, optimise the policy models via gradient descent, compute evaluation metrics, log data to Weights & Biases, etc. The agent can be configured to optimise any of the following loss functions implemented in the library: flow matching (FM), trajectory balance (TB), detailed balance (DB) and forward-looking (FL).
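To give an idea of what one of these objectives looks like, here is a minimal sketch of the trajectory balance loss for a single trajectory, as it is usually stated in the GFlowNet literature: the squared difference between log Z plus the sum of forward log-probabilities and log R(x) plus the sum of backward log-probabilities. It is written in plain Python with hypothetical names and is not the library's implementation, which presumably operates on batches of tensors.

```python
import math

def trajectory_balance_loss(log_z, log_pf, log_pb, reward):
    """Squared log-ratio between forward and backward flow of one trajectory."""
    forward_flow = log_z + sum(log_pf)               # log Z + sum_t log P_F
    backward_flow = math.log(reward) + sum(log_pb)   # log R(x) + sum_t log P_B
    return (forward_flow - backward_flow) ** 2

# Toy case: two binary choices build one of 4 objects, each with reward 1, so Z = 4.
# The forward policy is uniform and every state has a single parent (log P_B = 0).
log_pf = [math.log(0.5), math.log(0.5)]
log_pb = [0.0, 0.0]

balanced = trajectory_balance_loss(math.log(4.0), log_pf, log_pb, reward=1.0)
unbalanced = trajectory_balance_loss(0.0, log_pf, log_pb, reward=1.0)
print(balanced, unbalanced)  # the first value is (numerically) zero, the second is not
```

In this toy case the loss vanishes only when log Z matches the total reward, which is the sense in which the objective ties the policies to the reward.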
If you simply want to install everything, clone the repo and run setup_all.sh:
git clone git@github.com:alexhernandezgarcia/gflownet.git
cd gflownet
./setup_all.sh
- This project requires python 3.10 and cuda 11.8.
- Setup is currently only supported on Ubuntu. It should also work on OSX, but you will need to handle the package dependencies.
- The recommended installation is as follows:
python3.10 -m venv ~/envs/gflownet # Initialize your virtual env.
source ~/envs/gflownet/bin/activate # Activate your environment.
./prereq_ubuntu.sh # Installs some packages required by dependencies.
./prereq_python.sh # Installs python packages with specific wheels.
./prereq_geometric.sh # OPTIONAL - for the molecule environment.
pip install .[all] # Install the remaining elements of this package.
Aside from the base packages, you can optionally install development tools using the dev tag, materials dependencies using the materials tag, or molecules packages using the molecules tag. The simplest option is to use the all tag, as above, which installs all dependencies.
The gflownet library uses Hydra to handle configuration files. This makes it easy, for instance, to train a GFlowNet with the configuration of a specific YAML file. For example, to train a GFlowNet with a 10x10 Grid environment and the corners proxy, with the configuration from ./config/experiments/grid/corners.yaml, we can simply run:
python main.py +experiments=grid/corners
Alternatively, we can explicitly indicate the environment and the proxy as follows:
python main.py env=grid proxy=box/corners
The above command will train a GFlowNet with the default configuration, except for the environment, which will use ./config/env/grid.yaml; and the proxy, which will use ./config/proxy/box/corners.yaml.
A typical use case of the gflownet library is to extend it with a new environment and a new proxy to fit your purposes. In that case, you could create their respective configuration files ./config/env/myenv.yaml and ./config/proxy/myproxy.yaml and run
python main.py env=myenv proxy=myproxy
The objective function to optimise is selected directly via the gflownet configuration. The following GFlowNet objectives are supported:
- Flow-matching (FM): gflownet=flowmatch
- Trajectory balance (TB): gflownet=trajectorybalance
- Detailed balance (DB): gflownet=detailedbalance
- Forward-looking (FL): gflownet=forwardlooking
All other configurable options are handled similarly. For example, we recommend creating a user configuration file in ./config/user/myusername.yaml specifying the directory for the log files in logdir.root. Then, it can be included in the command with user=myusername or user=$USER if the name of the YAML file matches our system username.
As another example, you may also want to configure the functionality of the Logger, the class which helps manage logging to Weights & Biases during the training and evaluation of the model. Logging to WandB is disabled by default. In order to enable it, make sure to set up your WandB API key and set the configuration variable logger.do.online to True in your experiment config file or via the command line:
python main.py +experiments=grid/corners logger.do.online=True
Finally, also note that by default, PyTorch will operate on the CPU because we have not observed performance improvements by running on the GPU. You may run on GPU with device=cuda.
To better understand the functionality and implementation of GFlowNet environments, let us explore the Scrabble environment in more detail.
- Instantiating a Scrabble environment
from gflownet.envs.scrabble import Scrabble
env = Scrabble()
- Checking the initial (source) state
Every environment has a state attribute, which gets updated as actions are performed. The initial state corresponds to the source state:
env.state
>>> [0, 0, 0, 0, 0, 0, 0]
env.equal(env.state, env.source)
>>> True
In the Scrabble environment, the state is represented by a list of letter indices, padded by 0's up to the maximum word length (7 by default).
- Checking the action space
The actions of every environment are represented by tuples, and the set of all possible actions makes the action space:
env.action_space
>>> [(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (-1,)]
In the Scrabble environment, the action to append a letter from the English alphabet is represented by a single-element tuple with the letter index, from 1 to 26. The action space also contains (-1,), which represents the end-of-sequence (EOS) action, indicating the termination of word formation.
env.eos
>>> (-1,)
- Performing a step
We can apply one action from the action space to perform a state transition via the step() method:
action = (1,) # Action to add 'A'
new_state, performed_action, is_valid = env.step(action)
print("Updated state:", new_state)
print("Performed action:", performed_action)
print("Action was valid:", is_valid)
>>> Updated state: [1, 0, 0, 0, 0, 0, 0]
>>> Performed action: (1,)
>>> Action was valid: True
env.equal(env.state, new_state)
>>> True
The step() method applies the given action to the environment, provided it is valid. The output shows the new state, the action performed, and whether the action was valid.
- Performing a random step
We can also use the method step_random() to perform a randomly sampled action:
new_state, performed_action, is_valid = env.step_random()
print("Updated state:", new_state)
print("Performed action:", performed_action)
print("Action was valid:", is_valid)
>>> Updated state: [1, 24, 0, 0, 0, 0, 0]
>>> Performed action: (24,)
>>> Action was valid: True
- Unfolding a full random trajectory
Similarly, we can also unfold a complete random trajectory, that is, a sequence of actions terminated by the EOS action:
final_state, trajectory_actions = env.trajectory_random()
print("Final state:", final_state)
print("Sequence of actions:", trajectory_actions)
print("Trajectory is done:", env.done)
>>> Final state: [1, 24, 10, 6, 4, 21, 21]
>>> Sequence of actions: [(1,), (24,), (10,), (6,), (4,), (21,), (21,), (-1,)]
>>> Trajectory is done: True
- Displaying the state as a human-readable string
env.state2readable()
>>> 'A X J F D U U'
- Resetting the environment
env.reset()
env.state
>>> [0, 0, 0, 0, 0, 0, 0]
So far, we've seen how to manually set actions or use random actions in the GFlowNet environment. This approach is useful for testing or understanding the basic mechanics of the environment. However, in practice, the goal of a GFlowNet agent is to adjust the parameters of the policy model to sample actions that result in trajectories with likelihoods proportional to the reward.
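To make these basic mechanics concrete without installing anything, the sketch below mimics them with a toy stand-in. ToyWordEnv is a hypothetical class written for this illustration and is not the library's Scrabble environment; it only reproduces the conventions shown above (zero-padded states, tuple actions, an EOS action, and step() returning the state, the action and a validity flag). Details such as allowing EOS from the source state are assumptions.

```python
import random

class ToyWordEnv:
    """Hypothetical stand-in for a Scrabble-like environment (not the library class)."""

    def __init__(self, max_length=7, n_letters=26, seed=0):
        self.max_length = max_length
        self.eos = (-1,)
        self.action_space = [(i,) for i in range(1, n_letters + 1)] + [self.eos]
        self.source = [0] * max_length
        self.rng = random.Random(seed)
        self.reset()

    def reset(self):
        """Go back to the source state."""
        self.state = list(self.source)
        self.n_added = 0
        self.done = False
        return self

    def valid_actions(self):
        """No action once done; only EOS once the word is full."""
        if self.done:
            return []
        if self.n_added == self.max_length:
            return [self.eos]
        return list(self.action_space)

    def step(self, action):
        """Apply an action; return (state, action, is_valid)."""
        if action not in self.valid_actions():
            return self.state, action, False
        if action == self.eos:
            self.done = True
        else:
            self.state[self.n_added] = action[0]
            self.n_added += 1
        return self.state, action, True

    def step_random(self):
        """Apply a uniformly sampled valid action."""
        valid = self.valid_actions()
        if not valid:
            return self.state, None, False
        return self.step(self.rng.choice(valid))

    def trajectory_random(self):
        """Take random steps from the current state until EOS is sampled."""
        actions = []
        while not self.done:
            _, action, _ = self.step_random()
            actions.append(action)
        return self.state, actions

toy_env = ToyWordEnv()
print(toy_env.step((1,)))  # ([1, 0, 0, 0, 0, 0, 0], (1,), True)
toy_env.reset()
final_state, actions = toy_env.trajectory_random()
print(actions[-1] == toy_env.eos, toy_env.done)  # True True
```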
As the agent interacts with the environment, it collects data about the outcomes of its actions. This data is used to train the policy networks, which model the probability of state transitions given the current state.
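The phrase "likelihoods proportional to the reward" can be made concrete with a toy computation: given the rewards of a small, fully enumerated set of objects, the distribution a trained GFlowNet aims for is the reward divided by the sum of all rewards. The objects and reward values below are hypothetical and chosen only for illustration.

```python
def target_distribution(rewards):
    """Normalise non-negative rewards into the sampling distribution p(x) = R(x) / Z."""
    z = sum(rewards.values())  # normalising constant Z (total reward)
    return {x: r / z for x, r in rewards.items()}

# Hypothetical rewards for three one-letter words.
toy_rewards = {"A": 1.0, "C": 3.0, "F": 4.0}
probabilities = target_distribution(toy_rewards)
print(probabilities)  # {'A': 0.125, 'C': 0.375, 'F': 0.5}
```

In real problems the set of objects is far too large to enumerate, which is why the policy models are trained to approximate this distribution instead.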
- Sample a batch of trajectories from a trained agent
batch, _ = gflownet.sample_batch(n_forward=3, train=False)
batch.states
>>> [[20, 20, 21, 3, 0, 0, 0], [12, 16, 8, 6, 14, 11, 20], [17, 17, 16, 23, 20, 16, 24]]
We can convert the first state to a human-readable string:
env.state2readable(batch.states[0])
>>> 'T T U C'
We can also compute the proxy values for all the states in the batch:
proxy(env.states2proxy(batch.states))
>>> tensor([ 6., 19., 39.])
Or for a single state:
proxy(env.state2proxy(batch.states[0]))
>>> tensor([6.])
The state2proxy and states2proxy methods are helper functions that transform the input into the format expected by the proxy, for example a tensor.
We can also compute the rewards, and since our transformation function g is the identity, the rewards should be equal to the proxy directly.
proxy.rewards(env.states2proxy(batch.states))
>>> tensor([ 6., 19., 39.])
Many wonderful scientists and developers have contributed to this repository: Alex Hernandez-Garcia, Nikita Saxena, Alexandra Volokhova, Michał Koziarski, Divya Sharma, Pierre Luc Carrier and Victor Schmidt.
This repository has been used in at least the following research articles:
- Lahlou et al. A theory of continuous generative flow networks. ICML, 2023.
- Hernandez-Garcia, Saxena et al. Multi-fidelity active learning with GFlowNets. RealML at NeurIPS 2023.
- Mila AI4Science et al. Crystal-GFN: sampling crystals with desirable properties and constraints. AI4Mat at NeurIPS 2023 (spotlight).
- Volokhova, Koziarski et al. Towards equilibrium molecular conformation generation with GFlowNets. AI4Mat at NeurIPS 2023.
Bibtex Format
@misc{hernandez-garcia2024,
author = {Hernandez-Garcia, Alex and Saxena, Nikita and Volokhova, Alexandra and Koziarski, Michał and Sharma, Divya and Viviano, Joseph D and Carrier, Pierre Luc and Schmidt, Victor},
title = {gflownet},
url = {https://github.com/alexhernandezgarcia/gflownet},
year = {2024},
}
Or CFF file