Merge pull request #4 from jmitrevs/initialRecurrOneAPI
Initial state RNNs for oneAPI
JanFSchulte authored Jan 18, 2025
2 parents 75b5dca + 96b6903 commit ebd4de3
Showing 3 changed files with 142 additions and 10 deletions.
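For context, this is the kind of PyTorch module the change targets: a recurrent layer whose initial hidden state is passed in as a second input, similar to the forward(self, x, h0) signature in the updated tests further down. The module below is only an illustrative sketch; the layer sizes and names are made up, not taken from the repository:

import torch
import torch.nn as nn

class GRUNetWithState(nn.Module):  # hypothetical example, not from the repo
    def __init__(self):
        super().__init__()
        # 10 input features, 20 hidden units, batch-first layout (illustrative sizes)
        self.rnn = nn.GRU(10, 20, num_layers=1, batch_first=True)

    def forward(self, x, h0):
        # h0 is the user-supplied initial hidden state; with this change,
        # hls4ml's oneAPI backend can consume it instead of assuming zeros
        output, _ = self.rnn(x, h0)
        return output

model = GRUNetWithState()
x = torch.randn(1, 5, 10)   # (batch, timesteps, features)
h0 = torch.zeros(1, 1, 20)  # (num_layers, batch, hidden)
out = model(x, h0)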
33 changes: 29 additions & 4 deletions hls4ml/backends/oneapi/passes/recurrent_templates.py
@@ -96,6 +96,9 @@
}};\n'''

gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
gru_function_initial_state_template = (
    'nnet::gru_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {w}, {wr}, {b}, {br});'
)
gru_task_sequence_template = 'task_sequence<nnet::gru_stream<{input_pipe}, {output_pipe}, {config}>> {name};'
gru_stream_function_template = '{name}.async({w}, {wr}, {b}, {br});'
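To see what the new initial-state template expands to, the snippet below formats it with hypothetical values; the type, config, and variable names are placeholders, not names generated by hls4ml, and the template string is repeated from the diff above so the snippet runs standalone:

gru_function_initial_state_template = (
    'nnet::gru_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {w}, {wr}, {b}, {br});'
)

# Placeholder values standing in for what GRUFunctionTemplate.format() would supply
call = gru_function_initial_state_template.format(
    input_t='input_t', h_t='h0_t', output_t='result_t', config='config2',
    input='gru_in', init_state='h0', output='layer2_out',
    w='w2', wr='wr2', b='b2', br='br2',
)
print(call)
# nnet::gru_init_state<input_t, h0_t, result_t, config2>(gru_in, h0, layer2_out, w2, wr2, b2, br2);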

@@ -163,15 +166,23 @@ def format(self, node):
class GRUFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(GRU, include_header=recurrent_include_list)
-        self.template = gru_function_template

    def format(self, node):
        params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['init_state'] = node.get_input_variable(node.inputs[1]).name
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name
        params['wr'] = node.get_weights('recurrent_weight').name
        params['br'] = node.get_weights('recurrent_bias').name
-        return self.template.format(**params)

+        if params['pass_initial_states'] == 'true':
+            template = gru_function_initial_state_template
+        else:
+            template = gru_function_template

+        return template.format(**params)


class GRUTaskSequenceTemplate(TaskSequenceTemplate):
@@ -235,6 +246,10 @@ def format(self, node):
}};\n"""

lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
lstm_function_initial_state_template = (
    'nnet::lstm_init_state<{input_t}, {h_t}, {hc_t}, {output_t}, {config}>'
    '({input}, {init_state}, {init_cell}, {output}, {weights});'
)
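As with the GRU template, the rendered LSTM call can be previewed by formatting the string with hypothetical values; here {weights} is filled with an abbreviated stand-in, since the real comma-separated weight list is assembled in LSTMFunctionTemplate.format() below:

lstm_function_initial_state_template = (
    'nnet::lstm_init_state<{input_t}, {h_t}, {hc_t}, {output_t}, {config}>'
    '({input}, {init_state}, {init_cell}, {output}, {weights});'
)

# Placeholder values; the real ones come from LSTMFunctionTemplate.format()
call = lstm_function_initial_state_template.format(
    input_t='input_t', h_t='h0_t', hc_t='c0_t', output_t='result_t', config='config3',
    input='lstm_in', init_state='h0', init_cell='c0', output='layer3_out',
    weights='kernel_i_3, ..., bias_o_3',  # illustrative stand-in, not the generated list
)
print(call)
# nnet::lstm_init_state<input_t, h0_t, c0_t, result_t, config3>(lstm_in, h0, c0, layer3_out, kernel_i_3, ..., bias_o_3);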


class LSTMConfigTemplate(LayerConfigTemplate):
@@ -275,11 +290,16 @@ def format(self, node):
class LSTMFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(LSTM, include_header=recurrent_include_list)
-        self.template = lstm_function_template

    def format(self, node):
        params = self._default_function_params(node)

+        if params['pass_initial_states'] == 'true':
+            params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['init_state'] = node.get_input_variable(node.inputs[1]).name
+            params['init_cell'] = node.get_input_variable(node.inputs[2]).name
+            params['hc_t'] = node.get_input_variable(node.inputs[2]).type.name

        types = ['i', 'f', 'c', 'o']
        params['weights'] = ''
        for t in types:
@@ -289,7 +309,12 @@ def format(self, node):
        for t in types:
            params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')

-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            template = lstm_function_initial_state_template
+        else:
+            template = lstm_function_template

+        return template.format(**params)


################################################
107 changes: 107 additions & 0 deletions hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h
@@ -200,6 +200,41 @@ void gru(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weig
}
}

template <class data_T, class h_T, class res_T, typename CONFIG_T>
void gru_init_state(const data_T &data, const h_T &hin, res_T &res, const typename CONFIG_T::weight_t &weights,
                    const typename CONFIG_T::recurrent_weight_t &recurrent_weights, const typename CONFIG_T::bias_t &bias,
                    const typename CONFIG_T::recurrent_bias_t &recurrent_bias) {

    [[intel::fpga_register]] data_T x;

    [[intel::fpga_register]] h_T h = hin;

    // Loop dependency - cannot pipeline
    [[intel::disable_loop_pipelining]] for (int t = 0; t < CONFIG_T::n_timesteps; t++) {
        // Get data at current time step
        #pragma unroll
        for (int j = 0; j < CONFIG_T::n_in; j++) {
            x[j] = data[j + t * CONFIG_T::n_in];
        }

        nnet::gru_cell<data_T, h_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);

        if (CONFIG_T::return_sequences) {
            #pragma unroll
            for (int i = 0; i < CONFIG_T::n_units; i++) {
                res[CONFIG_T::n_units * t + i] = h[i];
            }
        }
    }

    if (!CONFIG_T::return_sequences) {
        #pragma unroll
        for (int i = 0; i < (CONFIG_T::n_units); i++) {
            res[i] = h[i];
        }
    }
}
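A plain-Python reference of the control flow above: the hidden state starts from the supplied hin instead of zeros, and the flat result buffer is written either per timestep or once at the end. Everything here is an illustrative sketch; gru_cell is a stand-in for nnet::gru_cell and the names are made up:

import numpy as np

def gru_init_state_ref(data, hin, n_timesteps, n_in, n_units, gru_cell, return_sequences):
    # Mimics the loop structure of gru_init_state above (illustrative only)
    h = np.asarray(hin, dtype=float)          # h starts from the initial state, not zeros
    res = np.zeros(n_timesteps * n_units if return_sequences else n_units)
    for t in range(n_timesteps):
        x = data[t * n_in:(t + 1) * n_in]     # slice out the current timestep
        h = gru_cell(x, h)                    # one recurrent update
        if return_sequences:
            res[t * n_units:(t + 1) * n_units] = h
    if not return_sequences:
        res[:] = h
    return res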

//----------------------
// SimpleRNN
//----------------------
@@ -561,6 +596,78 @@ void lstm(const data_T &data, res_T &res, const typename CONFIG_T::weight_i_t &W
}
}

template <class data_T, class h_T, class hc_T, class res_T, class CONFIG_T>
void lstm_init_state(const data_T &data, const h_T &hidden_state_initial, const hc_T &cell_state_initial, res_T &res,
                     const typename CONFIG_T::weight_i_t &WI, const typename CONFIG_T::weight_f_t &WF,
                     const typename CONFIG_T::weight_c_t &WC, const typename CONFIG_T::weight_o_t &WO,
                     const typename CONFIG_T::recurrent_weight_i_t &RWI, const typename CONFIG_T::recurrent_weight_f_t &RWF,
                     const typename CONFIG_T::recurrent_weight_c_t &RWC, const typename CONFIG_T::recurrent_weight_o_t &RWO,
                     const typename CONFIG_T::bias_i_t &BI, const typename CONFIG_T::bias_f_t &BF,
                     const typename CONFIG_T::bias_c_t &BC, const typename CONFIG_T::bias_o_t &BO) {

    // Note: currently this does not support recurrent bias

    using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;

    [[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
    [[intel::fpga_register]] h_T hidden_state_temp;
    [[intel::fpga_register]] h_T cell_state[CONFIG_T::n_timesteps + 1];
    [[intel::fpga_register]] h_T cell_state_temp; // should this be updated to a different type
    [[intel::fpga_register]] h_T h;
    [[intel::fpga_register]] h_T c;
    [[intel::fpga_register]] in_T in;

    // Seed the hidden and cell state with the provided initial values
INIT_LOOP:
    #pragma unroll
    for (int x = 0; x < CONFIG_T::n_out; x++) {
        hidden_state[0][x] = hidden_state_initial[x];
        cell_state[0][x] = cell_state_initial[x];
    }

    // Loop over timesteps
    [[intel::disable_loop_pipelining]] for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
        // Data at current time step
        for (int x = 0; x < CONFIG_T::n_in; x++) {
            in[x] = data[x + i * CONFIG_T::n_in];
        }

        // Hidden state at current time step
        #pragma unroll
        for (int x = 0; x < CONFIG_T::n_out; x++) {
            hidden_state_temp[x] = hidden_state[i][x];
            cell_state_temp[x] = cell_state[i][x];
        }

        // Do LSTM
        lstm_cell<in_T, h_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO, BI,
                                       BF, BC, BO);

        // Write result
        #pragma unroll
        for (int x = 0; x < CONFIG_T::n_out; x++) {
            hidden_state[i + 1][x] = h[x];
            cell_state[i + 1][x] = c[x];
        }
    }

    if (CONFIG_T::return_sequences == 0) {
        // Output when return_sequences is false
        #pragma unroll
        for (int x = 0; x < CONFIG_T::n_out; x++) {
            res[x] = hidden_state[CONFIG_T::n_timesteps][x];
        }
    } else {
        // Output when return_sequences is true
        #pragma unroll
        for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
            for (int h = 0; h < CONFIG_T::n_out; h++) {
                res[x * CONFIG_T::n_out + h] = hidden_state[x + 1][h];
            }
        }
    }
}
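The LSTM variant follows the same pattern but carries two states and buffers them per timestep, seeding entry 0 from the supplied initial values. A rough Python sketch of that buffering scheme, with lstm_cell as a stand-in for nnet::lstm_cell and all names and shapes illustrative:

import numpy as np

def lstm_init_state_ref(data, h0, c0, n_timesteps, n_in, n_out, lstm_cell, return_sequences):
    # Mirrors the per-timestep buffers of lstm_init_state above (illustrative only)
    hidden = np.zeros((n_timesteps + 1, n_out))
    cell = np.zeros((n_timesteps + 1, n_out))
    hidden[0], cell[0] = h0, c0                  # seed from the provided initial state
    for i in range(n_timesteps):
        x = data[i * n_in:(i + 1) * n_in]
        h, c = lstm_cell(x, hidden[i], cell[i])  # one LSTM step
        hidden[i + 1], cell[i + 1] = h, c
    if return_sequences:
        return hidden[1:].reshape(-1)            # res[t * n_out + j] = hidden[t + 1][j]
    return hidden[n_timesteps]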

} // namespace nnet

#endif
12 changes: 6 additions & 6 deletions test/pytest/test_recurrent_pytorch.py
@@ -31,7 +31,7 @@ def forward(self, x):
        return output


-@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_gru(backend, io_type):
    model = GRUNet()
@@ -56,7 +56,7 @@ def test_gru(backend, io_type):
    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=1e-1)


-@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
@pytest.mark.parametrize('io_type', ['io_stream'])
def test_gru_stream(backend, io_type):
    model = GRUNetStream()
@@ -98,7 +98,7 @@ def forward(self, x):
        return output


-@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_lstm(backend, io_type):
    model = LSTM()
@@ -132,10 +132,10 @@ def test_lstm(backend, io_type):
    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=1e-1)


-@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI'])
@pytest.mark.parametrize('io_type', ['io_stream'])
def test_lstm_stream(backend, io_type):
-    if not (backend == "Quartus" and io_type == "io_stream"):
+    if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):
        model = LSTMStream()
        model.eval()

@@ -174,7 +174,7 @@ def forward(self, x, h0):
@pytest.mark.parametrize('backend', ['Quartus'])
@pytest.mark.parametrize('io_type', ['io_parallel'])
def test_rnn(backend, io_type):
-    if not (backend == "Quartus" and io_type == "io_stream"):
+    if not (backend in ('Quartus', 'oneAPI') and io_type == "io_stream"):
        model = RNN()
        model.eval()
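To exercise only the newly enabled oneAPI parametrizations of these tests, one option is to filter by keyword; this assumes the command is run from the repository root and that the oneAPI toolchain needed to build the generated projects is installed:

import pytest

# Run just the oneAPI cases of the recurrent PyTorch tests
pytest.main(['-v', '-k', 'oneAPI', 'test/pytest/test_recurrent_pytorch.py'])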

