Skip to content

Commit

Permalink
fix choice
Browse files Browse the repository at this point in the history
  • Loading branch information
ajbalogh committed Jan 22, 2021
1 parent 4b4002a commit 0a7d608
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 43 deletions.
33 changes: 30 additions & 3 deletions snappi/snappicommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,37 @@ def _decode(self, dict_object):

class SnappiObject(SnappiBase):
"""Base class for any /components/schemas object
Every SnappiObject is reuseable within the schema so it can
exist in multiple locations within the hierarchy.
That means it can exist in multiple locations as a
leaf, parent/choice or parent.
"""
__slots__ = ('_properties')
__slots__ = ('_properties', '_parent', '_choice')

def __init__(self):
def __init__(self, parent=None, choice=None):
super(SnappiObject, self).__init__()
self._parent = parent
self._choice = choice
self._properties = {}

@property
def parent(self):
return self._parent

def _get_property(self, name, default_value=None):
if name not in self._properties or self._properties[name] is None:
if isinstance(default_value, type) is True:
self._properties[name] = default_value()
else:
self._properties[name] = default_value
return self._properties[name]

def _set_property(self, name, value):
self._properties[name] = value
if self._parent is not None and self._choice is not None and value is not None:
self._parent._set_property('choice', self._choice)

def _encode(self):
"""Helper method for serialization
"""
Expand All @@ -122,7 +146,10 @@ def _decode(self, obj):
if property_name in snappi_names:
if isinstance(property_value, dict):
child = self._get_child_class(property_name)
property_value = child[1]()._decode(property_value)
if '_choice' in dir(child[1]) and '_parent' in dir(child[1]):
property_value = child[1](self, property_name)._decode(property_value)
else:
property_value = child[1]()._decode(property_value)
elif isinstance(property_value,
list) and property_name in self._TYPES:
child = self._get_child_class(property_name, True)
Expand Down
81 changes: 41 additions & 40 deletions snappi/snappigenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _get_object_property_class_names(self, ref):
class_name = object_name.replace('.', '')
return (object_name, property_name, class_name)

def _write_snappi_object(self, ref):
def _write_snappi_object(self, ref, choice_method_name=None):
schema_object = self._get_object_from_ref(ref)
ref_name = ref.split('/')[-1]
class_name = ref_name.replace('.', '')
Expand All @@ -221,7 +221,10 @@ def _write_snappi_object(self, ref):
self._write()
self._write()
self._write(0, 'class %s(SnappiObject):' % class_name)
self._write(1, "__slots__ = ()")
slots = ''
# if choice_method_name is not None:
slots = "'_parent', '_choice'"
self._write(1, "__slots__ = (%s)" % slots)
self._write()

# write _TYPES definition
Expand All @@ -247,20 +250,27 @@ def _write_snappi_object(self, ref):

# write def __init__(self)
init_param_string = ''
# if choice_method_name is not None:
init_param_string = ", parent=None, choice=None" # everything will have a parent choice
for init_param in self._get_simple_type_names(schema_object):
init_param_string += ', %s=None' % (init_param)
self._write(1, 'def __init__(self%s):' % init_param_string)
self._write(2, 'super(%s, self).__init__()' % class_name)
# if choice_method_name is not None:
self._write(2, 'self._parent = parent')
self._write(2, 'self._choice = choice')
for init_param in self._get_simple_type_names(schema_object):
self._write(2, 'self.%s = %s' % (init_param, init_param))
# if len(parse('$..choice').find(schema_object)) > 0:
# self._write(2, 'self.choice = None')

# process properties - TBD use this one level up to process
# schema, in requestBody, Response and also
refs = self._process_properties(class_name, schema_object)
refs = self._process_properties(class_name, schema_object, choice_child=choice_method_name is not None)

# descend into child properties
for ref in refs:
self._write_snappi_object(ref[0])
self._write_snappi_object(ref[0], ref[3])
if ref[1] is True:
self._write_snappi_list(ref[0], ref[2])

Expand All @@ -284,7 +294,7 @@ def _get_choice_names(self, schema_object):
choice_names.append('choice')
return choice_names

def _process_properties(self, class_name=None, schema_object=None):
def _process_properties(self, class_name=None, schema_object=None, choice_child=False):
"""Process all properties of a /component/schema object
Write a factory method for all choice
If there are no properties then the schema_object is a primitive or array type
Expand All @@ -302,14 +312,13 @@ def _process_properties(self, class_name=None, schema_object=None):
if property_name in excluded_property_names:
continue
property = schema_object['properties'][property_name]
self._write_snappi_property(schema_object, property_name,
property)
self._write_snappi_property(schema_object, property_name, property, choice_child)
for property_name, property in schema_object['properties'].items():
ref = parse("$..'$ref'").find(property)
if len(ref) > 0:
restriction = self._get_type_restriction(property)
refs.append((ref[0].value, restriction.startswith('list['),
property_name))
choice_name = property_name if property_name in excluded_property_names else None
refs.append((ref[0].value, restriction.startswith('list['), property_name, choice_name))
return refs

def _write_snappi_list(self, ref, property_name):
Expand Down Expand Up @@ -353,11 +362,18 @@ def _write_snappi_list(self, ref, property_name):
self._write_snappilist_special_methods(contained_class_name)
# write factory method for the schema object in the list
self._write_factory_method(contained_class_name, ref_name.lower().split('.')[-1], ref, True, False)
# write choice factory methods if any
# write choice factory methods if the only properties are choice properties
write_factory_choice_methods = True
if 'properties' in yobject and 'choice' in yobject['properties']:
for property in yobject['properties']:
if property not in yobject['properties']['choice'][
'enum']:
if property not in yobject['properties']['choice']['enum']:
write_factory_choice_methods = False
break
else:
write_factory_choice_methods = False
if write_factory_choice_methods is True:
for property in yobject['properties']:
if property not in yobject['properties']['choice']['enum']:
continue
if '$ref' not in yobject['properties'][property]:
continue
Expand Down Expand Up @@ -406,23 +422,24 @@ def _write_factory_method(self,
self._write(2, '"""')
if choice_method is True:
self._write(2, 'item = %s()' % (contained_class_name))
self._write(2, 'item.%s' % (method_name))
self._write(2, "item.choice = '%s'" % (method_name))
self._write(2, "item.%s" % (method_name))
else:
self._write(2, 'item = %s(%s)' % (class_name, ', '.join(properties)))
params = []
for property in properties:
params.append('%s=%s' % (property, property))
self._write(2, 'item = %s(%s)' % (class_name, ', '.join(params)))
self._write(2, 'self._add(item)')
self._write(2, 'return self')
else:
self._write(1, '@property')
self._write(1, 'def %s(self):' % (method_name))
self._write(2, "# type: () -> %s" % (class_name))
self._write(2, '"""Factory method to create an instance of the %s class' % (class_name))
self._write(2, '"""Factory property that returns an instance of the %s class' % (class_name))
self._write()
self._write(2, '%s' % self._get_description(yobject))
self._write(2, '"""')
self._write(2, "if '%s' not in self._properties or self._properties['%s'] is None:" % (method_name, method_name))
self._write(3, "self._properties['%s'] = %s()" % (method_name, class_name))
self._write(2, 'self.choice = \'%s\'' % (method_name))
self._write(2, "return self._properties['%s']" % (method_name))
self._write(2, "return self._get_property('%s', %s(self, '%s'))" % (method_name, class_name, method_name))

def _get_property_param_string(self, yobject):
property_param_string = ''
Expand All @@ -443,7 +460,7 @@ def _get_property_param_string(self, yobject):
property_param_string += '=%s' % default
return (property_param_string, properties)

def _write_snappi_property(self, schema_object, name, property):
def _write_snappi_property(self, schema_object, name, property, write_set_choice=False):
ref = parse("$..'$ref'").find(property)
restriction = self._get_type_restriction(property)
if len(ref) > 0:
Expand All @@ -466,7 +483,7 @@ def _write_snappi_property(self, schema_object, name, property):
self._write(2, 'Returns: %s' % restriction)
self._write(2, '"""')
if len(parse("$..'type'").find(property)) > 0 and len(ref) == 0:
self._write(2, "return self._properties['%s']" % (name))
self._write(2, "return self._get_property('%s')" % (name))
self._write()
self._write(1, '@%s.setter' % name)
self._write(1, 'def %s(self, value):' % name)
Expand All @@ -476,28 +493,12 @@ def _write_snappi_property(self, schema_object, name, property):
self._write()
self._write(2, 'value: %s' % restriction)
self._write(2, '"""')
if name in self._get_choice_names(
schema_object) and name != 'choice':
self._write(2, "self._properties['choice'] = '%s'" % (name))
self._write(2, "self._properties['%s'] = value" % (name))
self._write(2, "self._set_property('%s', value)" % (name))
elif len(ref) > 0:
if restriction.startswith('list['):
self._write(
2,
"if '%s' not in self._properties or self._properties['%s'] is None:"
% (name, name))
self._write(
3,
"self._properties['%s'] = %sList()" % (name, class_name))
self._write(2, "return self._properties['%s']" % (name))
self._write(2, "return self._get_property('%s', %sList)" % (name, class_name))
else:
self._write(
2,
"if '%s' not in self._properties or self._properties['%s'] is None:"
% (name, name))
self._write(
3, "self._properties['%s'] = %s()" % (name, class_name))
self._write(2, "return self._properties['%s']" % (name))
self._write(2, "return self._get_property('%s', %s)" % (name, class_name))

def _get_description(self, yobject):
if 'description' not in yobject:
Expand Down
30 changes: 30 additions & 0 deletions snappi/tests/test_device_factory_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest


def test_device_factory_methods(api):
"""Test device factory methods
Device factory methods should populate saved structures
"""
config = api.config()

config = api.config()

param = ('name', 'container name', 11)
device = config.devices.device(name=param[0],
container_name=param[1],
device_count=param[2])[-1]
assert (device.name == param[0])
assert (device.container_name == param[1])
assert (device.device_count == param[2])

name = 'eth name'
eth = device.ethernet
eth.name = name
assert (eth.name == name)

print(config)


if __name__ == '__main__':
pytest.main(['-vv', '-s', __file__])

0 comments on commit 0a7d608

Please sign in to comment.