Skip to content

Commit

Permalink
Merge pull request #773 from pnbabu/vectors_bugs
Browse files Browse the repository at this point in the history
Fix vector resize
  • Loading branch information
clinssen authored May 11, 2022
2 parents 89805da + 1982121 commit 1111bf9
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 22 deletions.
32 changes: 14 additions & 18 deletions pynestml/codegeneration/printers/nest_reference_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def convert_name_reference(self, variable: ASTVariable, prefix: str = '') -> str
return "((POST_NEURON_TYPE*)(__target))->get_" + _name + "(_tr_t)"

if variable.get_name() == PredefinedVariables.E_CONSTANT:
return 'numerics::e'
return "numerics::e"

symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)
if symbol is None:
Expand All @@ -176,10 +176,14 @@ def convert_name_reference(self, variable: ASTVariable, prefix: str = '') -> str
code, message = Messages.get_could_not_resolve(variable.get_name())
Logger.log_message(log_level=LoggingLevel.ERROR, code=code, message=message,
error_position=variable.get_source_position())
return ''
return ""

if symbol.is_local():
return variable.get_name() + ('[' + variable.get_vector_parameter() + ']' if symbol.has_vector_parameter() else '')
vector_param = ""
if symbol.has_vector_parameter():
vector_param = "[" + variable.get_vector_parameter() + "]"

# if symbol.is_local():
# return variable.get_name() + vector_param

if symbol.is_buffer():
if isinstance(symbol.get_type_symbol(), UnitTypeSymbol):
Expand All @@ -190,33 +194,25 @@ def convert_name_reference(self, variable: ASTVariable, prefix: str = '') -> str
if not units_conversion_factor == 1:
s += "(" + str(units_conversion_factor) + " * "
s += self.print_origin(symbol, prefix=prefix) + self.buffer_value(symbol)
if symbol.has_vector_parameter():
s += '[' + variable.get_vector_parameter() + ']'
s += vector_param
if not units_conversion_factor == 1:
s += ")"
return s

if symbol.is_inline_expression:
return 'get_' + variable.get_name() + '()' + ('[i]' if symbol.has_vector_parameter() else '')
return self.getter(symbol) + "()" + vector_param

assert not symbol.is_kernel(), "NEST reference converter cannot print kernel; kernel should have been converted during code generation"

if symbol.is_state():
temp = ""
temp += self.getter(symbol) + "()"
temp += ('[' + variable.get_vector_parameter() + ']' if symbol.has_vector_parameter() else '')
return temp
if symbol.is_state() or symbol.is_inline_expression:
return self.getter(symbol) + "()" + vector_param

variable_name = self.convert_to_cpp_name(variable.get_complete_name())
if symbol.is_local():
return variable_name + ('[i]' if symbol.has_vector_parameter() else '')

if symbol.is_inline_expression:
return 'get_' + variable_name + '()' + ('[i]' if symbol.has_vector_parameter() else '')
return variable_name + vector_param

return self.print_origin(symbol, prefix=prefix) + \
self.name(symbol) + \
('[' + variable.get_vector_parameter() + ']' if symbol.has_vector_parameter() else '')
self.name(symbol) + vector_param

def __get_unit_name(self, variable: ASTVariable):
assert variable.get_scope() is not None, "Undeclared variable: " + variable.get_complete_name()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
@param variable VariableSymbol
#}
{%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %}
{%- if not variable.is_inline_expression and not variable.is_state() %}

{%- if not variable.is_inline_expression %}
{%- if not variable.is_state() %}
{{declarations.print_variable_type(variable)}} tmp_{{names.name(variable)}} = {{names.getter(variable)}}();
updateValue<{{declarations.print_variable_type(variable)}}>(__d, nest::{{names_namespace}}::_{{names.name(variable)}}, tmp_{{names.name(variable)}});

{%- if vector_symbols|length > 0 %}
// Resize vectors
if (tmp_{{names.name(variable)}} != {{names.getter(variable)}}())
Expand All @@ -19,9 +22,29 @@ if (tmp_{{names.name(variable)}} != {{names.getter(variable)}}())
{%- endfor %}
}
{%- endif %}
{%- elif not variable.is_inline_expression and variable.is_state() %}

{%- else %}
{{declarations.print_variable_type(variable)}} tmp_{{names.convert_to_cpp_name(variable.get_symbol_name())}} = {{names.getter(variable)}}();
updateValue<{{declarations.print_variable_type(variable)}}>(__d, nest::{{names_namespace}}::_{{variable.get_symbol_name()}}, tmp_{{names.convert_to_cpp_name(variable.get_symbol_name())}});
{%- endif %}

{%- if variable.has_vector_parameter() %}
{#
Typecast the vector parameter to an int. If the typecast fails with a return value of 0, the vector parameter is a
variable
#}
{%- set vector_size = variable.get_vector_parameter() | int %}
{%- if not vector_size %}
{%- set vector_size = "tmp_" + variable.get_vector_parameter() %}
{%- endif %}
// Check if the new vector size matches its original size
if ( tmp_{{names.name(variable)}}.size() != {{vector_size}} )
{
std::stringstream msg;
msg << "The vector \"{{names.name(variable)}}\" does not match its size: " << {{vector_size}};
throw nest::BadProperty(msg.str());
}
{%- endif %}
{%- else %}
// ignores '{{names.name(variable)}}' {{declarations.print_variable_type(variable)}}' since it is an function and setter isn't defined
{%- endif %}
28 changes: 26 additions & 2 deletions tests/nest_tests/nest_vectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
import os
import unittest
import numpy as np

import nest
import pytest
from nest.lib.hl_api_exceptions import NESTErrors

from pynestml.frontend.pynestml_frontend import generate_nest_target


class NestVectorsIntegrationTest(unittest.TestCase):
class TestNestVectorsIntegration:
r"""
Tests the code generation and vector operations from NESTML to NEST.
"""
Expand Down Expand Up @@ -70,3 +71,26 @@ def test_vectors(self):
v_m = multimeter.get("events")["V_m"]
print("V_m: {}".format(v_m))
np.testing.assert_almost_equal(v_m[-1], -0.3)

@pytest.mark.xfail(strict=True, raises=NESTErrors.BadProperty)
def test_vectors_resize(self):
input_path = os.path.join(
os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "VectorsResize.nestml")))
target_path = "target"
logging_level = "INFO"
module_name = "vectorsmodule"
suffix = "_nestml"

generate_nest_target(input_path,
target_path=target_path,
logging_level=logging_level,
module_name=module_name,
suffix=suffix)
nest.set_verbosity("M_ALL")

nest.ResetKernel()
nest.Install(module_name)

neuron = nest.Create("vector_resize_nestml", params={"N": 200})
neuron.set(x=[1.0, 1.0, 4.0])
nest.Simulate(10)
52 changes: 52 additions & 0 deletions tests/nest_tests/resources/VectorsResize.nestml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
VectorsResize.nestml
####################


Description
+++++++++++

This model is used to test vector operations with NEST.


Copyright statement
+++++++++++++++++++

This file is part of NEST.

Copyright (C) 2004 The NEST Initiative

NEST is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 2 of the License, or
(at your option) any later version.

NEST is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NEST. If not, see <http://www.gnu.org/licenses/>.
"""
neuron vector_resize:
state:
# accumulator for x values
y real = 0.

x[N] real = 1.
end

parameters:
N integer = 1 # array size
end

update:
j integer = 0
y = 0
for j in 0 ... N step 1:
y += x[j]
end
print ("y= {y}\n")
end
end

0 comments on commit 1111bf9

Please sign in to comment.