From 23c16bbcd4fa213196312c2a3b9fcb26dc166edd Mon Sep 17 00:00:00 2001
From: "Chris (Someguy123)"
Date: Sat, 26 Sep 2020 05:29:17 +0100
Subject: [PATCH] 3.0.0 - Overhaul 'net', new object cleaner, class generation/mocking + more

**Key Additions and Changes**

- `privex.helpers.common`
  - Added `strip_null` - a very simple helper function which strips both `\00` (null bytes) and whitespace
    from a string - with 2 cycles for good measure.

- `privex.helpers.types`
  - Added the `AUTO` / `AUTOMATIC` / `AUTO_DETECTED` dummy type, for use as the default value of function/method
    parameters, signalling to users that a parameter is auto-populated from another data source
    (e.g. an instance/class attribute) if not specified.

- `privex.helpers.collections`
  - Added `copy_class_simple` (alternative to `copy_class`)
  - Added `copy_func` for copying functions, methods, and classmethods
  - Improved `_q_copy` to handle copying functions, methods and classmethods
  - Added `generate_class` + `generate_class_kw`
  - Added `Mocker.make_mock_module`
  - Added `Mocker.add_mock_modules`
  - Added `Mocker.__dir__` to track the available mock attributes and modules
  - Added `dataclasses_mock` - a `Mocker` instance which emulates `dataclasses` as a drop-in, partially functional
    dummy for Python 3.6 when the `dataclasses` backport package isn't installed.
  - Various changes to `Mocker.make_mock_class` - potentially breaking, see the **BREAKING CHANGES** section.
  - Added `DictObject.__dir__` + `OrderedDictObject.__dir__` to enable proper tracking of dictionary keys as attributes

- `privex.helpers.net`
  - This module has now been converted into a folder-based module. Imports in `__init__.py` have been carefully
    set up to ensure that existing import statements still work as normal.
  - Added the new `SocketWrapper` and `AsyncSocketWrapper` classes, which are powerful wrapper classes for working
    with Python `socket.socket` objects, including support for SSL/TLS, partial support for running socket servers,
    and making basic HTTP requests.
  - **Many, many new functions and classes!** There are too many to list, and due to the conversion into a module
    folder instead of a single file, it's difficult to track which functions/classes are new and which existed before.

    If you really want to know what's new, just take a look around the `privex/helpers/net` module.

- `privex.helpers.converters`
  - Added `clean_obj` - a function which recursively "cleans" any arbitrary object so that it's safe to convert
    into JSON and other common serialisation formats. It supports `dict`s, `list`s, [attrs](https://attrs.org)
    objects, native Python `dataclass`es, `Decimal`, and many other types of objects.
  - Added `clean_dict` (used by `clean_obj`, usually no need to call it directly)
  - Added `clean_list` (used by `clean_obj`, usually no need to call it directly)

- Added the `privex.helpers.mockers` module, which contains pre-made `Mocker` objects designed to stand in for
  certain libraries / classes as partially functional dummies, if the real module(s) are unavailable for whatever reason.

- **And probably some other small additions / changes**


**BREAKING CHANGES**

- Both `_copy_class_dict` and `_copy_class_slotted` now check each attribute name against a blacklist
  (default: `COPY_CLASS_BLACKLIST`). The default blacklist contains `__dict__`, `__slots__` and `__weakref__`,
  as the first two can't be directly copied (though we copy their contents by iteration), and a weakref simply
  can't be deep copied (and it probably isn't a good idea to copy it anyway).
- `_copy_class_dict` (used by `copy_class`) no longer breaks the attribute copy loop if `deep_copy=False`

- `Mocker.make_mock_class` now returns a cloned `Mocker` class or instance by default, instead of a barebones
  class / instance of a barebones class.

  This was done simply because a Mocker class/instance is designed to handle being instantiated with any
  combination of constructor arguments, and to have arbitrary attributes retrieved / methods called without
  raising errors.

  If you absolutely require a plain, simple, empty class to be generated, you may pass the parameter `simple=True`
  to generate a bare class instead of a clone of Mocker (similar to the old behaviour). Unlike the old version of
  this method, you can now specify attributes as a dictionary to make your barebones mock class act similarly to
  the class it's mocking.

- Many things in `privex.helpers.net` such as `check_host` / `check_host_async` have been improved in various ways;
  however, there may be some breaking changes with certain `privex.helpers.net` functions/classes in certain use cases.
  - Due to the high risk of bugs with certain networking functions that have been completely revamped, the older,
    simpler versions of various networking functions are available under `privex.helpers.net.base` with their
    original names.

    Because of the naming conflicts, to use the legacy functions/classes from `base`, you must import them directly
    from `privex.helpers.net.base` like so:

    ```
    # Option 1: import the base module itself, with an alias to prevent naming conflicts (and make it more
    # clear what you're referencing)
    from privex.helpers.net import base as netbase

    if netbase.check_host('google.com', 80):
        print('google.com is up')

    # Option 2: import the required legacy functions directly (optionally, you can alias them as needed)
    # You could also alias the newer overhauled functions while testing them in small portions
    # of your application.
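    # Both imports below resolve the name clash explicitly: check_host refers to the legacy
    # implementation from net.base, while new_check_host is an alias for the overhauled version.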
from privex.helpers.net.base import check_host from privex.helpers.net import check_host as new_check_host if check_host('www.privex.io', 443, http_test=True, use_ssl=True): print('[old check_host] https://www.privex.io is up') if new_check_host('files.privex.io', 443, http_test=True, use_ssl=True): print('[new check_host] https://files.privex.io is up') ``` --- CHANGELOG.md | 103 ++ privex/helpers/__init__.py | 2 +- privex/helpers/collections.py | 656 ++++++--- privex/helpers/common.py | 16 +- privex/helpers/converters.py | 217 ++- privex/helpers/exceptions.py | 4 + privex/helpers/mockers.py | 40 + privex/helpers/net/__init__.py | 30 + privex/helpers/net/base.py | 197 +++ privex/helpers/net/common.py | 513 +++++++ privex/helpers/{net.py => net/dns.py} | 285 +--- privex/helpers/net/socket.py | 1796 +++++++++++++++++++++++++ privex/helpers/net/util.py | 172 +++ privex/helpers/settings.py | 87 ++ privex/helpers/types.py | 11 +- setup.py | 2 +- tests/test_cache.py | 5 +- tests/test_collections.py | 8 +- tests/test_converters.py | 120 ++ tests/test_crypto.py | 3 +- tests/test_geoip.py | 3 +- tests/test_net.py | 66 +- 22 files changed, 3892 insertions(+), 444 deletions(-) create mode 100644 privex/helpers/mockers.py create mode 100644 privex/helpers/net/__init__.py create mode 100644 privex/helpers/net/base.py create mode 100644 privex/helpers/net/common.py rename privex/helpers/{net.py => net/dns.py} (73%) create mode 100644 privex/helpers/net/socket.py create mode 100644 privex/helpers/net/util.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b587380..e6e671a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,108 @@ ----------------------------------------------------------------------------------------------------------------------- +3.0.0 - Overhauled net module, new object cleaner for easier serialisation, improved class generation/mocking + more +==================================================================================================================== + +----------------------------------------------------------------------------------------------------------------------- + +- `privex.helpers.common` + - Added `strip_null` - very simple helper function to strip both `\00` and white space + from a string - with 2 cycles for good measure. + +- `privex.helpers.types` + - Added `AUTO` / `AUTOMATIC` / `AUTO_DETECTED` dummy type, for use as the default value + of function/method parameters, signalling to users that a parameter is auto-populated + from another data source (e.g. instance/class attribute) if not specified. + +- `privex.helpers.collections` + - Added `copy_class_simple` (alternative to `copy_class`) + - Added `copy_func` for copying functions, methods, and classmethods + - Improved `_q_copy` to handle copying functions, methods and classmethods + - Added `generate_class` + `generate_class_kw` + - Added `Mocker.make_mock_module` + - Added `Mocker.add_mock_modules` + - Added `Mocker.__dir__` to track the available mock attributes and modules + - Added `dataclasses_mock` - a `Mocker` instance which emulates `dataclasses` as a drop-in + partially functional dummy for Python 3.6 when the `dataclasses` backport package isn't installed. + - Various changes to `Mocker.make_mock_class` - potentially breaking, see + the **BREAKING CHANGES** section. + - Added `DictObject.__dir__` + `OrderedDictObject.__dir__` to enable proper tracking of dictionary keys as attributes + +- `privex.helpers.net` + - This module has now been converted into a folder-based module. 
Imports in `__init__.py` have been carefully + setup to ensure that existing import statements should still work as normal + - Added new `SocketWrapper` and `AsyncSocketWrapper` classes, which are powerful wrapper classes for working with + Python `socket.socket` objects, including support for SSL/TLS, partial support for running socket servers, and\ + making basic HTTP requests + - **Many, many new functions and classes!** There's too many to list, and due to the conversion into a module folder + instead of a singular file, it's difficult to track which functions/classes are new, and which existed before. + + If you really want to know what's new, just take a look around the `privex/helpers/net` module. + +- `privex.helpers.converters` + - Added `clean_obj` - which is a function that recursively "cleans" any arbitrary object, as to make it safe to convert + into JSON and other common serialisation formats. It supports `dict`'s, `list`'s, [attrs](https://attrs.org) + objects, native Python `dataclass`'s, `Decimal`, and many other types of objects. + - Added `clean_dict` (used by `clean_obj`, usually no need to call it directly) + - Added `clean_list` (used by `clean_obj`, usually no need to call it directly) + +- Added `privex.helpers.mockers` module, which contains pre-made `Mocker` objects that are designed to stand-in + for certain libraries / classes as partially functional dummies, if the real module(s) are unavailable for whatever reason. + +- **And probably some other small additions / changes** + + +**BREAKING CHANGES** + +- Both `_copy_class_dict` and `_copy_class_slotted` now check each attribute name + against a blacklist (default: `COPY_CLASS_BLACKLIST`), and the default blacklist + contains `__dict__`, `__slots__` and `__weakref__`, as the first 2 can't be directly + copied (but we copy their contents by iteration), and weakref simply can't be deep copied + (and it probably isn't a good idea to copy it anyway). +- `_copy_class_dict` (used by `copy_class`) no longer breaks the attribute copy loop if `deep_copy=False` + +- `Mocker.make_mock_class` now returns a cloned `Mocker` class or instance by default, instead of + a barebones class / instance of a barebones class. + + This was done simply because a Mocker class/instance is designed to handle being + instantiated with any combination of constructor arguments, and have arbitrary + attributes be retrieved / methods called without raising errors. + + If you absolutely require a plain, simple, empty class to be generated, you may + pass the parameter `simple=True` to generate a bare class instead of a clone of Mocker + (similar to the old behaviour). Unlike the old version of this method, you can now specify attributes + as a dictionary to make your barebones mock class act similar to the class it's mocking. + +- Many things in `privex.helpers.net` such as `check_host` / `check_host_async` have been improved in various ways, however + there may be some breaking changes with certain `privex.helpers.net` functions/classes in certain usecases. + - Due to the high risk of bugs with certain networking functions that have been completely revamped, the + older, simpler versions of various networking functions are available under `privex.helpers.net.base` + with their original names. 
+ + Because of the naming conflicts, to use the legacy functions/classes from `base`, you must import them + directly from `privex.helpers.net.base` like so: + + ``` + # Option 1: import the base module itself, with an alias to prevent naming conflicts (and make it more + # clear what you're referencing) + from privex.helpers.net import base as netbase + if netbase.check_host('google.com', 80): + print('google.com is up') + # Option 2: import the required legacy functions directly (optionally, you can alias them as needed) + # You could also alias the newer overhauled functions while testing them in small portions + # of your application. + from privex.helpers.net.base import check_host + from privex.helpers.net import check_host as new_check_host + + if check_host('www.privex.io', 443, http_test=True, use_ssl=True): + print('[old check_host] https://www.privex.io is up') + if new_check_host('files.privex.io', 443, http_test=True, use_ssl=True): + print('[new check_host] https://files.privex.io is up') + ``` + + +----------------------------------------------------------------------------------------------------------------------- + 2.8.0 - Refactoring, bug fixes + new loop_run function =============================================================================================== diff --git a/privex/helpers/__init__.py b/privex/helpers/__init__.py index cfb4c19..24d6afe 100644 --- a/privex/helpers/__init__.py +++ b/privex/helpers/__init__.py @@ -148,7 +148,7 @@ def _setup_logging(level=logging.WARNING): log = _setup_logging() name = 'helpers' -VERSION = '2.19.0' +VERSION = '3.0.0rc1' diff --git a/privex/helpers/collections.py b/privex/helpers/collections.py index 7ef846d..018b389 100644 --- a/privex/helpers/collections.py +++ b/privex/helpers/collections.py @@ -190,19 +190,372 @@ """ import copy +import functools import inspect +import os import sys +import types +from os.path import dirname, abspath from collections import namedtuple, OrderedDict from json import JSONDecodeError from types import MemberDescriptorType -from typing import Dict, Optional, NamedTuple, Union, Type, List, Generator, Iterable, TypeVar -from privex.helpers.types import T, K +from typing import Any, Callable, Dict, Optional, NamedTuple, Union, Type, List, Generator, Iterable, TypeVar + +# from privex.helpers.decorators import mock_decorator + +from privex.helpers.types import AUTO, T, K import logging import warnings log = logging.getLogger(__name__) +def _mock_decorator(*dec_args, **dec_kwargs): + def _decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + return wrapper + return _decorator + + +def generate_class( + name: str, qualname: str = None, module: str = None, bases: Union[tuple, list] = None, + attributes: Dict[str, Any] = None, **kwargs +) -> Any: + """ + A small helper function for dynamically generating classes / types. + + **Basic usage** + + Generating a simple class, with an instance constructor, a basic instance method, and an instance factory classmethod:: + + >>> import random + >>> from privex.helpers.collections import generate_class + >>> def hello_init(self, example: int): + ... self.example = example + ... + >>> Hello = generate_class( + ... 'Hello', module='hello', + ... attributes=dict( + ... __init__=hello_init, lorem=lambda self: self.example * 10, + ... make_hello=classmethod(lambda cls: cls(random.randint(1, 100))) + ... ) + ... ) + ... 
+ >>> h = Hello(123) + >>> h.lorem() + 1230 + >>> j = Hello.make_hello() + >>> j.example + 77 + >>> j.lorem() + 770 + + Generating a child class which inherits from an existing class (the parent(s) can also be a generated classes):: + + >>> World = generate_class( + ... 'World', module='hello', bases=(Hello,), attributes=dict(ipsum=lambda self: float(self.example) / 3) + ... ) + >>> w = World(130) + >>> w.lorem() + 1300 + >>> w.ipsum() + 43.333333333333336 + + :param str name: The name of the class, e.g. ``Hello`` + :param str qualname: (Optional) The qualified name of the class, e.g. for nested classes ``A -> B -> C``, class ``C`` + would have the ``__name__``: ``C`` and ``__qualname__``: ``A.B.C`` + :param str module: (Optional) The module the class should appear to belong to (sets ``__module__``) + :param tuple|list bases: (Optional) A tuple or list of "base" / "parent" classes for inheritance. + :param dict attributes: (Optional) A dictionary of attributes to add to the class. (can include constructor + methods) + :param kwargs: + :return: + """ + qualname = name if qualname is None else qualname + bases = (object,) if bases is None else (tuple(bases) if not isinstance(bases, tuple) else bases) + attributes = {} if attributes is None else attributes + attributes['__module__'] = attributes.get('__module__', module) + # kwargs = dict(kwargs) + + x = type(name, bases, attributes) + x.__name__ = name + x.__qualname__ = qualname + x.__module__ = module + + return x + + +def generate_class_kw(name: str, qualname: str = None, module: str = None, bases: Union[tuple, list] = None, **kwargs) -> Type: + """ + Same as :func:`.generate_class`, but instead of a :class:`dict` ``attributes`` parameter - all additional keyword arguments + will be used for ``attributes`` + + **Example**:: + + >>> def lorem_init(self, ipsum=None): + ... self._ipsum = ipsum + ... + >>> Lorem = generate_class_kw('Lorem', + ... __init__=lorem_init, hello=staticmethod(lambda: 'world'), + ... ipsum=property(lambda self: 0 if self._ipsum is None else self._ipsum) + ... ) + >>> l = Lorem() + >>> l.ipsum() + 0 + >>> l.hello() + 'world' + + """ + return generate_class(name, qualname, module, bases, attributes=dict(kwargs)) + + +def copy_func(f: Callable, rewrap_classmethod=True, name=None, qualname=None, module=AUTO, **kwargs) -> Union[Callable, classmethod]: + """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" + if isinstance(f, classmethod): + fn = copy_func(f.__func__) + return classmethod(fn) if rewrap_classmethod else fn + g = types.FunctionType(f.__code__, f.__globals__, name=name if name is not None else f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + g.__qualname__ = g.__name__ if qualname is None else qualname + g.__module__ = getattr(f, '__module__') if module is AUTO else module + return g + + +def _q_copy(obj: K, key: str = None, deep_private: bool = False, quiet: bool = False, fail: bool = False, **kwargs) -> K: + should_copy = kwargs.pop('should_copy', True) + if not should_copy: + log.debug("Not copying object '%s' (key '%s') as should_copy is False", repr(obj), key) + return obj + # By default, deep_private is false, which means we avoid deep copying any attributes which have keys starting with __ + # This is because they're usually objects/types we simply can't deepcopy without issues. 
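+    # (e.g. descriptors such as __dict__ / __weakref__, which deep copying tends to choke on). Callers who know
+    # their private attributes are safe to deep copy can opt in by passing deep_private=True.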
+ if not deep_private and key is not None and key.startswith('__'): + log.debug("Not deep copying key '%s' as deep_private is true and key begins with __", key) + return obj + use_copy_func = kwargs.get('use_copy_func', True) + try: + if use_copy_func and any([inspect.isfunction(obj), inspect.ismethod(obj), isinstance(obj, classmethod)]): + copied = copy_func(obj) + else: + copied = copy.deepcopy(obj) + return copied + except Exception as ex: + if fail: + raise ex + log_args = "Exception while deep copying object %s ( %s ) - using normal ref. Ex: %s %s", key, obj, type(ex), str(ex) + if quiet: + log.debug(*log_args) + else: + log.warning(*log_args) + return obj + + +COPY_CLASS_BLACKLIST = [ + '__dict__', '__slots__', '__weakref__' +] + + +def _copy_class_dict(obj: Type[T], name, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: + """ + Internal function used by :func:`.copy_class` + + Make a deep copy of the :class:`.type` / class ``obj`` (for standard classes, which use ``__dict__``) + """ + orig_dict, filt_dict = dict(obj.__dict__), {} + attr_blacklist = kwargs.get('blacklist', COPY_CLASS_BLACKLIST) + # Try to deep copy each attribute's value from the original __dict__ in 'orig_dict' into 'filt_dict'. + # Attributes that are private (start with '__') will use standard references by default (if deep_private is False), + # along with any attributes that fail to be deep copied. + for k, v in orig_dict.items(): + # if not deep_copy: break + if k in attr_blacklist: continue + filt_dict[k] = v if not deep_copy else _q_copy(v, key=k, deep_private=deep_private, **kwargs) + + bases = kwargs.get('bases', obj.__bases__ if kwargs.get('use_bases', True) else (object,)) + return type(name, bases, filt_dict if deep_copy else orig_dict) + + +def _copy_class_slotted(obj: Type[T], name, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: + """ + Internal function used by :func:`.copy_class` + + Make a deep copy of the :class:`.type` / class ``obj`` (for slotted classes, i.e. 
those which use ``__slots__``) + + Based on a StackOverflow answer by user ``nkpro``: https://stackoverflow.com/a/61823543/2648583 + """ + slots = obj.__slots__ if type(obj.__slots__) != str else (obj.__slots__,) + orig_dict, slotted_members = {}, {} + attr_blacklist = kwargs.get('blacklist', COPY_CLASS_BLACKLIST) + + for k, v in obj.__dict__.items(): + if k in attr_blacklist: continue + dcval = v if not deep_copy else _q_copy(v, k, deep_private=deep_private, **kwargs) + if k not in slots: + orig_dict[k] = dcval + elif type(v) != MemberDescriptorType: + slotted_members[k] = dcval + + bases = kwargs.get('bases', obj.__bases__ if kwargs.get('use_bases', True) else (object,)) + new_obj = type(name, bases, orig_dict) + for k, v in slotted_members.items(): + setattr(new_obj, k, v) + + return new_obj + + +DEFAULT_ALLOWED_DUPE = [ + '__annotations__', '__doc__', '__init__', '__getattr__', '__setattr__', '__len__', '__sizeof__', + '__getattribute__', '__getitem__', '__setitem__', '__delitem__', '__str__', '__repr__', + '__del__', '__delattr__', '__enter__', '__exit__', '__aenter__', '__aexit__', + '__next__', '__iter__', '__hash__', '__call__', '__dir__', '__get__', '__set__', '__contains__' + '__add__', '__sub__', '__mul__', '__floordiv__', '__div__', '__mod__', '__pow__', + '__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__', '__cmp__', + '__copy__', '__deepcopy__', '__getstate__', '__setstate__', '__new__', +] + + +# +# class Asdf: +# def __ + +def copy_class_simple(obj: Type[T], name=None, qualname=None, module=AUTO, allow_attrs: list = None, ban_attrs: list = None, **kwargs): + """ + This is an alternative to :func:`.copy_class` which simply creates a blank class, then iterates over ``obj.__dict__``, + using :func:`setattr` to copy each attribute over to the cloned class. + + It uses :func:`._q_copy` to safely deep copy any attributes which are object references, and thus need their reference pointers + severed, to avoid edits to the copy affecting the original (and vice versa). + + + :param obj: The class to duplicate + :param str name: The class name to set on the duplicate. If left as ``None``, the duplicate will retain the original ``obj`` name. + :param str qualname: The qualified class name to set on the duplicate. + :param Optional[str] module: The module path to set on the duplicate, e.g. ``privex.helpers.common`` + :param list allow_attrs: Optionally, you may specify additional private attributes (ones which start with ``__``) that + are allowed to be copied from the original class to the duplicated class. + :param list ban_attrs: Optionally, you may blacklist certain attributes from being copied from the original class to the duplicate. + + Blacklisted attributes take priority over whitelisted attributes, so you may use this to cancel out + any attributes in the default attribute whitelist :attr:`.DEFAULT_ALLOWED_DUPE` which you don't want + to be copied to the duplicated class. + :param kwargs: + :keyword tuple|list bases: If specified, overrides the default inherited classes (``obj.__bases__``) which would be + set on the duplicated class's ``__bases__``. 
+ :return: + """ + kwargs = dict(kwargs) + + name = name if name is not None else obj.__name__ + qualname = qualname if qualname is not None else name + module = module if module is not AUTO else obj.__module__ + + mkr = generate_class( + name, qualname, module, bases=kwargs.pop('bases', getattr(obj, '__bases__')), attributes=kwargs.pop('dict_attrs', None) + ) + allow_attrs = [] if allow_attrs is None else allow_attrs + ban_attrs = [] if ban_attrs is None else ban_attrs + allow_attrs += [d for d in DEFAULT_ALLOWED_DUPE if d not in allow_attrs] + + for k, v in obj.__dict__.items(): + if k in ban_attrs or (k.startswith('__') and k not in allow_attrs): continue + try: + setattr(mkr, k, _q_copy(v, **kwargs)) + except Exception: + log.exception(f"Failed to copy attribute {k} (value: {v}) to duplicated {obj.__name__} class.") + # if kwargs.get('bases') is not None: + # bases = kwargs.pop('bases') + # mkr.__bases__ = bases if isinstance(bases, tuple) else tuple(bases) + if kwargs.get('str_func') is not None: + str_func = kwargs.pop('str_func') + mkr.__str__ = (lambda self: str_func) if isinstance(str_func, str) else str_func + + # mkr.__name__ = name if name is not None else obj.__name__ + # mkr.__qualname__ = qualname if qualname is not None else mkr.__name__ + # mkr.__module__ = module if module is not None else obj.__module__ + return mkr + + +def copy_class(obj: Type[T], name=None, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: + """ + Attempts to create a full copy of a :class:`.type` or class, severing most object pointers such as attributes containing a + :class:`.dict` / :class:`.list`, along with classes or instances of classes. + + Example:: + + >>> class SomeClass: + >>> example = 'lorem ipsum' + >>> data = ['hello', 'world'] + >>> testing = 123 + >>> + >>> from privex.helpers import copy_class + >>> OtherClass = copy_class(SomeClass, name='OtherClass') + + If you then append to the :class:`.list` attribute ``data`` on both SomeClass and OtherClass - with a different item + appended to each class, you'll see that the added item was only added to ``data`` for that class, and not to the other class, + proving the original and the copy are independent from each other:: + + >>> SomeClass.data.append('lorem') + >>> OtherClass.data.append('ipsum') + >>> SomeClass.data + ['hello', 'world', 'lorem'] + >>> OtherClass.data + ['hello', 'world', 'ipsum'] + + + :param Type[T] obj: A :class:`.type` / class to attempt to duplicate, deep copying each individual object in the class, to + avoid any object pointers shared between the original and the copy. + :param str|None name: The class name to use for the copy of ``obj``. If not specified, defaults to the original class name from ``obj`` + :param bool deep_copy: (Default: ``True``) If True, uses :func:`copy.deepcopy` to deep copy each attribute in ``obj`` to the copy. + If False, then standard references will be used, which may result in object pointers being copied. + :param bool deep_private: (Default: ``False``) If True, :func:`copy.deepcopy` will be used on "private" class attributes, + i.e. ones that start with ``__``. If False, attributes starting with ``__`` will not be deep copied, + only a standard assignment/reference will be used. + :param kwargs: Additional advanced settings (see ``keyword`` pydoc entries for this function) + + :keyword bool use_bases: (Default: ``True``) If True, copy the inheritance (bases) from ``obj`` into the class copy. 
+ :keyword bool quiet: (Default ``False``) If True, log deep copy errors as ``debug`` level (usually silent in production apps) + instead of the louder ``warning``. + :keyword tuple bases: A :class:`.tuple` of classes to use as "bases" (inheritance) for the class copy. If not specified, + copies ``__bases__`` from the original class. + :keyword str module: If specified, overrides the module ``__module__`` in the class copy with this string, instead of copying from + the original class. + :return Type[T] obj_copy: A deep copy of the original ``obj`` + """ + # If no class name was passed as an attribute, then we copy the name from the original class. + if not name: + name = obj.__name__ + + # Depending on whether 'obj' is a normal class using __dict__, or a slotted class using __slots__, we need to handle the + # deep copying of the class differently. + if hasattr(obj, '__slots__'): + new_obj = _copy_class_slotted(obj, name=name, deep_copy=deep_copy, deep_private=deep_private, **kwargs) + else: + new_obj = _copy_class_dict(obj, name=name, deep_copy=deep_copy, deep_private=deep_private, **kwargs) + + # Override the module path string if the user specified 'module' as a kwarg + module = kwargs.get('module') + if module is not None: + new_obj.__module__ = module + return new_obj + + +def _create_mocker_copy(name=None, **kwargs) -> Union[type, Type["Mocker"]]: + kwargs['dict_attrs'] = kwargs.get('dict_attrs', {}) + + def n(self): return self.__class__.__name__ + def m(self): return self.__class__.__module__ + + nm = copy_func(n, name='__name__', qualname=f"{name}.__name__", module=kwargs.get('module', AUTO)) + # mm = copy_func(m, name='__module__', qualname=f"{name}.__module__", module=kwargs.get('module', AUTO)) + kwargs['dict_attrs']['__name__'] = property(nm) + # kwargs['dict_attrs']['__module__'] = property(mm) + + return copy_class_simple(Mocker, name=name, **kwargs) + + class Mocker(object): """ This mock class is designed to be used either to act as a stand-in "noop" (no operation) object, which @@ -299,15 +652,42 @@ class Mocker(object): """ + # _ALLOWED_DUPE = [ + # '__module__', '__annotations__', '__doc__', '__init__', '__getattr__', '__setattr__', '__getitem__', + # '__setitem__' + # ] mock_modules: dict mock_attrs: dict - def __init__(self, modules: dict = None, attributes: dict = None): + def __init__(self, modules: dict = None, attributes: dict = None, *args, **kwargs): self.mock_attrs = {} if attributes is None else attributes self.mock_modules = {} if modules is None else modules @classmethod - def make_mock_class(cls, name='Mocker', instance=True, **kwargs): + def make_mock_module(cls, mod_name: str, attributes: dict = None, modules: dict = None, built_in=False, **kwargs): + mod_base = kwargs.pop('module_base', _module_dir() if built_in else dirname(dirname(dirname(abspath(__file__))))) + mod_file = os.path.join(mod_base, kwargs.pop('module_file', os.path.join(*mod_name.split('.'), '__init__.py'))) + fix_funcs = kwargs.get('fix_funcs', True) + + attributes = {} if not attributes else attributes + + if fix_funcs: + for k, v in attributes.items(): + if inspect.isfunction(v) or inspect.ismethod(v): + v.__module__ = mod_name + modrep = f"" + def x(self): return modrep + + # copy_func(x, name='__str__', qualname='__str__', module=mod_name) + attributes['__repr__'] = attributes.get('__repr__', copy_func(x, name='__repr__', qualname='__repr__', module=mod_name)) + attributes['__str__'] = attributes.get('__str__', copy_func(x, name='__str__', qualname='__str__', module=mod_name)) + 
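+        # Give the mocked module a plausible __file__ (defaulting to where the real module's __init__.py would live),
+        # so code which inspects module paths sees something sensible.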
attributes['__file__'] = attributes.get('__file__', mod_file) + return cls.make_mock_class('module', attributes=attributes, modules=modules, **kwargs) + + @classmethod + def make_mock_class( + cls, name='Mocker', instance=True, simple=False, attributes: dict = None, modules: dict = None, **kwargs + ) -> Union[Any, "Mocker", Type["Mocker"]]: """ Return a customized mock class or create an instance which appears to be named ``name`` @@ -339,8 +719,17 @@ def make_mock_class(cls, name='Mocker', instance=True, **kwargs): :param name: The name to write onto the mock class's ``__name__`` (and ``__qualname__`` if not specified) :param bool instance: If ``True`` then the disguised mock class will be returned as an instance. Otherwise the raw class itself will be returned for you to instantiate yourself. + :param bool simple: When ``True``, generates a very basic, new class - not based on :class:`.Mocker`, which contains + the attributes/methods defined in the param ``attributes``. :param kwargs: All kwargs (other than ``qualname``) are forwarded to ``__init__`` of the disguised class if ``instance`` is True. + :param dict attributes: If ``simple`` is True, then this dictionary of attributes is used to generate the class's + attributes, methods, and/or constructor. + If ``simple`` is False, and ``instance`` is True, these attributes are passed to the constructor + of the :class:`.Mocker` clone that was generated. + :param dict modules: If ``simple`` is False, and ``instance`` is True, this dict of modules are passed to the constructor + of the :class:`.Mocker` clone that was generated. + :key str qualname: Optionally specify the "qualified name" to insert into ``__qualname__``. If this isn't specified, then ``name`` is used for qualname, which is fine for most cases anyway. :key str module: Optionally override the module namespace that the class is supposedly from. If not specified, @@ -348,18 +737,26 @@ def make_mock_class(cls, name='Mocker', instance=True, **kwargs): :return: """ qualname = kwargs.pop('qualname', name) + mod_name = kwargs.pop('module', __name__) + attributes = {} if attributes is None else attributes + modules = {} if modules is None else modules + if simple: + c = generate_class( + name, qualname=qualname, module=mod_name, bases=kwargs.pop('bases', None), attributes=attributes + ) + return c(**kwargs) if instance else c + c = _create_mocker_copy(name, deep_private=True, module=mod_name, qualname=qualname) + + # OuterMocker.__name__ = name - class OuterMocker(cls): - pass - - OuterMocker.__name__ = name - OuterMocker.__qualname__ = qualname + # if mod_name is not None: + # OuterMocker.__module__ = kwargs['module'] + str_func = attributes.pop('__str__', None) + repr_func = attributes.pop('__repr__', None) + if str_func is not None: c.__str__ = str_func + if repr_func is not None: c.__repr__ = repr_func + return c(modules=modules, attributes=attributes, **kwargs) if instance else c - if 'module' in kwargs: - OuterMocker.__module__ = kwargs['module'] - - return OuterMocker() if instance else OuterMocker - def add_mock_module(self, name: str, value=None, mock_attrs: dict = None, mock_modules: dict = None): """ Add a fake sub-module to this Mocker instance. 
@@ -385,18 +782,68 @@ def add_mock_module(self, name: str, value=None, mock_attrs: dict = None, mock_m self.mock_modules[name] = Mocker(modules=mock_modules, attributes=mock_attrs) if value is None else value - def __getattribute__(self, item): + def add_mock_modules(self, *module_list, _dict_to_attrs=True, _parse_dict=True, **module_map): + """ + + >>> hello = Mocker.make_mock_class('Hello') + >>> hello.add_mock_modules( + ... world={ + ... 'lorem': 'ipsum', + ... 'dolor': 123, + ... } + ... ) + + :param module_list: + :param _parse_dict: + :param _dict_to_attrs: + :param module_map: + :return: + """ + module_map = dict(module_map) + for m in module_list: + log.debug("Adding simple mock module from module_list: %s", m) + self.add_mock_module(m) + for k, v in module_map.items(): + m_val, m_attrs, m_modules = v, {}, {} + if isinstance(v, dict): + if _parse_dict: + _m_val = None + if 'value' in v: _m_val = v['value'] + if 'attrs' in v: + log.debug("Popping 'attrs' from kwarg '%s' value as attributes for module: %s", k, v) + m_attrs = {**m_attrs, **v.pop('attrs')} + if 'modules' in v: + log.debug("Popping 'modules' from kwarg '%s' value as attributes for module: %s", k, v) + m_modules = {**m_modules, **v.pop('modules')} + if not _dict_to_attrs or not all([_m_val is None, m_attrs is None, m_modules is None]): + log.debug("Setting module value to value of kwarg '%s': %s", k, v) + m_val = _m_val if m_attrs is None and m_modules is None else None + if _dict_to_attrs: + log.debug("Importing kwarg '%s' value as attributes for module: %s", k, v) + m_attrs = {**m_attrs, **v} + + self.add_mock_module(k, m_val, mock_attrs=m_attrs, mock_modules=m_modules) + + @classmethod + def _duplicate_cls(cls, name=None, qualname=None, module=None, **kwargs) -> Type["Mocker"]: + return _create_mocker_copy(name=name, qualname=qualname, module=module, **kwargs) + + def _duplicate_ins(self, name=None, qualname=None, module=None, **kwargs) -> "Mocker": + mkr = _create_mocker_copy(name=name, qualname=qualname, module=module, **kwargs) + return mkr(modules=self.mock_modules, attributes=self.mock_attrs) + + def __getattr__(self, item): try: - return super().__getattribute__(item) + return object.__getattribute__(self, item) except AttributeError: pass try: - if item in super().__getattribute__('mock_modules'): + if item in object.__getattribute__(self, 'mock_modules'): return self.mock_modules[item] except AttributeError: pass try: - if item in super().__getattribute__('mock_attrs'): + if item in object.__getattribute__(self, 'mock_attrs'): return self.mock_attrs[item] except AttributeError: pass @@ -405,13 +852,13 @@ def __getattribute__(self, item): def __setattr__(self, key, value): if key in ['mock_attrs', 'mock_modules']: - return super().__setattr__(key, value) - m = super().__getattribute__('mock_attrs') + return object.__setattr__(self, key, value) + m = object.__getattribute__(self, 'mock_attrs') m[key] = value def __getitem__(self, item): try: - return self.__getattribute__(item) + return self.__getattr__(item) except AttributeError as ex: raise KeyError(str(ex)) @@ -425,6 +872,32 @@ def __setitem__(self, key, value): def __name__(self): return self.__class__.__name__ + def __dir__(self) -> Iterable[str]: + base_attrs = list(object.__dir__(self)) + extra_attrs = list(self.mock_attrs.keys()) + list(self.mock_modules.keys()) + return base_attrs + extra_attrs + + +def _module_dir(): + import collections + col_dir = dirname(abspath(collections.__file__)) + return dirname(col_dir) + + +dataclasses_mock = 
Mocker.make_mock_module( + 'dataclasses', + attributes=dict( + dataclass=_mock_decorator, + asdict=lambda obj, dict_factory=dict: dict_factory(obj), + astuple=lambda obj, tuple_factory=tuple: tuple_factory(obj), + is_dataclass=lambda obj: False, + field=lambda *args, **kwargs: kwargs.get('default', kwargs.get('default_factory', lambda: None)()), + ), built_in=True +) +""" +This is a :class:`.Mocker` instance which somewhat emulates the Python 3.7+ :mod:`dataclasses` module, +including the :func:`dataclasses.dataclass` decorator. +""" try: # noinspection PyCompatibility @@ -441,10 +914,10 @@ def __name__(self): # To avoid a severe syntax error caused by the missing dataclass types, we generate a dummy dataclasses module, along with a # dummy dataclass and field class so that type annotations such as Type[dataclass] don't cause the module to throw a syntax error. # noinspection PyTypeHints - dataclasses = Mocker() + dataclasses = dataclasses_mock # noinspection PyTypeHints - dataclass = Mocker.make_mock_class(name='dataclass', instance=False) - field = Mocker.make_mock_class(name='field', instance=False) + dataclass = dataclasses.dataclass + field = dataclasses.field class DictObject(dict): @@ -498,6 +971,9 @@ def __setattr__(self, key, value): except KeyError as ex: raise AttributeError(str(ex)) + def __dir__(self) -> Iterable[str]: + return list(dict.__dir__(self)) + list(self.keys()) + class OrderedDictObject(OrderedDict): """ @@ -522,6 +998,9 @@ def __setattr__(self, key, value): except KeyError as ex: raise AttributeError(str(ex)) + def __dir__(self) -> Iterable[str]: + return list(OrderedDict.__dir__(self)) + list(self.keys()) + class MockDictObj(DictObject): """ @@ -540,135 +1019,6 @@ class MockDictObj(DictObject): MockDictObj.__module__ = 'builtins' -def _q_copy(obj: K, key: str = None, deep_private: bool = False, quiet: bool = False, fail: bool = False, **kwargs) -> K: - # By default, deep_private is false, which means we avoid deep copying any attributes which have keys starting with __ - # This is because they're usually objects/types we simply can't deepcopy without issues. - if not deep_private and key.startswith('__'): - log.debug("Not deep copying key '%s' as deep_private is true and key begins with __", key) - return obj - - try: - copied = copy.deepcopy(obj) - return copied - except Exception as ex: - if fail: - raise ex - log_args = "Exception while deep copying object %s ( %s ) - using normal ref. Ex: %s %s", key, obj, type(ex), str(ex) - if quiet: - log.debug(*log_args) - else: - log.warning(*log_args) - return obj - - -def _copy_class_dict(obj: Type[T], name, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: - """ - Internal function used by :func:`.copy_class` - - Make a deep copy of the :class:`.type` / class ``obj`` (for standard classes, which use ``__dict__``) - """ - orig_dict, filt_dict = dict(obj.__dict__), {} - - # Try to deep copy each attribute's value from the original __dict__ in 'orig_dict' into 'filt_dict'. - # Attributes that are private (start with '__') will use standard references by default (if deep_private is False), - # along with any attributes that fail to be deep copied. 
- for k, v in orig_dict.items(): - if not deep_copy: break - filt_dict[k] = _q_copy(v, key=k, deep_private=deep_private, **kwargs) - - bases = kwargs.get('bases', obj.__bases__ if kwargs.get('use_bases', True) else (object,)) - return type(name, bases, filt_dict if deep_copy else orig_dict) - - -def _copy_class_slotted(obj: Type[T], name, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: - """ - Internal function used by :func:`.copy_class` - - Make a deep copy of the :class:`.type` / class ``obj`` (for slotted classes, i.e. those which use ``__slots__``) - - Based on a StackOverflow answer by user ``nkpro``: https://stackoverflow.com/a/61823543/2648583 - """ - slots = obj.__slots__ if type(obj.__slots__) != str else (obj.__slots__,) - orig_dict, slotted_members = {}, {} - - for k, v in obj.__dict__.items(): - dcval = v if not deep_copy else _q_copy(v, k, deep_private=deep_private, **kwargs) - if k not in slots: - orig_dict[k] = dcval - elif type(v) != MemberDescriptorType: - slotted_members[k] = dcval - - bases = kwargs.get('bases', obj.__bases__ if kwargs.get('use_bases', True) else (object,)) - new_obj = type(name, bases, orig_dict) - for k, v in slotted_members.items(): - setattr(new_obj, k, v) - - return new_obj - - -def copy_class(obj: Type[T], name=None, deep_copy=True, deep_private=False, **kwargs) -> Union[Type[T], type]: - """ - Attempts to create a full copy of a :class:`.type` or class, severing most object pointers such as attributes containing a - :class:`.dict` / :class:`.list`, along with classes or instances of classes. - - Example:: - - >>> class SomeClass: - >>> example = 'lorem ipsum' - >>> data = ['hello', 'world'] - >>> testing = 123 - >>> - >>> from privex.helpers import copy_class - >>> OtherClass = copy_class(SomeClass, name='OtherClass') - - If you then append to the :class:`.list` attribute ``data`` on both SomeClass and OtherClass - with a different item - appended to each class, you'll see that the added item was only added to ``data`` for that class, and not to the other class, - proving the original and the copy are independent from each other:: - - >>> SomeClass.data.append('lorem') - >>> OtherClass.data.append('ipsum') - >>> SomeClass.data - ['hello', 'world', 'lorem'] - >>> OtherClass.data - ['hello', 'world', 'ipsum'] - - - :param Type[T] obj: A :class:`.type` / class to attempt to duplicate, deep copying each individual object in the class, to - avoid any object pointers shared between the original and the copy. - :param str|None name: The class name to use for the copy of ``obj``. If not specified, defaults to the original class name from ``obj`` - :param bool deep_copy: (Default: ``True``) If True, uses :func:`copy.deepcopy` to deep copy each attribute in ``obj`` to the copy. - If False, then standard references will be used, which may result in object pointers being copied. - :param bool deep_private: (Default: ``False``) If True, :func:`copy.deepcopy` will be used on "private" class attributes, - i.e. ones that start with ``__``. If False, attributes starting with ``__`` will not be deep copied, - only a standard assignment/reference will be used. - :param kwargs: Additional advanced settings (see ``keyword`` pydoc entries for this function) - - :keyword bool use_bases: (Default: ``True``) If True, copy the inheritance (bases) from ``obj`` into the class copy. - :keyword bool quiet: (Default ``False``) If True, log deep copy errors as ``debug`` level (usually silent in production apps) - instead of the louder ``warning``. 
- :keyword tuple bases: A :class:`.tuple` of classes to use as "bases" (inheritance) for the class copy. If not specified, - copies ``__bases__`` from the original class. - :keyword str module: If specified, overrides the module ``__module__`` in the class copy with this string, instead of copying from - the original class. - :return Type[T] obj_copy: A deep copy of the original ``obj`` - """ - # If no class name was passed as an attribute, then we copy the name from the original class. - if not name: - name = obj.__name__ - - # Depending on whether 'obj' is a normal class using __dict__, or a slotted class using __slots__, we need to handle the - # deep copying of the class differently. - if hasattr(obj, '__slots__'): - new_obj = _copy_class_slotted(obj, name=name, deep_copy=deep_copy, deep_private=deep_private, **kwargs) - else: - new_obj = _copy_class_dict(obj, name=name, deep_copy=deep_copy, deep_private=deep_private, **kwargs) - - # Override the module path string if the user specified 'module' as a kwarg - module = kwargs.get('module') - if module is not None: - new_obj.__module__ = module - return new_obj - def is_namedtuple(*objs) -> bool: """ diff --git a/privex/helpers/common.py b/privex/helpers/common.py index bc5291d..095e4f9 100644 --- a/privex/helpers/common.py +++ b/privex/helpers/common.py @@ -36,7 +36,7 @@ from decimal import Decimal, getcontext from os import getenv as env from subprocess import PIPE, STDOUT -from typing import Sequence, List, Union, Tuple, Type, Dict, Any, Iterable, Optional, BinaryIO, Generator, Mapping +from typing import Callable, Sequence, List, Union, Tuple, Type, Dict, Any, Iterable, Optional, BinaryIO, Generator, Mapping from privex.helpers import settings from privex.helpers.collections import DictObject, OrderedDictObject from privex.helpers.types import T, K, V, C, USE_ORIG_VAR, STRBYTES, NumberStr @@ -1840,4 +1840,16 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> Any: return await self.aexit(exc_type, exc_val, exc_tb) - +def strip_null(value: Union[str, bytes], conv: Callable[[str], Union[str, bytes, T]] = stringify, nullc="\00") -> Union[str, bytes, T]: + """ + Small convenience function which :func:`.stringify`'s ``value`` then strips it of whitespace and null bytes, with + two passes for good measure. 
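+
+    Quick illustrative example (bytes input is stringified, then surrounding whitespace/null bytes are stripped)::
+
+        >>> strip_null(b'   hello world   ')
+        'hello world'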
+ + :param str|bytes value: The value to clean whitespace/null bytes out of + :param callable conv: (Default :func:`.stringify`) Optionally, you can override the casting function used after + the stripping is completed + :param str nullc: (Default: ``\00``) Null characters to remove + :return str|bytes|T cleaned: The cleaned up ``value`` + """ + value = stringify(value).strip().strip(nullc).strip().strip(nullc) + return conv(value) diff --git a/privex/helpers/converters.py b/privex/helpers/converters.py index 7cff4e2..48c192e 100644 --- a/privex/helpers/converters.py +++ b/privex/helpers/converters.py @@ -22,10 +22,59 @@ """ +import warnings from datetime import datetime, date from decimal import Decimal -from typing import Optional, Union, AnyStr +from typing import Any, Optional, Union, AnyStr +from privex.helpers.exceptions import ValidatorNotMatched +from privex.helpers.types import T +from privex.helpers.collections import Mocker + +try: + from privex.helpers.extras.attrs import AttribDictable +except ImportError as e: + warnings.warn(f"Failed to import privex.helpers.extras.attrs.AttribDictable - falling back to placeholder type") + AttribDictable = Mocker.make_mock_class('AttribDictable', instance=False) + +try: + from privex.helpers.collections import Dictable +except ImportError as e: + warnings.warn(f"Failed to import privex.helpers.collections.Dictable - falling back to placeholder type") + Dictable = Mocker.make_mock_class('Dictable', instance=False) + +try: + from privex.helpers.collections import DictDataClass +except ImportError as e: + warnings.warn(f"Failed to import privex.helpers.collections.DictDataClass - falling back to placeholder type") + DictDataClass = Mocker.make_mock_class('DictDataClass', instance=False) + + +try: + import dataclasses +except ImportError as e: + warnings.warn(f"Failed to import dataclasses - falling back to placeholder type") + from privex.helpers.mockers import dataclasses + +try: + import attr + from attr.exceptions import NotAnAttrsClassError +except ImportError as e: + warnings.warn(f"Failed to import attr - falling back to placeholder type") + from privex.helpers.mockers import attr + + + class NotAnAttrsClassError(Exception): + pass + # + # attr = Mocker( + # attributes=dict( + # s=mock_decorator, + # asdict=lambda obj, dict_factory=dict: dict_factory(obj), + # astuple=lambda obj, tuple_factory=tuple: tuple_factory(obj), + # validate=lambda obj: False + # ) + # ) from privex.helpers.common import empty, is_true, stringify import logging @@ -174,9 +223,175 @@ def convert_int_bool(d, if_empty=False, fail_empty=False) -> bool: return is_true(d) +DICT_TYPES = (dict, AttribDictable, Dictable, DictDataClass) + +FLOAT_TYPES = (float, Decimal) +INTEGER_TYPES = (int,) +NUMBER_TYPES = FLOAT_TYPES + INTEGER_TYPES + +LIST_TYPES = (list, set, tuple) + +SIMPLE_TYPES = Union[list, dict, str, float, int] +SIMPLE_TYPES_TUPLE = (list, dict, str, float, int) + + +def _clean_attrs_matcher(ob): + try: + attr.validate(ob) + return True + # return clean_dict(attr.asdict(ob)) + except NotAnAttrsClassError: + return False + + +_clean_floats = lambda ob, number_str=False, **kwargs: str(ob) if number_str else float(ob) +_clean_ints = lambda ob, number_str=False, **kwargs: str(ob) if number_str else int(ob) + + +def _clean_strs(ob, **kwargs): + try: + return stringify(ob) + except Exception: + return str(repr(ob)) + + +def clean_obj(ob: Any, number_str: bool = False, fail=False, fallback: T = None) -> Union[SIMPLE_TYPES, T]: + """ + Cleans an object by converting it 
/ it's contents into basic, simple, JSON-compatible types. + + For example, :class:`.Decimal`'s will become :class:`.float`'s (or :class:`str`'s if ``number_str=True``), + :class:`bytes` will be decoded into a :class:`str` if possible, + :param Any ob: An object to clean - making it safe for use with JSON/YAML etc. + :param bool number_str: (Default: ``False``) When set to ``True``, numbers will be converted to strings instead of int/float. + :param bool fail: (Default: ``False``) When set to ``True``, will raise the exception thrown by the fallback converter + if an error occurs, instead of returning ``fallback`` + :param Any fallback: (Default: ``None``) The value to return if all matchers/converters fail to handle the object, + only used when ``fail=False`` (the default) + + :return SIMPLE_TYPES|T res: A clean version of the object for serialisation - or ``fallback`` if something went wrong. + """ + # if isinstance(ob, FLOAT_TYPES): return str(ob) if number_str else float(ob) + # if isinstance(ob, INTEGER_TYPES): return str(ob) if number_str else int(ob) + # if isinstance(ob, NUMBER_TYPES): return str(ob) if number_str else float(ob) + # + # if isinstance(ob, (str, bytes)): + # try: + # return stringify(ob) + # except Exception: + # return str(repr(ob)) + # + # if isinstance(ob, DICT_TYPES): return clean_dict(dict(ob)) + # if isinstance(ob, LIST_TYPES): return clean_list(list(ob)) + # if dataclasses.is_dataclass(ob): return dataclasses.asdict(ob) + + matched = False + for matcher, convt in CLEAN_OBJ_VALIDATORS.items(): + try: + log.debug("Checking matcher: %s - against object: %s", matcher, ob) + if isinstance(matcher, (list, set)): matcher = tuple(matcher) + if isinstance(matcher, tuple): + if not isinstance(ob, matcher): continue + matched = True + if not matched and callable(matcher): + if not matcher(ob): continue + matched = True + if not matched: + if type(ob) is not type(matcher): continue + matched = True + log.debug("Matched %s has matched against object. Running converter. Object is: %s", matcher, ob) + + res = convt(ob, number_str=number_str) + return res + except ValidatorNotMatched: + log.info("Matcher %s raised ValidatorNotMatched for object '%s' - continuing.", matcher, ob) + continue + except Exception as e: + log.error("Matcher %s raised %s for object '%s' - continuing. 
Message was: %s", matcher, type(e), ob, str(e)) + continue + + log.warning( + "All %s matchers failed to match against object '%s' - using fallback converter: %s", + len(CLEAN_OBJ_VALIDATORS), ob, CLEAN_OBJ_FALLBACK + ) + try: + res = CLEAN_OBJ_FALLBACK(ob, number_str=number_str) + return res + except Exception as e: + log.exception("Fallback matcher failed to convert object '%s' ...", ob) + if fail: + raise e + return fallback + + +def clean_list(ld: list, **kwargs) -> list: + ld = list(ld) + nl = [] + for d in ld: + try: + x = clean_obj(d, **kwargs) + nl.append(x) + # if isinstance(d, (int, float, str)): + # nl.append(d) + # continue + # if isinstance(d, LIST_TYPES): + # nl.append(clean_list(list(d))) + # continue + # if isinstance(d, LIST_TYPES): + # nl.append(clean_list(list(d))) + # continue + # if isinstance(d, DICT_TYPES): + # nl.append(clean_dict(dict(d))) + # continue + # nl.append(str(d)) + except Exception: + log.exception("Error while cleaning list item: %s", d) + return nl + + +def clean_dict(data: dict, **kwargs) -> dict: + data = dict(data) + cleaned = {} + for k, v in data.items(): + try: + n = clean_obj(v, **kwargs) + cleaned[k] = n + # if isinstance(v, (dict, AttribDictable)): + # n = clean_dict(dict(v)) + # cleaned[k] = n + # continue + # if isinstance(v, list): + # n = clean_list(list(v)) + # cleaned[k] = n + # continue + # if isinstance(v, (int, float, str)): + # cleaned[k] = v + # continue + # cleaned[k] = str(v) + except Exception: + log.exception("Error while cleaning dict item: %s = %s", k, v) + + return cleaned + + +CLEAN_OBJ_VALIDATORS = { + FLOAT_TYPES: _clean_floats, + INTEGER_TYPES: _clean_ints, + NUMBER_TYPES: _clean_floats, + (str, bytes): _clean_strs, + DICT_TYPES: clean_dict, + LIST_TYPES: clean_list, + _clean_attrs_matcher: lambda ob, **kwargs: clean_obj(attr.asdict(ob), **kwargs), + dataclasses.is_dataclass: lambda ob, **kwargs: clean_obj(dataclasses.asdict(ob), **kwargs), +} + +CLEAN_OBJ_FALLBACK = lambda ob, **kwargs: str(ob) + __all__ = [ 'convert_datetime', 'convert_unixtime_datetime', 'convert_bool_int', 'convert_int_bool', 'parse_date', 'parse_datetime', 'parse_epoch', 'parse_unixtime', 'convert_epoch_datetime', + 'DICT_TYPES', 'FLOAT_TYPES', 'INTEGER_TYPES', 'NUMBER_TYPES', 'LIST_TYPES', + 'clean_obj', 'clean_list', 'clean_dict', 'CLEAN_OBJ_FALLBACK', 'CLEAN_OBJ_VALIDATORS', 'MINUTE', 'HOUR', 'DAY', 'MONTH', 'YEAR', 'DECADE', + ] diff --git a/privex/helpers/exceptions.py b/privex/helpers/exceptions.py index 76086c9..e7ea5d4 100644 --- a/privex/helpers/exceptions.py +++ b/privex/helpers/exceptions.py @@ -170,3 +170,7 @@ class EventWaitTimeout(PrivexException): Raised when a timeout has been reached while waiting for an event (:class:`threading.Event`) to be signalled. 
""" + +class ValidatorNotMatched(PrivexException): + pass + diff --git a/privex/helpers/mockers.py b/privex/helpers/mockers.py new file mode 100644 index 0000000..ea19be8 --- /dev/null +++ b/privex/helpers/mockers.py @@ -0,0 +1,40 @@ +import warnings + +from privex.helpers.decorators import mock_decorator +from privex.helpers.collections import Mocker, dataclasses_mock + +module = Mocker.make_mock_class('module') + + +def mkclass(name: str = 'module', instance: bool = True, **kwargs): + return Mocker.make_mock_class(name, instance=instance, **kwargs) + + +def mkmodule(mod_name: str, attributes: dict = None, modules: dict = None, **kwargs): + return Mocker.make_mock_module(mod_name, attributes, modules, **kwargs) + + +dataclasses = dataclasses_mock +dataclass, field = dataclasses.dataclass, dataclasses.field + +pytest = mkmodule( + 'pytest', + dict( + skip=lambda msg, allow_module_level=True: warnings.warn(msg), + mark=Mocker.make_mock_class( + '_pytest.mark.structures.MarkGenerator', + attributes=dict(skip=mock_decorator, skipif=mock_decorator()) + ) + ) +) + +attr = Mocker( + attributes=dict( + s=mock_decorator, + asdict=lambda obj, dict_factory=dict: dict_factory(obj), + astuple=lambda obj, tuple_factory=tuple: tuple_factory(obj), + validate=lambda obj: False + ) +) + + diff --git a/privex/helpers/net/__init__.py b/privex/helpers/net/__init__.py new file mode 100644 index 0000000..f667dc8 --- /dev/null +++ b/privex/helpers/net/__init__.py @@ -0,0 +1,30 @@ +""" +Network related helper code + +**Copyright**:: + + +===================================================+ + | © 2020 Privex Inc. | + | https://www.privex.io | + +===================================================+ + | | + | Originally Developed by Privex Inc. | + | License: X11 / MIT | + | | + | Core Developer(s): | + | | + | (+) Chris (@someguy123) [Privex] | + | (+) Kale (@kryogenic) [Privex] | + | | + +===================================================+ + + Copyright 2019 Privex Inc. 
( https://www.privex.io ) + +""" + +from privex.helpers.exceptions import BoundaryException, NetworkUnreachable + +from privex.helpers.net.dns import * +from privex.helpers.net.util import * +from privex.helpers.net.common import * +from privex.helpers.net.socket import * diff --git a/privex/helpers/net/base.py b/privex/helpers/net/base.py new file mode 100644 index 0000000..e04f275 --- /dev/null +++ b/privex/helpers/net/base.py @@ -0,0 +1,197 @@ +import asyncio +import socket +import ssl + +from privex.helpers import settings +from privex.helpers.common import byteify, empty + +from privex.helpers.net.util import generate_http_request, get_ssl_context, ip_is_v6, sock_ver + +from privex.helpers.types import AnyNum, IP_OR_STR +from privex.helpers.net.dns import resolve_ip, resolve_ip_async +import logging + +log = logging.getLogger(__name__) + + +def _ssl_context(kwargs, ssl_params=None): + if not ssl_params: + ssl_params = kwargs.pop( + 'ssl_params', dict(verify_cert=settings.SSL_VERIFY_CERT, check_hostname=settings.SSL_VERIFY_HOSTNAME) + ) + return get_ssl_context(**ssl_params) + + +def _wrap_socket(s: socket.socket, kwargs: dict, host=None, wrap_params=None, ssl_params=None, ) -> ssl.SSLSocket: + if not wrap_params: + wrap_params = kwargs.pop('wrap_params', dict( + server_hostname=kwargs.get('server_hostname', host), + session=kwargs.get('session'), + do_handshake_on_connect=kwargs.get('do_handshake_on_connect', True) + )) + ctx = _ssl_context(kwargs, ssl_params) + return ctx.wrap_socket(s, **wrap_params) + + +def check_host(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: + """ + Test if the service on port ``port`` for host ``host`` is working. AsyncIO version: :func:`.check_host_async` + + Basic usage (services which send the client data immediately after connecting):: + + >>> check_host('hiveseed-se.privex.io', 2001) + True + >>> check_host('hiveseed-se.privex.io', 9991) + False + + For some services, such as HTTP - it's necessary to transmit some data to the host before it will + send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. + + Sending data to ``host`` after connecting:: + + >>> check_host('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") + True + + + :param str|IPv4Address|IPv6Address host: Hostname or IP to test + :param int|str port: Port number on ``host`` to connect to + :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar + to ensure the connection is via that IP version + + :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` + :param kwargs: Additional configuration options (see below) + + :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) + :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. + :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` + + :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``5.0`` + + :raises socket.timeout: When ``throw=True`` and a timeout occurs. 
+ :raises socket.gaierror: When ``throw=True`` and various errors occur + :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused + :raises ConnectionResetError: When ``throw=True`` and the connection was reset + + :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. + """ + kwargs = dict(kwargs) + receive, stype = int(kwargs.get('receive', 100)), kwargs.get('stype', socket.SOCK_STREAM) + timeout, send, use_ssl = kwargs.get('timeout', 'n/a'), kwargs.get('send'), kwargs.get('ssl', kwargs.get('use_ssl')) + http_test, hostname = kwargs.get('http_test', False), kwargs.get('hostname', host) + + # ssl_params = kwargs.get('ssl_params', dict(verify_cert=False, check_hostname=False)) + if timeout == 'n/a': + t = socket.getdefaulttimeout() + timeout = settings.DEFAULT_SOCKET_TIMEOUT if not t else t + if http_test: + send = generate_http_request(url=kwargs.get('url', '/'), host=hostname) + try: + s_ver = socket.AF_INET + ip = resolve_ip(host, version) + + if ip_is_v6(ip): s_ver = socket.AF_INET6 + + if port == 443 and use_ssl is None: + log.warning("check_host: automatically setting use_ssl=True as port is 443 and use_ssl was not specified.") + use_ssl = True + with socket.socket(s_ver, stype) as s: + if use_ssl: s = _wrap_socket(s, kwargs, host) + if timeout: s.settimeout(float(timeout)) + + s.connect((ip, int(port))) + if not empty(send): + s.sendall(byteify(send)) + if receive > 0: + s.recv(int(receive)) + if use_ssl: + s.close() + return True + except (socket.timeout, TimeoutError, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if throw: + raise e + return False + + +async def check_host_async(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: + """ + AsyncIO version of :func:`.check_host`. Test if the service on port ``port`` for host ``host`` is working. + + Basic usage (services which send the client data immediately after connecting):: + + >>> await check_host_async('hiveseed-se.privex.io', 2001) + True + >>> await check_host_async('hiveseed-se.privex.io', 9991) + False + + For some services, such as HTTP - it's necessary to transmit some data to the host before it will + send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. + + Sending data to ``host`` after connecting:: + + >>> await check_host_async('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") + True + + + :param str|IPv4Address|IPv6Address host: Hostname or IP to test + :param int|str port: Port number on ``host`` to connect to + :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar + to ensure the connection is via that IP version + + :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` + :param kwargs: Additional configuration options (see below) + + :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) + :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. + :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` + + :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``5.0`` + + :raises socket.timeout: When ``throw=True`` and a timeout occurs. 
+ :raises socket.gaierror: When ``throw=True`` and various errors occur + :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused + :raises ConnectionResetError: When ``throw=True`` and the connection was reset + + :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. + """ + kwargs = dict(kwargs) + receive, stype = int(kwargs.get('receive', 100)), kwargs.get('stype', socket.SOCK_STREAM) + timeout, send, use_ssl = kwargs.get('timeout', 'n/a'), kwargs.get('send'), kwargs.get('ssl', kwargs.get('use_ssl')) + http_test, hostname = kwargs.get('http_test', False), kwargs.get('hostname', host) + + if timeout == 'n/a': + t = socket.getdefaulttimeout() + timeout = 10.0 if not t else t + + loop = asyncio.get_event_loop() + if http_test: + send = generate_http_request(url=kwargs.get('url', '/'), host=hostname) + if port == 443 and use_ssl is None: + log.warning("check_host_async: automatically setting use_ssl=True as port is 443 and use_ssl was not specified.") + use_ssl = True + try: + if sock_ver(version) is None: + s_ver = socket.AF_INET + host = await resolve_ip_async(host, version) + if ip_is_v6(host): s_ver = socket.AF_INET6 + else: + s_ver = sock_ver(version) + + with socket.socket(s_ver, stype) as s: + if use_ssl: s = _wrap_socket(s, kwargs, host) + if timeout: + s.settimeout(float(timeout)) + await asyncio.wait_for(loop.sock_connect(s, (host, int(port))), timeout) + else: + await loop.sock_connect(s, (host, int(port))) + + if not empty(send): + await loop.sock_sendall(s, byteify(send)) + if receive > 0: + await loop.sock_recv(s, int(receive)) + return True + except (socket.timeout, TimeoutError, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if throw: + raise e + return False diff --git a/privex/helpers/net/common.py b/privex/helpers/net/common.py new file mode 100644 index 0000000..5f21d75 --- /dev/null +++ b/privex/helpers/net/common.py @@ -0,0 +1,513 @@ +""" +General uncategorised functions/classes for network related helper code + +**Copyright**:: + + +===================================================+ + | © 2019 Privex Inc. | + | https://www.privex.io | + +===================================================+ + | | + | Originally Developed by Privex Inc. | + | License: X11 / MIT | + | | + | Core Developer(s): | + | | + | (+) Chris (@someguy123) [Privex] | + | (+) Kale (@kryogenic) [Privex] | + | | + +===================================================+ + + Copyright 2019 Privex Inc. 
( https://www.privex.io ) + +""" +import asyncio +import logging +import random +import socket +from datetime import datetime +from math import ceil +from typing import List, Tuple + +from privex.helpers.decorators import r_cache, r_cache_async + +from privex.helpers import settings +from privex.helpers.common import byteify, empty, empty_if, is_true +from privex.helpers.asyncx import run_coro_thread_async +from privex.helpers.net import base as netbase +from privex.helpers.net.dns import resolve_ip, resolve_ip_async +from privex.helpers.net.socket import AsyncSocketWrapper +from privex.helpers.net.util import get_ssl_context, ip_is_v6 +from privex.helpers.types import AUTO, AnyNum, IP_OR_STR + +log = logging.getLogger(__name__) + +__all__ = [ + 'check_host', 'check_host_async', 'check_host_http', 'check_host_http_async', 'test_hosts_async', + 'test_hosts', 'check_v4', 'check_v6', 'check_v4_async', 'check_v6_async' +] + + +def check_host(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: + """ + Test if the service on port ``port`` for host ``host`` is working. AsyncIO version: :func:`.check_host_async` + + Basic usage (services which send the client data immediately after connecting):: + + >>> check_host('hiveseed-se.privex.io', 2001) + True + >>> check_host('hiveseed-se.privex.io', 9991) + False + + For some services, such as HTTP - it's necessary to transmit some data to the host before it will + send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. + + Sending data to ``host`` after connecting:: + + >>> check_host('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") + True + + + :param str|IPv4Address|IPv6Address host: Hostname or IP to test + :param int|str port: Port number on ``host`` to connect to + :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar + to ensure the connection is via that IP version + + :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` + :param kwargs: Additional configuration options (see below) + + :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) + :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. + :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` + + :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``5.0`` + + :raises socket.timeout: When ``throw=True`` and a timeout occurs. + :raises socket.gaierror: When ``throw=True`` and various errors occur + :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused + :raises ConnectionResetError: When ``throw=True`` and the connection was reset + + :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. 
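The field list above stops short of the SSL-related keywords that the body of `check_host()` below actually reads (`use_ssl` / `ssl`, `ssl_params`, `server_hostname`). The sketch below is illustrative only, with placeholder hostnames; the keyword names are taken from the implementation that follows.

```
# Illustrative sketch of the SSL keywords handled by check_host() (hosts are placeholders).
from privex.helpers.net import check_host

# Port 443 auto-enables SSL when use_ssl isn't given (see the warning in the body below)
check_host('www.privex.io', 443, send=b"GET / HTTP/1.1\r\nHost: www.privex.io\r\n\r\n")

# Explicit SSL on a non-standard port, with relaxed certificate checks
check_host('example.com', 8443, use_ssl=True,
           ssl_params=dict(verify_cert=False, check_hostname=False))
```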
+ """ + receive, stype = int(kwargs.get('receive', 100)), kwargs.get('stype', socket.SOCK_STREAM) + timeout, send, use_ssl = kwargs.get('timeout', 'n/a'), kwargs.get('send'), kwargs.get('ssl', kwargs.get('use_ssl')) + ssl_params = kwargs.get('ssl_params', dict(verify_cert=False, check_hostname=False)) + if timeout == 'n/a': + t = socket.getdefaulttimeout() + timeout = 10.0 if not t else t + + try: + s_ver = socket.AF_INET + ip = resolve_ip(host, version) + + if ip_is_v6(ip): s_ver = socket.AF_INET6 + + if port == 443 and use_ssl is None: + log.warning("check_host: automatically setting use_ssl=True as port is 443 and use_ssl was not specified.") + use_ssl = True + with socket.socket(s_ver, stype) as s: + orig_sock = s + if timeout: s.settimeout(float(timeout)) + if use_ssl: + ctx = get_ssl_context(**ssl_params) + s = ctx.wrap_socket( + s, + server_hostname=kwargs.get('server_hostname'), + session=kwargs.get('session'), + do_handshake_on_connect=kwargs.get('do_handshake_on_connect', True), + ) + + s.connect((ip, int(port))) + if not empty(send): + s.sendall(byteify(send)) + if receive > 0: + s.recv(int(receive)) + if use_ssl: + s.close() + return True + except (socket.timeout, TimeoutError, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if throw: + raise e + return False + + +async def check_host_async(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: + """ + AsyncIO version of :func:`.check_host`. Test if the service on port ``port`` for host ``host`` is working. + + Basic usage (services which send the client data immediately after connecting):: + + >>> await check_host_async('hiveseed-se.privex.io', 2001) + True + >>> await check_host_async('hiveseed-se.privex.io', 9991) + False + + For some services, such as HTTP - it's necessary to transmit some data to the host before it will + send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. + + Sending data to ``host`` after connecting:: + + >>> await check_host_async('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") + True + + + :param str|IPv4Address|IPv6Address host: Hostname or IP to test + :param int|str port: Port number on ``host`` to connect to + :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar + to ensure the connection is via that IP version + + :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` + :param kwargs: Additional configuration options (see below) + + :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) + :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. + :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` + + :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``5.0`` + + :raises socket.timeout: When ``throw=True`` and a timeout occurs. + :raises socket.gaierror: When ``throw=True`` and various errors occur + :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused + :raises ConnectionResetError: When ``throw=True`` and the connection was reset + + :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. 
+ """ + receive, stype = int(kwargs.get('receive', 16)), kwargs.get('stype', socket.SOCK_STREAM) + timeout, send = kwargs.get('timeout', 'n/a'), kwargs.get('send') + http_test, use_ssl = kwargs.get('http_test', False), kwargs.get('use_ssl', False) + if timeout == 'n/a': + t = socket.getdefaulttimeout() + timeout = settings.DEFAULT_SOCKET_TIMEOUT if not t else t + + # loop = asyncio.get_event_loop() + s_ver = socket.AF_INET + ip = await resolve_ip_async(host, version) + + if ip_is_v6(ip): s_ver = socket.AF_INET6 + + try: + aw = AsyncSocketWrapper(host, int(port), family=s_ver, use_ssl=use_ssl, timeout=timeout) + await aw.connect() + if http_test: + log.info("Sending HTTP request to %s", host) + log.info("Response from %s : %s", host, await aw.http_request()) + + elif not empty(send) and receive > 0: + log.info("Sending query data '%s' and trying to receive data from %s", send, host) + log.info("Response from %s : %s", host, await aw.query(send, receive, read_timeout=kwargs.get('read_timeout', AUTO))) + + elif not empty(send): + log.info("Sending query data '%s' to %s", send, host) + await aw.sendall(send) + else: + log.info("Receiving data from %s", host) + + log.info("Response from %s : %s", host, await aw.read_eof( + receive, strip=False, read_timeout=kwargs.get('read_timeout', AUTO), + )) + + # with socket.socket(s_ver, stype) as s: + # if timeout: s.settimeout(float(timeout)) + # await loop.sock_connect(s, (ip, int(port))) + # if not empty(send): + # await loop.sock_sendall(s, byteify(send)) + # if receive > 0: + # await loop.sock_recv(s, int(receive)) + return True + except (socket.timeout, TimeoutError, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if throw: + raise e + return False + + +def check_host_http(host: IP_OR_STR, port: AnyNum = 80, version='any', throw=False, **kwargs) -> bool: + return netbase.check_host(host, port, version, throw=throw, http_test=True, **kwargs) + + +async def check_host_http_async( + host: IP_OR_STR, port: AnyNum = 80, version='any', throw=False, send=b"GET / HTTP/1.1\\n\\n", **kwargs + ) -> bool: + # return await check_host_async(host, port, version, throw=throw, send=send, **kwargs) + return await netbase.check_host_async(host, port, version, throw=throw, http_test=True, **kwargs) + + +async def test_hosts_async(hosts: List[str] = None, ipver: str = 'any', timeout: AnyNum = None, **kwargs) -> bool: + randomise = is_true(kwargs.get('randomise', True)) + max_hosts = kwargs.get('max_hosts', settings.NET_CHECK_HOST_COUNT_TRY) + if max_hosts is not None: max_hosts = int(max_hosts) + timeout = empty_if(timeout, empty_if(socket.getdefaulttimeout(), 4, zero=True), zero=True) + + v4h, v6h = list(settings.V4_TEST_HOSTS), list(settings.V6_TEST_HOSTS) + if randomise: random.shuffle(v4h) + if randomise: random.shuffle(v6h) + + if empty(hosts, True, True): + # if empty(ipver, True, True) or ipver in ['any', 'all', 'both', 10, '10', '46', 46]: + # settings.V4_CHECKED_AT + if isinstance(ipver, str): ipver = ipver.lower() + if ipver in [4, '4', 'v4', 'ipv4']: + hosts = v4h + ipver = 4 + elif ipver in [6, '6', 'v6', 'ipv6']: + hosts = v6h + ipver = 6 + else: + ipver = 'any' + if max_hosts: + hosts = v4h[:int(ceil(max_hosts / 2))] + v6h[:int(ceil(max_hosts / 2))] + else: + hosts = v4h + v6h + + if max_hosts: hosts = hosts[:max_hosts] + + # st4_empty = any([empty(settings.HAS_WORKING_V4, True, True), empty(settings.V4_CHECKED_AT, True, True)]) + # st6_empty = any([empty(settings.HAS_WORKING_V6, True, True), empty(settings.V6_CHECKED_AT, True, 
True)]) + + # if ipver == 6 and not st6_empty and settings.V6_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # log.debug("Returning cached IPv6 status: working = %s", settings.HAS_WORKING_V6) + # return settings.HAS_WORKING_V6 + # if ipver == 4 and not st4_empty and settings.V4_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # log.debug("Returning cached IPv4 status: working = %s", settings.HAS_WORKING_V4) + # return settings.HAS_WORKING_V4 + # + # if ipver == 'any' and any([not st4_empty, not st6_empty]) and settings.V4_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # if st4_empty: + # log.debug("test_hosts being requested for 'any' ip ver. IPv6 status cached, but not IPv4 status. Checking IPv4 status...") + # await check_v4_async() + # if st6_empty: + # log.debug("test_hosts being requested for 'any' ip ver. IPv4 status cached, but not IPv6 status. Checking IPv6 status...") + # await check_v6_async(hosts) + # # if not st4_empty and not st6_empty: + # log.debug( + # "Returning status %s based on: Working IPv4 = %s || Working IPv6 = %s", + # settings.HAS_WORKING_V4 or settings.HAS_WORKING_V6, settings.HAS_WORKING_V4, settings.HAS_WORKING_V6 + # ) + # return settings.HAS_WORKING_V4 or settings.HAS_WORKING_V6 + + # max_hosts = int(kwargs.get('max_hosts', settings.NET_CHECK_HOST_COUNT_TRY)) + min_hosts_pos = int(kwargs.get('required_positive', settings.NET_CHECK_HOST_COUNT)) + + # hosts = empty_if(hosts, settings.V4_TEST_HOSTS, itr=True) + hosts = [x for x in hosts] + + if randomise: random.shuffle(hosts) + + if len(hosts) > max_hosts: hosts = hosts[:max_hosts] + + # port = empty_if(port, 80, zero=True) + + total_hosts = len(hosts) + total_working, total_broken = 0, 0 + working_list, broken_list = [], [] + log.debug("Testing %s hosts with IP version '%s' - timeout: %s", total_hosts, ipver, timeout) + + host_checks = [] + host_checks_hosts = [] + for h in hosts: + # host_checks.append( + # asyncio.create_task(_test_host_async(h, ipver=ipver, timeout=timeout)) + # ) + host_checks.append( + asyncio.create_task( + run_coro_thread_async(_test_host_async, h, ipver=ipver, timeout=timeout) + ) + ) + host_checks_hosts.append(h) + + host_checks_res = await asyncio.gather(*host_checks, return_exceptions=True) + for i, _res in enumerate(host_checks_res): + h = host_checks_hosts[i] + if isinstance(_res, Exception): + log.warning("Exception while checking host %s", h) + total_broken += 1 + continue + + res, h, port = _res + + if res: + total_working += 1 + working_list.append(f"{h}:{port}") + log.debug("check_host for %s (port %s) came back True (WORKING). incremented working hosts: %s", h, port, total_working) + else: + total_broken += 1 + broken_list.append(f"{h}:{port}") + log.debug("check_host for %s (port %s) came back False (! BROKEN !). incremented broken hosts: %s", h, port, total_broken) + + # port = 80 + # for h in hosts: + # try: + # h, port, res = await _test_host_async(h, ipver, timeout) + # if res: + # total_working += 1 + # log.debug("check_host for %s came back true. incremented working hosts: %s", h, total_working) + # else: + # total_broken += 1 + # log.debug("check_host for %s came back false. incremented broken hosts: %s", h, total_broken) + # + # except Exception as e: + # log.warning("Exception while checking host %s port %s", h, port) + + working = total_working >= min_hosts_pos + + log.info("test_hosts - proto: %s - protocol working? 
%s || total hosts: %s || working hosts: %s || broken hosts: %s", + ipver, working, total_hosts, total_working, total_broken) + log.debug("working hosts: %s", working_list) + log.debug("broken hosts: %s", broken_list) + + return working + + +async def _test_host_async(host, ipver: str = 'any', timeout: AnyNum = None) -> Tuple[bool, str, int]: + nh = host.split(':') + if len(nh) > 1: + port = int(nh[-1]) + host = ':'.join(nh[:-1]) + else: + host = ':'.join(nh) + log.warning("Host is missing port: %s - falling back to port 80") + port = 80 + log.debug("Checking host %s via port %s + IP version '%s'", host, port, ipver) + if port == 80: + res = await check_host_http_async(host, port, ipver, throw=False, timeout=timeout) + elif port == 53: + res = await netbase.check_host_async(host, port, ipver, throw=False, timeout=timeout, send="hello\nworld\n") + else: + res = await netbase.check_host_async(host, port, ipver, throw=False, timeout=timeout) + return res, host, port + + +def test_hosts(hosts: List[str] = None, ipver: str = 'any', timeout: AnyNum = None, **kwargs) -> bool: + randomise = is_true(kwargs.get('randomise', True)) + max_hosts = kwargs.get('max_hosts', settings.NET_CHECK_HOST_COUNT_TRY) + if max_hosts is not None: max_hosts = int(max_hosts) + timeout = empty_if(timeout, empty_if(socket.getdefaulttimeout(), 4, zero=True), zero=True) + + v4h, v6h = list(settings.V4_TEST_HOSTS), list(settings.V6_TEST_HOSTS) + if randomise: random.shuffle(v4h) + if randomise: random.shuffle(v6h) + + if empty(hosts, True, True): + # if empty(ipver, True, True) or ipver in ['any', 'all', 'both', 10, '10', '46', 46]: + # settings.V4_CHECKED_AT + if isinstance(ipver, str): ipver = ipver.lower() + if ipver in [4, '4', 'v4', 'ipv4']: + hosts = v4h + ipver = 4 + elif ipver in [6, '6', 'v6', 'ipv6']: + hosts = v6h + ipver = 6 + else: + ipver = 'any' + if max_hosts: + hosts = v4h[:int(ceil(max_hosts / 2))] + v6h[:int(ceil(max_hosts / 2))] + else: + hosts = v4h + v6h + + if max_hosts: hosts = hosts[:max_hosts] + + # st4_empty = any([empty(settings.HAS_WORKING_V4, True, True), empty(settings.V4_CHECKED_AT, True, True)]) + # st6_empty = any([empty(settings.HAS_WORKING_V6, True, True), empty(settings.V6_CHECKED_AT, True, True)]) + + # if ipver == 6 and not st6_empty and settings.V6_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # log.debug("Returning cached IPv6 status: working = %s", settings.HAS_WORKING_V6) + # return settings.HAS_WORKING_V6 + # if ipver == 4 and not st4_empty and settings.V4_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # log.debug("Returning cached IPv4 status: working = %s", settings.HAS_WORKING_V4) + # return settings.HAS_WORKING_V4 + + # if ipver == 'any' and any([not st4_empty, not st6_empty]) and settings.V4_CHECKED_AT > datetime.utcnow(): + # # if settings.V6_CHECKED_AT > datetime.utcnow() + # if st4_empty: + # log.debug("test_hosts being requested for 'any' ip ver. IPv6 status cached, but not IPv4 status. Checking IPv4 status...") + # check_v4() + # if st6_empty: + # log.debug("test_hosts being requested for 'any' ip ver. IPv4 status cached, but not IPv6 status. 
Checking IPv6 status...") + # check_v6() + # # if not st4_empty and not st6_empty: + # log.debug( + # "Returning status %s based on: Working IPv4 = %s || Working IPv6 = %s", + # settings.HAS_WORKING_V4 or settings.HAS_WORKING_V6, settings.HAS_WORKING_V4, settings.HAS_WORKING_V6 + # ) + # return settings.HAS_WORKING_V4 or settings.HAS_WORKING_V6 + + # max_hosts = int(kwargs.get('max_hosts', settings.NET_CHECK_HOST_COUNT_TRY)) + min_hosts_pos = int(kwargs.get('required_positive', settings.NET_CHECK_HOST_COUNT)) + + # hosts = empty_if(hosts, settings.V4_TEST_HOSTS, itr=True) + hosts = [x for x in hosts] + + if randomise: random.shuffle(hosts) + + if len(hosts) > max_hosts: hosts = hosts[:max_hosts] + + + total_hosts = len(hosts) + total_working, total_broken = 0, 0 + + log.debug("Testing %s hosts with IP version '%s' - timeout: %s", total_hosts, ipver, timeout) + port = 80 + + for h in hosts: + try: + nh = h.split(':') + if len(nh) > 1: + port = int(nh[-1]) + h = ':'.join(nh[:-1]) + else: + h = ':'.join(nh) + log.warning("Host is missing port: %s - falling back to port 80") + port = 80 + + log.debug("Checking host %s via port %s + IP version '%s'", h, port, ipver) + + if port == 80: + res = check_host_http(h, port, ipver, throw=False, timeout=timeout) + else: + res = check_host(h, port, ipver, throw=False, timeout=timeout) + if res: + total_working += 1 + log.debug("check_host for %s came back true. incremented working hosts: %s", h, total_working) + else: + total_broken += 1 + log.debug("check_host for %s came back false. incremented broken hosts: %s", h, total_broken) + + except Exception as e: + log.warning("Exception while checking host %s port %s", h, port) + + working = total_working >= min_hosts_pos + + log.info("test_hosts - proto: %s - protocol working? %s || total hosts: %s || working hosts: %s || broken hosts: %s", + ipver, working, total_hosts, total_working, total_broken) + + return working + + +@r_cache("pvxhelpers:check_v4", settings.NET_CHECK_TIMEOUT) +def check_v4(hosts: List[str] = None, *args, **kwargs) -> bool: + """Check and cache whether IPv4 is functional by testing a handful of IPv4 hosts""" + return test_hosts(hosts, ipver='v4', *args, **kwargs) + + +@r_cache("pvxhelpers:check_v6", settings.NET_CHECK_TIMEOUT) +def check_v6(hosts: List[str] = None, *args, **kwargs) -> bool: + """Check and cache whether IPv6 is functional by testing a handful of IPv6 hosts""" + return test_hosts(hosts, ipver='v6', *args, **kwargs) + + +@r_cache_async("pvxhelpers:check_v4", settings.NET_CHECK_TIMEOUT) +async def check_v4_async(hosts: List[str] = None, *args, **kwargs) -> bool: + """(Async ver of :func:`.check_v4`) Check and cache whether IPv4 is functional by testing a handful of IPv4 hosts""" + return await test_hosts_async(hosts, ipver='v4', *args, **kwargs) + + +@r_cache_async("pvxhelpers:check_v6", settings.NET_CHECK_TIMEOUT) +async def check_v6_async(hosts: List[str] = None, *args, **kwargs) -> bool: + """(Async ver of :func:`.check_v6`) Check and cache whether IPv6 is functional by testing a handful of IPv6 hosts""" + return await test_hosts_async(hosts, ipver='v6', *args, **kwargs) diff --git a/privex/helpers/net.py b/privex/helpers/net/dns.py similarity index 73% rename from privex/helpers/net.py rename to privex/helpers/net/dns.py index ee48a08..569e760 100644 --- a/privex/helpers/net.py +++ b/privex/helpers/net/dns.py @@ -1,5 +1,5 @@ """ -Network related helper code +Functions/classes related to hostnames/domains/reverse DNS etc. 
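A short illustrative sketch of the connectivity checks defined above: `check_v4` / `check_v6` wrap `test_hosts()`, and their results are cached via `r_cache` for `settings.NET_CHECK_TIMEOUT` seconds. The hosts below are placeholders, written in the `host:port` form the code expects.

```
# Illustrative sketch: cached IPv4/IPv6 connectivity checks (hosts below are placeholders).
from privex.helpers.net import check_v4, check_v6, test_hosts

if not check_v4():
    print("IPv4 connectivity appears to be broken")
if not check_v6():
    print("IPv6 connectivity appears to be broken")

# test_hosts() can also be called directly with custom 'host:port' entries
ok = test_hosts(['files.privex.io:80', 'hiveseed-se.privex.io:2001'], ipver='any', timeout=5)
print('enough hosts reachable:', ok)
```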
- network related helper code **Copyright**:: @@ -21,32 +21,33 @@ Copyright 2019 Privex Inc. ( https://www.privex.io ) """ + import asyncio -import logging -import platform -import subprocess import socket +from ipaddress import IPv4Address, IPv6Address, ip_address -from privex.helpers.common import empty_if, empty, byteify +from typing import AsyncGenerator, Generator, List, Optional, Tuple, Union -from privex.helpers.exceptions import BoundaryException, NetworkUnreachable, ReverseDNSNotFound, InvalidHost from privex.helpers import plugin -from ipaddress import ip_address, IPv4Address, IPv6Address -from typing import Union, Optional, List, Dict, Generator, Tuple, AsyncGenerator +from privex.helpers.common import empty +from privex.helpers.exceptions import BoundaryException, InvalidHost, ReverseDNSNotFound +import logging -from privex.helpers.types import IP_OR_STR, AnyNum +from privex.helpers.net.util import is_ip, sock_ver +from privex.helpers.types import IP_OR_STR log = logging.getLogger(__name__) __all__ = [ - 'ip_to_rdns', '_check_boundaries', 'ip4_to_rdns', 'ip6_to_rdns', 'ip_is_v4', 'ip_is_v6', - 'ping', 'resolve_ip', 'resolve_ips', 'resolve_ips_multi', 'resolve_ip_async', 'resolve_ips_async', 'resolve_ips_multi_async', - 'get_rdns', 'get_rdns_multi', 'get_rdns_async', 'check_host', 'check_host_async', 'BoundaryException', 'NetworkUnreachable' + 'ip_to_rdns', 'ip4_to_rdns', 'ip6_to_rdns', 'resolve_ips_async', 'resolve_ip_async', 'resolve_ips_multi_async', + 'resolve_ips', 'resolve_ip', 'resolve_ips_multi', 'get_rdns_async', 'get_rdns', 'get_rdns_multi' ] + try: from dns.resolver import Resolver, NoAnswer, NXDOMAIN + def asn_to_name(as_number: Union[int, str], quiet: bool = True) -> str: """ Look up an integer Autonomous System Number and return the human readable @@ -61,12 +62,12 @@ def asn_to_name(as_number: Union[int, str], quiet: bool = True) -> str: This helper function requires ``dnspython>=1.16.0``, it will not be visible unless you install the dnspython package in your virtualenv, or systemwide:: - + pip3 install dnspython - + :param int/str as_number: The AS number as a string or integer, e.g. 210083 or '210083' - :param bool quiet: (default True) If True, returns 'Unknown ASN' if a lookup fails. + :param bool quiet: (default True) If True, returns 'Unknown ASN' if a lookup fails. If False, raises a KeyError if no results are found. :raises KeyError: Raised when a lookup returns no results, and ``quiet`` is set to False. :return str as_name: The name and country code of the ASN, e.g. 'PRIVEX, SE' @@ -85,9 +86,10 @@ def asn_to_name(as_number: Union[int, str], quiet: bool = True) -> str: return 'Unknown ASN' raise KeyError('ASN {} was not found, or server did not respond.'.format(as_number)) + __all__ += ['asn_to_name'] plugin.HAS_DNSPYTHON = True - + except ImportError: log.debug('privex.helpers.net failed to import "dns.resolver" (pypi package "dnspython"), skipping some helpers') pass @@ -158,7 +160,7 @@ def _check_boundaries(v4_boundary: int, v6_boundary: int): def ip4_to_rdns(ip_obj: IPv4Address, v4_boundary: int = 24, boundary: bool = False) -> str: """ - Internal function for getting the rDNS domain for a given v4 address. Use :py:func:`.ip_to_rdns` unless + Internal function for getting the rDNS domain for a given v4 address. Use :py:func:`.ip_to_rdns` unless you have a specific need for this one. 
:param IPv4Address ip_obj: An IPv4 ip_address() object to get the rDNS domain for @@ -179,7 +181,7 @@ def ip4_to_rdns(ip_obj: IPv4Address, v4_boundary: int = 24, boundary: bool = Fal def ip6_to_rdns(ip_obj: IPv6Address, v6_boundary: int = 32, boundary: bool = False) -> str: """ - Internal function for getting the rDNS domain for a given v6 address. Use :py:func:`.ip_to_rdns` unless + Internal function for getting the rDNS domain for a given v6 address. Use :py:func:`.ip_to_rdns` unless you have a specific need for this one. :param IPv6Address ip_obj: An IPv4 ip_address() object to get the rDNS domain for @@ -201,82 +203,6 @@ def ip6_to_rdns(ip_obj: IPv6Address, v6_boundary: int = 32, boundary: bool = Fal return addr_joined + '.ip6.arpa' # and finally, return the completed string, a.f.0.0.1.0.0.2.ip6.arpa -def ip_is_v4(ip: str) -> bool: - """ - Determines whether an IP address is IPv4 or not - - :param str ip: An IP address as a string, e.g. 192.168.1.1 - :raises ValueError: When the given IP address ``ip`` is invalid - :return bool: True if IPv6, False if not (i.e. probably IPv4) - """ - return type(ip_address(ip)) == IPv4Address - - -def ip_is_v6(ip: str) -> bool: - """ - Determines whether an IP address is IPv6 or not - - :param str ip: An IP address as a string, e.g. 192.168.1.1 - :raises ValueError: When the given IP address ``ip`` is invalid - :return bool: True if IPv6, False if not (i.e. probably IPv4) - """ - return type(ip_address(ip)) == IPv6Address - - -def ping(ip: str, timeout: int = 30) -> bool: - """ - Sends a ping to a given IPv4 / IPv6 address. Tested with IPv4+IPv6 using ``iputils-ping`` on Linux, as well as the - default IPv4 ``ping`` utility on Mac OSX (Mojave, 10.14.6). - - Fully supported when using Linux with the ``iputils-ping`` package. Only IPv4 support on Mac OSX. - - **Example Usage**:: - - >>> from privex.helpers import ping - >>> if ping('127.0.0.1', 5) and ping('::1', 10): - ... print('Both 127.0.0.1 and ::1 are up') - ... else: - ... print('127.0.0.1 or ::1 failed to respond to a ping within the given timeout.') - - **Known Incompatibilities**: - - * NOT compatible with IPv6 addresses on OSX due to the lack of a timeout argument with ``ping6`` - * NOT compatible with IPv6 addresses when using ``inetutils-ping`` on Linux due to separate ``ping6`` command - - :param str ip: An IP address as a string, e.g. 
``192.168.1.1`` or ``2a07:e00::1`` - :param int timeout: (Default: 30) Number of seconds to wait for a response from the ping before timing out - :raises ValueError: When the given IP address ``ip`` is invalid or ``timeout`` < 1 - :return bool: ``True`` if ping got a response from the given IP, ``False`` if not - """ - ip_obj = ip_address(ip) # verify IP is valid (this will throw if it isn't) - if timeout < 1: - raise ValueError('timeout value cannot be less than 1 second') - opts4 = { - 'Linux': ["/bin/ping", "-c1", f"-w{timeout}"], - 'Darwin': ["/sbin/ping", "-c1", f"-t{timeout}"] - } - opts6 = {'Linux': ["/bin/ping", "-c1", f"-w{timeout}"]} - opts = opts4 if ip_is_v4(ip_obj) else opts6 - if platform.system() not in opts: - raise NotImplementedError(f"{__name__}.ping is not fully supported on platform '{platform.system()}'...") - - with subprocess.Popen(opts[platform.system()] + [ip], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: - out, err = proc.communicate() - err = err.decode('utf-8') - if 'network is unreachable' in err.lower(): - raise NetworkUnreachable(f'Got error from ping: "{err}"') - - return 'bytes from {}'.format(ip) in out.decode('utf-8') - - -def _sock_ver(version): - version = empty_if(version, 'any', zero=True, itr=True) - version = version.lower() if isinstance(version, str) and version not in [socket.AF_INET, socket.AF_INET6] else version - if version in [4, 'v4', '4', 'ipv4', 'inet', 'inet4']: version = socket.AF_INET - if version in [6, 'v6', '6', 'ipv6', 'inet6']: version = socket.AF_INET6 - return version - - async def resolve_ips_async(addr: IP_OR_STR, version: Union[str, int] = 'any', v4_convert=False) -> List[str]: """ AsyncIO version of :func:`.resolve_ips_async` - resolves the IPv4/v6 addresses for a given host (``addr``) @@ -303,9 +229,9 @@ async def resolve_ips_async(addr: IP_OR_STR, version: Union[str, int] = 'any', v :return List[str] ips: Zero or more IP addresses in a list of :class:`str`'s """ loop = asyncio.get_event_loop() - addr, version = str(addr), _sock_ver(version) + addr, version = str(addr), sock_ver(version) ips = [] - ip = _is_ip(addr, version) + ip = is_ip(addr, version) if ip: return [str(ip)] try: if version in [socket.AF_INET, socket.AF_INET6]: @@ -394,28 +320,6 @@ async def resolve_ips_multi_async(*addr: IP_OR_STR, version: Union[str, int] = ' yield (a, None) -def _is_ip(addr: str, version: int = None): - try: - res = _sock_validate_ip(addr, version=version) - return res - except AttributeError as e: - raise e - except ValueError: - return False - - -def _sock_validate_ip(addr: IP_OR_STR, version: int, throw=True) -> Optional[Union[IPv4Address, IPv4Address]]: - ip = ip_address(addr) - ver = "v4" if ip_is_v4(ip) else "v6" - if version == socket.AF_INET and ver != 'v4': - if not throw: return None - raise AttributeError(f"Passed address '{addr}' was an IPv6 address, but 'version' requested an IPv4 address.") - if version == socket.AF_INET6 and ver != 'v6': - if not throw: return None - raise AttributeError(f"Passed address '{addr}' was an IPv4 address, but 'version' requested an IPv6 address.") - return ip - - def resolve_ips(addr: IP_OR_STR, version: Union[str, int] = 'any', v4_convert=False) -> List[str]: """ With just a single hostname argument, both IPv4 and IPv6 addresses will be returned as strings:: @@ -486,8 +390,8 @@ def resolve_ips(addr: IP_OR_STR, version: Union[str, int] = 'any', v4_convert=Fa :return List[str] ips: Zero or more IP addresses in a list of :class:`str`'s """ - addr, version, ips = str(addr), 
_sock_ver(version), [] - ip = _is_ip(addr, version) + addr, version, ips = str(addr), sock_ver(version), [] + ip = is_ip(addr, version) if ip: return [str(ip)] try: if version in [socket.AF_INET, socket.AF_INET6]: @@ -628,7 +532,7 @@ async def get_rdns_async(host: IP_OR_STR, throw=True, version='any', name_port=8 loop = asyncio.get_event_loop() host = str(host) try: - if not _is_ip(host): + if not is_ip(host): orig_host = host host = await resolve_ip_async(host, version=version) if empty(host): @@ -638,7 +542,7 @@ async def get_rdns_async(host: IP_OR_STR, throw=True, version='any', name_port=8 res = await loop.getnameinfo((host, name_port)) rdns = res[0] - if _is_ip(rdns): + if is_ip(rdns): if throw: raise ReverseDNSNotFound(f"No reverse DNS records found for host '{host}' - result was: {rdns}") return None return rdns @@ -745,140 +649,3 @@ def get_rdns_multi(*hosts: IP_OR_STR, throw=False) -> Generator[Tuple[str, Optio """ for h in hosts: yield (str(h), get_rdns(str(h), throw=throw)) - - -def check_host(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: - """ - Test if the service on port ``port`` for host ``host`` is working. AsyncIO version: :func:`.check_host_async` - - Basic usage (services which send the client data immediately after connecting):: - - >>> check_host('hiveseed-se.privex.io', 2001) - True - >>> check_host('hiveseed-se.privex.io', 9991) - False - - For some services, such as HTTP - it's necessary to transmit some data to the host before it will - send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. - - Sending data to ``host`` after connecting:: - - >>> check_host('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") - True - - - :param str|IPv4Address|IPv6Address host: Hostname or IP to test - :param int|str port: Port number on ``host`` to connect to - :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar - to ensure the connection is via that IP version - - :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` - :param kwargs: Additional configuration options (see below) - - :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) - :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. - :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` - - :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. - If the global default timeout is ``None``, then falls back to ``5.0`` - - :raises socket.timeout: When ``throw=True`` and a timeout occurs. - :raises socket.gaierror: When ``throw=True`` and various errors occur - :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused - :raises ConnectionResetError: When ``throw=True`` and the connection was reset - - :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. 
- """ - receive, stype = int(kwargs.get('receive', 100)), kwargs.get('stype', socket.SOCK_STREAM) - timeout, send = kwargs.get('timeout', 'n/a'), kwargs.get('send') - if timeout == 'n/a': - t = socket.getdefaulttimeout() - timeout = 10.0 if not t else t - - s_ver = socket.AF_INET - ip = resolve_ip(host, version) - - if ip_is_v6(ip): s_ver = socket.AF_INET6 - - try: - with socket.socket(s_ver, stype) as s: - if timeout: s.settimeout(float(timeout)) - s.connect((ip, int(port))) - if not empty(send): - s.sendall(byteify(send)) - if receive > 0: - s.recv(int(receive)) - return True - except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: - if throw: - raise e - return False - - -async def check_host_async(host: IP_OR_STR, port: AnyNum, version='any', throw=False, **kwargs) -> bool: - """ - AsyncIO version of :func:`.check_host`. Test if the service on port ``port`` for host ``host`` is working. - - Basic usage (services which send the client data immediately after connecting):: - - >>> await check_host_async('hiveseed-se.privex.io', 2001) - True - >>> await check_host_async('hiveseed-se.privex.io', 9991) - False - - For some services, such as HTTP - it's necessary to transmit some data to the host before it will - send a response. Using the ``send`` kwarg, you can transmit an arbitrary string/bytes upon connection. - - Sending data to ``host`` after connecting:: - - >>> await check_host_async('files.privex.io', 80, send=b"GET / HTTP/1.1\\n\\n") - True - - - :param str|IPv4Address|IPv6Address host: Hostname or IP to test - :param int|str port: Port number on ``host`` to connect to - :param str|int version: When connecting to a hostname, this can be set to ``'v4'``, ``'v6'`` or similar - to ensure the connection is via that IP version - - :param bool throw: (default: ``False``) When ``True``, will raise exceptions instead of returning ``False`` - :param kwargs: Additional configuration options (see below) - - :keyword int receive: (default: ``100``) Amount of bytes to attempt to receive from the server (``0`` to disable) - :keyword bytes|str send: If ``send`` is specified, the data in ``send`` will be transmitted to the server before receiving. - :keyword int stype: Socket type, e.g. :attr:`socket.SOCK_STREAM` - - :keyword float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. - If the global default timeout is ``None``, then falls back to ``5.0`` - - :raises socket.timeout: When ``throw=True`` and a timeout occurs. - :raises socket.gaierror: When ``throw=True`` and various errors occur - :raises ConnectionRefusedError: When ``throw=True`` and the connection was refused - :raises ConnectionResetError: When ``throw=True`` and the connection was reset - - :return bool success: ``True`` if successfully connected + sent/received data. Otherwise ``False``. 
- """ - receive, stype = int(kwargs.get('receive', 100)), kwargs.get('stype', socket.SOCK_STREAM) - timeout, send = kwargs.get('timeout', 'n/a'), kwargs.get('send') - if timeout == 'n/a': - t = socket.getdefaulttimeout() - timeout = 10.0 if not t else t - - loop = asyncio.get_event_loop() - s_ver = socket.AF_INET - ip = await resolve_ip_async(host, version) - - if ip_is_v6(ip): s_ver = socket.AF_INET6 - - try: - with socket.socket(s_ver, stype) as s: - if timeout: s.settimeout(float(timeout)) - await loop.sock_connect(s, (ip, int(port))) - if not empty(send): - await loop.sock_sendall(s, byteify(send)) - if receive > 0: - await loop.sock_recv(s, int(receive)) - return True - except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: - if throw: - raise e - return False diff --git a/privex/helpers/net/socket.py b/privex/helpers/net/socket.py new file mode 100644 index 0000000..1980dea --- /dev/null +++ b/privex/helpers/net/socket.py @@ -0,0 +1,1796 @@ +""" +Various wrapper functions/classes which use :mod:`socket` or are strongly tied to functions in this file +which use :mod:`socket`. Part of :mod:`privex.helpers.net` - network related helper code. + +**Copyright**:: + + +===================================================+ + | © 2019 Privex Inc. | + | https://www.privex.io | + +===================================================+ + | | + | Originally Developed by Privex Inc. | + | License: X11 / MIT | + | | + | Core Developer(s): | + | | + | (+) Chris (@someguy123) [Privex] | + | (+) Kale (@kryogenic) [Privex] | + | | + +===================================================+ + + Copyright 2019 Privex Inc. ( https://www.privex.io ) + +""" +import asyncio +import functools +import socket +import ssl +import time +from ipaddress import ip_network +from typing import Any, Callable, Generator, IO, Iterable, List, Optional, Tuple, Union + +import attr + +from privex.helpers import settings +from privex.helpers.common import LayeredContext, byteify, empty, empty_if, is_true, stringify, strip_null +from privex.helpers.thread import SafeLoopThread +from privex.helpers.asyncx import await_if_needed, run_coro_thread +from privex.helpers.net.util import generate_http_request, get_ssl_context, ip_is_v6, ip_sock_ver, is_ip +from privex.helpers.net.dns import resolve_ip, resolve_ip_async +from privex.helpers.types import AUTO, AUTO_DETECTED, AnyNum, STRBYTES, T + +import logging + +log = logging.getLogger(__name__) + +__all__ = [ + 'AnySocket', 'OpAnySocket', 'SocketContextManager', + 'StopLoopOnMatch', 'SocketWrapper', 'AsyncSocketWrapper', 'send_data_async', 'send_data', 'upload_termbin', + 'upload_termbin_file', 'upload_termbin_async', 'upload_termbin_file_async' +] + +AnySocket = Union[ssl.SSLSocket, "socket.socket"] +OpAnySocket = Optional[Union[ssl.SSLSocket, "socket.socket"]] + + +class SocketContextManager: + parent_class: Union["SocketWrapper", "AsyncSocketWrapper"] + + def __init__(self, parent_class: Union["SocketWrapper", "AsyncSocketWrapper"]): + self.parent_class = parent_class + + def __enter__(self) -> "SocketWrapper": + log.debug("Entering SocketContextManager") + self.parent_class.reconnect() + return self.parent_class + + def __exit__(self, exc_type, exc_val, exc_tb): + log.debug("Exiting SocketContextManager") + self.parent_class.close() + + async def __aenter__(self) -> "AsyncSocketWrapper": + log.debug("[async] Entering SocketContextManager") + await self.parent_class.reconnect() + return self.parent_class + + async def __aexit__(self, exc_type, 
exc_val, exc_tb): + log.debug("[async] Exiting SocketContextManager") + self.parent_class.close() + + +class StopLoopOnMatch(Exception): + def __init__(self, message: str, match: Any = None, compare: str = None, compare_lower: bool = True, **extra): + self.message = message + self.match = match + self.compare = compare + self.compare_lower = compare_lower + self.extra = extra + super().__init__(message) + + +def _sockwrapper_auto_connect(new_sock: bool = False): + def _decorator(f): + @functools.wraps(f) + def wrapper(self: Union["SocketWrapper"], *args, _sock_tries=0, **kwargs): + kwargs = dict(kwargs) + gensock = None + if kwargs.pop('new_sock', new_sock): + log.debug("new_sock is true for call to function %s - generating socket to kwarg 'sock'...", f.__name__) + # kwargs['sock'] = self._select_socket(new_sock=True) + gensock = SocketTracker.duplicate(self.tracker) + elif 'sock' in kwargs and kwargs['sock'] not in [None, False, '']: + gensock = kwargs.pop('sock') + + if gensock not in [None, False, '']: + + log.debug("'sock' is present for call to function %s...", f.__name__) + with gensock as sck: + log.debug('ensuring socket is open (inside with). now connecting socket.') + try: + # self.connect(host=kwargs.get('host'), port=kwargs.get('port'), sock=sck) + kwargs['sock'] = sck + except OSError as e: + if 'already connected' in str(e): + log.debug('socket already connected. continuing.') + log.debug('socket should now be connected. calling function %s', f.__name__) + return f(self, *args, **kwargs) + + if not self.connected: + log.debug('instance socket is not connected ( calling function %s )', f.__name__) + + if not self.auto_connect: + raise ConnectionError( + "Would've auto-connected SocketWrapper, but self.auto_connect is False. Please call connect before " + "interacting with the socket." + ) + if any([empty(self.host, zero=True), empty(self.port, zero=True)]): + raise ConnectionError("Tried to auto-connect SocketWrapper, but self.host and/or self.port are empty!") + log.debug('connecting instance socket ( calling function %s )', f.__name__) + # self.connect(self.host, self.port) + self.tracker.connect() + try: + _sock_tries += 1 + return f(self, *args, **kwargs) + except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError) as e: + if self.error_reconnect and _sock_tries < 3: + log.error("The socket appears to have broken. Resetting and trying again. Error was: %s - %s", type(e), str(e)) + self.tracker.reconnect() + return wrapper(self, *args, _sock_tries=_sock_tries, **kwargs) + raise e + return wrapper + return _decorator + + +def _async_sockwrapper_auto_connect(): + def _decorator(f): + @functools.wraps(f) + async def wrapper(self: Union["AsyncSocketWrapper"], *args, _sock_tries=0, **kwargs): + if not self.tracker.connected: + if not self.auto_connect: + raise ConnectionError( + "Would've auto-connected AsyncSocketWrapper, but self.auto_connect is False. Please call connect before " + "interacting with the socket." + ) + + if any([empty(self.host, zero=True), empty(self.port, zero=True)]): + raise ConnectionError("Tried to auto-connect AsyncSocketWrapper, but self.host and/or self.port are empty!") + # await self.connect(self.host, self.port) + await self.tracker.connect_async() + try: + _sock_tries += 1 + return await f(self, *args, **kwargs) + except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError) as e: + log.error("The socket appears to have broken. 
Error was: %s - %s", type(e), str(e)) + if self.error_reconnect and _sock_tries < 3: + log.error("Resetting the connection and trying again...") + await self.tracker.reconnect_async() + return await wrapper(self, *args, _sock_tries=_sock_tries, **kwargs) + raise e + return wrapper + return _decorator + + +class MockContext: + def __enter__(self): + # return self.auto_socket + return "yes" + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + async def __aenter__(self): + return "yes" + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +@attr.s +class SocketTracker: + """ + Data class used by :class:`.SocketWrapper` / :class:`.AsyncSocketWrapper` for managing sockets + """ + host: str = attr.ib() + port: int = attr.ib(converter=int) + timeout: Union[int, float] = attr.ib(factory=lambda: settings.DEFAULT_SOCKET_TIMEOUT) + server: bool = attr.ib(default=False, converter=is_true) + connected: bool = attr.ib(default=False, converter=is_true) + binded: bool = attr.ib(default=False, converter=is_true) + listening: bool = attr.ib(default=False, converter=is_true) + use_ssl: bool = attr.ib(default=False, converter=is_true) + socket_conf: dict = attr.ib(factory=dict) + ssl_conf: dict = attr.ib(factory=dict) + ssl_wrap_conf: dict = attr.ib(factory=dict) + hostname: str = attr.ib(default=None) + _ssl_context: ssl.SSLContext = attr.ib(default=None) + _ssl_socket: ssl.SSLSocket = attr.ib(default=None) + _loop: asyncio.AbstractEventLoop = attr.ib(default=None) + _socket: AnySocket = attr.ib(default=None) + _socket_layer_ctx = attr.ib(default=None) + + _host_v4: Optional[str] = attr.ib(default=None) + _host_v6: Optional[str] = attr.ib(default=None) + + _host_v4_resolved: bool = attr.ib(default=False) + _host_v6_resolved: bool = attr.ib(default=False) + + def __attrs_post_init__(self): + self.hostname = empty_if(self.hostname, self.host, zero=True) + + @property + def family(self) -> int: + return self.socket_conf.get('family', -1) + + @family.setter + def family(self, value: int): + self.socket_conf['family'] = value + + @property + def host_v4(self) -> Optional[str]: + if not self._host_v4_resolved: + self._host_v4 = resolve_ip(self.host, 'v4') + self._host_v4_resolved = True + return self._host_v4 + + @property + def host_v6(self) -> Optional[str]: + if not self._host_v6_resolved: + self._host_v6 = resolve_ip(self.host, 'v6') + self._host_v6_resolved = True + return self._host_v6 + + @property + def socket(self): + if not self._socket: + self._socket = socket.socket(**self.socket_conf) + return self._socket + + @socket.setter + def socket(self, value): + pass + + @property + def socket_layer_ctx(self): + if not self._socket_layer_ctx: + self._socket_layer_ctx = LayeredContext(MockContext()) + return self._socket_layer_ctx + + @socket_layer_ctx.setter + def socket_layer_ctx(self, value): + self._socket_layer_ctx = value + + def _make_context(self, **kwargs) -> ssl.SSLContext: + cnf = {**self.ssl_conf, **kwargs} + return get_ssl_context(**cnf) + + @property + def ssl_context(self): + if not self._ssl_context: + self._ssl_context = self._make_context() + return self._ssl_context + + @ssl_context.setter + def ssl_context(self, value): + self._ssl_context = value + + @property + def ssl_socket(self): + if not self._ssl_socket: + self._ssl_socket = self.ssl_context.wrap_socket(self.socket, **self.ssl_wrap_conf) + return self._ssl_socket + + @ssl_socket.setter + def ssl_socket(self, value): + self._ssl_socket = value + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if not 
self._loop: + self._loop = asyncio.get_event_loop() + return self._loop + + @property + def _auto_socket(self): + return self.ssl_socket if self.use_ssl else self.socket + + @property + def auto_socket(self) -> AnySocket: + if not self.connected: self.connect() + return self._auto_socket + + @property + def ip_address(self): + try: + if empty(self._auto_socket) or empty(self._auto_socket.getpeername()): + return None + except Exception as e: + log.warning("Error while getting peername: %s %s", type(e), str(e)) + return None + return self._auto_socket.getpeername()[0] + connected_ip = ip_address + + @property + def connected_port(self): + try: + if empty(self._auto_socket) or empty(self._auto_socket.getpeername()): + return None + except Exception as e: + log.warning("Error while getting peername: %s %s", type(e), str(e)) + return None + return self._auto_socket.getpeername()[1] + + def bind(self, address: Tuple[str, AnyNum] = None, force=False, **kwargs): + if self.binded and not force: + return self.auto_socket + self.auto_socket.bind(address) + self.binded = True + return self.auto_socket + + def listen(self, backlog: int = 10, force=False, **kwargs): + s = self.auto_socket + if self.listening and not force: + return s + self.auto_socket.listen(backlog) + self.listening = True + return self.auto_socket + + def post_connect(self, sock: AnySocket): + log.debug("[%s.%s] Connected to host: %s", __name__, self.__class__.__name__, sock.getpeername()) + sock.settimeout(self.timeout) + return sock + + def v6_fallback(self, ex: Exception = None) -> bool: + ip = self.ip_address + if self.family == socket.AF_INET6 or (self.family != socket.AF_INET and not empty(ip) and ip_is_v6(ip)): + if self.host_v4: + if ex: + log.warning( + "[%s.%s] Error while using IPv6. Falling back to v4. %s %s", + __name__, self.__class__.__name__, type(ex), str(ex) + ) + self.family = socket.AF_INET + return True + return False + + def connect(self, force=False, override_ssl=None, _conn_tries=0) -> AnySocket: + if not self.connected or force: + sock = self.socket + if self.use_ssl and override_ssl in [None, True]: + sock = self.ssl_socket + log.debug("[%s.%s] Connecting to host %s on port %s", __name__, self.__class__.__name__, self.host, self.port) + + # log.debug("Connecting to host %s on port %s", self.host, self.port) + try: + _conn_tries += 1 + sock.settimeout(self.timeout) + sock.connect((self.host, self.port)) + except OSError as e: + if 'already connected' in str(e): + log.debug("[%s.%s] Got OSError. Already connected. %s - %s", __name__, self.__class__.__name__, type(e), str(e)) + self.connected = True + return self.post_connect(self.auto_socket) + if _conn_tries >= 3: + raise e + if not self.v6_fallback(e): + log.warning("[%s.%s] Got OSError. Resetting. 
%s - %s", __name__, self.__class__.__name__, type(e), str(e)) + # self._socket = None + return self.reconnect(force=True, override_ssl=override_ssl, _conn_tries=_conn_tries) + # sock.settimeout(self.timeout) + self.connected = True + if self.use_ssl: + self.ssl_socket = sock + else: + self.socket = sock + return self.post_connect(sock) + sock = self.ssl_socket if self.use_ssl and override_ssl in [None, True] else self.socket + return self.post_connect(sock) + + def reconnect(self, force=True, override_ssl=None, _conn_tries=0) -> AnySocket: + if self.connected or force: + self.disconnect() + return self.connect(force=True, override_ssl=override_ssl, _conn_tries=_conn_tries) + + async def reconnect_async(self, force=True, override_ssl=None, _conn_tries=0) -> AnySocket: + if self.connected or force: + self.disconnect() + return await self.connect_async(force=True, override_ssl=override_ssl, _conn_tries=_conn_tries) + + async def connect_async(self, force=False, override_ssl=None, _conn_tries=0) -> AnySocket: + if not self.connected or force: + sock = self.socket + if self.use_ssl and override_ssl in [None, True]: + sock = self.ssl_socket + log.debug("[async] [%s.%s] Connecting to host %s on port %s (timeout: %s)", __name__, self.__class__.__name__, + self.host, self.port, self.timeout) + try: + _conn_tries += 1 + sock.settimeout(self.timeout) + await asyncio.wait_for(self.loop.sock_connect(sock, (self.host, self.port)), self.timeout + 0.1) + except (OSError, asyncio.TimeoutError) as e: + if 'already connected' in str(e): + log.debug("[%s.%s] Got OSError. Already connected. %s - %s", __name__, self.__class__.__name__, type(e), str(e)) + self.connected = True + return self.post_connect(self.auto_socket) + if _conn_tries >= 3: + raise e + if not self.v6_fallback(e): + log.warning("[%s.%s] Got OSError. Resetting. 
%s - %s", __name__, self.__class__.__name__, type(e), str(e)) + # self._socket = None + return await self.reconnect_async(force=True, override_ssl=override_ssl, _conn_tries=_conn_tries) + # sock.settimeout(self.timeout) + self.connected = True + sock = self.ssl_socket if self.use_ssl and override_ssl in [None, True] else self.socket + return self.post_connect(sock) + + def _shutdown(self, sck): + try: + sck.shutdown(socket.SHUT_RDWR) + except OSError as e: + if 'not connected' in str(e): return + log.warning("OSError while shutting down socket: %s %s", type(e), str(e)) + except Exception as e: + log.warning("Exception while shutting down socket: %s %s", type(e), str(e)) + + def _close(self, sck): + try: + sck.close() + except OSError as e: + log.warning("OSError while closing socket: %s %s", type(e), str(e)) + except Exception as e: + log.warning("Exception while closing socket: %s %s", type(e), str(e)) + + def disconnect(self): + self.connected, self.binded, self.listening = False, False, False + try: + log.debug("[%s.%s] Disconnecting socket for host %s on port %s", __name__, self.__class__.__name__, self.host, self.port) + + # log.debug() + if self._socket: + self._shutdown(self._socket) + self._close(self._socket) + # try: + # self._socket.shutdown(socket.SHUT_RDWR) + # except OSError as e: + # log.warning("OSError while shutting down socket: %s %s", type(e), str(e)) + # except Exception as e: + # log.warning("Exception while shutting down socket: %s %s", type(e), str(e)) + # self._socket.close() + self._socket = None + if self._ssl_socket: + self._shutdown(self._ssl_socket) + self._close(self._ssl_socket) + # self._ssl_socket.shutdown(socket.SHUT_RDWR) + # self._ssl_socket.close() + self._ssl_socket = None + return True + except Exception: + log.exception("error while closing socket") + return False + + @classmethod + def duplicate(cls, inst: "SocketTracker", **kwargs) -> "SocketTracker": + cfg = dict( + host=inst.host, port=inst.port, timeout=inst.timeout, server=inst.server, use_ssl=inst.use_ssl, + socket_conf=inst.socket_conf, ssl_conf=inst.ssl_conf, ssl_wrap_conf=inst.ssl_wrap_conf + ) + cfg = {**cfg, **kwargs} + return cls(**cfg) + + def __enter__(self): + if self.socket_layer_ctx.virtual_layer == 0: + self._socket_layer_ctx = None + if self.connected: + self.reconnect() + elif not self.connected: self.connect() + # return self.auto_socket + self.socket_layer_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.socket_layer_ctx.virtual_layer <= 1: + self.disconnect() + self.socket_layer_ctx.__exit__(exc_type, exc_val, exc_tb) + # return self.auto_socket + + async def __aenter__(self): + if self.socket_layer_ctx.virtual_layer == 0: + self._socket_layer_ctx = None + if self.connected: + await self.reconnect_async() + elif not self.connected: await self.connect_async() + # return self.auto_socket + await self.socket_layer_ctx.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # self.disconnect() + if self.socket_layer_ctx.virtual_layer <= 1: + self.disconnect() + await self.socket_layer_ctx.__aexit__(exc_type, exc_val, exc_tb) + + def __getattr__(self, item): + try: + return super().__getattribute__(item) + except AttributeError: + pass + sock: AnySocket = super().__getattribute__('auto_socket') + return getattr(sock, item) + + +class SocketWrapper(object): + """ + A wrapper class to make working with :class:`socket.socket` much simpler. + + .. NOTE:: For AsyncIO, use :class:`.AsyncSocketWrapper` instead. 
+ + **Features** + + * Automatic address family detection - detects whether you have working IPv4 / IPv6, and decides the best way + to connect to a host, depending on what IP versions that host supports + * ``Happy Eyeballs`` for IPv6. If something goes wrong with an IPv6 connection, it will fallback to IPv4 if the + host has it available (i.e. a domain with both ``A`` and ``AAAA`` records) + * Easy to use SSL, which works with HTTPS and other SSL-secured protocols. Just pass ``use_ssl=True`` in the constructor. + * Many wrapper methods such as :meth:`.recv_eof`, :meth:`.query`, and :meth:`.http_request` to make working + with sockets much easier. + + + **Examples** + + Send a string of bytes / text to a server, and then read until EOF:: + + >>> sw = SocketWrapper('icanhazip.org', 80) + >>> res = sw.query("GET / HTTP/1.1\\nHost: icanhazip.com\\n\\n") + >>> print(res) + HTTP/1.1 200 OK + Server: nginx + Content-Type: text/plain; charset=UTF-8 + Content-Length: 17 + x-rtfm: Learn about this site at http://bit.ly/icanhazip-faq and do not abuse the service. + + 2a07:e00::abc + + For basic HTTP requests, you can use :meth:`.http_request`, which will automatically send ``Host`` (based on the host you passed), + and ``User-Agent``. SSL works too, just set ``use_ssl=True``:: + + >>> sw = SocketWrapper('myip.privex.io', 443, use_ssl=True) + >>> res = sw.http_request('/?format=json') + >>> print(res) + HTTP/1.1 200 OK + Server: nginx + Date: Tue, 22 Sep 2020 03:40:48 GMT + Content-Type: application/json + Content-Length: 301 + Connection: close + Access-Control-Allow-Origin: * + {"error":false,"geo":{"as_name":"Privex Inc.","as_number":210083,"city":"Stockholm","country":"Sweden", + "country_code":"SE","error":false,"zip":"173 11"},"ip":"2a07:e00::abc","ip_type":"ipv6","ip_valid":true, + "messages":[], "ua":"Python Privex Helpers ( https://github.com/Privex/python-helpers )"} + + Standard low-level sending and receiving data:: + + >>> sw = SocketWrapper('127.0.0.1', 8888) + >>> sw.sendall(b"hello world") # Send the text 'hello world' + >>> sw.recv(64) # read up to 64 bytes of data from the socket + b"lorem ipsum\n" + + """ + DEFAULT_TIMEOUT = empty_if(socket.getdefaulttimeout(), settings.DEFAULT_SOCKET_TIMEOUT, zero=True) + + _context: Optional[ssl.SSLContext] + _socket: OpAnySocket + _base_socket: Optional[socket.socket] + _ssl_socket: Optional[ssl.SSLSocket] + _layer_context: Optional[LayeredContext] + _socket_ctx_mgr: SocketContextManager + # connected: bool + auto_connect: bool + auto_listen: bool + listen_backlog: int + tracker: SocketTracker + + def __init__( + self, host: str, port: int, server=False, family=-1, type=socket.SOCK_STREAM, proto=-1, fileno=None, + timeout=DEFAULT_TIMEOUT, use_ssl=False, verify_cert=False, **kwargs + ): + self.host, self.port = host, int(port) + self.server = is_true(server) + # if self.server and (empty(type) or type == -1): + # type = socket.SOCK_STREAM + # self._socket = kwargs.get('socket', None) + # self._base_socket = kwargs.get('base_socket', None) + # self._ssl_socket = kwargs.get('ssl_socket', None) + _context = kwargs.get('ssl_context', None) + # self.connected = not (self._socket is None) + binded, listening = kwargs.get('binded', False), kwargs.get('listening', False) + check_connectivity = kwargs.get('check_connectivity', settings.CHECK_CONNECTIVITY) + self.auto_connect = kwargs.get('auto_connect', True) + self.error_reconnect = kwargs.get('error_reconnect', True) + self.auto_listen = kwargs.get('auto_listen', True) + self.listen_backlog = 
kwargs.get('listen_backlog', 10) + self.read_timeout = kwargs.get('read_timeout', settings.DEFAULT_READ_TIMEOUT) + self.send_timeout = kwargs.get('send_timeout', settings.DEFAULT_WRITE_TIMEOUT) + + from privex.helpers.net.common import check_v4_async, check_v6_async + + if family == -1 and is_ip(host): + log.debug("Host '%s' appears to be an IP. Automatically setting address family based on IP.", host) + family = ip_sock_ver(host) + + if family == -1 and check_connectivity: + host_v4 = resolve_ip(host, 'v4') + host_v6 = resolve_ip(host, 'v6') + + if host_v6 is not None and run_coro_thread(check_v6_async): + log.debug("Domain %s has one or more IPv6 addresses, and current system appears to have IPv6 connectivity. " + "Using domain's IPv6 address: %s", host, host_v6) + family = socket.AF_INET6 + elif host_v4 is not None and run_coro_thread(check_v4_async): + log.debug("Domain %s has one or more IPv4 addresses, and current system appears to have IPv4 connectivity. " + "Using domain's IPv4 address: %s", host, host_v4) + family = socket.AF_INET + + # self.use_ssl = use_ssl + # self.socket_conf = dict(family=family, type=type, proto=proto, fileno=fileno) + # self.ssl_wrap_conf = dict( + # server_hostname=kwargs.get('server_hostname'), + # session=kwargs.get('session'), + # do_handshake_on_connect=kwargs.get('do_handshake_on_connect', True) + # ) + # self.ssl_conf = dict( + # verify_cert=verify_cert, + # check_hostname=kwargs.get('check_hostname'), + # verify_mode=kwargs.get('verify_mode') + # ) + # sck = self._socket if self._socket is not None else socket.socket(**self.socket_conf) + self.tracker = SocketTracker( + self.host, self.port, + timeout=timeout, server=server, binded=binded, connected=kwargs.get('connected', False), + listening=listening, use_ssl=use_ssl, + socket_conf=dict(family=family, type=type, proto=proto, fileno=fileno), + ssl_conf=dict( + verify_cert=verify_cert, + check_hostname=kwargs.get('check_hostname'), + verify_mode=kwargs.get('verify_mode') + ), + ssl_wrap_conf=dict( + server_hostname=kwargs.get('server_hostname'), + session=kwargs.get('session'), + do_handshake_on_connect=kwargs.get('do_handshake_on_connect', True) + ), hostname=kwargs.get('hostname', None) + ) + + _socket = kwargs.get('socket', None) + _base_socket = kwargs.get('base_socket', None) + _ssl_socket = kwargs.get('ssl_socket', None) + + if _context is not None: self.tracker.ssl_context = _context + if _socket is not None: self.tracker.socket = _socket + if _base_socket is not None: self.tracker.socket = _base_socket + if _ssl_socket is not None: self.tracker.ssl_socket = _ssl_socket + + # self._timeout = float(timeout) + self._layer_context = None + self._socket_ctx_mgr = SocketContextManager(self) + # if use_ssl: + # ctx = get_ssl_context(**ssl_params) + # s = ctx.wrap_socket( + # server_hostname=kwargs.get('server_hostname'), + # session=kwargs.get('session'), + # do_handshake_on_connect=kwargs.get('do_handshake_on_connect', True), + # ) + + @property + def ssl_conf(self) -> dict: + return self.tracker.ssl_conf + + @ssl_conf.setter + def ssl_conf(self, value): + self.tracker.ssl_conf = value + + @property + def ssl_wrap_conf(self) -> dict: + return self.tracker.ssl_wrap_conf + + @ssl_wrap_conf.setter + def ssl_wrap_conf(self, value): + self.tracker.ssl_wrap_conf = value + + @property + def socket_conf(self) -> dict: + return self.tracker.socket_conf + + @socket_conf.setter + def socket_conf(self, value): + self.tracker.socket_conf = value + + @property + def timeout(self): + return self.tracker.timeout 
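+
+    # NOTE: the property wrappers in this section (ssl_conf, ssl_wrap_conf, socket_conf, timeout, hostname, etc.)
+    # simply proxy to ``self.tracker``, so socket/SSL configuration and connection state live in one place
+    # (the SocketTracker) instead of being duplicated on the wrapper itself.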
+ + @property + def _auto_socket(self): + return self.tracker._auto_socket + + @timeout.setter + def timeout(self, value): + self.socket.settimeout(value) + self.tracker.timeout = value + # self.base_socket.settimeout(value) + # if self._socket: + # self._socket.settimeout(value) + # self._timeout = value + + def _make_context(self, **kwargs) -> ssl.SSLContext: + cnf = {**self.ssl_conf, **kwargs} + return get_ssl_context(**cnf) + + def _make_socket(self, **kwargs) -> socket.socket: + cnf = {**self.socket_conf, **kwargs} + # if self.server: + # if 'family' in cnf: del cnf['family'] + # if 'type' in cnf: del cnf['type'] + # if 'proto' in cnf: del cnf['proto'] + # if 'fileno' in cnf: del cnf['fileno'] + # log.info("socket host: %s || port: %s", self.host, self.port) + # log.info("socket extra config: %s", cnf) + # return socket.create_server((self.host, self.port), **cnf) + return socket.socket(**cnf) + + def _ssl_wrap_socket(self, sock: socket.socket = None, ctx: ssl.SSLContext = None, **kwargs) -> ssl.SSLSocket: + cnf = {**self.ssl_wrap_conf, **kwargs} + ctx = empty_if(ctx, self.context, itr=True, zero=True) + sock = empty_if(sock, self.base_socket, itr=True, zero=True) + return ctx.wrap_socket(sock, **cnf) + + def _select_socket(self, new_sock=False, **kwargs) -> Union[ssl.SSLSocket, "socket.socket"]: + if new_sock: + sock = self._make_socket() + if kwargs.get('use_ssl', self.use_ssl): + sock = self._ssl_wrap_socket(sock, **kwargs) + return sock + if self.use_ssl: + return self._ssl_wrap_socket(**kwargs) + return self.base_socket + + @property + def hostname(self): + return self.tracker.hostname + + @hostname.setter + def hostname(self, value): + self.tracker.hostname = value + + @property + def context(self) -> ssl.SSLContext: + # if not self._context: + # self._context = self._make_context() + # return self._context + if not self.tracker.ssl_context: + self.tracker.ssl_context = self._make_context() + return self.tracker.ssl_context + + ssl_context = context + + @property + def base_socket(self) -> socket.socket: + if not self.tracker.socket: + self.tracker.socket = self._make_socket() + return self.tracker.socket + + @base_socket.setter + def base_socket(self, value: socket.socket): + self.tracker.socket = value + + @property + def socket(self) -> AnySocket: + # if not self._socket: + # self._socket = self._select_socket() + # if not self.server: self._socket.settimeout(self.timeout) + return self.tracker.auto_socket + + @socket.setter + def socket(self, value: AnySocket): + if self.tracker.use_ssl: + self.tracker.ssl_socket = value + else: + self.tracker.socket = value + # self._socket = value + + @property + def connected(self): + return self.tracker.connected + + # @connected.setter + # def connected(self, value): + # self.tracker.connected = value + + def _connect_sanity(self, host, port, sock: OpAnySocket = None, **kwargs): + port = int(port) + sck = self.socket if sock is None else sock + + if sock is None and self.connected and self.socket is not None: + if host != self.host or port != int(self.port): + log.debug(f"Already connected, but {self.__class__.__name__}.connect called with different host/port than stored. " + f"Trigerring a reconnect.") + return self.reconnect(host, port, sock=sck) + log.debug(f"Already connected, {self.__class__.__name__}.connect called with same details as previously. " + f"Returning existing socket.") + return sck + if empty(port, True, True): + raise ValueError(f"{self.__class__.__name__}.connect requires a port. 
Either connect(host, port) or connect( (host,port) )") + return True + + def _connect(self, host: str, port: AnyNum, sock: OpAnySocket = None, **kwargs) -> AnySocket: + port = int(port) + sck = self.tracker if sock is None else sock + if self.server: + log.debug("Binding to host '%s' on port %s", host, port) + self.bind(host, port, sock=sock) + log.debug("Successfully binded to host '%s' on port %s", host, port) + if self.auto_listen: + log.debug("Auto-listen is enabled. Calling %s.listen(%s)", self.__class__.__name__, self.listen_backlog) + self.listen(self.listen_backlog, sock=sock) + log.debug("%s is now listening on host(s) '%s' on port %s", self.__class__.__name__, host, port) + # if sock is None: self.host, self.port, self.connected = host, port, True + return sck + + log.debug("[%s.%s] Connecting to host %s on port %s", self.__class__.__name__, __name__, host, port) + sck.connect((host, port)) + + # if sock is None: self.host, self.port, self.connected = host, port, True + return sck + + def _get_addr(self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None) -> Tuple[str, int]: + csn = self.__class__.__name__ + + if host is None: + if self.host is None: raise ValueError(f"No host specified to {csn}.reconnect(host, port) - and no host in {csn}.host") + host = self.host + if port is None: + if self.port is None: raise ValueError(f"No port specified to {csn}.connect(host, port) - and no port in {csn}.port") + port = self.port + if isinstance(host, (list, set)): host = tuple(host) + if isinstance(host, tuple): host, port = host + + return host, int(port) + + def bind(self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None, sock: OpAnySocket = None, **kwargs): + sck = self.socket if sock is None else sock + if sock is None and self.binded: + return + sck.bind(self._get_addr(host, port)) + if sock is None: self.binded = True + return True + + def connect(self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None, sock: OpAnySocket = None, **kwargs) -> AnySocket: + # csn = self.__class__.__name__ + # + # if host is None: + # if self.host is None: raise ValueError(f"No host specified to {csn}.reconnect(host, port) - and no host in {csn}.host") + # host = self.host + # if port is None: + # if self.port is None: raise ValueError(f"No port specified to {csn}.connect(host, port) - and no port in {csn}.port") + # port = self.port + # if isinstance(host, (list, set)): host = tuple(host) + # if isinstance(host, tuple): host, port = host + + host, port = self._get_addr(host, port) + sanity = self._connect_sanity(host, port, sock=sock) + if sanity is not True: return sanity + return self._connect(host, port, sock=sock) + + def reconnect(self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None, sock: OpAnySocket = None, **kwargs): + csn = self.__class__.__name__ + + # self.close(sock=sock) + if host is None: + if port is not None: + if self.host is None: + raise ValueError(f"No host specified to {csn}.reconnect(host, port) - and no host in {csn}.host") + # return self.connect(self.host, port, sock=sock, **kwargs) + host = self.host + # self.tracker.host, self.tracker.port = self.host, port + # self.tracker.reconnect() + # return self.tracker + # if all([self.host is not None, self.port is not None]): + # return self.connect(self.host, self.port, sock=sock, **kwargs) + # self.tracker.host, self.tracker.port = self.host, + + # self.tracker.connect() + # return self.tracker + elif port is None: + port = self.port + # self.tracker.host, 
self.tracker.port = host, port + # else: + self.tracker.host, self.tracker.port = host, port + + # return self.connect(host, port, sock=sock, **kwargs) + self.tracker.reconnect() + return self.tracker + # return self.connect(host, self.port, sock=sock, **kwargs) + + def listen(self, backlog=10, sock: OpAnySocket = None, **kwargs): + if self.listening: + return True + (self.socket if sock is None else sock).listen(backlog) + if sock is None: self.listening = True + return True + + @_sockwrapper_auto_connect() + def accept(self, sock: OpAnySocket = None, **kwargs) -> Tuple[AnySocket, Tuple[str, int]]: + return (self.socket if sock is None else sock).accept() + + @_sockwrapper_auto_connect() + def settimeout(self, value, sock: OpAnySocket = None, **kwargs): + return (self.socket if sock is None else sock).settimeout(value) + + def close(self, sock: OpAnySocket = None): + log.debug("Closing socket connection to host: %s || port: %s", self.host, self.port) + if sock is not None: + log.debug(" !! sock was specified. only closing sock.") + try: + sock.close() + log.debug("Closed sock.") + except Exception: + log.exception("error while closing sock") + return + self.tracker.disconnect() + # try: + # if self._socket is not None: + # log.debug("closing self.socket") + # self.socket.close() + # except Exception: + # log.exception("error while closing self.socket") + # try: + # if self._base_socket is not None: + # log.debug("closing self.base_socket") + # self.base_socket.close() + # except Exception: + # log.exception("error while closing self.base_socket") + # + # try: + # if self._ssl_socket is not None: + # self._ssl_socket.close() + # log.debug("closing self._ssl_socket") + # except Exception: + # log.exception("error while closing self._ssl_socket") + # self.connected = False + # log.debug("setting socket instance attributes to None") + # self._socket, self._ssl_socket, self._base_socket = None, None, None + + @_sockwrapper_auto_connect() + def recv(self, bufsize: int, flags: int = None, sock: OpAnySocket = None, **kwargs) -> bytes: + if flags is None: return (self.socket if sock is None else sock).recv(bufsize) + return (self.socket if sock is None else sock).recv(bufsize, flags) + + @_sockwrapper_auto_connect() + def recvfrom(self, bufsize: int, flags: int = None, sock: OpAnySocket = None, **kwargs) -> Tuple[bytes, Any]: + if flags is None: return (self.socket if sock is None else sock).recvfrom(bufsize) + return (self.socket if sock is None else sock).recvfrom(bufsize, flags) + + @_sockwrapper_auto_connect() + def recvmsg( + self, bufsize: int, ancbufsize:int = None, flags: int = None, sock: OpAnySocket = None, **kwargs + ) -> Tuple[bytes, List[Tuple[int, int, bytes]], int, Any]: + args = [bufsize] + if ancbufsize is not None: args.append(ancbufsize) + if flags is not None: args.append(flags) + return (self.socket if sock is None else sock).recvmsg(*args) + + @_sockwrapper_auto_connect() + def read_eof( + self, bufsize: int = 256, eof_timeout: AnyNum = 120, flags: int = None, timeout_fail=False, strip=True, + conv: Optional[Callable[[Union[bytes, str]], T]] = stringify, sock: OpAnySocket = None, **kwargs + ) -> Union[bytes, str, T]: + strip_func = kwargs.get('strip_func', lambda d: strip_null(d, conv=conv)) + data = b'' + total_time = 0.0 + + while True: + st_time = time.time() + chunk = self.recv(bufsize, flags, sock=sock) + if not chunk: + log.debug("Finished reading until EOF") + break + e_time = time.time() + total_time += (e_time - st_time) + data += chunk + if total_time > eof_timeout: 
+ log.error("Giving up, spent over %f seconds (%f) reading until EOF for host %s", eof_timeout, total_time, self.host) + if timeout_fail: + raise TimeoutError(f"Giving up, spent over {eof_timeout} seconds ({total_time}) reading until EOF for host {self.host}") + break + + return strip_func(data) if strip else data + + @_sockwrapper_auto_connect() + def shutdown(self, how: int = None, sock: OpAnySocket = None, **kwargs): + how = empty_if(how, socket.SHUT_RDWR, itr=True) + return (self.socket if sock is None else sock).shutdown(how) + + @_sockwrapper_auto_connect() + def send(self, data: Union[str, bytes], flags: int = None, sock: OpAnySocket = None, **kwargs): + a = [byteify(data)] + if not empty(flags): a.append(flags) + return (self.socket if sock is None else sock).send(*a) + + @_sockwrapper_auto_connect() + def sendall(self, data: Union[str, bytes], flags: int = None, sock: OpAnySocket = None, **kwargs): + a = [byteify(data)] + if not empty(flags): a.append(flags) + return (self.socket if sock is None else sock).sendall(*a) + + @_sockwrapper_auto_connect() + def sendto(self, data: Union[str, bytes], *args, sock: OpAnySocket = None, **kwargs): + return (self.socket if sock is None else sock).sendto(byteify(data), *args, **kwargs) + + @_sockwrapper_auto_connect() + def send_chunks(self, gen: Union[Iterable, Generator], flags: int = None, sock: OpAnySocket = None, **kwargs): + results = [] + for c in gen: + results.append(self.send(c, flags, sock=sock, **kwargs)) + return results + + # @_sockwrapper_auto_connect() + # def query(self, data: Union[str, bytes], bufsize: int = 32, eof_timeout=30, **kwargs): + # timeout_fail, send_flags = kwargs.get('timeout_fail'), kwargs.get('send_flags', kwargs.get('flags', None)) + # recv_flags = kwargs.get('recv_flags', kwargs.get('flags', None)) + # log.debug(" >> Sending %s bytes to %s:%s", len(data), self.host, self.port) + # self.sendall(byteify(data), flags=send_flags) + # log.debug(" >> Reading %s bytes per chunk from %s:%s", bufsize, self.host, self.port) + # return self.read_eof(bufsize, eof_timeout=eof_timeout, flags=recv_flags, timeout_fail=timeout_fail) + + # @_sockwrapper_auto_connect() + # def http_request( + # self, url="/", host=AUTO_DETECTED, method="GET", user_agent=DEFAULT_USER_AGENT, extra_data: Union[STRBYTES, List[str]] = None, + # body: STRBYTES = None, eof_timeout=30, **kwargs + # ) -> Union[bytes, Awaitable[bytes]]: + # bufsize, flags, timeout_fail = kwargs.pop('bufsize', 256), kwargs.pop('flags', None), kwargs.pop('timeout_fail', False) + # data = self._http_request(url, host=host, method=method, user_agent=user_agent, extra=extra_data, body=body, **kwargs) + # self.sendall(data, flags=flags) + # return self.read_eof(bufsize, eof_timeout=eof_timeout, flags=flags, timeout_fail=timeout_fail) + + def _http_request(self, url, host: str, method: str, user_agent: str = settings.DEFAULT_USER_AGENT, extra=None, **kwargs) -> bytes: + host = self.hostname if host == AUTO_DETECTED else host + return generate_http_request(url, host, method=method, user_agent=user_agent, extra_data=extra, **kwargs) + + @_sockwrapper_auto_connect() + def query(self, data: Union[str, bytes], bufsize: int = 32, eof_timeout=30, sock: OpAnySocket = None, **kwargs): + timeout_fail, send_flags = kwargs.pop('timeout_fail', False), kwargs.pop('send_flags', kwargs.get('flags', None)) + recv_flags = kwargs.pop('recv_flags', kwargs.pop('flags', None)) + log.debug(" >> Sending %s bytes to %s:%s", len(data), self.host, self.port) + self.sendall(byteify(data), 
flags=send_flags, sock=sock) + log.debug(" >> Reading %s bytes per chunk from %s:%s", bufsize, self.host, self.port) + return self.read_eof(bufsize, eof_timeout=eof_timeout, flags=recv_flags, timeout_fail=timeout_fail, sock=sock, **kwargs) + + @_sockwrapper_auto_connect() + def http_request( + self, url="/", host=AUTO_DETECTED, method="GET", user_agent=settings.DEFAULT_USER_AGENT, + extra_data: Union[STRBYTES, List[str]] = None, body: STRBYTES = None, eof_timeout=30, bufsize: int = 256, + conv: Optional[Callable[[Union[bytes, str]], T]] = stringify, sock: OpAnySocket = None, **kwargs + ) -> Union[str, bytes, T]: + + data = self._http_request(url, host=host, method=method, user_agent=user_agent, extra=extra_data, body=body, **kwargs) + kargs = dict(data=data, bufsize=bufsize, eof_timeout=eof_timeout, timeout_fail=kwargs.get('timeout_fail', False), conv=conv, + sock=sock, **kwargs) + if sock is not None: return self.query(**kargs) + # with self: + with self.tracker: + return self.query(**kargs) + + @_sockwrapper_auto_connect() + def setblocking(self, flag: bool, sock: OpAnySocket = None, **kwargs): + return (self.socket if sock is None else sock).setblocking(flag) + + def handle_connection( + self, sock: AnySocket, addr: Tuple[str, int], callback: Callable[["SocketWrapper", Tuple[str, int]], Any], + stop_return: Union[str, bytes] = None, + **kwargs + ): + stop_compare, stop_compare_lower = kwargs.get('stop_compare', 'equal'), kwargs.get('stop_compare_lower', True) + if stop_return is not None: stop_return = stringify(stop_return) + log.info("NEW CONNECTION: %s || %s", sock, addr) + log.info("Running callback: %s(%s, %s)\n", callback.__name__, sock, addr) + orig_cres = callback(self.from_socket(sock), addr) + cres = stringify(orig_cres) + log.info("Callback return data: %s\n\n", cres) + if stop_return is not None: + if stop_compare_lower: stop_return, cres = stop_return.lower(), cres.lower() + + if stop_compare.lower() in ['in', 'contain', 'contains', 'contained', 'within', 'inside']: + if stop_return in cres or strip_null(stop_return) in strip_null(cres): + raise StopLoopOnMatch("Matched stop_return. Parent should stop loop.", cres, stop_compare, stop_compare_lower) + + if cres == stop_return or strip_null(cres) == strip_null(stop_return): + raise StopLoopOnMatch("Matched stop_return. Parent should stop loop.", cres, stop_compare, stop_compare_lower) + return orig_cres + + @_sockwrapper_auto_connect() + def on_connect( + self, callback: Callable[["SocketWrapper", Tuple[str, int]], Any], timeout: AnyNum = None, + stop_return: Union[str, bytes] = None, **kwargs + ): + if not self.server: + raise ValueError("This SocketWrapper has 'server' set to False. Can't handle incoming connections.") + if not self.binded: self.bind() + if not self.listening: self.listen(self.listen_backlog) + stop_return_match = None + + while self.connected and stop_return_match is None: + log.info("Waiting for incoming connection ( %s:%s || %s ) ...", self.host, self.port, self.socket.getsockname()) + sock, addr = self.accept() + try: + self.handle_connection(sock, addr, callback, stop_return, **kwargs) + except StopLoopOnMatch as e: + log.info(" !!! Stopping on_connect as 'stop_return' has been matched: %s", stop_return) + log.info(" !!! The matching message was: %s", e.match) + break + + # if stop_return_match is not None: + # log.info(" !!! Stopping on_connect as 'stop_return' has been matched: %s", stop_return) + # log.info(" !!! The matching message was: %s", stop_return_match) + + log.info(" !!! Disconnected. 
Stopping on_connect.") + + class SocketWrapperThread(SafeLoopThread): + def __init__(self, *args, parent_instance: "SocketWrapper", callback, stop_return, conn_kwargs: dict = None, **kwargs): + kwargs = dict(kwargs) + self.parent_instance = parent_instance + self.callback = callback + self.conn_kwargs = empty_if(conn_kwargs, {}, itr=True) + self.stop_return = stop_return + self.stop_compare = kwargs.pop('stop_compare', 'equal') + self.stop_compare_lower = kwargs.pop('stop_compare_lower', True) + super().__init__(*args, **kwargs) + + def loop(self): + pi = self.parent_instance + log.info("Waiting for incoming connection ( %s:%s || %s ) ...", pi.host, pi.port, pi.socket.getsockname()) + sock, addr = pi.accept() + try: + pi.handle_connection( + sock, addr, self.callback, self.stop_return, + stop_compare=self.stop_compare, stop_compare_lower=self.stop_compare_lower, **self.conn_kwargs + ) + except StopLoopOnMatch as e: + log.info(" !!! Stopping on_connect as 'stop_return' has been matched: %s", self.stop_return) + log.info(" !!! The matching message was: %s", e.match) + self.emit_stop() + + def run(self): + self.parent_instance.reconnect() + return super().run() + + def on_connect_thread( + self, callback: Callable[["SocketWrapper", Tuple[str, int]], Any], timeout: AnyNum = None, + stop_return: Union[str, bytes] = None, daemon=True, auto_start=True, **kwargs + ) -> SocketWrapperThread: + t = self.SocketWrapperThread(parent_instance=self, callback=callback, stop_return=stop_return, **kwargs) + t.setDaemon(daemon) + if auto_start: + t.start() + return t + + @classmethod + def from_socket(cls, sock: AnySocket, server=False, **kwargs) -> Union["SocketWrapper", "AsyncSocketWrapper"]: + sock_host, sock_port = sock.getsockname() + cfg = dict( + family=sock.family, proto=sock.proto, type=sock.type, fileno=sock.fileno(), + host=sock_host, port=sock_port, server=server, socket=sock, base_socket=sock + ) + cfg = {**cfg, **kwargs} + return cls(**cfg) + + def __getattribute__(self, item): + try: + return super().__getattribute__(item) + except AttributeError: + pass + sock: AnySocket = super().__getattribute__('socket') + return getattr(sock, item) + + def __enter__(self): + # if not self._socket_ctx_mgr: + # self._socket_ctx_mgr = SocketContextManager(self) + # if not self._layer_context: + # self._layer_context = LayeredContext(self._socket_ctx_mgr, max_layers=1) + # return self._layer_context.__enter__() + self.tracker.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # return self.tracker.__aexit__() + + # return self._layer_context.__exit__(exc_type, exc_val, exc_tb) + return self.tracker.__exit__(exc_type, exc_val, exc_tb) + + +class AsyncSocketWrapper(SocketWrapper): + """ + + >>> from privex.helpers import AsyncSocketWrapper + >>> sw = AsyncSocketWrapper('termbin.com', 9999) + >>> url = await sw.query("HELLO world\\n\\nThis is a test\\nusing async sockets\\n\\nwith Python") + 'https://termbin.com/lsd93' + >>> url = await sw.read_eof() + """ + _loop: Optional[asyncio.AbstractEventLoop] + DEFAULT_TIMEOUT = empty_if(socket.getdefaulttimeout(), settings.DEFAULT_SOCKET_TIMEOUT, zero=True) + + def __init__( + self, host: str, port: int, server=False, family=-1, type=socket.SOCK_STREAM, proto=-1, fileno=None, + timeout=DEFAULT_TIMEOUT, use_ssl=False, verify_cert=False, loop=None, **kwargs + ): + self._loop = loop + super().__init__( + host=host, port=port, server=server, family=family, type=type, proto=proto, fileno=fileno, timeout=timeout, + use_ssl=use_ssl, 
verify_cert=verify_cert, **kwargs + ) + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if not self._loop: + self._loop = asyncio.get_event_loop() + return self._loop + + async def _connect(self, host: str, port: AnyNum, sock: OpAnySocket = None, **kwargs) -> AnySocket: + port = int(port) + # sck = self.socket if sock is None else sock + + if self.server: + log.debug("Binding to host '%s' on port %s", host, port) + self.bind(host, port, sock=sock) + log.debug("Successfully binded to host '%s' on port %s", host, port) + if self.auto_listen: + log.debug("Auto-listen is enabled. Calling %s.listen(%s)", self.__class__.__name__, self.listen_backlog) + self.listen(self.listen_backlog, sock=sock) + log.debug("%s is now listening on host(s) '%s' on port %s", self.__class__.__name__, host, port) + # if sock is None: self.host, self.port, self.connected = host, port, True + return self.socket + log.debug("Connecting to host %s on port %s", host, port) + if sock: + await self.loop.sock_connect(sock, (host, port)) + return sock + await self.tracker.connect_async() + return self.tracker.auto_socket + # self.loop.soc + # if sock is None: self.host, self.port, self.connected = host, port, True + + async def connect( + self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None, sock: OpAnySocket = None, **kwargs) -> AnySocket: + + host, port = self._get_addr(host, port) + sanity = await self._connect_sanity(host, port) + if sanity is not True: return sanity + return await self._connect(host, port, sock=sock) + + async def _connect_sanity(self, host, port, sock: OpAnySocket = None, **kwargs): + port = int(port) + # sock = self._auto_socket if sock is None else sock + + if sock is not None or (self.connected and self._auto_socket is not None): + if host != self.host or port != int(self.port): + log.debug(f"Already connected, but {self.__class__.__name__}.connect called with different host/port than stored. " + f"Trigerring a reconnect.") + return await self.reconnect(host, port, sock=sock) + log.debug(f"Already connected, {self.__class__.__name__}.connect called with same details as previously. " + f"Returning existing socket.") + return sock + if empty(port, True, True): + raise ValueError(f"{self.__class__.__name__}.connect requires a port. 
Either connect(host, port) or connect( (host,port) )") + return True + + async def reconnect(self, host: Union[str, Tuple[str, AnyNum]] = None, port: AnyNum = None, sock: OpAnySocket = None, **kwargs): + csn = self.__class__.__name__ + self.close() + if host is None: + if port is not None: + if self.host is None: + raise ValueError(f"No host specified to {csn}.reconnect(host, port) - and no host in {csn}.host") + return await self.connect(self.host, port, sock=sock) + if all([self.host is not None, self.port is not None]): + return await self.connect(self.host, self.port, sock=sock) + if port is not None: + return await self.connect(host, port, sock=sock) + return await self.connect(host, self.port, sock=sock) + + @_async_sockwrapper_auto_connect() + async def read_eof( + self, bufsize: int = 256, eof_timeout: AnyNum = 120, flags: int = None, timeout_fail=False, strip=True, + conv: Optional[Callable[[Union[bytes, str]], T]] = stringify, sock: OpAnySocket = None, **kwargs + ) -> Union[str, bytes, T]: + strip_func = kwargs.get('strip_func', lambda d: strip_null(d, conv=conv)) + data, total_time = b'', 0.0 + + while True: + st_time = time.time() + chunk = await self.recv(bufsize, flags, timeout=kwargs.get('read_timeout', AUTO), sock=sock) + if not chunk: + log.debug("Finished reading until EOF") + break + e_time = time.time() + total_time += (e_time - st_time) + data += chunk + if not empty(eof_timeout, True) and total_time > eof_timeout: + log.error("Giving up, spent over %f seconds (%f) reading until EOF for host %s", eof_timeout, total_time, self.host) + if timeout_fail: + raise TimeoutError(f"Giving up, spent over {eof_timeout} seconds ({total_time}) reading until EOF for host {self.host}") + break + + return strip_func(data) if strip else data + + @_async_sockwrapper_auto_connect() + async def recv(self, bufsize: int, flags: int = None, sock: OpAnySocket = None, timeout: Union[float, int] = AUTO, **kwargs) -> bytes: + timeout, sck = self.read_timeout if timeout is AUTO else timeout, self.socket if sock is None else sock + if timeout not in [None, False]: + return await asyncio.wait_for(self.loop.sock_recv(sck, bufsize), timeout) + return await self.loop.sock_recv(sck, bufsize) + + @_async_sockwrapper_auto_connect() + async def recv_into(self, buf: bytearray, sock: OpAnySocket = None, **kwargs) -> int: + return await self.loop.sock_recv_into(self.socket if sock is None else sock, buf) + + @_async_sockwrapper_auto_connect() + async def send(self, data: Union[str, bytes], flags: int = None, sock: OpAnySocket = None, timeout: Union[float, int] = AUTO, **kwargs): + # asyncio event loops only expose sock_sendall(), so the async send() simply delegates to sendall() + return await self.sendall(data, flags, sock, timeout, **kwargs) + + @_async_sockwrapper_auto_connect() + async def sendall( + self, data: Union[str, bytes], flags: int = None, sock: OpAnySocket = None, timeout: Union[float, int] = AUTO, **kwargs + ): + timeout, sck = self.send_timeout if timeout is AUTO else timeout, self.socket if sock is None else sock + if timeout not in [None, False]: + return await asyncio.wait_for(self.loop.sock_sendall(sck, byteify(data)), timeout) + return await self.loop.sock_sendall(sck, byteify(data)) + + @_async_sockwrapper_auto_connect() + async def sendfile( + self, file: IO[bytes], offset: int = None, count: int = None, fallback: bool = True, sock: OpAnySocket = None, + timeout: Union[float, int] = AUTO, **kwargs + ): + timeout, sck = self.send_timeout if timeout is AUTO else timeout, self.socket if sock is None else sock + if timeout not in [None, False]: + return await 
asyncio.wait_for(self.loop.sock_sendfile(sck, file, offset=offset, count=count, fallback=fallback), timeout) + return await self.loop.sock_sendfile(sck, file, offset=offset, count=count, fallback=fallback) + + @_async_sockwrapper_auto_connect() + async def query(self, data: Union[str, bytes], bufsize: int = 32, eof_timeout=30, sock: OpAnySocket = None, **kwargs): + timeout_fail, send_flags = kwargs.pop('timeout_fail', False), kwargs.pop('send_flags', kwargs.get('flags', None)) + recv_flags = kwargs.pop('recv_flags', kwargs.pop('flags', None)) + shared_timeout = kwargs.pop('timeout', AUTO) + log.debug(" >> Sending %s bytes to %s:%s", len(data), self.host, self.port) + snd_tmout, rcv_tmout = kwargs.pop('send_timeout', shared_timeout), kwargs.pop('read_timeout', shared_timeout) + await self.sendall(byteify(data), flags=send_flags, sock=self.socket if sock is None else sock, timeout=snd_tmout) + log.debug(" >> Reading %s bytes per chunk from %s:%s", bufsize, self.host, self.port) + return await self.read_eof( + bufsize, eof_timeout=eof_timeout, flags=recv_flags, timeout_fail=timeout_fail, + sock=self.socket if sock is None else sock, read_timeout=rcv_tmout, **kwargs + ) + + @_async_sockwrapper_auto_connect() + async def http_request( + self, url="/", host=AUTO_DETECTED, method="GET", user_agent=settings.DEFAULT_USER_AGENT, + extra_data: Union[STRBYTES, List[str]] = None, body: STRBYTES = None, eof_timeout=30, bufsize: int = 256, + conv: Optional[Callable[[Union[bytes, str]], T]] = stringify, sock: OpAnySocket = None, **kwargs + ) -> Union[str, bytes, T]: + async with self: + data = self._http_request( + url, host=host, method=method, user_agent=user_agent, extra=extra_data, body=body, sock=sock, **kwargs + ) + + # await self.sendall(data) + return await self.query( + data, bufsize, eof_timeout=eof_timeout, timeout_fail=kwargs.get('timeout_fail', False), conv=conv, sock=sock, **kwargs + ) + # return await super().http_request( + # url, host=host, method=method, user_agent=user_agent, extra=extra_data, body=body, eof_timeout=eof_timeout, **kwargs + # ) + + async def accept(self, sock: OpAnySocket = None, **kwargs) -> Tuple[AnySocket, Tuple[str, int]]: + return await self.loop.sock_accept(self.socket if sock is None else sock) + + async def handle_connection( + self, sock: AnySocket, addr: Tuple[str, int], callback: Callable[["AsyncSocketWrapper", Tuple[str, int]], Any], + stop_return: Union[str, bytes] = None, + **kwargs + ): + stop_compare, stop_compare_lower = kwargs.get('stop_compare', 'equal'), kwargs.get('stop_compare_lower', True) + if stop_return is not None: stop_return = stringify(stop_return) + log.info("[async] NEW CONNECTION: %s || %s", sock, addr) + log.info("[async] Running callback: %s(%s, %s)\n", callback.__name__, sock, addr) + orig_cres = await await_if_needed(callback(self.from_socket(sock), addr)) + cres = stringify(orig_cres) + log.info("[async] Callback return data: %s\n\n", cres) + if stop_return is not None: + if stop_compare_lower: stop_return, cres = stop_return.lower(), cres.lower() + + if stop_compare.lower() in ['in', 'contain', 'contains', 'contained', 'within', 'inside']: + if stop_return in cres or strip_null(stop_return) in strip_null(cres): + raise StopLoopOnMatch("[async] Matched stop_return. Parent should stop loop.", cres, stop_compare, stop_compare_lower) + + if cres == stop_return or strip_null(cres) == strip_null(stop_return): + raise StopLoopOnMatch("[async] Matched stop_return. 
Parent should stop loop.", cres, stop_compare, stop_compare_lower) + return orig_cres + + @_sockwrapper_auto_connect() + async def on_connect( + self, callback: Callable[["AsyncSocketWrapper", Tuple[str, int]], Any], timeout: AnyNum = None, + stop_return: Union[str, bytes] = None, sock: OpAnySocket = None, **kwargs + ): + if not self.server: + raise ValueError("This AsyncSocketWrapper has 'server' set to False. Can't handle incoming connections.") + if not self.binded: self.bind(sock=self.socket if sock is None else sock) + if not self.listening: self.listen(self.listen_backlog, sock=self.socket if sock is None else sock) + # if stop_return is not None: stop_return = stringify(stop_return) + # stop_compare, stop_compare_lower = kwargs.get('stop_compare', 'equal'), kwargs.get('stop_compare_lower', True) + stop_return_match = None + + while self.connected and stop_return_match is None: + log.info("[async] Waiting for incoming connection ( %s:%s || %s ) ...", self.host, self.port, self.socket.getsockname()) + sock, addr = await self.accept() + try: + with sock: + await self.handle_connection(sock, addr, callback, stop_return, **kwargs) + except StopLoopOnMatch as e: + log.info(" !!! Stopping on_connect as 'stop_return' has been matched: %s", stop_return) + log.info(" !!! The matching message was: %s", e.match) + break + # log.info("[async] NEW CONNECTION: %s || %s", sock, addr) + # log.info("[async] Running callback: %s(%s, %s)\n", callback.__name__, sock, addr) + # cres = await await_if_needed(callback(self.from_socket(sock), addr)) + # cres = stringify(cres) + # log.info("[async] Callback return data: %s\n\n", cres) + # if stop_return is not None: + # if stop_compare_lower: + # stop_return, cres = stop_return.lower(), cres.lower() + # if stop_compare.lower() in ['in', 'contain', 'contains', 'contained', 'within', 'inside']: + # if stop_return in cres or strip_null(stop_return) in strip_null(cres): + # stop_return_match = cres + # break + # if cres == stop_return or strip_null(cres) == strip_null(stop_return): + # stop_return_match = cres + # break + # if stop_return_match is not None: + # log.info(" !!! Stopping on_connect as 'stop_return' has been matched: %s", stop_return) + # log.info(" !!! The matching message was: %s", stop_return_match) + + log.info(" !!! Disconnected. 
Stopping on_connect.") + + async def __aenter__(self): + # if not self._socket_ctx_mgr: + # self._socket_ctx_mgr = SocketContextManager(self) + # if not self._layer_context: + # self._layer_context = LayeredContext(self._socket_ctx_mgr, max_layers=1) + # return await self._layer_context.__aenter__() + return await self.tracker.__aenter__() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # return await self._layer_context.__aexit__(exc_type, exc_val, exc_tb) + return await self.tracker.__aexit__(exc_type, exc_val, exc_tb) + + +async def send_data_async( + host: str, port: int, data: Union[bytes, str, Iterable], timeout: AnyNum = None, **kwargs +) -> Optional[Union[str, bytes]]: + """ + + >>> await send_data_async('termbin.com', 9999, "hello world\\nthis is a test\\n\\nlorem ipsum dolor\\n") + 'https://termbin.com/oi07' + + :param host: + :param port: + :param data: + :param timeout: + :param kwargs: + :return: + """ + fhost = f"({host}):{port}" + chunk_size = int(kwargs.get('chunk', kwargs.get('chunk_size', 64))) + string_result = is_true(kwargs.get('string_result', True)) + strip_result = is_true(kwargs.get('strip_result', True)) + fail = is_true(kwargs.get('fail', True)) + ip_version = kwargs.get('ip_version', 'any') + timeout = empty_if(timeout, empty_if(socket.getdefaulttimeout(), 15, zero=True), zero=True) + + is_iter, data_iter = False, None + + if data is not None: + if isinstance(data, (str, bytes, int, float)): + data = byteify(data) + else: + try: + data_iter = iter(data) + is_iter = True + except TypeError: + # noinspection PyTypeChecker + data = byteify(data) + + loop = asyncio.get_event_loop() + try: + s_ver = socket.AF_INET + ip = await resolve_ip_async(host, ip_version) + + if ip_is_v6(ip): s_ver = socket.AF_INET6 + + fhost += f" (IP: {ip})" + + with socket.socket(s_ver, socket.SOCK_STREAM) as s: + s.settimeout(float(timeout)) + log.debug(" [...] Connecting to host: %s", fhost) + await loop.sock_connect(s, (ip, port)) + log.debug(" [+++] Connected to %s\n", fhost) + + if data is None: + log.debug(" [!!!] 'data' is None. Not transmitting any data to the host.") + elif is_iter: + i = 1 + for c in data_iter: + log.debug(" [...] Sending %s byte chunk (%s)\n", len(c), i) + await loop.sock_sendall(s, c) + else: + # We use 'sendall' to reliably send the entire contents of 'data' to the service we're connected to. + log.debug(" [...] Sending %s bytes to %s ...\n", len(data), fhost) + await loop.sock_sendall(s, data) + # s.sendall(data) + log.debug(" >> Reading response ...") + res = b'' + i = 1 + while True: + chunk = await loop.sock_recv(s, chunk_size) + if not chunk: break + res += chunk + log.debug(" [...] 
Read %s byte chunk (%s)\n", len(chunk), i) + i += 1 + if string_result: + res = stringify(res) + if strip_result: res = res.strip("\x00").strip().strip("\x00").strip() + log.debug(" [+++] Got result ( %s bytes ) \n", len(res)) + except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if fail: + raise e + log.warning("Exception while connecting + sending data to: %s - reason: %s %s", fhost, type(e), str(e)) + return None + return res + + +def send_data( + host: str, port: int, data: Optional[Union[bytes, str, Iterable]] = None, timeout: Union[int, float] = None, **kwargs +) -> Optional[Union[str, bytes]]: + """ + >>> from privex.helpers import send_data + >>> send_data('termbin.com', 9999, "hello world\\nthis is a test\\n\\nlorem ipsum dolor\\n") + 'https://termbin.com/oi07' + + :param str host: The hostname or IPv4/v6 address to connect to + :param port: The port number to connect to on ``host`` + :param bytes|str|iter data: The data to send to ``host:port`` via a TCP socket. Generally :class:`bytes` / :class:`str`. + Can be an iterator/generator to send data in chunks. Can be ``None`` to disable sending data, instead + only receiving and returning data. + :param float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``15`` + :param kwargs: + :keyword int chunk: (Default: ``64``) Maximum number of bytes to read into buffer per socket receive call. + :keyword bool string_result: (Default: ``True``) If ``True``, the response sent by the server will be casted into a :class:`str` + before returning it. + :keyword bool strip_result: (Default: ``True``) This argument only works if ``string_result`` is also True. + If both ``string_result`` and ``strip_result`` are ``True``, the response sent by the server will + have whitespace, newlines, and null bytes trimmed from the start and end after it's casted into a string. + :keyword bool fail: (Default: ``True``) If ``True``, will raise exceptions when connection errors occur. When ``False``, will simply + ``None`` if there are connection exceptions raised during this function's execution. + :keyword str|int ip_version: (Default: ``any``) + :return: + """ + fhost = f"({host}):{port}" + chunk_size = int(kwargs.get('chunk', kwargs.get('chunk_size', 64))) + string_result = is_true(kwargs.get('string_result', True)) + strip_result = is_true(kwargs.get('strip_result', True)) + fail = is_true(kwargs.get('fail', True)) + ip_version = kwargs.get('ip_version', 'any') + timeout = empty_if(timeout, empty_if(socket.getdefaulttimeout(), 15, zero=True), zero=True) + + is_iter, data_iter, is_v6, v4_address, host_is_ip = False, None, False, None, False + + if data is not None: + if isinstance(data, (str, bytes, int, float)): + data = byteify(data) + else: + try: + data_iter = iter(data) + is_iter = True + except TypeError: + # noinspection PyTypeChecker + data = byteify(data) + + try: + ip_network(host) + host_is_ip = True + except (TypeError, ValueError) as e: + host_is_ip = False + + try: + # First we resolve the IP address of 'host', so we can detect whether we're connecting to an IPv4 or IPv6 host, + # letting us adjust the AF_INET variable accordingly. 
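+        # If the host resolves to an IPv6 address, we also try to resolve an IPv4 address up-front (when 'host'
+        # isn't already a literal IP), so a fallback target is ready if the IPv6 connection attempt fails later on.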
+ s_ver = socket.AF_INET + ip = resolve_ip(host, ip_version) + + if ip_is_v6(ip): + s_ver, is_v6 = socket.AF_INET6, True + if not host_is_ip: + try: + v4_address = resolve_ip(host, 'v4') + except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror, AttributeError) as e: + log.warning( + "Warning: failed to resolve IPv4 address for %s (to be used as a backup if IPv6 is broken). Reason: %s %s ", + host, type(e), str(e) + ) + + fhost += f" (IP: {ip})" + + except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + if fail: + raise e + log.warning("Exception while connecting + sending data to: %s - reason: %s %s", fhost, type(e), str(e)) + return None + + try: + with socket.socket(s_ver, socket.SOCK_STREAM) as s: + # Once we have our socket object, we set the timeout (by default it could hang forever), and open the connection. + s.settimeout(timeout) + log.debug(" [...] Connecting to host: %s", fhost) + s.connect((ip, port)) + log.debug(" [+++] Connected to %s\n", fhost) + + if data is None: + log.debug(" [!!!] 'data' is None. Not transmitting any data to the host.") + elif is_iter: + i = 1 + for c in data_iter: + log.debug(" [...] Sending %s byte chunk (%s)\n", len(c), i) + s.sendall(c) + else: + # We use 'sendall' to reliably send the entire contents of 'data' to the service we're connected to. + log.debug(" [...] Sending %s bytes to %s ...\n", len(data), fhost) + s.sendall(data) + # Once we've sent 'data', we read the response back in chunks until the server closes the connection (EOF). + log.debug(" >> Reading response ...") + res = b'' + i = 1 + while True: + chunk = s.recv(chunk_size) + if not chunk: break + res += chunk + log.debug(" [...] Read %s byte chunk (%s)\n", len(chunk), i) + i += 1 + if string_result: + res = stringify(res) + if strip_result: res = res.strip("\x00").strip().strip("\x00").strip() + log.debug(" [+++] Got result ( %s bytes ) \n", len(res)) + except (socket.timeout, ConnectionRefusedError, ConnectionResetError, socket.gaierror) as e: + log.warning("Exception while connecting + sending data to: %s - reason: %s %s", fhost, type(e), str(e)) + if is_v6 and not empty(v4_address): + log.warning( + "Retrying connection to %s over IPv4 instead of IPv6. || IPv6 address: %s || IPv4 address: %s ", + fhost, ip, v4_address + ) + # Force IPv4 on the retry - otherwise we would simply resolve (and fail on) the IPv6 address again. + kwargs['ip_version'] = 'v4' + return send_data(host, port, data, timeout=timeout, **kwargs) + + if fail: + raise e + return None + return res + + +def upload_termbin(data: Union[bytes, str], timeout: Union[int, float] = None, **kwargs) -> str: + """ + Upload the :class:`bytes` / :class:`string` ``data`` to the pastebin service `TermBin`_ , + using the hostname and port defined in :attr:`privex.helpers.settings.TERMBIN_HOST` + and :attr:`privex.helpers.settings.TERMBIN_PORT` + + NOTE - An AsyncIO version of this function is available: :func:`.upload_termbin_async` + + Returns the `TermBin`_ URL as a string - which is a raw download / viewing link for the paste. + + .. _TermBin: https://termbin.com + + >>> my_data = "hello world\\nthis is a test\\n\\nlorem ipsum dolor\\n" + >>> upload_termbin(my_data) + 'https://termbin.com/kerjk' + + :param bytes|str data: The data to upload to `TermBin`_ - as either :class:`str` or :class:`bytes` + :param float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``15`` + :return str url: The `TermBin`_ URL to your paste as a string - which is a raw download / viewing link for the paste. + """ + data = byteify(data) + log.info(" [...] 
Uploading %s bytes to termbin ...\n", len(data)) + res = send_data(settings.TERMBIN_HOST, settings.TERMBIN_PORT, data, timeout=timeout, **kwargs) + log.info(" [+++] Got termbin link: %s \n", res) + + return res + + +def upload_termbin_file(filename: str, timeout: int = 15, **kwargs) -> str: + """ + Uploads the file ``filename`` to `TermBin`_ and returns the paste URL as a string. + + .. NOTE:: An AsyncIO version of this function is available: :func:`.upload_termbin_file_async` + + .. NOTE:: If the data you want to upload is already loaded into a variable - you can use :func:`.upload_termbin` instead, + which accepts your data directly - through a :class:`str` or :class:`bytes` parameter + + .. _TermBin: https://termbin.com + + :param str filename: The path (absolute or relative) to the file you want to upload to `TermBin`_ - as a :class:`str` + :param float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``15`` + :return str url: The `TermBin`_ URL to your paste as a string - which is a raw download / viewing link for the paste. + """ + log.info(" >> Uploading file '%s' to termbin", filename) + + with open(filename, 'rb') as fh: + log.debug(" [...] Opened file %s - reading contents into RAM...", filename) + data = fh.read() + log.debug(" [+++] Loaded file into RAM. Total size: %s bytes", len(data)) + + res = upload_termbin(data, timeout=timeout, **kwargs) + log.info(" [+++] Uploaded file %s to termbin. Got termbin link: %s \n", filename, res) + return res + + +async def upload_termbin_async(data: Union[bytes, str], timeout: Union[int, float] = None) -> str: + """ + Upload the :class:`bytes` / :class:`string` ``data`` to the pastebin service `TermBin`_ , + using the hostname and port defined in :attr:`privex.helpers.settings.TERMBIN_HOST` + and :attr:`privex.helpers.settings.TERMBIN_PORT` + + NOTE - A synchronous (non-async) version of this function is available: :func:`.upload_termbin` + + Returns the `TermBin`_ URL as a string - which is a raw download / viewing link for the paste. + + .. _TermBin: https://termbin.com + + >>> my_data = "hello world\\nthis is a test\\n\\nlorem ipsum dolor\\n" + >>> await upload_termbin_async(my_data) + 'https://termbin.com/kerjk' + + :param bytes|str data: The data to upload to `TermBin`_ - as either :class:`str` or :class:`bytes` + :param float|int timeout: Socket timeout. If not passed, uses the default from :func:`socket.getdefaulttimeout`. + If the global default timeout is ``None``, then falls back to ``15`` + :return str url: The `TermBin`_ URL to your paste as a string - which is a raw download / viewing link for the paste. + """ + data = byteify(data) + log.info(" [...] Uploading %s bytes to termbin ...\n", len(data)) + res = await send_data_async(settings.TERMBIN_HOST, settings.TERMBIN_PORT, data, timeout=timeout) + log.info(" [+++] Got termbin link: %s \n", res) + + return res + + +async def upload_termbin_file_async(filename: str, timeout: int = 15) -> str: + """ + Uploads the file ``filename`` to `TermBin`_ and returns the paste URL as a string. + + .. NOTE:: A synchronous (non-async) version of this function is available: :func:`.upload_termbin_file` + + .. NOTE:: If the data you want to upload is already loaded into a variable - you can use :func:`.upload_termbin_async` instead, + which accepts your data directly - through a :class:`str` or :class:`bytes` parameter + + + .. 
diff --git a/privex/helpers/net/util.py b/privex/helpers/net/util.py new file mode 100644 index 0000000..8f10a44 --- /dev/null +++ b/privex/helpers/net/util.py @@ -0,0 +1,172 @@ +import platform +import socket +import ssl +import subprocess +from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network +from typing import List, Optional, Union +from privex.helpers import settings +from privex.helpers.common import empty_if, byteify, is_true, stringify +from privex.helpers.exceptions import NetworkUnreachable +from privex.helpers.types import IP_OR_STR, STRBYTES + +__all__ = [ + 'ip_is_v4', 'ip_is_v6', 'ping', 'IPV4_ALIASES', 'IPV6_ALIASES', 'ip_ver_to_int', 'ip_ver_to_sock', + 'sock_ver', 'is_ip', 'sock_validate_ip' +] + + +def ip_is_v4(ip: str) -> bool: + """ + Determines whether an IP address is IPv4 or not + + :param str ip: An IP address as a string, e.g. 192.168.1.1 + :raises ValueError: When the given IP address ``ip`` is invalid + :return bool: True if IPv4, False if not (i.e. probably IPv6) + """ + return type(ip_address(ip)) == IPv4Address + + +def ip_is_v6(ip: str) -> bool: + """ + Determines whether an IP address is IPv6 or not + + :param str ip: An IP address as a string, e.g. 192.168.1.1 + :raises ValueError: When the given IP address ``ip`` is invalid + :return bool: True if IPv6, False if not (i.e. probably IPv4) + """ + return type(ip_address(ip)) == IPv6Address + + +def ping(ip: str, timeout: int = 30) -> bool: + """ + Sends a ping to a given IPv4 / IPv6 address. Tested with IPv4+IPv6 using ``iputils-ping`` on Linux, as well as the + default IPv4 ``ping`` utility on Mac OSX (Mojave, 10.14.6). + + Fully supported when using Linux with the ``iputils-ping`` package. Only IPv4 support on Mac OSX. + + **Example Usage**:: + + >>> from privex.helpers import ping + >>> if ping('127.0.0.1', 5) and ping('::1', 10): + ... print('Both 127.0.0.1 and ::1 are up') + ... else: + ... print('127.0.0.1 or ::1 failed to respond to a ping within the given timeout.') + + **Known Incompatibilities**: + + * NOT compatible with IPv6 addresses on OSX due to the lack of a timeout argument with ``ping6`` + * NOT compatible with IPv6 addresses when using ``inetutils-ping`` on Linux due to separate ``ping6`` command + + :param str ip: An IP address as a string, e.g. ``192.168.1.1`` or ``2a07:e00::1`` + :param int timeout: (Default: 30) Number of seconds to wait for a response from the ping before timing out + :raises ValueError: When the given IP address ``ip`` is invalid or ``timeout`` < 1 + :return bool: ``True`` if ping got a response from the given IP, ``False`` if not + """ + ip_obj = ip_address(ip) # verify IP is valid (this will throw if it isn't) + if timeout < 1: + raise ValueError('timeout value cannot be less than 1 second') + opts4 = { + 'Linux': ["/bin/ping", "-c1", f"-w{timeout}"], + 'Darwin': ["/sbin/ping", "-c1", f"-t{timeout}"] + } + opts6 = {'Linux': ["/bin/ping", "-c1", f"-w{timeout}"]} + opts = opts4 if ip_is_v4(ip_obj) else opts6 + if platform.system() not in opts: + raise NotImplementedError(f"{__name__}.ping is not fully supported on platform '{platform.system()}'...") + + with subprocess.Popen(opts[platform.system()] + [ip], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: + out, err = proc.communicate() + err = err.decode('utf-8') + if 'network is unreachable' in err.lower(): + raise NetworkUnreachable(f'Got error from ping: "{err}"') + + return 'bytes from {}'.format(ip) in out.decode('utf-8') + + +IPV4_ALIASES = [4, 'v4', '4', 'ipv4', 'ip4', 'inet', 'inet4', socket.AF_INET, str(socket.AF_INET)] +IPV6_ALIASES = [6, 'v6', '6', 'ipv6', 'ip6', 'inet6', socket.AF_INET6, str(socket.AF_INET6)] + + +def ip_ver_to_int(ver: Union[str, int]) -> int: + ver = str(ver).lower() + if ver in IPV4_ALIASES: return 4 + if ver in IPV6_ALIASES: return 6 + return 0 + + +def sock_ver(version) -> Optional[int]: + version = empty_if(version, 'any', zero=True, itr=True) + version = str(version).lower() + if ip_ver_to_int(version) == 4: return socket.AF_INET + if ip_ver_to_int(version) == 6: return socket.AF_INET6 + return None + + +ip_ver_to_sock = sock_ver + + +def ip_sock_ver(ip_addr) -> Optional[int]: + a = ip_network(ip_addr, strict=False) + if isinstance(a, (IPv4Address, IPv4Network)): return socket.AF_INET + if isinstance(a, (IPv6Address, IPv6Network)): return socket.AF_INET6 + return None + + +def is_ip(addr: str, version: int = None): + try: + res = sock_validate_ip(addr, version=version) + return res + except AttributeError as e: + raise e + except ValueError: + return False + + +def sock_validate_ip(addr: IP_OR_STR, version: int, throw=True) -> Optional[Union[IPv4Address, IPv6Address]]: + ip = ip_address(addr) + ver = "v4" if ip_is_v4(ip) else "v6" + if version == socket.AF_INET and ver != 'v4': + if not throw: return None + raise AttributeError(f"Passed address '{addr}' was an IPv6 address, but 'version' requested an IPv4 address.") + if version == socket.AF_INET6 and ver != 'v6': + if not throw: return None + raise AttributeError(f"Passed address '{addr}' was an IPv4 address, but 'version' requested an IPv6 address.") + return ip + +
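A short sketch of how `is_ip` and `sock_validate_ip` above fit together. Note that, as written, the `version` argument is matched against the `socket.AF_INET` / `socket.AF_INET6` constants rather than the integers `4` / `6`:

```
# Sketch: validate addresses, optionally pinning the expected address family.
import socket
from privex.helpers.net.util import is_ip, sock_validate_ip

print(bool(is_ip('192.168.1.1')))            # True  - valid IPv4, no family constraint
print(bool(is_ip('notanip')))                # False - not an IP address at all
print(bool(is_ip('::1', socket.AF_INET6)))   # True  - valid IPv6, family matches

# sock_validate_ip raises AttributeError on a family mismatch unless throw=False
print(sock_validate_ip('192.168.1.1', socket.AF_INET6, throw=False))   # None
```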
+def get_ssl_context( + verify_cert: bool = False, check_hostname: Optional[bool] = None, verify_mode: Optional[int] = None, **kwargs + ) -> ssl.SSLContext: + check_hostname = empty_if(check_hostname, is_true(verify_cert)) + verify_mode = empty_if(verify_mode, ssl.CERT_REQUIRED if verify_cert else ssl.CERT_NONE) + + ctx = ssl.create_default_context() + ctx.check_hostname = check_hostname + ctx.verify_mode = verify_mode + return ctx + + +def generate_http_request( + url="/", host=None, method="GET", user_agent=settings.DEFAULT_USER_AGENT, extra_data: Union[STRBYTES, List[str]] = None, + body: STRBYTES = None, **kwargs +) -> bytes: + method, url = stringify(method), stringify(url) + http_ver = stringify(kwargs.get('http_ver', '1.0')) + data = f"{method.upper()} {url} HTTP/{http_ver}\n" + if host is not None: data += f"Host: {stringify(host)}\n" + if user_agent is not None: data += f"User-Agent: {stringify(user_agent)}\n" + data = byteify(data) + + if extra_data is not None: + if isinstance(extra_data, list): + extra_data = [byteify(x) for x in extra_data] + data += b"\n".join(extra_data) + else: + data += byteify(extra_data) + if not data.endswith(b"\n"): data += b"\n" + + if body is not None: + data += byteify(body) + if not data.endswith(b"\n"): data += b"\n" + if not data.endswith(b"\n\n"): data += b"\n" + return data
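To make the header-building logic of `generate_http_request` concrete, a small sketch - the host below is only illustrative, and the resulting bytes can be fed to helpers such as `send_data` or `check_host(..., send=...)`:

```
# Sketch: build a minimal HTTP/1.0 request and inspect it.
from privex.helpers.net.util import generate_http_request

req = generate_http_request(url='/lookup', host='files.privex.io', method='GET')
print(req.decode())
# Prints roughly:
#   GET /lookup HTTP/1.0
#   Host: files.privex.io
#   User-Agent: Python Privex Helpers ( https://github.com/Privex/python-helpers )
```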
diff --git a/privex/helpers/settings.py b/privex/helpers/settings.py index 8f28dc7..98e274e 100644 --- a/privex/helpers/settings.py +++ b/privex/helpers/settings.py @@ -29,8 +29,11 @@ """ +import random +from datetime import datetime from os import getcwd, getenv as env from os.path import dirname, abspath, join, expanduser +from typing import Optional BASE_DIR = dirname(dirname(dirname(abspath(__file__)))) """The root folder of this project (i.e. where setup.py is)""" @@ -41,6 +44,17 @@ EXTRAS_FOLDER = 'extras' """Folder where additional requirements files can be found for :py:func:`privex.helpers.setuppy.common.extras`""" + +def _is_true(v): + return (v.lower() if type(v) is str else v) in [True, 'true', 'yes', 'y', '1', 1] + + +def _env_bool(v, d) -> bool: return _is_true(env(v, d)) + + +def _env_int(v, d) -> int: return int(env(v, d)) + + ######################################## # # # Cache Module Settings # # # ######################################## @@ -106,4 +120,77 @@ GEOCITY, GEOASN, GEOCOUNTRY = join(GEOIP_DIR, GEOCITY_NAME), join(GEOIP_DIR, GEOASN_NAME), join(GEOIP_DIR, GEOCOUNTRY_NAME) +TERMBIN_HOST, TERMBIN_PORT = 'termbin.com', 9999 + +CHECK_CONNECTIVITY: bool = _env_bool('CHECK_CONNECTIVITY', True) + +HAS_WORKING_V4: Optional[bool] = None +""" +This is a storage variable - becomes either ``True`` or ``False`` after :func:`.check_v4` has been run. + + * ``None`` - The connectivity checking function has never been run - unsure whether this IP version works or not. + * ``True`` - This IP version appears to be fully functional - at least it was the last time the IP connectivity checking function was run + * ``False`` - This IP version appears to be broken - at least it was the last time the IP connectivity checking function was run + +""" +HAS_WORKING_V6: Optional[bool] = None +""" +This is a storage variable - becomes either ``True`` or ``False`` after :func:`.check_v6` has been run. + + * ``None`` - The connectivity checking function has never been run - unsure whether this IP version works or not. + * ``True`` - This IP version appears to be fully functional - at least it was the last time the IP connectivity checking function was run + * ``False`` - This IP version appears to be broken - at least it was the last time the IP connectivity checking function was run + +""" + +SSL_VERIFY_CERT: bool = _env_bool('SSL_VERIFY_CERT', True) +SSL_VERIFY_HOSTNAME: bool = _env_bool('SSL_VERIFY_HOSTNAME', True) + +DEFAULT_USER_AGENT = "Python Privex Helpers ( https://github.com/Privex/python-helpers )" + +# V4_CHECKED_AT: Optional[datetime] = None +# """ +# This is a storage variable - used by :func:`.check_v4` to determine how long it's been since the host's IPv4 was tested. +# """ +# +# V6_CHECKED_AT: Optional[datetime] = None +# """ +# This is a storage variable - used by :func:`.check_v6` to determine how long it's been since the host's IPv6 was tested. +# """ + +NET_CHECK_TIMEOUT: int = _env_int('NET_CHECK_TIMEOUT', 3600) +""" +Number of seconds to cache the functional status of an IP version (caching applies to both positive and negative test results). +""" + +NET_CHECK_HOST_COUNT: int = _env_int('NET_CHECK_HOST_COUNT', 3) +""" +Number of hosts in :attr:`.V4_TEST_HOSTS` / :attr:`.V6_TEST_HOSTS` that must be accessible - before that IP protocol +is considered functional. +""" + +NET_CHECK_HOST_COUNT_TRY: int = _env_int('NET_CHECK_HOST_COUNT_TRY', 8) +""" +Maximum number of hosts in :attr:`.V4_TEST_HOSTS` / :attr:`.V6_TEST_HOSTS` that will be tested by :func:`.check_v4` / :func:`.check_v6` +""" + +V4_TEST_HOSTS = [ + '185.130.44.10:80', '8.8.4.4:53', '1.1.1.1:53', '185.130.44.20:53', 'privex.io:80', 'files.privex.io:80', + 'google.com:80', 'www.microsoft.com:80', 'facebook.com:80', 'python.org:80' +] + +V6_TEST_HOSTS = [ + '2a07:e00::333:53', '2001:4860:4860::8888:53', '2606:4700:4700::1111:53', '2a07:e00::abc:80', + 'privex.io:80', 'files.privex.io:80', 'google.com:80', 'facebook.com:80', 'bitbucket.org:80' +] + +random.shuffle(V4_TEST_HOSTS) +random.shuffle(V6_TEST_HOSTS) + +DEFAULT_SOCKET_TIMEOUT = 45 + +DEFAULT_READ_TIMEOUT = _env_int('DEFAULT_READ_TIMEOUT', 60) +DEFAULT_WRITE_TIMEOUT = _env_int('DEFAULT_WRITE_TIMEOUT', DEFAULT_READ_TIMEOUT) +DEFAULT_READ_TIMEOUT = None if DEFAULT_READ_TIMEOUT == 0 else DEFAULT_READ_TIMEOUT +DEFAULT_WRITE_TIMEOUT = None if DEFAULT_WRITE_TIMEOUT == 0 else DEFAULT_WRITE_TIMEOUT diff --git a/privex/helpers/types.py b/privex/helpers/types.py index 4961a88..72c8693 100644 --- a/privex/helpers/types.py +++ b/privex/helpers/types.py @@ -24,9 +24,9 @@ """Plain generic type variable for use in helper functions""" V = TypeVar('V') """Plain generic type variable for use in helper functions""" -C = TypeVar('C', type, callable, Callable) +C = TypeVar('C', type, Callable) """Generic type variable constrained to :class:`type` / :class:`typing.Callable` for use in helper functions""" -CL = TypeVar('CL', type, callable, Callable) +CL = TypeVar('CL', type, Callable) """Generic type variable constrained to :class:`type` / :class:`typing.Callable` for use in helper functions""" @@ -59,3 +59,10 @@ STRBYTES = Union[bytes, str] """Shorter alias for ``Union[bytes, str]``""" + +AUTO = AUTOMATIC = AUTO_DETECTED = type('AutoDetected', (), {}) +""" +Another functionless type, intended to stand-in as the default value for a parameter, with the +meaning "automatically populate this parameter from another source" e.g. instance state attributes +""" +
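Since `AUTO` / `AUTOMATIC` / `AUTO_DETECTED` is purely a parameter-default sentinel, a tiny sketch of the intended pattern - the `Fetcher` class here is made up for illustration:

```
# Sketch: use AUTO as a "fill this in from instance state" default value.
from privex.helpers.types import AUTO

class Fetcher:
    def __init__(self, timeout=10):
        self.timeout = timeout

    def fetch(self, url: str, timeout=AUTO):
        # If the caller didn't pass a timeout, fall back to the instance attribute.
        timeout = self.timeout if timeout is AUTO else timeout
        return url, timeout

f = Fetcher(timeout=30)
print(f.fetch('https://privex.io'))             # ('https://privex.io', 30)
print(f.fetch('https://privex.io', timeout=5))  # ('https://privex.io', 5)
```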
diff --git a/setup.py b/setup.py index 70bd24d..64de4f8 100755 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ license='MIT', install_requires=[ - 'privex-loghelper>=1.0.4', 'python-dateutil', 'sniffio', 'async-property', + 'privex-loghelper>=1.0.4', 'python-dateutil', 'sniffio', 'async-property', 'attrs' ], cmdclass=extra_commands, extras_require=extras_require(extensions), diff --git a/tests/test_cache.py b/tests/test_cache.py index a072f7b..13dcd8c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -50,10 +50,7 @@ HAS_PYTEST = True except ImportError: warnings.warn('WARNING: Could not import pytest. 
You should run "pip3 install pytest" to ensure tests work best') - pytest = helpers.Mocker.make_mock_class('module') - pytest.skip = lambda msg, allow_module_level=True: warnings.warn(msg) - pytest.add_mock_module('mark') - pytest.mark.skip, pytest.mark.skipif = helpers.mock_decorator, helpers.mock_decorator + from privex.helpers.mockers import pytest HAS_PYTEST = False HAS_REDIS = plugin.HAS_REDIS diff --git a/tests/test_collections.py b/tests/test_collections.py index 8529d85..432c561 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -56,10 +56,7 @@ HAS_PYTEST = True except ImportError: warnings.warn('WARNING: Could not import pytest. You should run "pip3 install pytest" to ensure tests work best') - pytest = helpers.Mocker.make_mock_class('module') - pytest.skip = lambda msg, allow_module_level=True: warnings.warn(msg) - pytest.add_mock_module('mark') - pytest.mark.skip, pytest.mark.skipif = helpers.mock_decorator, helpers.mock_decorator + from privex.helpers.mockers import pytest HAS_PYTEST = False try: @@ -74,8 +71,7 @@ ) # To avoid a severe syntax error caused by the missing dataclass types, we generate a dummy dataclass and field class # so that type annotations such as Type[dataclass] don't break the test before it can be skipped. - dataclass = Mocker.make_mock_class(name='dataclass', instance=False) - field = Mocker.make_mock_class(name='field', instance=False) + from privex.helpers.mockers import dataclass, field, dataclasses log = logging.getLogger(__name__) diff --git a/tests/test_converters.py b/tests/test_converters.py index e34756a..3346d93 100644 --- a/tests/test_converters.py +++ b/tests/test_converters.py @@ -1,4 +1,7 @@ +import warnings +import pytest from datetime import datetime, date +from decimal import Decimal from dateutil.tz import tzutc @@ -127,3 +130,120 @@ def test_convert_int_bool_empty_cust(self): def test_convert_int_bool_empty_fail(self): with self.assertRaises(AttributeError): helpers.convert_int_bool(None, fail_empty=True) + + +try: + from dataclasses import dataclass, field + + HAS_DATACLASSES = True +except ImportError: + HAS_DATACLASSES = False + warnings.warn( + 'WARNING: Could not import dataclasses module (Python older than 3.7?). ' + 'For older python versions such as 3.6, you can run "pip3 install dataclasses" to install the dataclasses ' + 'backport library, which emulates Py3.7+ dataclasses using older syntax.', category=ImportWarning + ) + # To avoid a severe syntax error caused by the missing dataclass types, we generate a dummy dataclass and field class + # so that type annotations such as Type[dataclass] don't break the test before it can be skipped. + from privex.helpers.mockers import dataclass, field + +try: + import attr + + HAS_ATTRS = True +except ImportError: + HAS_ATTRS = False + warnings.warn( + 'WARNING: Could not import "attr" module. Please run "pip3 install attrs" to install the attrs module.', category=ImportWarning + ) + # To avoid a severe syntax error caused by the missing dataclass types, we generate a dummy dataclass and field class + # so that type annotations such as Type[dataclass] don't break the test before it can be skipped. 
+ from privex.helpers.mockers import attr + +EXAMP_DEC = Decimal('1.2345') +EXAMP_DEC_STR = '1.2345' +EXAMP_DEC_FLOAT = 1.2345 + +EXAMP_INT, EXAMP_INT_STR = 123123, '123123' +EXAMP_FLOAT, EXAMP_FLOAT_STR = 543.12342, '543.12342' +EXAMP_STR = 'hello world' + +EXAMP_LIST = [1, 2, Decimal('3.123'), ('a', 'b', b'c')] +EXAMP_LIST_CLEAN_NUM = [1, 2, 3.123, ['a', 'b', 'c']] +EXAMP_LIST_CLEAN_STR = ['1', '2', '3.123', ['a', 'b', 'c']] + +EXAMP_DICT = dict(lorem='ipsum', dolor=3, world=list(EXAMP_LIST), example=Decimal(EXAMP_DEC)) +EXAMP_DICT_CLEAN_STR = dict(lorem='ipsum', dolor='3', world=list(EXAMP_LIST_CLEAN_STR), example=EXAMP_DEC_STR) +EXAMP_DICT_CLEAN_NUM = dict(lorem='ipsum', dolor=3, world=list(EXAMP_LIST_CLEAN_NUM), example=EXAMP_DEC_FLOAT) + + +@attr.s +class ExampleAttrs: + hello = attr.ib(default='world') + lorem = attr.ib(factory=lambda: Decimal(EXAMP_DEC)) + ipsum = attr.ib(factory=lambda: list(EXAMP_LIST)) + dolor = attr.ib(factory=lambda: dict(EXAMP_DICT)) + + +@dataclass +class ExampleDataClass: + hello: str = 'world' + lorem: Decimal = field(default_factory=lambda: Decimal(EXAMP_DEC)) + ipsum: list = field(default_factory=lambda: list(EXAMP_LIST)) + dolor: dict = field(default_factory=lambda: dict(EXAMP_DICT)) + + +class TestCleanData(PrivexBaseCase): + def test_clean_obj_decimal(self): + self.assertEqual(helpers.clean_obj(EXAMP_DEC), EXAMP_DEC_FLOAT) + self.assertEqual(helpers.clean_obj(EXAMP_DEC, number_str=True), EXAMP_DEC_STR) + + def test_clean_obj_float(self): + self.assertEqual(helpers.clean_obj(EXAMP_FLOAT), EXAMP_FLOAT) + self.assertEqual(helpers.clean_obj(EXAMP_FLOAT, number_str=True), EXAMP_FLOAT_STR) + + def test_clean_obj_int(self): + self.assertEqual(helpers.clean_obj(EXAMP_INT), EXAMP_INT) + self.assertEqual(helpers.clean_obj(EXAMP_INT, number_str=True), EXAMP_INT_STR) + + def test_clean_obj_list(self): + self.assertEqual(helpers.clean_obj(EXAMP_LIST), EXAMP_LIST_CLEAN_NUM) + self.assertEqual(helpers.clean_obj(EXAMP_LIST, number_str=True), EXAMP_LIST_CLEAN_STR) + + def test_clean_obj_dict(self): + self.assertEqual(helpers.clean_obj(EXAMP_DICT), EXAMP_DICT_CLEAN_NUM) + self.assertEqual(helpers.clean_obj(EXAMP_DICT, number_str=True), EXAMP_DICT_CLEAN_STR) + + @pytest.mark.skipif(HAS_ATTRS is False, reason='HAS_ATTRS is False (must install attrs: pip3 install attrs)') + def test_clean_obj_attrs(self): + o = ExampleAttrs() + c = helpers.clean_obj(o) + self.assertIsInstance(c, dict) + self.assertEqual(c['hello'], 'world') + self.assertEqual(c['lorem'], EXAMP_DEC_FLOAT) + self.assertListEqual(c['ipsum'], EXAMP_LIST_CLEAN_NUM) + self.assertDictEqual(c['dolor'], EXAMP_DICT_CLEAN_NUM) + + cs = helpers.clean_obj(o, True) + self.assertIsInstance(cs, dict) + self.assertEqual(cs['hello'], 'world') + self.assertEqual(cs['lorem'], EXAMP_DEC_STR) + self.assertListEqual(cs['ipsum'], EXAMP_LIST_CLEAN_STR) + self.assertDictEqual(cs['dolor'], EXAMP_DICT_CLEAN_STR) + + @pytest.mark.skipif(HAS_DATACLASSES is False, reason='HAS_DATACLASSES is False (Python older than 3.7?)') + def test_clean_obj_dataclass(self): + o = ExampleDataClass() + c = helpers.clean_obj(o) + self.assertIsInstance(c, dict) + self.assertEqual(c['hello'], 'world') + self.assertEqual(c['lorem'], EXAMP_DEC_FLOAT) + self.assertListEqual(c['ipsum'], EXAMP_LIST_CLEAN_NUM) + self.assertDictEqual(c['dolor'], EXAMP_DICT_CLEAN_NUM) + + cs = helpers.clean_obj(o, True) + self.assertIsInstance(cs, dict) + self.assertEqual(cs['hello'], 'world') + self.assertEqual(cs['lorem'], EXAMP_DEC_STR) + self.assertListEqual(cs['ipsum'], 
EXAMP_LIST_CLEAN_STR) + self.assertDictEqual(cs['dolor'], EXAMP_DICT_CLEAN_STR) diff --git a/tests/test_crypto.py b/tests/test_crypto.py index a8fabeb..964f5fc 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -54,8 +54,7 @@ HAS_PYTEST = True except ImportError: warnings.warn('WARNING: Could not import pytest. You should run "pip3 install pytest" to ensure tests work best') - pytest = Mocker.make_mock_class('module') - pytest.skip = lambda msg, allow_module_level=True: warnings.warn(msg) + from privex.helpers.mockers import pytest HAS_PYTEST = False if plugin.HAS_CRYPTO: diff --git a/tests/test_geoip.py b/tests/test_geoip.py index b790deb..e7aed2e 100644 --- a/tests/test_geoip.py +++ b/tests/test_geoip.py @@ -57,8 +57,7 @@ HAS_PYTEST = True except ImportError: warnings.warn('WARNING: Could not import pytest. You should run "pip3 install pytest" to ensure tests work best') - pytest = Mocker.make_mock_class('module') - pytest.skip = lambda msg, allow_module_level=True: warnings.warn(msg) + from privex.helpers.mockers import pytest HAS_PYTEST = False if plugin.HAS_GEOIP: diff --git a/tests/test_net.py b/tests/test_net.py index 5ed6001..fb4dd0c 100644 --- a/tests/test_net.py +++ b/tests/test_net.py @@ -4,9 +4,11 @@ import socket import warnings -from privex.helpers import loop_run +from privex.helpers import loop_run, settings from tests import PrivexBaseCase from privex import helpers +from privex.helpers import run_coro_thread +from privex.helpers.net import base as netbase try: import pytest @@ -14,10 +16,7 @@ HAS_PYTEST = True except ImportError: warnings.warn('WARNING: Could not import pytest. You should run "pip3 install pytest" to ensure tests work best') - pytest = helpers.Mocker.make_mock_class('module') - pytest.skip = lambda msg, allow_module_level=True: warnings.warn(msg) - pytest.add_mock_module('mark') - pytest.mark.skip, pytest.mark.skipif = helpers.mock_decorator, helpers.mock_decorator + from privex.helpers.mockers import pytest HAS_PYTEST = False HAS_DNSPYTHON = helpers.plugin.HAS_DNSPYTHON @@ -26,6 +25,13 @@ class TestNet(PrivexBaseCase): """Test cases related to :py:mod:`privex.helpers.net` or generally network related functions""" + def __init__(self, *args, **kwargs): + settings.CHECK_CONNECTIVITY = False + settings.DEFAULT_SOCKET_TIMEOUT = 5 + settings.DEFAULT_READ_TIMEOUT = 7 + settings.DEFAULT_WRITE_TIMEOUT = 3 + super().__init__(*args, **kwargs) + def test_ping(self): """Test success & failure cases for ping function with IPv4, as well as input validation""" try: @@ -138,6 +144,14 @@ def test_check_host_send(self): self.assertTrue(helpers.check_host('files.privex.io', 80, send=http_req)) self.assertFalse(helpers.check_host('files.privex.io', 9991)) + @pytest.mark.xfail() + def test_check_host_http(self): + self.assertTrue(helpers.check_host_http('files.privex.io', 80)) + + @pytest.mark.xfail() + def test_check_host_http_ssl(self): + self.assertTrue(helpers.check_host_http('www.privex.io', 443, use_ssl=True)) + @pytest.mark.xfail() def test_check_host_throw(self): with self.assertRaises(ConnectionRefusedError): @@ -397,6 +411,13 @@ def test_resolve_ips_multi_v6(self): class TestAsyncNet(PrivexBaseCase): + def __init__(self, *args, **kwargs): + settings.CHECK_CONNECTIVITY = False + settings.DEFAULT_SOCKET_TIMEOUT = 5 + settings.DEFAULT_READ_TIMEOUT = 7 + settings.DEFAULT_WRITE_TIMEOUT = 3 + super().__init__(*args, **kwargs) + def test_get_rdns_privex_ns1_ip(self): """Test resolving IPv4 and IPv6 addresses into ns1.privex.io""" 
self.assertEqual(loop_run(helpers.get_rdns_async('2a07:e00::100')), 'ns1.privex.io') @@ -417,20 +438,43 @@ def test_get_rdns_no_rdns_records(self): with self.assertRaises(helpers.ReverseDNSNotFound): loop_run(helpers.get_rdns_async('192.168.5.1')) + @pytest.mark.xfail() + def test_base_check_host_async(self): + self.assertTrue(loop_run(netbase.check_host_async, 'hiveseed-se.privex.io', 2001)) + self.assertFalse(loop_run(netbase.check_host_async, 'hiveseed-se.privex.io', 9991)) + + @pytest.mark.xfail() + def test_base_check_host_async_send(self): + http_req = b"GET / HTTP/1.1\n\n" + self.assertTrue(loop_run(netbase.check_host_async, 'files.privex.io', 80, send=http_req)) + self.assertFalse(loop_run(netbase.check_host_async, 'files.privex.io', 9991)) + + @pytest.mark.xfail() + def test_base_check_host_async_throw(self): + with self.assertRaises(ConnectionRefusedError): + loop_run(netbase.check_host_async, 'files.privex.io', 9991, throw=True) + @pytest.mark.xfail() def test_check_host_async(self): - self.assertTrue(loop_run(helpers.check_host_async('hiveseed-se.privex.io', 2001))) - self.assertFalse(loop_run(helpers.check_host_async('hiveseed-se.privex.io', 9991))) + self.assertTrue(run_coro_thread(helpers.check_host_async, 'hiveseed-se.privex.io', 2015, throw=True)) + self.assertFalse(run_coro_thread(helpers.check_host_async, 'hiveseed-se.privex.io', 9991, timeout=2)) @pytest.mark.xfail() def test_check_host_async_send(self): http_req = b"GET / HTTP/1.1\n\n" - self.assertTrue(loop_run(helpers.check_host_async('files.privex.io', 80, send=http_req))) - self.assertFalse(loop_run(helpers.check_host_async('files.privex.io', 9991))) + self.assertTrue(run_coro_thread(helpers.check_host_async, 'files.privex.io', 80, send=http_req)) + self.assertFalse(run_coro_thread(helpers.check_host_async, 'files.privex.io', 9991, timeout=2)) + + @pytest.mark.xfail() + def test_check_host_async_http(self): + self.assertTrue(run_coro_thread(helpers.check_host_http_async, 'files.privex.io', 80)) + + @pytest.mark.xfail() + def test_check_host_async_http_ssl(self): + self.assertTrue(run_coro_thread(helpers.check_host_http_async, 'www.privex.io', 443, use_ssl=True)) @pytest.mark.xfail() def test_check_host_async_throw(self): with self.assertRaises(ConnectionRefusedError): - loop_run(helpers.check_host_async('files.privex.io', 9991, throw=True)) - + run_coro_thread(helpers.check_host_async, 'files.privex.io', 9991, timeout=5, throw=True)