Skip to content

Commit

Permalink
Refactor: Replace casefold with lower and use setdefault
Browse files Browse the repository at this point in the history
  • Loading branch information
jicampos committed Jan 17, 2025
1 parent 91efc06 commit 72fea13
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 34 deletions.
4 changes: 2 additions & 2 deletions hls4ml/backends/catapult/passes/conv_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class GenerateConvStreamingInstructions(OptimizerPass):
def match(self, node):
is_match = (
isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
and node.model.config.get_config_value('IOType').casefold() == 'io_stream'
and node.get_attr('implementation').casefold() == 'encoded'
and node.model.config.get_config_value('IOType').lower() == 'io_stream'
and node.get_attr('implementation').lower() == 'encoded'
)
return is_match

Expand Down
20 changes: 5 additions & 15 deletions hls4ml/backends/catapult/passes/convolution_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,8 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv1DBuffer'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'
params.setdefault('min_width', node.get_attr('in_width'))
params.setdefault('instructions', '0')

conv_config = self.template.format(**params)

Expand Down Expand Up @@ -218,14 +213,9 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv2DBuffer'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_height'] = node.get_attr('in_height')
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'
params.setdefault('min_height', node.get_attr('in_height'))
params.setdefault('min_width', node.get_attr('in_width'))
params.setdefault('instructions', '0')

conv_config = self.template.format(**params)

Expand Down
4 changes: 2 additions & 2 deletions hls4ml/backends/vivado/passes/conv_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class GenerateConvStreamingInstructions(OptimizerPass):
def match(self, node):
is_match = (
isinstance(node, (Conv1D, SeparableConv1D, Conv2D, SeparableConv2D))
and node.model.config.get_config_value('IOType').casefold() == 'io_stream'
and node.get_attr('implementation').casefold() == 'encoded'
and node.model.config.get_config_value('IOType').lower() == 'io_stream'
and node.get_attr('implementation').lower() == 'encoded'
)
return is_match

Expand Down
20 changes: 5 additions & 15 deletions hls4ml/backends/vivado/passes/convolution_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,8 @@ def format(self, node):
else:
params['conv_fn'] = 'Conv1DResource'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'
params.setdefault('min_width', node.get_attr('in_width'))
params.setdefault('instructions', '0')

conv_config = self.template.format(**params)

Expand Down Expand Up @@ -247,14 +242,9 @@ def format(self, node):
else:
params['fill_fn'] = 'FillConv2DBuffer'

if (
node.get_attr('implementation').casefold() == 'linebuffer'
or node.model.config.get_config_value('IOType').casefold() == 'io_parallel'
):
# these are unused; just put dummy values
params['min_height'] = node.get_attr('in_height')
params['min_width'] = node.get_attr('in_width')
params['instructions'] = '0'
params.setdefault('min_height', node.get_attr('in_height'))
params.setdefault('min_width', node.get_attr('in_width'))
params.setdefault('instructions', '0')

conv_config = self.template.format(**params)

Expand Down

0 comments on commit 72fea13

Please sign in to comment.