# train_lstm.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from os import path
import numpy as np
import tensorflow as tf
from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import model as ts_model
from tensorflow.contrib.timeseries.python.timeseries import NumpyReader
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
class _LSTMModel(ts_model.SequentialTimeSeriesModel):
    """Sequential time series model whose predictions come from an LSTM cell.

    The model state is a 3-tuple ``(time, value, lstm_state)`` holding the time
    the state refers to, the most recent observation or prediction (in the
    normalized scale produced by `_transform`), and the LSTM cell state.
    """

    def __init__(self, num_units, num_features, dtype=tf.float32):
        """Configure the model without building any graph.

        This object only stores configuration; the TensorFlow graph pieces are
        created later in `initialize_graph`.

        Args:
          num_units: The number of units in the model's LSTMCell.
          num_features: The dimensionality of the time series (features per
            timestep).
          dtype: The floating point data type to use.
        """
        # "mean" is the only output this model registers, for both training
        # and prediction.
        super(_LSTMModel, self).__init__(
            train_output_names=["mean"],
            predict_output_names=["mean"],
            num_features=num_features,
            dtype=dtype)
        self._num_units = num_units
        # Graph components; all three are populated by initialize_graph().
        self._lstm_cell = None
        self._lstm_cell_run = None
        self._predict_from_lstm_output = None

    def initialize_graph(self, input_statistics):
        """Create the LSTM cell and the templates used to apply it.

        Args:
          input_statistics: A math_utils.InputStatistics object.
        """
        super(_LSTMModel, self).initialize_graph(
            input_statistics=input_statistics)
        cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
        self._lstm_cell = cell
        # Wrapping the cell in a template lets it be called repeatedly without
        # having to manage variable reuse by hand.
        self._lstm_cell_run = tf.make_template(
            name_="lstm_cell",
            func_=cell,
            create_scope_now_=True)

        def _dense_to_features(inputs):
            # Projects the LSTM output to one value per feature.
            return tf.layers.dense(inputs=inputs, units=self.num_features)

        self._predict_from_lstm_output = tf.make_template(
            name_="predict_from_lstm_output",
            func_=_dense_to_features,
            create_scope_now_=True)

    def get_start_state(self):
        """Return the initial `(time, value, lstm_state)` state tuple."""
        # Time associated with the state; checked in _filtering_step.
        start_time = tf.zeros([], dtype=tf.int64)
        # Placeholder for the previous observation or prediction.
        start_value = tf.zeros([self.num_features], dtype=self.dtype)
        # The cell's zero state is requested with batch_size=1 and the batch
        # dimension is then squeezed away from every element.
        batched_cell_state = self._lstm_cell.zero_state(
            batch_size=1, dtype=self.dtype)
        start_cell_state = [
            tf.squeeze(element, axis=0) for element in batched_cell_state]
        return (start_time, start_value, start_cell_state)

    def _transform(self, data):
        """Map data from the input scale to the model's normalized scale.

        Note that the centered data is divided by the variance (not by the
        standard deviation); `_de_transform` is the exact inverse.
        """
        moments = self._input_statistics.overall_feature_moments
        mean, variance = moments
        centered = data - mean
        return centered / variance

    def _de_transform(self, data):
        """Map data from the normalized scale back to the input scale."""
        moments = self._input_statistics.overall_feature_moments
        mean, variance = moments
        rescaled = data * variance
        return rescaled + mean

    def _filtering_step(self, current_times, current_values, state,
                        predictions):
        """Compute the loss for new observations and store them in the state.

        The LSTM itself is not run here. The normalized observation is saved in
        the returned state and is fed through the LSTM by the next
        `_prediction_step`, which therefore handles both observations and the
        model's own predictions.

        Args:
          current_times: A [batch size] integer Tensor.
          current_values: A [batch size, self.num_features] floating point
            Tensor with new observations.
          state: The model's state tuple.
          predictions: The output of the previous `_prediction_step`.

        Returns:
          A tuple of the new state and the `predictions` dictionary, updated in
          place with a "loss" entry.
        """
        state_time, predicted_values, cell_state = state
        # The state must refer to the same times as the observations.
        times_agree = tf.assert_equal(current_times, state_time)
        with tf.control_dependencies([times_agree]):
            observed = self._transform(current_values)
            # Mean squared error across the feature dimension.
            squared_error = (predicted_values - observed) ** 2
            predictions["loss"] = tf.reduce_mean(squared_error, axis=-1)
        updated_state = (current_times, observed, cell_state)
        return (updated_state, predictions)

    def _prediction_step(self, current_times, state):
        """Run the LSTM on the previous value and emit the next prediction.

        Args:
          current_times: Times to associate with the returned state.
          state: The model's state tuple.

        Returns:
          A tuple of the new state (holding the normalized prediction) and a
          dictionary whose "mean" entry is the prediction in the input scale.
        """
        _, last_value, cell_state = state
        cell_output, next_cell_state = self._lstm_cell_run(
            inputs=last_value, state=cell_state)
        prediction = self._predict_from_lstm_output(cell_output)
        advanced_state = (current_times, prediction, next_cell_state)
        outputs = {"mean": self._de_transform(prediction)}
        return advanced_state, outputs

    def _imputation_step(self, current_times, state):
        """Advance model state across a gap (a no-op for this model)."""
        # The state is passed through unchanged regardless of the gap size.
        return state

    def _exogenous_input_step(
            self, current_times, current_exogenous_regressors, state):
        """Reject exogenous regressors, which this example does not support.

        Raises:
          NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Exogenous inputs are not implemented for this example.")
if __name__ == '__main__':
    # Train the LSTM model on a noisy synthetic periodic series, evaluate it on
    # the whole series, forecast 200 further steps and save a plot.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Single source of truth for the series length, so that the time axis, the
    # noise vector and the marker drawn on the plot cannot drift apart.
    num_samples = 1000
    x = np.arange(num_samples)
    noise = np.random.uniform(-0.2, 0.2, num_samples)
    y = (np.sin(np.pi * x / 50) + np.cos(np.pi * x / 50)
         + np.sin(np.pi * x / 25) + noise)

    data = {
        tf.contrib.timeseries.TrainEvalFeatures.TIMES: x,
        tf.contrib.timeseries.TrainEvalFeatures.VALUES: y,
    }
    reader = NumpyReader(data)

    # Train on random windows of 100 steps, 4 windows per batch.
    train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
        reader, batch_size=4, window_size=100)
    estimator = ts_estimators.TimeSeriesRegressor(
        model=_LSTMModel(num_features=1, num_units=128),
        optimizer=tf.train.AdamOptimizer(0.001))
    estimator.train(input_fn=train_input_fn, steps=2000)

    # Evaluate on the whole series in a single step.
    evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
    evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)

    # Predict starting after the evaluation.
    (predictions,) = tuple(estimator.predict(
        input_fn=tf.contrib.timeseries.predict_continuation_input_fn(
            evaluation, steps=200)))

    observed_times = evaluation["times"][0]
    observed = evaluation["observed"][0, :, :]
    evaluated_times = evaluation["times"][0]
    evaluated = evaluation["mean"][0]
    predicted_times = predictions['times']
    predicted = predictions["mean"]

    plt.figure(figsize=(15, 5))
    # Dotted vertical line at the last time step of the input series.
    plt.axvline(x[-1], linestyle="dotted", linewidth=4, color='r')
    observed_lines = plt.plot(
        observed_times, observed, label="observation", color="k")
    evaluated_lines = plt.plot(
        evaluated_times, evaluated, label="evaluation", color="g")
    predicted_lines = plt.plot(
        predicted_times, predicted, label="prediction", color="r")
    plt.legend(
        handles=[observed_lines[0], evaluated_lines[0], predicted_lines[0]],
        loc="upper left")
    plt.savefig('predict_result.jpg')