Skip to content

Commit

Permalink
Merge pull request #38 from puentesarrin/method_name_on_require
Browse files Browse the repository at this point in the history
 Add method name as argument for validating required fields with callable
  • Loading branch information
bodbdigr authored Jun 17, 2019
2 parents 05a47db + 17d7e19 commit e1a8bb9
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 41 deletions.
2 changes: 1 addition & 1 deletion restea/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.7'
__version__ = '0.3.9'
19 changes: 14 additions & 5 deletions restea/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,17 @@ def prepare_response(self, content, status_code, content_type, headers):
'''
raise NotImplementedError

def get_original_request(self, *args, **kwargs):
def split_request_and_arguments(self, *args, **kwargs):
'''
Returns the original request object.
Hook to return the original request object and arguments.
This method receives all arguments that the `wrap_request` method
receives and return the first argument as is commonly received.
receives and return the first argument as the request object by
default which is commonly received in that order.
Override this method in your subclass wrapper if the behavior is
different for your framework.
'''
return args[0]
return args[0], args[1:], kwargs

def wrap_request(self, *args, **kwargs):
'''
Expand All @@ -65,7 +68,9 @@ def wrap_request(self, *args, **kwargs):
'''
data_format, kwargs = self._get_format_name(kwargs)
formatter = formats.get_formatter(data_format)
original_request = self.get_original_request(*args, **kwargs)
original_request, args, kwargs = self.split_request_and_arguments(
*args, **kwargs
)

if not self.request_wrapper_class:
raise RuntimeError(
Expand All @@ -78,6 +83,10 @@ def wrap_request(self, *args, **kwargs):
)
response_tuple = resource.dispatch(*args, **kwargs)

if len(response_tuple) == 3:
# For backward compatibility, it adds an empty dict as headers
response_tuple += ({},)

return self.prepare_response(*response_tuple)


Expand Down
2 changes: 1 addition & 1 deletion restea/adapters/djangowrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def prepare_response(self, content, status_code, content_type, headers):
response[name] = value
return response

def get_routes(self, path='', iden_format='(?P<iden>\w+)'):
def get_routes(self, path='', iden_format=r'(?P<iden>\w+)'):
'''
Prepare routes for the given REST resource
Expand Down
4 changes: 2 additions & 2 deletions restea/adapters/flaskwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def app(self):
'''
return flask.current_app

def get_original_request(self, *args, **kwargs):
return flask.request
def split_request_and_arguments(self, *args, **kwargs):
return flask.request, args, kwargs

def prepare_response(self, content, status_code, content_type, headers):
response = flask.Response(
Expand Down
2 changes: 1 addition & 1 deletion restea/adapters/wheezywebwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def prepare_response(self, content, status_code, content_type, headers):
response.headers.append((name, value))
return response

def get_routes(self, path='', iden_format='(?P<iden>\w+)'):
def get_routes(self, path='', iden_format=r'(?P<iden>\w+)'):
'''
Prepare routes for the given REST resource
Expand Down
19 changes: 13 additions & 6 deletions restea/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def field_names(self):
'''
return set(self.fields.keys())

def get_required_field_names(self, data):
def get_required_field_names(self, method_name, data):
'''
Returns only required field names
:returns: required field names (from self.fields)
:rtype: set
'''
def is_required_field(field, data):
if callable(field.required):
return field.required(data)
return field.required(method_name, data)
else:
return field.required

Expand All @@ -56,9 +56,11 @@ def is_required_field(field, data):
if is_required_field(field, data)
)

def validate(self, data):
def validate(self, method_name, data):
'''
Validates payload input
:param method_name: name of the method
:type method_name: str
:param data: input playload data to be validated
:type data: dict
:raises restea.fields.FieldSet.Error: field validation failed
Expand All @@ -74,7 +76,10 @@ def validate(self, data):
continue
cleaned_data[name] = self.fields[name].validate(value)

for req_field in self.get_required_field_names(cleaned_data):
required_field_names = self.get_required_field_names(
method_name, cleaned_data
)
for req_field in required_field_names:
if req_field not in cleaned_data:
raise self.Error('Field "{}" is missing'.format(req_field))

Expand Down Expand Up @@ -277,8 +282,10 @@ class Email(String):
Email implements field validation for emails
'''
error_message = '"%s" is not a valid email'
pattern = r'^[_a-z0-9-]+(\.[_a-z0-9-]+)*@[a-z0-9-]+(\.[a-z0-9-]+)*' \
'(\.[a-z]{2,16})$'
pattern = (
r'^[_a-z0-9-]+(\.[_a-z0-9-]+)*@[a-z0-9-]+(\.[a-z0-9-]+)*'
r'(\.[a-z]{2,16})$'
)

def _validate_field(self, field_value):
if not re.match(self.pattern, field_value, re.IGNORECASE):
Expand Down
16 changes: 8 additions & 8 deletions restea/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ def _get_method(self, method_name):
)
return getattr(type(self), method_name)

def _get_payload(self):
def _get_payload(self, method_name):
'''
Returns a validated and parsed payload data for request
:param method_name: name of the method
:type method_name: str
:raises restea.errors.BadRequestError: unparseable data
:raises restea.errors.BadRequestError: payload is not mappable
:raises restea.errors.BadRequestError: validation of fields not passed
Expand All @@ -197,7 +199,7 @@ def _get_payload(self):
)

try:
return self.fields.validate(payload_data)
return self.fields.validate(method_name, payload_data)
except fields.FieldSet.Error as e:
raise errors.BadRequestError(str(e))
except fields.FieldSet.ConfigurationError as e:
Expand All @@ -224,16 +226,13 @@ def process(self, *args, **kwargs):
if not self._is_valid_formatter:
raise errors.BadRequestError('Not recognizable format')

self.payload = self._get_payload()

self.prepare()

method_name = self._get_method_name(has_iden=bool(args or kwargs))
self.payload = self._get_payload(method_name)
method = self._get_method(method_name)
method = self._apply_decorators(method)

self.prepare()
response = method(self, *args, **kwargs)

response = self.finish(response)

try:
Expand All @@ -246,7 +245,8 @@ def dispatch(self, *args, **kwargs):
Dispatches the request and handles exception to return data, status
and content type
:returns: 3 element tuple: result, HTTP status code and content type
:returns: 4-element tuple: result, HTTP status code, content type, and
headers
:rtype: tuple
'''
try:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
setup(
name='restea',
packages=['restea', 'restea.adapters'],
version='0.3.8',
version='0.3.9',
description='Simple RESTful server toolkit',
long_description=readme_content,
author='Walery Jadlowski',
Expand Down
23 changes: 14 additions & 9 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,35 @@ def test_field_set_required_fields():
fs, f1, f2 = create_field_set_helper()
f1.required = True
f2.required = False
assert fs.get_required_field_names({}) == set(['field1'])
assert fs.get_required_field_names('create', {}) == set(['field1'])


def test_field_set_required_fields_callable():
fs, f1, f2 = create_field_set_helper()

def foo(data):
def foo(method_name, data):
return data.get('field2') == 0
f1.required = foo
f2.required = False
assert fs.get_required_field_names({'field2': 0}) == set(['field1'])
assert fs.get_required_field_names({}) == set([])
required_field_names = fs.get_required_field_names('create', {'field2': 0})
assert required_field_names == set(['field1'])
required_field_names = fs.get_required_field_names('create', {})
assert required_field_names == set([])

f1.required = lambda data: data.get('field2') == 0
f1.required = lambda method_name, data: data.get('field2') == 0
f2.required = False
assert fs.get_required_field_names({'field2': 0}) == set(['field1'])
assert fs.get_required_field_names({}) == set([])
required_field_names = fs.get_required_field_names('create', {'field2': 0})
assert required_field_names == set(['field1'])
required_field_names = fs.get_required_field_names('create', {})
assert required_field_names == set([])


def test_field_set_validate():
fs, f1, f2 = create_field_set_helper()
f1.validate.return_value = 1
f2.validate.return_value = 2
res = fs.validate({'field1': '1', 'field2': '2', 'field3': 'wrong!'})
payload = {'field1': '1', 'field2': '2', 'field3': 'wrong!'}
res = fs.validate('create', payload)

assert res == {'field1': 1, 'field2': 2}
f1.validate.assert_called_with('1')
Expand All @@ -84,7 +89,7 @@ def test_feild_set_validate_requred_fields_missing():
f1.requred = True

with pytest.raises(FieldSet.Error) as e:
fs.validate({'field2': '2'})
fs.validate('create', {'field2': '2'})
assert 'Field "field1" is missing' in str(e)


Expand Down
14 changes: 7 additions & 7 deletions tests/test_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_get_payload_should_pass_validation():
resource.fields = mock.Mock()
resource.fields.validate.return_value = expected_data

assert resource._get_payload() == expected_data
assert resource._get_payload('edit') == expected_data


def test_get_payload_unexpected_data():
Expand All @@ -215,7 +215,7 @@ def test_get_payload_unexpected_data():
formatter_mock.unserialize.side_effect = formats.LoadError()

with pytest.raises(errors.BadRequestError) as e:
resource._get_payload()
resource._get_payload('edit')
assert 'Fail to load the data' in str(e)


Expand All @@ -226,7 +226,7 @@ def test_get_payload_not_mapable_payload():
formatter_mock.unserialize.return_value = ['item']

with pytest.raises(errors.BadRequestError) as e:
resource._get_payload()
resource._get_payload('edit')
assert 'Data should be key -> value structure' in str(e)


Expand All @@ -243,7 +243,7 @@ def test_get_payload_field_validation_fails():
)

with pytest.raises(errors.BadRequestError) as e:
resource._get_payload()
resource._get_payload('edit')
assert field_error_message in str(e)


Expand All @@ -261,21 +261,21 @@ def test_get_payload_field_misconfigured_fields_fails():
resource.fields.validate.side_effect = conf_error

with pytest.raises(errors.ServerError) as e:
resource._get_payload()
resource._get_payload('edit')
assert configuration_error_message in str(e)


def test_get_payload_field_validation_no_data_empty_payload():
resource, _, _ = create_resource_helper(method='POST')
assert {} == resource._get_payload()
assert {} == resource._get_payload('create')


def test_get_payload_validation_no_fields_case_empty_payload():
resource, _, formatter_mock = create_resource_helper(
method='PUT', data='data'
)
formatter_mock.unserialize.return_value = {'data': 'test'}
assert {} == resource._get_payload()
assert {} == resource._get_payload('edit')


@patch.object(formats.JsonFormat, 'serialize')
Expand Down

0 comments on commit e1a8bb9

Please sign in to comment.