This repository has been archived by the owner on Oct 31, 2022. It is now read-only.

Add tensor rematerialization.
nshepperd committed Mar 16, 2021
1 parent 2de5d1b commit ffc54c7
Showing 21 changed files with 1,515 additions and 50 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@ __pycache__
.mypy_cache/
models/
checkpoint
samples
dist-newstyle
bin
57 changes: 53 additions & 4 deletions README.md
@@ -15,19 +15,68 @@ PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz
```

Make sure `cudnn` is installed. [Some have
reported](https://github.com/nshepperd/gpt-2/issues/8) that `train.py`
runs without it but has worse memory usage and might OOM.

### Tensor Rematerialization

Experimental: a rematerialization rewriter based on [Efficient
Rematerialization for Deep Networks](https://papers.nips.cc/paper/9653-efficient-rematerialization-for-deep-networks.pdf),
which, unlike gradient checkpointing, works in tensorflow 2.0 and can
automatically select checkpoints in arbitrary graphs. Using this I was
able to finetune GPT-2 1.5B on a single graphics card using slightly
less than 12G of video ram with very little slowdown.

Using this is a little involved, because the graph optimization
algorithm is offloaded to an optimized Haskell program. First, go into
the `twremat` subdirectory and build it:

    cd twremat
    cabal v2-install --installdir=../bin

(You'll need to install cabal if you haven't already -- but setting up
ghc and haskell compilation is beyond the scope of this README.)

Then run `train.py` as normal, enabling `--twremat` and setting
`--twremat_memlimit` to an appropriate value. The memlimit sets the
amount of memory assumed to be available for the computation of
gradients, so it should be roughly the memory size of your graphics
card minus whatever is taken up by the gpt-2 weights and any other
bookkeeping variables. You may need to experiment until you find the
largest memlimit that doesn't OOM.

(You probably also want to use SGD as the optimizer instead of Adam
to minimize those bookkeeping variables, of which Adam uses a lot.)
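
For example, a hypothetical invocation for a 12G card (the memlimit
value is illustrative -- tune it as described above):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz \
    --optimizer sgd --twremat --twremat_memlimit 5G
```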

### Gradient Checkpointing

https://github.com/openai/gradient-checkpointing is included to reduce
the memory requirements of the model, and can be enabled by
`--memory_saving_gradients`. The checkpoints are currently chosen
manually (poorly) by just adding layer 10 to the 'checkpoints'
collection in model.py.

Gradient checkpointing doesn't work in tensorflow v2.0 and later due
to the removal of tf.contrib. You should use tensor rematerialization
instead if possible.
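
A minimal sketch of enabling it (tensorflow 1.x only, flags as
described above):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz --memory_saving_gradients
```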

### Validation loss

Set `--val_every` to a number of steps `N > 0`, and "validation" loss
against a fixed sample of the dataset will be calculated every N steps
to give a better sense of training progress. A value of N around 200
is suggested. You can set `--val_dataset` to choose a separate
validation dataset; otherwise it defaults to a sample from the train
dataset (so not a real cross-validation loss!).
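
For example (paths are placeholders):

```
PYTHONPATH=src ./train.py --dataset /path/to/train.npz \
    --val_every 200 --val_dataset /path/to/val.npz
```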

### Optimizer

You can use SGD instead of Adam with `--optimizer sgd`. This also
helps conserve memory when training larger models. Note: the learning
rate needs to be adjusted for SGD, since it lacks Adam's gradient
normalization (0.0006 seems to be a good value from some
experiments).
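
For example (this assumes `train.py` exposes a `--learning_rate`
flag; check `./train.py --help` for the exact spelling):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz \
    --optimizer sgd --learning_rate 0.0006
```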

# Original README

181 changes: 181 additions & 0 deletions src/tfremat.py
@@ -0,0 +1,181 @@
import tensorflow.compat.v1 as tf

import twremat

def splice_op(op, input_map, control_inputs=None):
    # Clone `op` under a fresh name, remapping its data and control
    # inputs through input_map and appending any extra control_inputs.
    g = op.graph
node_def = tf.NodeDef()
node_def.CopyFrom(op.node_def)
node_def.name = g.unique_name(op.name + '_copy')
inputs = [input_map.get(x, x) for x in op.inputs]
new_control_inputs = [input_map.get(x, x) for x in op.control_inputs]
if control_inputs:
new_control_inputs.extend([x for x in control_inputs if x is not None])
output_types = [o.dtype for o in op.outputs]
op_def = op.op_def
return tf.Operation(node_def, g, inputs=inputs, output_types=output_types, op_def=op_def, control_inputs=new_control_inputs)

def splice_tensor(ten, new_op):
i = ten.op.outputs.index(ten)
return new_op.outputs[i]

def splice(obj, input_map, control_inputs=None):
if type(obj) is tf.Operation:
return splice_op(obj, input_map, control_inputs=control_inputs)
elif type(obj) is tf.Tensor:
return splice_tensor(obj, input_map.get(obj.op, obj.op))
elif type(obj) is tf.IndexedSlices:
return tf.IndexedSlices(values=input_map.get(obj.values, obj.values),
indices=input_map.get(obj.indices, obj.indices),
dense_shape=input_map.get(obj.dense_shape, obj.dense_shape))
    else:
        raise AssertionError(f'Could not splice {type(obj)!r} {obj!r}')

def product(xs):
r = 1
for x in xs:
r *= x
return r

def shape_size(shape):
    # Estimate the element count of a possibly partially-unknown shape.
    if shape.rank is None:
        return 16  # arbitrary small guess for fully unknown shapes
    shape = shape.as_list()
    for i in range(len(shape)):
        if shape[i] is None and i == 0:
            shape[i] = 1  # assume a batch dimension of 1
        elif shape[i] is None:
            shape[i] = 1024  # pessimistic guess for other unknown dims
    return product(shape)

def graph_from_dfs(deps, starts):
visited = set()
frontier = starts
while frontier:
x = frontier.pop()
if x in visited:
continue
visited.add(x)
frontier.extend(list(deps(x)))
return {x : list(deps(x)) for x in visited}

def get_deps(obj):
if type(obj) is tf.Operation:
return list(obj.inputs) + list(obj.control_inputs)
elif type(obj) is tf.Tensor:
return [obj.op]
elif type(obj) is tf.IndexedSlices:
return [obj.indices, obj.values, obj.dense_shape]
    else:
        raise AssertionError(f'Could not get deps from {type(obj)!r} {obj!r}')


def tensor_graph(compute):
return graph_from_dfs(get_deps, list(compute))

def blacklist(obj):
if type(obj) is tf.Operation:
if 'Assign' in obj.type or 'Variable' in obj.type or 'Placeholder' in obj.type:
# TODO: Should we do special accounting for
# ReadVariableOp? Currently we forbid cloning altogether,
# but it's actually ok to clone this op as long as it
# doesn't float across an effectful op (Assign). Also
# currently we don't account for the memory used by
# ReadVariableOp (is it copy-on-write?).
# https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp?hl=uk
return True
elif type(obj) is tf.Tensor:
return blacklist(obj.op)
return False

def estimate_cpu(op):
    # Rough compute-cost proxy: total bytes moved through the op,
    # assuming 4 bytes per element.
    return (sum(4 * shape_size(t.shape) for t in op.inputs if type(t) is tf.Tensor)
            + sum(4 * shape_size(t.shape) for t in op.outputs))

def estimate_mem(op):
    # Resident memory of the op's outputs, assuming 4 bytes per element.
    return sum(4 * shape_size(t.shape) for t in op.outputs)

def info(op):
    # Classify a node for twremat: effectful ops must not be cloned,
    # pointer nodes are treated as free, normal ops get cost estimates.
    if blacklist(op):
        return {'type': 'effectful'}
elif type(op) is tf.Operation:
if 'Reshape' in op.type:
return {'type': 'pointer'}
return {'type': 'normal',
'cpu': estimate_cpu(op),
'mem': estimate_mem(op)}
elif type(op) is tf.Tensor:
return {'type': 'pointer'}
elif type(op) is tf.IndexedSlices:
return {'type': 'pointer'}
else:
raise AssertionError(repr((type(op), op)))


# Helper functions to flatten and unflatten nested structures of
# tensors and ops so that tf_remat can be applied to structures
# without fiddly marshalling.
def get_ops(compute):
output = []
stack = [compute]
while stack:
top = stack.pop()
if type(top) is dict:
for v in top.values():
stack.append(v)
elif type(top) in (list, tuple):
stack.extend(top)
elif type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
output.append(top)
return output

def replace_ops(top, live):
if type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices):
return live[top]
elif type(top) is dict:
return {k : replace_ops(v, live) for (k,v) in top.items()}
elif type(top) is list:
return [replace_ops(v, live) for v in top]
elif type(top) is tuple:
return tuple(replace_ops(v, live) for v in top)
else:
return top


def tf_remat(compute, memlimit):
    # Rewrite the graph needed to evaluate `compute` (a nested structure
    # of ops/tensors/IndexedSlices) so that peak memory stays under
    # `memlimit` bytes, recomputing intermediate tensors as needed.
    compute_ops = get_ops(compute)
    tf_deps = tensor_graph(compute_ops)

    # Relabel ops with integer ids for the twremat solver.
    from_op = {op : i for (i, op) in enumerate(tf_deps.keys())}
    from_node = {i : op for (op, i) in from_op.items()}
    nodes = set(from_node.keys())

node_info = {}
for n in nodes:
node_info[n] = info(from_node[n])
node_info[n]['deps'] = [from_op[d] for d in tf_deps[from_node[n]]]

steps = twremat.runtwremat(node_info, memlimit, {from_op[c] for c in compute_ops})

    print('Constructing tensorflow graph...')
    live = {}       # original op/tensor -> its currently live replacement
    last_op = None  # chains clones with control deps to fix execution order
for (action, n) in steps:
base = from_node[n]
if action == 'compute':
input_map = {d : live[d] for d in tf_deps[base] if live[d] != d}
            if blacklist(base) and not input_map:
                # Effectful ops with unchanged inputs are reused, not cloned.
                live[base] = base
else:
live[base] = splice(base, input_map, control_inputs=[last_op])
if type(base) is tf.Operation:
last_op = live[base]
elif action == 'free':
del live[base]

return replace_ops(compute, live)
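
# Rough usage sketch (illustrative only; names like `opt_apply` and
# `loss` are hypothetical -- train.py wires this up via --twremat):
#
#   opt_apply, loss = tf_remat((opt_apply, loss), memlimit='10G')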
60 changes: 60 additions & 0 deletions src/twremat.py
@@ -0,0 +1,60 @@
from subprocess import Popen
import os
import sys
import tempfile

BINDIR = os.path.join(os.path.dirname(sys.argv[0]), 'bin')
TWREMAT = os.path.join(BINDIR, 'twremat')

# Allow users to pass 'humanized' memlimit values as strings.
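# Suffixes are decimal: e.g. parse_memlimit('12G') == 12 * 10**9.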
def parse_memlimit(memlimit):
if memlimit[-1] == 'K':
return int(memlimit[:-1]) * 1000
elif memlimit[-1] == 'M':
return int(memlimit[:-1]) * 1000000
elif memlimit[-1] == 'G':
return int(memlimit[:-1]) * 1000000000
else:
return int(memlimit)

def runtwremat(gr, memlimit, target):
if type(memlimit) is str:
memlimit = parse_memlimit(memlimit)

    # mkstemp avoids the filename race inherent in tempfile.mktemp.
    infd, fname = tempfile.mkstemp()
    os.close(infd)
    outfd, outname = tempfile.mkstemp()
    os.close(outfd)
with open(fname, 'w') as fp:
print('p remat2', file=fp)
print(f'memlimit {memlimit}', file=fp)
for (n, info) in gr.items():
deps = ' '.join(str(d) for d in info['deps'])
if info['type'] == 'normal':
cpu = info['cpu']
mem = info['mem']
weight = f'cpu {cpu} mem {mem}'
elif info['type'] == 'effectful':
weight = 'effectful'
elif info['type'] == 'pointer':
weight = 'pointer'
if n in target:
tstr = 'target'
else:
tstr = ''
print(f'node {n} deps {deps} {weight} {tstr}', file=fp)
print(' '.join([TWREMAT, fname, outname]))
proc = Popen([TWREMAT, fname, outname])
assert proc.wait() == 0
out = []
with open(outname, 'r') as fp:
for line in fp:
line = line.split()
if line and line[0] == 'c':
out.append(('compute', int(line[1])))
elif line and line[0] == 'f':
out.append(('free', int(line[1])))
            elif line:
                raise RuntimeError(f'Unexpected twremat output: {line!r}')
return out