diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ea62fcce..826f2aa8 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -22,7 +22,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.10", "3.11"] + python-version: ["3.11"] fail-fast: false steps: diff --git a/.github/workflows/sphinx.yml b/.github/workflows/sphinx.yml index 6a227951..d974f1cf 100644 --- a/.github/workflows/sphinx.yml +++ b/.github/workflows/sphinx.yml @@ -8,7 +8,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: '3.10' + python-version: '3.11' # pip cache - name: pip cache diff --git a/.gitignore b/.gitignore index d80f36db..43c14080 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ **/*.pyc *.py[cod] *$py.class +*-env/ # C extensions *.so diff --git a/dev-environment.yml b/dev-environment.yml index 3042da6f..4100690c 100644 --- a/dev-environment.yml +++ b/dev-environment.yml @@ -2,7 +2,7 @@ name: dev-vai-lab-env channels: - conda-forge dependencies: - - python=3.10 + - python=3.11 - c-compiler - pip - pip: diff --git a/environment.yml b/environment.yml index e65480a8..3b78aaa5 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: vai-lab-env channels: - conda-forge dependencies: - - python=3.10 + - python=3.11 - c-compiler - pip - pip: diff --git a/pyproject.toml b/pyproject.toml index f1c9a7c5..21e92a40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,10 +12,9 @@ authors = [ readme = "README.md" classifiers = [ "License :: OSI Approved :: MIT License", - 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11' ] -requires-python = ">=3.8" +requires-python = ">=3.11" version = "0.0.dev4" description = "AI assisted Virtual Laboratory framework." diff --git a/sphinx-environment.yml b/sphinx-environment.yml index 45f3b5ae..8950f8ab 100644 --- a/sphinx-environment.yml +++ b/sphinx-environment.yml @@ -2,7 +2,7 @@ name: sphinx-env channels: - conda-forge dependencies: - - python=3.10 + - python=3.11 - sphinx - sphinx-rtd-theme - myst-parser diff --git a/sphinx-requirements.txt b/sphinx-requirements.txt index 4cbfb8b7..56c8516d 100644 --- a/sphinx-requirements.txt +++ b/sphinx-requirements.txt @@ -1,4 +1,3 @@ -# requirements.txt for sphinx sphinx sphinx-rtd-theme myst-parser diff --git a/src/vai_lab/Core/vai_lab_core.py b/src/vai_lab/Core/vai_lab_core.py index a2aeca8b..ebe9ca5d 100644 --- a/src/vai_lab/Core/vai_lab_core.py +++ b/src/vai_lab/Core/vai_lab_core.py @@ -47,9 +47,9 @@ def load_config_file(self, filename: Union[str,List,Tuple]): self._xml_handler.load_XML(filedir) self._initialised = True - def _load_data(self, module = 'Initialiser') -> None: + def _load_data(self, specs, module = 'Initialiser') -> None: """Loads data from XML file into Data object""" - init_data_fn = self._xml_handler.data_to_load(module) + init_data_fn = self._xml_handler.data_to_load(modules=specs, module=module) if module not in self.data.keys(): self.data[module] = Data() if isinstance(init_data_fn, str): @@ -66,7 +66,7 @@ def _execute_module(self, specs): mod: ModuleInterface = import_module(globals(), specs["module_type"]).__call__() mod._debug = self._debug mod.set_avail_plugins(self._avail_plugins) - self._load_data(specs["name"]) + self._load_data(specs, specs["name"]) mod.set_data_in(self.data[specs["name"]]) mod.set_options(specs) print("\t"*self.loop_level @@ -108,6 +108,7 @@ def _execute_exit_point(self, specs): with open(rel_to_abs(specs['plugin']['options']['outpath']), 'wb') as handle: pickle.dump(data_out, handle, protocol=pickle.HIGHEST_PROTOCOL) + def _parse_loop_condition(self, condition): try: condition = int(condition) @@ -169,7 +170,6 @@ def _execute(self, specs): if not _tracker['terminate']: self.load_config_file(self._xml_handler.filename) - # pass else: print('Pipeline terminated') exit() @@ -185,7 +185,7 @@ def run(self): self._initialise_with_gui() print("Running pipeline...") if len(self._xml_handler.loaded_modules) > 0: - self._load_data() + self._load_data(self._xml_handler.loaded_modules) self._init_status(self._xml_handler.loaded_modules) self._execute(self._xml_handler.loaded_modules) diff --git a/src/vai_lab/Data/xml_handler.py b/src/vai_lab/Data/xml_handler.py index b3ebeae3..9c0a62c4 100644 --- a/src/vai_lab/Data/xml_handler.py +++ b/src/vai_lab/Data/xml_handler.py @@ -9,7 +9,7 @@ root_mod = path.dirname(path.dirname(path.dirname(__file__))) sys.path.append(root_mod) -from vai_lab._import_helper import get_lib_parent_dir +from vai_lab._import_helper import get_lib_parent_dir, rel_to_abs class XML_handler: @@ -30,6 +30,8 @@ def __init__(self, filename: str = None): "pipeline": "declaration", "relationships": "relationships", "plugin": "plugin", + "method": "method", + "options": "options", "coordinates": "list", "Initialiser": "entry_point", "inputdata": "data", @@ -91,9 +93,42 @@ def load_XML(self, filename: str) -> None: """ if filename != None: self.set_filename(filename) - self.tree = ET.parse(self.filename) + if hasattr(self, 'tree'): + prev_tree = self.tree + self.tree = self._combine_XML(prev_tree.getroot(), ET.parse(self.filename).getroot()) + else: + self.tree = ET.parse(self.filename) self._parse_XML() + def _combine_XML(self, tree1, tree2): + """ + This function recursively updates either the text or the children + of an element if another element is found in `tree1`, or adds it + from `tree2` if not found. + """ + # Create a mapping from tag name to element, as that's what we are fltering with + mapping = {el.tag: el for el in tree1} + for el in tree2: + if len(el) == 0: + # Not nested + try: + # Update the text + mapping[el.tag].text = el.text + except KeyError: + # An element with this name is not in the mapping + mapping[el.tag] = el + # Add it + tree1.append(el) + else : + try: + # Recursively process the element, and update it in the same way + self._combine_XML(mapping[el.tag], el) + except KeyError: + # Not in the mapping + mapping[el.tag] = el + tree1.append(el) + return ET.ElementTree(tree1) + def _parse_XML(self) -> None: self.root = self.tree.getroot() self._parse_tags(self.root, self.loaded_modules) @@ -144,18 +179,38 @@ def _load_plugin(self, element: ET.Element, parent: dict) -> None: """ parent["plugin"] = {} parent["plugin"]["plugin_name"] = element.attrib["type"] + parent["plugin"]["methods"] = {"_order" : []} parent["plugin"]["options"] = {} + self._parse_tags(element, parent["plugin"]) + + def _load_method(self, element: ET.Element, parent: dict) -> None: + """Parses tags associated with methods and appends to parent dict + :param elem: xml.etree.ElementTree.Element to be parsed + :param parent: dict or dict fragment parsed tags will be appened to + """ + for key in element.attrib: + parent["methods"]["_order"].append(element.attrib[key]) + parent["methods"][element.attrib[key]] = {'options': {}} + self._parse_tags(element, parent["methods"][element.attrib[key]]) + + def _load_options(self, element: ET.Element, parent: dict) -> None: + """Parses tags associated with options and appends to parent dict + :param elem: xml.etree.ElementTree.Element to be parsed + :param parent: dict or dict fragment parsed tags will be appened to + """ for child in element: if child.text is not None: - val = self._parse_text_to_list(child) - val = (val[0] if len(val) == 1 else val) - parent["plugin"]["options"][child.tag] = val + try: + parent["options"][child.tag] = literal_eval(child.text.strip()) + except Exception as exc: + val = self._parse_text_to_list(child) + parent["options"][child.tag] = (val[0] if len(val) == 1 else val) + for key in child.attrib: if key == "val": - parent["plugin"]["options"][child.tag] = child.attrib[key] + parent["options"][child.tag] = child.attrib[key] else: - parent["plugin"]["options"][child.tag] = { - key: child.attrib[key]} + parent["options"][child.tag] = {key: child.attrib[key]} def _load_entry_point(self, element: ET.Element, parent: dict) -> None: """Parses tags associated with initialiser and appends to parent dict @@ -453,30 +508,33 @@ def update_plugin_options(self, else: xml_parent = xml_parent_name plugin_elem: ET.Element = xml_parent.find("./plugin") - self._add_plugin_options(plugin_elem, options) + self._add_options(plugin_elem, options) self._parse_XML() if save_changes: self.write_to_XML() - def _add_plugin_options(self, + def _add_options(self, plugin_elem: ET.Element, options ): + opt_elem = plugin_elem.find("./options") + if opt_elem is None: + opt_elem = ET.SubElement(plugin_elem, "options") for key in options.keys(): if isinstance(options[key], list): - new_option = ET.SubElement(plugin_elem, key) + new_option = ET.SubElement(opt_elem, key) option_text = ("\n{}".format( "\n".join([*options[key]]))) new_option.text = option_text elif isinstance(options[key], (int, float, str)): - new_option = plugin_elem.find(str("./" + key)) + new_option = opt_elem.find(str("./" + key)) if new_option is None: - new_option = ET.SubElement(plugin_elem, key) + new_option = ET.SubElement(opt_elem, key) text_lead = "\n" if "\n" not in str(options[key]) else "" new_option.text = "{0} {1}".format( text_lead, str(options[key])) elif isinstance(options[key], (dict)): - self._add_plugin_options(plugin_elem, options[key]) + self._add_options(opt_elem, options[key]) def append_input_data(self, data_name: str, @@ -511,14 +569,44 @@ def append_input_data(self, if save_dir_as_relative: data_dir = data_dir.replace(self.lib_base_path, "./") data_dir = data_dir.replace("\\", "/") - if path.exists(path.dirname(data_dir)): + if path.exists(path.dirname(rel_to_abs(data_dir))): plugin_elem.set('file', data_dir) else: plugin_elem.set('module', data_dir) + def append_method_to_plugin(self, + method_type: str, + method_options: dict, + xml_parent: Union[ET.Element, str], + overwrite_existing: Union[bool, int] = False + ): + """Appened method as subelement to existing plugin element + + :param method_type: string type of method to be loaded into plugin + :param method_options: dict where keys & values are options & values + :param xml_parent: dict OR str. + If string given, parent elem is found via search, + Otherwise, method appeneded directly + """ + if isinstance(xml_parent, str): + xml_parent = self._get_element_from_name(xml_parent) + + method_elem = xml_parent.find("./method") + + if method_elem is not None and overwrite_existing: + if method_elem.attrib['type'] == method_type: + xml_parent.remove(method_elem) + method_elem = None + + if method_elem is None: + method_elem = ET.SubElement(xml_parent, "method") + method_elem.set('type', method_type) + self._add_options(method_elem, method_options) + def append_plugin_to_module(self, plugin_type: str, plugin_options: dict, + method_list: list, plugin_data: str, xml_parent: Union[ET.Element, str], overwrite_existing: Union[bool, int] = False @@ -546,7 +634,13 @@ def append_plugin_to_module(self, plugin_elem.set('type', plugin_type) if plugin_data is not None and len(plugin_data) > 0: self.append_input_data('X', plugin_data, xml_parent, False) - self._add_plugin_options(plugin_elem, plugin_options) + if '__init__' in plugin_options.keys(): + self._add_options(plugin_elem, plugin_options['__init__']) + for f in method_list: + self.append_method_to_plugin(f, + plugin_options[f], + plugin_elem, + overwrite_existing) def append_pipeline_module(self, module_type: str, @@ -577,6 +671,7 @@ def append_pipeline_module(self, if plugin_type != None: self.append_plugin_to_module(plugin_type, plugin_options, + [], parents[0], new_mod, 0 @@ -625,10 +720,15 @@ def append_pipeline_loop(self, xml_parent_element.append(new_loop) - def _get_data_structure(self, module) -> Dict[str, Any]: - data_struct = self._find_dict_with_key_val_pair( - self.loaded_modules[module], - "class", "data") + def _get_data_structure(self, modules, module) -> Dict[str, Any]: + try: + data_struct = self._find_dict_with_key_val_pair( + modules[module], + "class", "data") + except Exception as exc: + data_struct = self._find_dict_with_key_val_pair( + modules, + "class", "data") assert len(data_struct) < 2, \ "Multiple data with same ID, please check XML" @@ -641,8 +741,10 @@ def _get_data_structure(self, module) -> Dict[str, Any]: return out #@property - def data_to_load(self, module='Initialiser') -> Dict[str, str]: - return self._get_data_structure(module)["to_load"] + def data_to_load(self, modules=False, module='Initialiser') -> Dict[str, str]: + if not modules: + modules = self.loaded_modules + return self._get_data_structure(modules, module)["to_load"] # Use case examples: diff --git a/src/vai_lab/DataProcessing/DataProcessing_core.py b/src/vai_lab/DataProcessing/DataProcessing_core.py index 666e2a2d..ba3716d6 100644 --- a/src/vai_lab/DataProcessing/DataProcessing_core.py +++ b/src/vai_lab/DataProcessing/DataProcessing_core.py @@ -31,8 +31,14 @@ def set_options(self, module_config: dict) -> None: def launch(self) -> None: self._plugin.set_data_in(self._data_in) self._plugin.configure(self._module_config["plugin"]) - self._plugin.fit() - self.output_data = self._plugin.transform(self._data_in) + self._plugin.init() + for method in self._module_config["plugin"]["methods"]["_order"]: + if "options" in self._module_config["plugin"]["methods"][method].keys(): + getattr(self._plugin, "{}".format(method))(self._plugin._parse_options_dict(self._module_config["plugin"]["methods"][method]["options"])) + else: + getattr(self._plugin, "{}".format(method))() + + self.output_data = self._data_in.copy() def get_result(self) -> DataInterface: - return self.output_data + return self.output_data \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/binarizer.py b/src/vai_lab/DataProcessing/plugins/binarizer.py index 14d0e285..71e334ed 100644 --- a/src/vai_lab/DataProcessing/plugins/binarizer.py +++ b/src/vai_lab/DataProcessing/plugins/binarizer.py @@ -5,7 +5,7 @@ _PLUGIN_READABLE_NAMES = {"Binarizer":"default","binarizer":"alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"threshold": "float"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/kbinsdiscretizer.py b/src/vai_lab/DataProcessing/plugins/kbinsdiscretizer.py index a844db5e..4fc7023e 100644 --- a/src/vai_lab/DataProcessing/plugins/kbinsdiscretizer.py +++ b/src/vai_lab/DataProcessing/plugins/kbinsdiscretizer.py @@ -5,10 +5,10 @@ _PLUGIN_READABLE_NAMES = {"KBinsDiscretizer":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"n_bins": "int"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore -_PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst","Y_tst"} # type:ignore +_PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst","Y_tst"} # type:ignore class KBinsDiscretizer(DataProcessingT): """ @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/labelbinarizer.py b/src/vai_lab/DataProcessing/plugins/labelbinarizer.py index b9858065..811e1774 100644 --- a/src/vai_lab/DataProcessing/plugins/labelbinarizer.py +++ b/src/vai_lab/DataProcessing/plugins/labelbinarizer.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"LabelBinarizer":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"neg_label": "int", "pos_label": "int"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -18,4 +18,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/labelencoder.py b/src/vai_lab/DataProcessing/plugins/labelencoder.py index 84bed864..a7166062 100644 --- a/src/vai_lab/DataProcessing/plugins/labelencoder.py +++ b/src/vai_lab/DataProcessing/plugins/labelencoder.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"LabelEncoder":"default","LE":"alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/maxabsscaler.py b/src/vai_lab/DataProcessing/plugins/maxabsscaler.py index d8e700ce..58def4b4 100644 --- a/src/vai_lab/DataProcessing/plugins/maxabsscaler.py +++ b/src/vai_lab/DataProcessing/plugins/maxabsscaler.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"MaxAbsScaler":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "scaler"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/minmaxscaler.py b/src/vai_lab/DataProcessing/plugins/minmaxscaler.py index d3010e3b..c7e2d170 100644 --- a/src/vai_lab/DataProcessing/plugins/minmaxscaler.py +++ b/src/vai_lab/DataProcessing/plugins/minmaxscaler.py @@ -4,8 +4,8 @@ _PLUGIN_READABLE_NAMES = {"MinMaxScaler":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "scaler"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"feature_range": "tuple"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/multilabelbinarizer.py b/src/vai_lab/DataProcessing/plugins/multilabelbinarizer.py index 24be8036..4be65da8 100644 --- a/src/vai_lab/DataProcessing/plugins/multilabelbinarizer.py +++ b/src/vai_lab/DataProcessing/plugins/multilabelbinarizer.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"MultiLabelBinarizer":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"classes": "array-like"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/normalizer.py b/src/vai_lab/DataProcessing/plugins/normalizer.py index 5b92cfe6..6c0cbb8c 100644 --- a/src/vai_lab/DataProcessing/plugins/normalizer.py +++ b/src/vai_lab/DataProcessing/plugins/normalizer.py @@ -7,8 +7,8 @@ _PLUGIN_READABLE_NAMES = {"Normalizer": "default", "Norm": "alias", "normalizer": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "scaler"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"norm": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X", "Y", "X_tst", 'Y_tst'} # type:ignore @@ -23,4 +23,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/onehotencoder.py b/src/vai_lab/DataProcessing/plugins/onehotencoder.py index 259574bc..73d8eb2c 100644 --- a/src/vai_lab/DataProcessing/plugins/onehotencoder.py +++ b/src/vai_lab/DataProcessing/plugins/onehotencoder.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"OneHotEncoder":"default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"categories": "array-like"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X","Y","X_tst", 'Y_tst'} # type:ignore @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/ordinalencoder.py b/src/vai_lab/DataProcessing/plugins/ordinalencoder.py index 558083cf..a9095ccf 100644 --- a/src/vai_lab/DataProcessing/plugins/ordinalencoder.py +++ b/src/vai_lab/DataProcessing/plugins/ordinalencoder.py @@ -4,7 +4,7 @@ _PLUGIN_READABLE_NAMES = {"OrdinalEncoder": "default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"categories": "array-like"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X", "Y", "X_tst", 'Y_tst'} # type:ignore @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/polynomialfeatures.py b/src/vai_lab/DataProcessing/plugins/polynomialfeatures.py index deae011c..c40c5ea7 100644 --- a/src/vai_lab/DataProcessing/plugins/polynomialfeatures.py +++ b/src/vai_lab/DataProcessing/plugins/polynomialfeatures.py @@ -6,9 +6,8 @@ "polyfeat": "alias", "polynomialfeatures": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "Other"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"degree": "int", - "interaction_only": "bool", +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"interaction_only": "bool", "include_bias": "bool"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X", "Y", "X_tst", 'Y_tst'} # type:ignore @@ -24,4 +23,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/quantiletransformer.py b/src/vai_lab/DataProcessing/plugins/quantiletransformer.py index 1e6b1c5b..91085593 100644 --- a/src/vai_lab/DataProcessing/plugins/quantiletransformer.py +++ b/src/vai_lab/DataProcessing/plugins/quantiletransformer.py @@ -5,7 +5,7 @@ _PLUGIN_READABLE_NAMES = { "QuantileTransformer": "default", "Quantile": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "encoder"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"n_quantiles": "int"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X", "Y", "X_tst", 'Y_tst'} # type:ignore @@ -21,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DataProcessing/plugins/standardscaler.py b/src/vai_lab/DataProcessing/plugins/standardscaler.py index f44c7d7b..cda34552 100644 --- a/src/vai_lab/DataProcessing/plugins/standardscaler.py +++ b/src/vai_lab/DataProcessing/plugins/standardscaler.py @@ -5,7 +5,7 @@ _PLUGIN_READABLE_NAMES = { "StandardScaler": "default", "standardscaler": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "scaler"} # type:ignore -_PLUGIN_REQUIRED_SETTINGS = {"Data": "str"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore _PLUGIN_OPTIONAL_SETTINGS = {"with_mean": "bool"} # type:ignore _PLUGIN_REQUIRED_DATA = {} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X", "Y", "X_tst", 'Y_tst'} # type:ignore @@ -21,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.proc = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/DecisionMaking/DecisionMaking_core.py b/src/vai_lab/DecisionMaking/DecisionMaking_core.py new file mode 100644 index 00000000..88c8ea44 --- /dev/null +++ b/src/vai_lab/DecisionMaking/DecisionMaking_core.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +from vai_lab._import_helper import import_plugin_absolute +class DecisionMaking(object): + def __init__(self): + self.output_data = None + + def set_avail_plugins(self,avail_plugins): + self._avail_plugins = avail_plugins + + def set_data_in(self,data_in): + self._data_in = data_in + + def _load_plugin(self, plugin_name:str): + avail_plugins = self._avail_plugins.find_from_readable_name(plugin_name) + self._plugin_name = plugin_name + self._plugin = import_plugin_absolute(globals(),\ + avail_plugins["_PLUGIN_PACKAGE"],\ + avail_plugins["_PLUGIN_CLASS_NAME"])\ + .__call__() + + def set_options(self, module_config: dict): + """Send configuration arguments to plugin + + :param module_config: dict of settings to configure the plugin + """ + self._module_config = module_config + self._load_plugin(self._module_config["plugin"]["plugin_name"]) + + def launch(self): + self._plugin.set_data_in(self._data_in) + self._plugin.configure(self._module_config["plugin"]) + # self._plugin.optimise() + self.output_data = self._plugin.suggest_locations() + + def get_result(self): + return self.output_data \ No newline at end of file diff --git a/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(GPy).py b/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(GPy).py new file mode 100644 index 00000000..bf9ff291 --- /dev/null +++ b/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(GPy).py @@ -0,0 +1,61 @@ +from vai_lab._plugin_templates import DecisionMakingPluginT +from GPyOpt.methods import BayesianOptimization as model +from typing import Dict + +_PLUGIN_READABLE_NAMES = {"GPyOpt": "default", + "BayesianOptimisation": "alias", + "BayesianOptimisation_GPy": "alias",} # type:ignore +_PLUGIN_MODULE_OPTIONS = {"Type": "decision making"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {"f": "function"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"domain": "list", + "constraints": "list", + "acquisition_type ": "str", + "files": "list", + "normalize_Y": "bool", + "evaluator_type": "str", + "batch_size": "int", + "acquisition_jitter": "float"} # type:ignore +_PLUGIN_REQUIRED_DATA = {} # type:ignore +_PLUGIN_OPTIONAL_DATA = {"X","Y"} # type:ignore + +class GPyOpt(DecisionMakingPluginT): + """ + Bayesian optimisation model using GPyOpt. Compatible with no objective function using tabular data. + """ + + def __init__(self): + """Initialises parent class. + Passes `globals` dict of all current variables + """ + super().__init__(globals()) + self.model = model + + def _parse_options_dict(self, options_dict:Dict): + super()._parse_options_dict(options_dict) + if self.X is not None: + options_dict['X'] = self.X + if self.Y is not None: + options_dict['Y'] = self.Y.reshape(-1,1) + return options_dict + + def optimise(self): + """Sends parameters to optimizer, then runs Bayesian Optimization for a number 'max_iter' of iterations""" + try: + self.BO.run_optimization() + except Exception as exc: + print('The plugin encountered an error when running the optimization ' + +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') + raise + + def suggest_locations(self): + """Run a single optimization step and return the next locations to evaluate the objective. + Number of suggested locations equals to batch_size. + :returns: array, shape (n_samples,) + Returns suggested values. + """ + try: + return self.BO.suggest_next_locations() + except Exception as exc: + print('The plugin encountered an error when suggesting points with ' + +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') + raise \ No newline at end of file diff --git a/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(bayes_opt).py b/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(bayes_opt).py new file mode 100644 index 00000000..544bcdc9 --- /dev/null +++ b/src/vai_lab/DecisionMaking/plugins/BayesianOptimisation(bayes_opt).py @@ -0,0 +1,49 @@ +from vai_lab._plugin_templates import DecisionMakingPluginT +from bayes_opt import BayesianOptimization as model + +_PLUGIN_READABLE_NAMES = {"bayes_opt": "default", + "BayesOpt": "alias",} # type:ignore +_PLUGIN_MODULE_OPTIONS = {"Type": "decision making"} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {"f": "function", + "pbounds": "dict"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = { + # "constraint": "ConstraintModel", + "random_state ": "int", + "verbose": "bool", + # "bounds_transformer": "DomainTransformer", + "allow_duplicate_points": "str"} # type:ignore +_PLUGIN_REQUIRED_DATA = {} # type:ignore +_PLUGIN_OPTIONAL_DATA = {"X","Y"} # type:ignore + +class bayes_opt(DecisionMakingPluginT): + """ + Bayesian optimisation model using bayes_opt. + """ + + def __init__(self): + """Initialises parent class. + Passes `globals` dict of all current variables + """ + super().__init__(globals()) + self.model = model + + def optimise(self): + """Probes the target space to find the parameters that yield the maximum value for the given function.""" + try: + self.BO.maximize() + except Exception as exc: + print('The plugin encountered an error when running the optimization ' + +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') + raise + + def suggest_locations(self, utility_function): + """Run a single optimization step and return the next locations to evaluate the objective. + :returns: array, shape (n_samples,) + Returns suggested values. + """ + try: + return self.BO.suggest(utility_function) + except Exception as exc: + print('The plugin encountered an error when suggesting points with ' + +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') + raise \ No newline at end of file diff --git a/src/vai_lab/Modelling/Modelling_core.py b/src/vai_lab/Modelling/Modelling_core.py index 81695629..fa19e436 100644 --- a/src/vai_lab/Modelling/Modelling_core.py +++ b/src/vai_lab/Modelling/Modelling_core.py @@ -29,9 +29,15 @@ def set_options(self, module_config: dict): def launch(self): self._plugin.set_data_in(self._data_in) self._plugin.configure(self._module_config["plugin"]) - self._plugin.solve() + self._plugin.init() + for method in self._module_config["plugin"]["methods"]["_order"]: + if "options" in self._module_config["plugin"]["methods"][method].keys(): + getattr(self._plugin, "{}".format(method))(self._plugin._parse_options_dict(self._module_config["plugin"]["methods"][method]["options"])) + else: + getattr(self._plugin, "{}".format(method))() + self.output_data = self._data_in.copy() - # self.output_data = self._plugin._test(self.output_data) + self.output_data = self._plugin._test(self.output_data) def get_result(self): - return self.output_data + return self.output_data \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/affinitypropagation.py b/src/vai_lab/Modelling/plugins/affinitypropagation.py index b3dbc666..648a2091 100644 --- a/src/vai_lab/Modelling/plugins/affinitypropagation.py +++ b/src/vai_lab/Modelling/plugins/affinitypropagation.py @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/bayesianridge.py b/src/vai_lab/Modelling/plugins/bayesianridge.py index 9844f610..c8c034f4 100644 --- a/src/vai_lab/Modelling/plugins/bayesianridge.py +++ b/src/vai_lab/Modelling/plugins/bayesianridge.py @@ -19,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/birch.py b/src/vai_lab/Modelling/plugins/birch.py index c1bd4a01..86325a30 100644 --- a/src/vai_lab/Modelling/plugins/birch.py +++ b/src/vai_lab/Modelling/plugins/birch.py @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/decisiontreeclassifier.py b/src/vai_lab/Modelling/plugins/decisiontreeclassifier.py index 32a8b3cc..797de431 100644 --- a/src/vai_lab/Modelling/plugins/decisiontreeclassifier.py +++ b/src/vai_lab/Modelling/plugins/decisiontreeclassifier.py @@ -1,7 +1,7 @@ from vai_lab._plugin_templates import ModellingPluginTClass from sklearn.tree import DecisionTreeClassifier as model -_PLUGIN_READABLE_NAMES = {"DecissionTreeClassifier": "default", +_PLUGIN_READABLE_NAMES = {"DecisionTreeClassifier": "default", "DTClassifier": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "classification"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/decisiontreeregressor.py b/src/vai_lab/Modelling/plugins/decisiontreeregressor.py index d3b071f6..4665cd54 100644 --- a/src/vai_lab/Modelling/plugins/decisiontreeregressor.py +++ b/src/vai_lab/Modelling/plugins/decisiontreeregressor.py @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/elasticnet.py b/src/vai_lab/Modelling/plugins/elasticnet.py index 0e85f49d..eabd33bf 100644 --- a/src/vai_lab/Modelling/plugins/elasticnet.py +++ b/src/vai_lab/Modelling/plugins/elasticnet.py @@ -4,8 +4,7 @@ _PLUGIN_READABLE_NAMES = {"ElasticNet": "default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "regression"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"alpha": "float", - "l1_ratio": "float"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"l1_ratio": "float"} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -20,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/gpclassifier.py b/src/vai_lab/Modelling/plugins/gpclassifier.py index b332f673..cefcf61a 100644 --- a/src/vai_lab/Modelling/plugins/gpclassifier.py +++ b/src/vai_lab/Modelling/plugins/gpclassifier.py @@ -22,4 +22,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/gpregressor.py b/src/vai_lab/Modelling/plugins/gpregressor.py index edf23cb6..add31a24 100644 --- a/src/vai_lab/Modelling/plugins/gpregressor.py +++ b/src/vai_lab/Modelling/plugins/gpregressor.py @@ -22,4 +22,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/kernelridge.py b/src/vai_lab/Modelling/plugins/kernelridge.py index 95cbdb20..fcac937a 100644 --- a/src/vai_lab/Modelling/plugins/kernelridge.py +++ b/src/vai_lab/Modelling/plugins/kernelridge.py @@ -5,8 +5,7 @@ "KR": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "regression"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"alpha": "float", - "kernel": "str", +_PLUGIN_OPTIONAL_SETTINGS = {"kernel": "str", "gamma": "float", "degree": "int"} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore @@ -23,4 +22,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/kmeans.py b/src/vai_lab/Modelling/plugins/kmeans.py index 960512ce..11e3e71e 100644 --- a/src/vai_lab/Modelling/plugins/kmeans.py +++ b/src/vai_lab/Modelling/plugins/kmeans.py @@ -4,8 +4,7 @@ _PLUGIN_READABLE_NAMES = {"KMeans": "default"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "clustering"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"n_clusters": "int", - "n_init": "int"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {"X"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"Y", "X_tst", 'Y_tst'} # type:ignore @@ -20,4 +19,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/knnclassifier.py b/src/vai_lab/Modelling/plugins/knnclassifier.py index 3a56f8ec..2c880e95 100644 --- a/src/vai_lab/Modelling/plugins/knnclassifier.py +++ b/src/vai_lab/Modelling/plugins/knnclassifier.py @@ -5,8 +5,7 @@ "KNN-C": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "classification"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"n_neighbors": "int", - "weights": "str"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"weights": "str"} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -21,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/knnregressor.py b/src/vai_lab/Modelling/plugins/knnregressor.py index 8be1bcc8..64982701 100644 --- a/src/vai_lab/Modelling/plugins/knnregressor.py +++ b/src/vai_lab/Modelling/plugins/knnregressor.py @@ -5,8 +5,7 @@ "KNN-R": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "regression"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"n_neighbors": "int", - "weights": "str"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"weights": "str"} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -21,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/lasso.py b/src/vai_lab/Modelling/plugins/lasso.py index 10ccbfed..bfb391b4 100644 --- a/src/vai_lab/Modelling/plugins/lasso.py +++ b/src/vai_lab/Modelling/plugins/lasso.py @@ -18,4 +18,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/linearregression.py b/src/vai_lab/Modelling/plugins/linearregression.py index c25523cf..25facd00 100644 --- a/src/vai_lab/Modelling/plugins/linearregression.py +++ b/src/vai_lab/Modelling/plugins/linearregression.py @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/logisticregression.py b/src/vai_lab/Modelling/plugins/logisticregression.py index 84390752..25eed6c2 100644 --- a/src/vai_lab/Modelling/plugins/logisticregression.py +++ b/src/vai_lab/Modelling/plugins/logisticregression.py @@ -6,7 +6,7 @@ "MaxEnt": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "classification"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"penalty": "str", "C": "float"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -21,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/meanshift.py b/src/vai_lab/Modelling/plugins/meanshift.py index d3d8c207..28d1327b 100644 --- a/src/vai_lab/Modelling/plugins/meanshift.py +++ b/src/vai_lab/Modelling/plugins/meanshift.py @@ -18,4 +18,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/passiveaggressiveclassifier.py b/src/vai_lab/Modelling/plugins/passiveaggressiveclassifier.py index 772ef0e6..dc01447a 100644 --- a/src/vai_lab/Modelling/plugins/passiveaggressiveclassifier.py +++ b/src/vai_lab/Modelling/plugins/passiveaggressiveclassifier.py @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/perceptron.py b/src/vai_lab/Modelling/plugins/perceptron.py index 63155166..93f35f40 100644 --- a/src/vai_lab/Modelling/plugins/perceptron.py +++ b/src/vai_lab/Modelling/plugins/perceptron.py @@ -21,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/randomforestclassifier.py b/src/vai_lab/Modelling/plugins/randomforestclassifier.py index 39850f41..a8dc60cd 100644 --- a/src/vai_lab/Modelling/plugins/randomforestclassifier.py +++ b/src/vai_lab/Modelling/plugins/randomforestclassifier.py @@ -6,8 +6,7 @@ "RFC": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "classification"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"max_depth": "int", - "n_estimators": "int"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -22,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/randomforestregressor.py b/src/vai_lab/Modelling/plugins/randomforestregressor.py index bea9acbb..85cf9c10 100644 --- a/src/vai_lab/Modelling/plugins/randomforestregressor.py +++ b/src/vai_lab/Modelling/plugins/randomforestregressor.py @@ -6,8 +6,7 @@ "RFR": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "regression"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"max_depth": "int", - "n_estimators": "int"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -22,4 +21,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/ridgeregression.py b/src/vai_lab/Modelling/plugins/ridgeregression.py index 6578c83d..0ee2f965 100644 --- a/src/vai_lab/Modelling/plugins/ridgeregression.py +++ b/src/vai_lab/Modelling/plugins/ridgeregression.py @@ -2,10 +2,10 @@ from sklearn.linear_model import Ridge as model _PLUGIN_READABLE_NAMES = {"RidgeRegression": "default", - "Ridge": "alias"} # type:ignore + "Ridge": "alias"} # type:ignore _PLUGIN_MODULE_OPTIONS = {"Type": "regression"} # type:ignore _PLUGIN_REQUIRED_SETTINGS = {} # type:ignore -_PLUGIN_OPTIONAL_SETTINGS = {"alpha": "float"} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {} # type:ignore _PLUGIN_REQUIRED_DATA = {"X", "Y"} # type:ignore _PLUGIN_OPTIONAL_DATA = {"X_tst", 'Y_tst'} # type:ignore @@ -20,4 +20,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() + self.model = model() diff --git a/src/vai_lab/Modelling/plugins/svc.py b/src/vai_lab/Modelling/plugins/svc.py index 7862ae70..7fd4d700 100644 --- a/src/vai_lab/Modelling/plugins/svc.py +++ b/src/vai_lab/Modelling/plugins/svc.py @@ -23,4 +23,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/Modelling/plugins/svr.py b/src/vai_lab/Modelling/plugins/svr.py index 9519bf7e..282fb622 100644 --- a/src/vai_lab/Modelling/plugins/svr.py +++ b/src/vai_lab/Modelling/plugins/svr.py @@ -23,4 +23,4 @@ def __init__(self): Passes `globals` dict of all current variables """ super().__init__(globals()) - self.clf = model() \ No newline at end of file + self.model = model() \ No newline at end of file diff --git a/src/vai_lab/UserInteraction/plugins/OptimisationInput.py b/src/vai_lab/UserInteraction/plugins/OptimisationInput.py new file mode 100644 index 00000000..d0027488 --- /dev/null +++ b/src/vai_lab/UserInteraction/plugins/OptimisationInput.py @@ -0,0 +1,301 @@ +from vai_lab._plugin_templates import UI +from vai_lab._import_helper import get_lib_parent_dir +from vai_lab._types import DictT, DataInterface, GUICoreInterface + +import os +import numpy as np +import pandas as pd +from typing import Tuple, List, Union +from PIL import Image, ImageTk, PngImagePlugin +import matplotlib.pyplot as plt +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg + +import tkinter as tk +from tkinter import messagebox, ttk +from tkinter.filedialog import asksaveasfile + +_PLUGIN_READABLE_NAMES = {"OptimisationInput": "default", + "BOUI": "alias", + "optimisationUI": "alias"} # type:ignore +_PLUGIN_MODULE_OPTIONS = {"layer_priority": 2, + "required_children": None} # type:ignore +_PLUGIN_REQUIRED_SETTINGS = {} # type:ignore +_PLUGIN_OPTIONAL_SETTINGS = {"Bounds": "list"} # type:ignore +_PLUGIN_REQUIRED_DATA = {"X"} # type:ignore + + +class OptimisationInput(tk.Frame, UI): # type:ignore + """Method of user interaction for optimisation problems""" + + def __init__(self, parent, controller, config: DictT): + self.parent = parent + super().__init__(parent, bg=self.parent['bg']) + self.controller: GUICoreInterface = controller + self.controller.title('Optimisation Interaction') + + self.dirpath = get_lib_parent_dir() + self.tk.call('wm', 'iconphoto', self.controller._w, ImageTk.PhotoImage( + file=os.path.join(os.path.join( + self.dirpath, + 'utils', + 'resources', + 'Assets', + 'VAILabsIcon.ico')))) + + self.assets_path = os.path.join(self.dirpath, 'utils', 'resources', 'Assets') + + self._data_in: DataInterface + self._config = config + self.save_path = '' + self.saved = True + + + def _load_values_from_data(self): + + self.frame1 = tk.Frame(self, bg=self.parent['bg']) + frame4 = tk.Frame(self, bg=self.parent['bg']) + frame5 = tk.Frame(self, bg=self.parent['bg']) + # frame6 = tk.Frame(self, bg=self.parent['bg']) + + self.opt_var = list(self._data_in["X"].columns.values) + if len(self.opt_var) < 3: + figure = plt.Figure(figsize=(5, 4), dpi=100) + self.ax = figure.add_subplot(111) + + self.plot_points(self._data_in["X"], self.opt_var) + + self.canvas = FigureCanvasTkAgg(figure, self.frame1) + plot_frame = self.canvas.get_tk_widget() + plot_frame.grid(column=0, row=0, pady=10, padx=10, sticky="nsew") + self.frame1.grid_rowconfigure(0, weight=1) + self.frame1.grid_columnconfigure(0, weight=1) + + # Inital window + self.N = len(self._data_in["X"]) + + # Buttons + self.button_save = tk.Button( + frame4, text='Save', fg='white', bg=self.parent['bg'], height=3, + width=20, command=self.save_file) + self.button_save.grid(column=0, row=0, sticky="news", pady=2, padx=[2,0]) + + tk.Button( + frame5, text="Done", + fg='white', bg=self.parent['bg'], + height=3, width=20, + command=self.check_quit).grid(column=0, row=0, sticky="news", pady=2, padx=[0,2]) + + self.frame1.grid(row=0, column=0, sticky="nsew") + frame4.grid(row=1, column=0, sticky="nsew") + frame5.grid(row=1, column=1, sticky="nsew") + + frame4.grid_columnconfigure(0, weight=1) + frame5.grid_columnconfigure(0, weight=1) + self.grid_rowconfigure(0, weight=1) + self.grid_columnconfigure(tuple(range(2)), weight=1) + + def plot_points(self, data, labels, x=[None]): + """Plots points in pre-existing axis. If some extra points are given, + these are plotted with a different colour. + :param data: dict, dictionary to be plotted + :param labels: list, column and axis labels + :param x: array, extra points to be plotted + """ + + self.ax.clear() # clear axes from previous plot + self.ax.scatter(data[labels[0]], data[labels[1]]) + self.ax.set_xlabel(labels[0]) + self.ax.set_ylabel(labels[1]) + # self.ax.set_xlim(min(data[labels[0]]), max(data[labels[0]])) + # self.ax.set_ylim(min(data[labels[0]]), max(data[labels[0]])) + self.ax.set_title('Suggested points') + if None not in x: + self.ax.scatter(x[0], x[1], color='r') + + def set_data_in(self, data_in): + req_check = [ + r for r in _PLUGIN_REQUIRED_DATA if r not in data_in.keys()] + if len(req_check) > 0: + raise Exception("Minimal Data Requirements not met" + + "\n\t{0} ".format("OptimisationInput") + + "requires data: {0}".format(_PLUGIN_REQUIRED_DATA) + + "\n\tThe following data is missing:" + + "\n\t\u2022 {}".format(",\n\t\u2022 ".join([*req_check]))) + self._data_in = data_in + self._load_values_from_data() + self._load_classes_from_data() + + def class_list(self): + """Getter for required _class_list variable + + :return: list of class labels + :rtype: list of strings + """ + return self._class_list + + def _load_classes_from_data(self): + """Setter for required _class_list variable + + :param value: class labels for binary classification + :type value: list of strings + """ + self.out_data = self._data_in["X"] + + self.button_cl = {} + + # Tree defintion. Output display + style = ttk.Style() + style.configure( + "Treeview", background='white', foreground='white', + rowheight=25, fieldbackground='white', + font=self.controller.pages_font) + style.configure("Treeview.Heading", font=self.controller.pages_font) + style.map('Treeview', background=[('selected', 'grey')]) + + tree_frame = tk.Frame(self) + if len(self.opt_var) < 3: + tree_frame.grid(row=0, column=1, sticky="nsew", pady=10, padx=10) + else: + tree_frame.grid(row=0, column=0, columnspan = 2, sticky="nsew", pady=10, padx=10) + + tree_scrollx = tk.Scrollbar(tree_frame, orient='horizontal') + tree_scrollx.pack(side=tk.BOTTOM, fill=tk.X) + tree_scrolly = tk.Scrollbar(tree_frame) + tree_scrolly.pack(side=tk.RIGHT, fill=tk.Y) + + self.tree = ttk.Treeview(tree_frame, + yscrollcommand=tree_scrolly.set, + xscrollcommand=tree_scrollx.set) + self.tree.pack(fill='both', expand=True) + + tree_scrollx.config(command=self.tree.xview) + tree_scrolly.config(command=self.tree.yview) + + self.tree['columns'] = self.opt_var + + # Format columns + self.tree.column("#0", width=80, + minwidth=50) + for n, cl in enumerate(self.opt_var): + self.tree.column( + cl, width=int(self.controller.pages_font.measure(str(cl)))+20, + minwidth=50, anchor=tk.CENTER) + # Headings + self.tree.heading("#0", text="Sample", anchor=tk.CENTER) + for cl in self.opt_var: + self.tree.heading(cl, text=cl, anchor=tk.CENTER) + self.tree.tag_configure('odd', foreground='black', + background='#E8E8E8') + self.tree.tag_configure('even', foreground='black', + background='#DFDFDF') + # Add data + for n, sample in enumerate(self.out_data.values): + if n % 2 == 0: + self.tree.insert(parent='', index='end', iid=n, text=n+1, + values=tuple(sample.astype(float)), tags=('even',)) + else: + self.tree.insert(parent='', index='end', iid=n, text=n+1, + values=tuple(sample.astype(float)), tags=('odd',)) + + # Select the current row + self.tree.selection_set(str(int(0))) + + # Define click on row action + if len(self.opt_var) < 3: + self.tree.bind('', self.OnClick) + + # Define double-click on row action + self.tree.bind("", self.OnDoubleClick) + + def OnClick(self, event): + "Displays the corresponding ." + + item = self.tree.selection()[0] + x = [float(i) for i in self.tree.item(item)['values']] + if len(self.opt_var) < 3: + self.plot_points(self.out_data, self.opt_var, x = x) + self.canvas.draw() + + def OnDoubleClick(self, event): + """ Executed when a row is double clicked. + Opens an entry box to edit a cell and updates the plot and the + stored data. """ + + self.treerow = int(self.tree.identify_row(event.y)) + self.treecol = self.tree.identify_column(event.x) + + # get column position info + x, y, width, height = self.tree.bbox(self.treerow, self.treecol) + + # y-axis offset + pady = height // 2 + # pady = 0 + + if hasattr(self, 'entry'): + self.entry.destroy() + + self.entry = tk.Entry(self.tree, justify='center') + + if int(self.treecol[1:]) > 0: + self.entry.insert( + 0, self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-1]) + self.entry['exportselection'] = False + + self.entry.focus_force() + self.entry.bind("", self.OnReturn) + self.entry.bind("", lambda *ignore: self.entry.destroy()) + + self.entry.place(x=x, + y=y + pady, + anchor=tk.W, width=width) + + def OnReturn(self, event): + """ Updates the stored data with the values in the entry. """ + val = self.tree.item(self.treerow)['values'] + val = [float(i) for i in val] + val[int(self.treecol[1:])-1] = float(self.entry.get()) + self.tree.item(self.treerow, values=val) + self.entry.destroy() + self.saved = False + + self.out_data.loc[self.treerow] = val + + self.OnClick(0) + self.saved = False + + def check_quit(self): + + if not self.saved: + response = messagebox.askokcancel( + "Exit?", + "Do you want to leave the program without saving?") + if response: + self.controller.destroy() + else: + response = messagebox.askokcancel( + "Exit?", + "Are you sure you are finished?") + self.controller.destroy() + + def save_file_as(self): + + self.save_path = asksaveasfile(mode='w') + self.save_file() + + def save_file(self): + + if self.save_path == '': + self.save_path = asksaveasfile(defaultextension='.txt', + filetypes=[('Text file', '.txt'), + ('CSV file', '.csv'), + ('All Files', '*.*')]) + # asksaveasfile return `None` if dialog closed with "cancel". + if self.save_path is not None: + filedata = pd.DataFrame( + self.out_data, columns=self.opt_var).to_string() + self.save_path.seek(0) # Move to the first row to overwrite it + self.save_path.write(filedata) + self.save_path.flush() # Save without closing + # typically the above line would do. however this is used to ensure that the file is written + os.fsync(self.save_path.fileno()) + self.saved = True \ No newline at end of file diff --git a/src/vai_lab/_import_helper.py b/src/vai_lab/_import_helper.py index e56d2118..3095ddbc 100644 --- a/src/vai_lab/_import_helper.py +++ b/src/vai_lab/_import_helper.py @@ -43,4 +43,25 @@ def rel_to_abs(filename: str) -> str: filename = path.normpath(path.join(get_lib_parent_dir(), filename)) elif filename[0] == "/" or (filename[0].isalpha() and filename[0].isupper()): filename = filename - return filename \ No newline at end of file + return filename + +def abs_to_rel(filename: str) -> str: + """Checks if path is relative or absolute + If absolute, converts path to relative if possible + If relative, returns itself + """ + if filename[0] == ".": + #Relative path + return filename + elif filename[0].isalpha() and filename[0] != get_lib_parent_dir()[0]: + #Different drive -> Absolute path + return filename + else: + #Same drive not relative + _folder = path.relpath(filename, get_lib_parent_dir()) + if _folder[:2] == '..': + # Absolute path + return filename + else: + # Relative path + return path.join('.',_folder) diff --git a/src/vai_lab/_plugin_templates.py b/src/vai_lab/_plugin_templates.py index 98d289d4..437ff74c 100644 --- a/src/vai_lab/_plugin_templates.py +++ b/src/vai_lab/_plugin_templates.py @@ -18,8 +18,8 @@ def __init__(self, plugin_globals: dict) -> None: """ self.X = None self.Y = None - self.X_tst = None - self.Y_tst = None + self.X_test = None + self.Y_test = None self._PLUGIN_READABLE_NAMES: dict self._PLUGIN_MODULE_OPTIONS: dict @@ -78,9 +78,8 @@ def _parse_config(self): """Parse incoming data and args, sets them as class variables""" self.X = np.array(self._get_data_if_exist(self._data_in, "X")) self.Y = np.array(self._get_data_if_exist(self._data_in, "Y")).ravel() - self.X_tst = self._get_data_if_exist(self._data_in, "X_test") - self.Y_tst = np.array(self._get_data_if_exist( - self._data_in, "Y_test")).ravel() + self.X_test = self._get_data_if_exist(self._data_in, "X_test") + self.Y_test = np.array(self._get_data_if_exist(self._data_in, "Y_test")).ravel() self._clean_options() def _get_data_if_exist(self, data_dict: dict, key: str, default=None): @@ -105,20 +104,33 @@ def _reshape(self, data, shape): def _parse_options_dict(self,options_dict:Dict): for key, val in options_dict.items(): - if type(val) == str and val.replace('.', '').replace(',', '').isnumeric(): - cleaned_opts = [] - for el in val.split(","): - val = float(el) - if val.is_integer(): - val = int(val) - cleaned_opts.append(val) - options_dict[key] = cleaned_opts - elif type(val) == str and val.lower() in ('y', 'yes', 't', 'true', 'on'): - options_dict[key] = True - elif type(val) == str and val.lower() in ('n', 'no', 'f', 'false', 'off'): - options_dict[key] = False - elif type(val) == str and val.lower() in ('none'): - options_dict[key] = None + if type(val) == str: + if val.replace('.', '').replace(',', '').isnumeric(): + cleaned_opts = [] + for el in val.split(","): + val = float(el) + if val.is_integer(): + val = int(val) + cleaned_opts.append(val) + options_dict[key] = cleaned_opts + elif val.lower() in ('yes', 'true'): + options_dict[key] = True + elif val.lower() in ('no', 'false'): + options_dict[key] = False + elif val.lower() in ('none'): + options_dict[key] = None + elif val == 'X': + options_dict[key] = self.X + elif val == 'Y': + options_dict[key] = self.Y + elif val == 'X_test': + options_dict[key] = self.X_test + elif val == 'Y_test': + options_dict[key] = self.Y_test + elif key.lower() == 'x': + options_dict[key] = self.X + elif key.lower() == 'y': + options_dict[key] = self.Y return options_dict def _clean_options(self): @@ -134,26 +146,26 @@ def _test(self, data: DataInterface) -> DataInterface: """ if self._PLUGIN_MODULE_OPTIONS['Type'] == 'classification': print('Training accuracy: %.2f%%' % - (self.score(self.X, self.Y)*100)) # type: ignore - if self.Y_tst is not None: + (self.score([self.X, self.Y])*100)) # type: ignore + if self.Y_test is not None: print('Test accuracy: %.2f%%' % - (self.score(self.X_tst, self.Y_tst)*100)) - if self.X_tst is not None: - data.append_data_column("Y_pred", self.predict(self.X_tst)) + (self.score([self.X_test, self.Y_test])*100)) + if self.X_test is not None: + data.append_data_column("Y_pred", self.predict([self.X_test])) return data elif self._PLUGIN_MODULE_OPTIONS['Type'] == 'regression': print('Training R2 score: %.3f' % - (self.score(self.X, self.Y))) # type: ignore - if self.Y_tst is not None: - print('Training R2 score: %.3f' % - (self.score(self.X_tst, self.Y_tst))) - if self.X_tst is not None: - data.append_data_column("Y_pred", self.predict(self.X_tst)) + (self.score([self.X, self.Y]))) # type: ignore + if self.Y_test is not None: + print('Test R2 score: %.3f' % + (self.score([self.X_test, self.Y_test]))) + if self.X_test is not None: + data.append_data_column("Y_pred", self.predict([self.X_test])) return data elif self._PLUGIN_MODULE_OPTIONS['Type'] == 'clustering': print('Clustering completed') - if self.X_tst is not None: - data.append_data_column("Y_pred", self.predict(self.X_tst)) + if self.X_test is not None: + data.append_data_column("Y_pred", self.predict([self.X_test])) return data else: return data @@ -183,60 +195,88 @@ def _clean_solver_options(self): _cleaned[key] = True elif val == 'False': _cleaned[key] = False + elif val == 'X': + _cleaned[key] = self.X + elif val == 'Y': + _cleaned[key] = self.Y + elif val == 'X_test': + _cleaned[key] = self.X_test + elif val == 'Y_test': + _cleaned[key] = self.Y_test + elif key.lower() == 'x': + _cleaned[key] = self.X + elif key.lower() == 'y': + _cleaned[key] = self.Y return _cleaned - - def fit(self): - cleaned_options = self._clean_solver_options() + + def init(self): + """Sends params to model""" + try: + self.model.set_params(**self._config["options"]) + except Exception as exc: + print('The plugin encountered an error on the parameters of ' + +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') + raise + + def fit(self, options={}): try: - self.proc.set_params(**cleaned_options) + if type(self._clean_solver_options()) == list: + self.model.set_params(*self._clean_solver_options()) + else: + self.model.set_params(**self._clean_solver_options()) except Exception as exc: print('The plugin encountered an error on the parameters of ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise try: - self.proc.fit(self.X) + if type(options) == list: + return self.model.fit(*options) + else: + return self.model.fit(**options) except Exception as exc: print('The plugin encountered an error when fitting ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise - def transform(self, data: DataInterface) -> DataInterface: + def transform(self, options={}) -> DataInterface: try: - data.append_data_column("X", pd.DataFrame(self.proc.transform(self.X))) + if type(options) == list: + return pd.DataFrame(self.model.transform(*options)) + else: + return pd.DataFrame(self.model.transform(**options)) except Exception as exc: print('The plugin encountered an error when transforming the data with ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise - if self.X_tst is not None: - try: - data.append_data_column("X_test", pd.DataFrame(self.proc.transform(self.X_tst))) - except Exception as exc: - print('The plugin encountered an error when transforming the data with ' - +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') - raise return data class ModellingPluginT(PluginTemplate, ABC): def __init__(self, plugin_globals: dict) -> None: super().__init__(plugin_globals) - - def solve(self): - """Sends params to solver, then runs solver""" + + def init(self): + """Sends params to model""" try: - self.clf.set_params(**self._config["options"]) + self.model.set_params(**self._config["options"]) except Exception as exc: print('The plugin encountered an error on the parameters of ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise + + def fit(self, options={}): + """Sends params to fit, then runs fit""" try: - self.clf.fit(self.X, self.Y) + if type(options) == list: + return self.model.fit(*options) + else: + return self.model.fit(**options) except Exception as exc: print('The plugin encountered an error when fitting ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise - def predict(self, data): + def predict(self, options={}): """Uses fitted model to predict output of a given Y :param data: array-like or sparse matrix, shape (n_samples, n_features) Samples @@ -245,23 +285,27 @@ def predict(self, data): Returns predicted values. """ try: - return self.clf.predict(data) + if type(options) == list: + return self.model.predict(*options) + else: + return self.model.predict(**options) except Exception as exc: print('The plugin encountered an error when predicting with ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.') raise - def score(self, X, Y, sample_weight): + def score(self, options={}): """Return the coefficient of determination :param X : array-like of shape (n_samples, n_features) :param Y : array-like of shape (n_samples,) or (n_samples, n_outputs) - :param sample_weight : array-like of shape (n_samples,), default=None - Sample weights. :returns: score : float of ``self.predict(X)`` wrt. `y`. """ try: - return self.clf.score(X, Y, sample_weight) + if type(options) == list: + return self.model.score(*options) + else: + return self.model.score(**options) except Exception as exc: print('The plugin encountered an error when calculating the score with ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') @@ -271,7 +315,7 @@ class ModellingPluginTClass(ModellingPluginT, ABC): def __init__(self, plugin_globals: dict) -> None: super().__init__(plugin_globals) - def predict_proba(self, data): + def predict_proba(self, options={}): """Uses fitted model to predict the probability of the output of a given Y :param data: array-like or sparse matrix, shape (n_samples, n_features) Samples @@ -280,7 +324,10 @@ def predict_proba(self, data): Returns predicted values. """ try: - return self.clf.predict_proba(data) + if type(options) == list: + return self.model.predict_proba(*options) + else: + return self.model.predict_proba(**options) except Exception as exc: print('The plugin encountered an error when predicting the probability with ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') @@ -294,7 +341,10 @@ def configure(self, config: dict): """Extended from PluginTemplate.configure""" super().configure(config) try: - self.BO = self.model(**self._clean_options()) + if type(self._clean_options()) == list: + self.BO = self.model(*self._clean_options()) + else: + self.BO = self.model(**self._clean_options()) except Exception as exc: print('The plugin encountered an error on the parameters of ' +str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+'.') diff --git a/src/vai_lab/examples/crystalDesign/phasestability/.DS_Store b/src/vai_lab/examples/crystalDesign/phasestability/.DS_Store new file mode 100644 index 00000000..0084bcb0 Binary files /dev/null and b/src/vai_lab/examples/crystalDesign/phasestability/.DS_Store differ diff --git a/src/vai_lab/examples/crystalDesign/phasestability/CsFA/.DS_Store b/src/vai_lab/examples/crystalDesign/phasestability/CsFA/.DS_Store new file mode 100644 index 00000000..3bf150c7 Binary files /dev/null and b/src/vai_lab/examples/crystalDesign/phasestability/CsFA/.DS_Store differ diff --git a/src/vai_lab/examples/crystalDesign/phasestability/CsMA/.DS_Store b/src/vai_lab/examples/crystalDesign/phasestability/CsMA/.DS_Store new file mode 100644 index 00000000..5bdbc9ea Binary files /dev/null and b/src/vai_lab/examples/crystalDesign/phasestability/CsMA/.DS_Store differ diff --git a/src/vai_lab/examples/crystalDesign/phasestability/FAMA/.DS_Store b/src/vai_lab/examples/crystalDesign/phasestability/FAMA/.DS_Store new file mode 100644 index 00000000..9ffa3a33 Binary files /dev/null and b/src/vai_lab/examples/crystalDesign/phasestability/FAMA/.DS_Store differ diff --git a/src/vai_lab/examples/optimisation/X.csv b/src/vai_lab/examples/optimisation/X.csv new file mode 100644 index 00000000..94e55801 --- /dev/null +++ b/src/vai_lab/examples/optimisation/X.csv @@ -0,0 +1,29 @@ +CsPbI,MAPbI,FAPbI +0.75,0.25,0.0 +0.0,0.0,1.0 +0.5,0.5,0.0 +1.0,0.0,0.0 +0.5,0.0,0.5 +0.0,0.0,1.0 +0.5,0.5,0.0 +1.0,0.0,0.0 +0.0,0.25,0.75 +0.0,0.25,0.75 +0.25,0.0,0.75 +0.75,0.0,0.25 +0.0,0.5,0.5 +0.0,0.5,0.5 +0.25,0.25,0.5 +0.75,0.25,0.0 +0.0,0.75,0.25 +0.0,0.75,0.25 +0.25,0.5,0.25 +0.5,0.0,0.5 +0.25,0.25,0.5 +0.0,1.0,0.0 +0.25,0.75,0.0 +0.5,0.25,0.25 +0.25,0.5,0.25 +0.0,1.0,0.0 +0.75,0.0,0.25 +0.5,0.25,0.25 diff --git a/src/vai_lab/examples/optimisation/Y.csv b/src/vai_lab/examples/optimisation/Y.csv new file mode 100644 index 00000000..d9c602b5 --- /dev/null +++ b/src/vai_lab/examples/optimisation/Y.csv @@ -0,0 +1,29 @@ +merit +102012.10929790093 +480185.0038224852 +1167260.6222070677 +144556.1215368709 +239851.85597959475 +505656.58303249744 +1182235.6440845595 +351227.01607136783 +914155.5189523202 +981850.240125919 +144073.96913849353 +273687.1193117654 +1103449.8761320617 +1155542.3150353911 +284710.437804268 +128970.90311931062 +1300295.164428879 +1281399.3586171553 +446746.6993871238 +416656.76497605036 +602576.1724037583 +1267615.3317977716 +876981.2028616397 +925192.9561762783 +533060.412747728 +1274992.873707038 +253423.44691605103 +714734.685825648 diff --git a/src/vai_lab/examples/results/output.pkl b/src/vai_lab/examples/results/output.pkl index c89ae4a9..37f85b48 100644 Binary files a/src/vai_lab/examples/results/output.pkl and b/src/vai_lab/examples/results/output.pkl differ diff --git a/src/vai_lab/examples/xml_files/KNN-classification_demo.xml b/src/vai_lab/examples/xml_files/KNN-classification_demo.xml index e6bc3fb5..7dc49229 100644 --- a/src/vai_lab/examples/xml_files/KNN-classification_demo.xml +++ b/src/vai_lab/examples/xml_files/KNN-classification_demo.xml @@ -4,7 +4,7 @@ - [(350.0,50),0,{}] + [(350.0, 50), 0, {}] @@ -14,32 +14,66 @@ + + + - [(350.0,350.0),2,{0:'d0-u2'}] + [(350.0, 350.0), 2, {0: 'd0-u2'}] - - - - + + + + distance + + + 7 + + + + + + X + + + Y + + + + + + + X_test + + + Y_test + + + + + + + - [(350.0,650),1,{2:'d2-u1'}] + [(350.0, 650), 1, {2: 'd2-u1'}] - - Modelling - - - .\examples\results\output.pkl - + + + Modelling + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/SVR_demo.xml b/src/vai_lab/examples/xml_files/SVR_demo.xml index f94fa31e..2e8ceb8e 100644 --- a/src/vai_lab/examples/xml_files/SVR_demo.xml +++ b/src/vai_lab/examples/xml_files/SVR_demo.xml @@ -4,7 +4,7 @@ - [(350.0,50),0,{}] + [(350.0, 50), 0, {}] @@ -14,39 +14,69 @@ + + + - [(350.0,350.0),2,{0:'d0-u2'}] + [(350.0, 350.0), 2, {0: 'd0-u2'}] - - 0.1 - - - linear - + + + 0.01 + + + linear + + + + + + X + + + Y + + + + + + + X + + + Y + + + + + + - [(350.0,650),1,{2:'d2-u1'}] + [(350.0, 650), 1, {2: 'd2-u1'}] - - Modelling - - - .\examples\results\output.pkl - + + + Modelling + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/canvas_demo.xml b/src/vai_lab/examples/xml_files/canvas_demo.xml index d9d4859d..afd2827b 100644 --- a/src/vai_lab/examples/xml_files/canvas_demo.xml +++ b/src/vai_lab/examples/xml_files/canvas_demo.xml @@ -1,44 +1,49 @@ - - - - + - [(350.0,50),0,{}] + [(350.0, 50), 0, {}] + + + - - - - + + + + - + - [(350.0,350.0),2,{0:'d0-u2'}] - + [(350.0, 350.0), 2, {0: 'd0-u2'}] + + + + + + - + - - - - - My First UserFeedback Module - - - .\examples\results\output.pkl - - - + - [(350.0,650),1,{2:'d2-u1'}] + [(350.0, 650), 1, {2: 'd2-u1'}] + + + + User Interaction + + + .\examples\results\output.pkl + + + diff --git a/src/vai_lab/examples/xml_files/k-mean_clustering_demo.xml b/src/vai_lab/examples/xml_files/k-mean_clustering_demo.xml index 658fc9db..997202a7 100644 --- a/src/vai_lab/examples/xml_files/k-mean_clustering_demo.xml +++ b/src/vai_lab/examples/xml_files/k-mean_clustering_demo.xml @@ -4,7 +4,7 @@ - [(350.0,50),0,{}] + [(350.0, 50), 0, {}] @@ -12,49 +12,86 @@ + + + - [(350,200),2,{0:'d0-u2'}] + [(227, 254), 2, {0: 'd0-u2'}] - - - - - X - + + + (0, 1) + + + + + + X + + + + + + - [(350.0,350.0),3,{2:'d2-u3'}] + [(474, 412), 3, {2: 'd2-u3'}] - - - - + + + + 4 + + + 500 + + + + + + X + + + + + + + X + + + + + + + - [(350.0,650),1,{3:'d3-u1'}] + [(350.0, 650), 1, {3: 'd3-u1'}] - - Modelling - - - .\examples\results\output.pkl - + + + Modelling + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/pybullet_env_example.xml b/src/vai_lab/examples/xml_files/pybullet_env_example.xml index 53629908..41e8a999 100644 --- a/src/vai_lab/examples/xml_files/pybullet_env_example.xml +++ b/src/vai_lab/examples/xml_files/pybullet_env_example.xml @@ -15,25 +15,27 @@ [(350.0,350.0),2,{0:'d0-u2'}] - - - plane.urdf - ./Environment/resources/models/half_cheetah_with_mass.xml - - - False - - - 0.0 - 0.0 - -9.81 - - - 0.01 - - - 10 - + + + + plane.urdf + ./Environment/resources/models/half_cheetah_with_mass.xml + + + False + + + 0.0 + 0.0 + -9.81 + + + 0.01 + + + 10 + + @@ -41,12 +43,14 @@ - - MyEnv - - - .\examples\results\output.pkl - + + + MyEnv + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/random_forest_class_demo.xml b/src/vai_lab/examples/xml_files/random_forest_class_demo.xml index 42d4e203..f3f6aee6 100644 --- a/src/vai_lab/examples/xml_files/random_forest_class_demo.xml +++ b/src/vai_lab/examples/xml_files/random_forest_class_demo.xml @@ -1,62 +1,114 @@ - + - - - - - + + [(350.0, 50), 0, {}] + + + + + + [(350.0,50),0,{}] - - + + + + - + + + [(349, 207), 2, {0: 'd0-u2'}] + - - X - + + + + + X + + + + + + + X + + + - - + + + + - + - - - - - [(350.0,350.0),2,{0:'d0-u2'}] + [(349, 419), 3, {2: 'd2-u3'}] + + + + 50 + + + 500 + + + + + + X + + + Y + + + + + + + X + + + Y + + + + - - - [(350.0,650),1,{2:'d2-u1'}] - + + + - + + + [(350.0, 650), 1, {3: 'd3-u1'}] + - - My Modelling Module - - - .\examples\results\output.pkl - + + + Modelling + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/regression_demo.xml b/src/vai_lab/examples/xml_files/regression_demo.xml deleted file mode 100644 index 38faeca5..00000000 --- a/src/vai_lab/examples/xml_files/regression_demo.xml +++ /dev/null @@ -1,44 +0,0 @@ - - - - - - - [(350.0,50),0,{}] - - - - - - - - - - - - - [(350.0,350.0),2,{0:'d0-u2'}] - - - - - - - - - - - Modelling - - - .\examples\results\output.pkl - - - - - - - [(350.0,650),1,{2:'d2-u1'}] - - - diff --git a/src/vai_lab/examples/xml_files/ridge-scalar-ridge_demo.xml b/src/vai_lab/examples/xml_files/ridge-scalar-ridge_demo.xml index c98d1191..56aa8de1 100644 --- a/src/vai_lab/examples/xml_files/ridge-scalar-ridge_demo.xml +++ b/src/vai_lab/examples/xml_files/ridge-scalar-ridge_demo.xml @@ -4,7 +4,7 @@ - [(350.0,50),0,{}] + [(350.0, 50), 0, {}] @@ -14,71 +14,137 @@ + + + - [(350,180),2,{0:'d0-u2'}] + [(178, 218), 2, {0: 'd0-u2'}] - - 0.01 - + + + 1e-3 + + + + + + X + + + Y + + + + + + + X + + + Y + + + + + + - [(350.0,350.0),3,{2:'d2-u3'}] + [(450, 219), 3, {2: 'r2-l3'}] - - X - + + + + + X + + + + + + + X + + + + + + - [(350,408),4,{3:'d3-u4'}] + [(350.0, 350.0), 4, {3: 'd3-u4'}] - - 0.01 - + + + + + X + + + Y + + + + + + + X + + + Y + + + + + + - [(350.0,650),1,{4:'d4-u1'}] + [(350.0, 650), 1, {4: 'd4-u1'}] - - Modelling - Modelling-1 - - - .\examples\results\output.pkl - + + + Modelling + Modelling-1 + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/examples/xml_files/ridge_regression_demo.xml b/src/vai_lab/examples/xml_files/ridge_regression_demo.xml deleted file mode 100644 index 58e182f7..00000000 --- a/src/vai_lab/examples/xml_files/ridge_regression_demo.xml +++ /dev/null @@ -1,49 +0,0 @@ - - - - - - - [(350.0,50),0,{}] - - - - - - - - - - - - - - - [(350.0,350.0),2,{0:'d0-u2'}] - - - - 0.01 - - - - - - - - - - - - [(350.0,650),1,{2:'d2-u1'}] - - - - Modelling - - - .\examples\results\output.pkl - - - - diff --git a/src/vai_lab/examples/xml_files/scalar_demo.xml b/src/vai_lab/examples/xml_files/scalar_demo.xml deleted file mode 100644 index 5b2b1d9a..00000000 --- a/src/vai_lab/examples/xml_files/scalar_demo.xml +++ /dev/null @@ -1,52 +0,0 @@ - - - - - - - [(350.0, 50), 0, {}] - - - - - - - - - - - - - - - - - - [(350.0, 350.0), 2, {0: 'd0-u2'}] - - - - X - - - - - - - - - - - - [(350.0, 650), 1, {2: 'd2-u1'}] - - - - Data Processing - - - .\examples\results\output.pkl - - - - diff --git a/src/vai_lab/examples/xml_files/scaler-lasso_demo.xml b/src/vai_lab/examples/xml_files/scaler-lasso_demo.xml deleted file mode 100644 index a05d1bfa..00000000 --- a/src/vai_lab/examples/xml_files/scaler-lasso_demo.xml +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - - [(350.0,50),0,{}] - - - - - - - - - - - - - - - [(350,183),2,{0:'d0-u2'}] - - - - - - - X - - - - - - - - - - [(350,328),3,{2:'d2-u3'}] - - - - - - - 0.2 - - - - - - - - - [(350.0,650),1,{3:'d3-u1'}] - - - - Data Processing - Modelling - - - ./examples/results/test.pkl - - - - diff --git a/src/vai_lab/examples/xml_files/user_feedback_demo2.xml b/src/vai_lab/examples/xml_files/user_feedback_demo.xml similarity index 79% rename from src/vai_lab/examples/xml_files/user_feedback_demo2.xml rename to src/vai_lab/examples/xml_files/user_feedback_demo.xml index bfe76cc7..8fe8f048 100644 --- a/src/vai_lab/examples/xml_files/user_feedback_demo2.xml +++ b/src/vai_lab/examples/xml_files/user_feedback_demo.xml @@ -1,15 +1,15 @@ - - - - [(350.0, 50), 0, {}] + + + + @@ -22,7 +22,8 @@ [(350.0, 350.0), 2, {0: 'd0-u2'}] - + + @@ -35,12 +36,14 @@ [(350.0, 650), 1, {2: 'd2-u1'}] - - User Interaction - - - .\examples\results\output.pkl - + + + User Interaction + + + .\examples\results\output.pkl + + diff --git a/src/vai_lab/run_pipeline.py b/src/vai_lab/run_pipeline.py index a11c9620..a2c0537e 100644 --- a/src/vai_lab/run_pipeline.py +++ b/src/vai_lab/run_pipeline.py @@ -9,7 +9,6 @@ import vai_lab as ai - def parse_args(): """ Parse command line arguments @@ -57,10 +56,8 @@ def main(): for i in range(0,len(args.file)): args.file[i] = abspath(args.file[i]) core.load_config_file(args.file) - # Run pipeline core.run() if __name__=='__main__': - main() \ No newline at end of file diff --git a/src/vai_lab/utils/plugins/MainPage.py b/src/vai_lab/utils/plugins/MainPage.py index e9b5dabe..ca78db04 100644 --- a/src/vai_lab/utils/plugins/MainPage.py +++ b/src/vai_lab/utils/plugins/MainPage.py @@ -5,9 +5,8 @@ from tkinter.filedialog import askopenfilename, askdirectory import pandas as pd -from vai_lab.Data.xml_handler import XML_handler from vai_lab.utils.plugins.dataLoader import dataLoader -from vai_lab._import_helper import get_lib_parent_dir, rel_to_abs +from vai_lab._import_helper import get_lib_parent_dir, abs_to_rel _PLUGIN_READABLE_NAMES = {"main": "default", "main page": "alias", @@ -217,6 +216,9 @@ def upload_xml(self): ('All Files', '*.*')]) if filename is not None and len(filename) > 0: self.controller._append_to_output("xml_filename", filename) + self.controller.xml_handler.filename = filename + self.controller.xml_handler.load_XML(filename) + self.controller.xml_handler.write_to_XML() self.controller.XML.set(True) def upload_data_file(self): @@ -299,17 +301,6 @@ def start_dataloader(self): """ Reads all the selected files, loads the data and passes it to dataLoader. """ - - # s = XML_handler() - # s.new_config_file(self.save_path.name) - # s.filename = self.save_path.name - - # self.s = XML_handler() - # self.s.new_config_file() - - # self.s._print_xml_config() - # self.s.load_XML(self.controller.output["xml_filename"]) - data = {} isVar = [0] * len(self.var) if len(self.label_list[0].cget("text")) > 0: @@ -320,7 +311,7 @@ def start_dataloader(self): # Infers by default, should it be None? data[variable] = pd.read_csv(filename) isVar[i] = 1 - self.controller.xml_handler.append_input_data(variable, rel_to_abs(filename)) + self.controller.xml_handler.append_input_data(variable, abs_to_rel(filename)) if i == 0: self.controller.Data.set(True) if any(isVar[1::2]) and ( @@ -330,6 +321,9 @@ def start_dataloader(self): for i in data[variable].to_numpy().flatten()]) if not any(isVar[1::2]): self.controller.output_type = 'unsupervised' + if hasattr(self.controller.xml_handler, 'filename'): + self.controller.xml_handler._parse_XML() + self.controller.xml_handler.write_to_XML() self.newWindow.destroy() dataLoader(self.controller, data) else: @@ -343,7 +337,7 @@ def upload_data_path(self): title='Select a folder', mustexist=True) if folder is not None and len(folder) > 0: - self.controller.xml_handler.append_input_data('X', rel_to_abs(folder)) + self.controller.xml_handler.append_input_data('X', abs_to_rel(folder)) def upload_data_folder(self): """ Stores the directory containing the data that will be later loaded diff --git a/src/vai_lab/utils/plugins/aidCanvas.py b/src/vai_lab/utils/plugins/aidCanvas.py index 04171f0d..8e4d03a9 100644 --- a/src/vai_lab/utils/plugins/aidCanvas.py +++ b/src/vai_lab/utils/plugins/aidCanvas.py @@ -1,4 +1,3 @@ -from vai_lab.Data.xml_handler import XML_handler from vai_lab._types import DictT import os @@ -762,10 +761,8 @@ def upload(self): if filename is not None and len(filename) > 0: self.reset() - s = XML_handler() - s.load_XML(filename) - # s._print_pretty(s.loaded_modules) - modules = s.loaded_modules + self.controller.xml_handler.load_XML(filename) + modules = self.controller.xml_handler.loaded_modules modout = modules['Output'] del modules['Initialiser'], modules['Output'] # They are generated when resetting disp_mod = ['Initialiser', 'Output'] @@ -897,13 +894,14 @@ def reset(self): if hasattr(self, 'entry2'): self.entry2.destroy() + self.controller.xml_handler.new_config_file() + self.canvas_startxy = [] self.out_data = pd.DataFrame() self.connections = {} self.modules = 0 self.module_list = [] self.module_names = [] - self.add_module('Initialiser', self.width/2, self.h, ini=True) self.add_module('Output', self.width/2, self.height - self.h, out=True) diff --git a/src/vai_lab/utils/plugins/pluginCanvas.py b/src/vai_lab/utils/plugins/pluginCanvas.py index 5c1e56c3..62dd8e54 100644 --- a/src/vai_lab/utils/plugins/pluginCanvas.py +++ b/src/vai_lab/utils/plugins/pluginCanvas.py @@ -1,10 +1,12 @@ from vai_lab.Data.xml_handler import XML_handler from vai_lab._plugin_helpers import PluginSpecs -from vai_lab._import_helper import get_lib_parent_dir +from vai_lab._import_helper import get_lib_parent_dir, import_plugin_absolute import os import numpy as np import pandas as pd +from inspect import getmembers, isfunction, ismethod, getfullargspec +from functools import reduce from typing import Dict, List from PIL import Image, ImageTk @@ -31,7 +33,6 @@ def __init__(self, parent, controller, config: dict): self.bg = parent['bg'] self.parent = parent self.controller = controller - self.s = XML_handler() self.m: int self.w, self.h = 100, 50 self.cr = 4 @@ -300,7 +301,8 @@ def check_updated(self): self.id_done.append(self.m) self.xml_handler.append_plugin_to_module( self.plugin[self.m].get(), - {**self.req_settings, **self.opt_settings}, + self.merge_dicts(self.req_settings, self.opt_settings), + self.meths_sort, self.plugin_inputData.get(), np.array(self.module_names)[self.m == np.array(self.id_mod)][0], True) @@ -311,19 +313,35 @@ def check_updated(self): self.id_done.append(self.m) if os.path.normpath(get_lib_parent_dir()) == os.path.normpath(os.path.commonpath([self.path_out, get_lib_parent_dir()])): - rel_path = os.path.join('.', os.path.relpath(self.path_out, + rel_path = os.path.join('.', os.path.relpath(self.path_out, os.path.commonpath([self.path_out, get_lib_parent_dir()]))) else: rel_path = self.path_out self.xml_handler.append_plugin_to_module( 'Output', - {'outdata': self.out_data_xml, - 'outpath': rel_path}, + {'__init__': {'outdata': self.out_data_xml, + 'outpath': rel_path}}, + [], None, np.array(self.module_names)[self.m == np.array(self.id_mod)][0], True) + def merge_dicts(self, a, b, path = None): + "merges b into a" + if path is None: path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + self.merge_dicts(a[key], b[key], path + [str(key)]) + elif a[key] == b[key]: + pass # same leaf value + else: + raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) + else: + a[key] = b[key] + return a + def display_buttons(self): """ Updates the displayed radiobuttons and the description windows. It loads the information corresponding to the selected module (self.m) @@ -477,85 +495,194 @@ def optionsWindow(self): module = np.array(self.module_list)[self.m == np.array(self.id_mod)][0] ps = PluginSpecs() - self.opt_settings = ps.optional_settings[module][self.plugin[self.m].get( - )] - self.req_settings = ps.required_settings[module][self.plugin[self.m].get( - )] - if (len(self.opt_settings) != 0) or (len(self.req_settings) != 0): - if hasattr(self, 'newWindow') and (self.newWindow != None): - self.newWindow.destroy() - self.newWindow = tk.Toplevel(self.controller) - # Window options - self.newWindow.title(self.plugin[self.m].get()+' plugin options') - script_dir = get_lib_parent_dir() - self.tk.call('wm', 'iconphoto', self.newWindow, ImageTk.PhotoImage( - file=os.path.join(os.path.join( - script_dir, - 'utils', - 'resources', - 'Assets', - 'VAILabsIcon.ico')))) - # self.newWindow.geometry("350x400") - - frame1 = tk.Frame(self.newWindow) - frame4 = tk.Frame(self.newWindow) - - # Print settings - tk.Label(frame1, - text="Please indicate your desired options for the "+self.plugin[self.m].get()+" plugin.", anchor=tk.N, justify=tk.LEFT).pack(expand=True) - - style = ttk.Style() - style.configure( - "Treeview", background='white', foreground='white', - rowheight=25, fieldbackground='white', - # font=self.controller.pages_font - ) - style.configure("Treeview.Heading", + file_name = os.path.split(ps.find_from_class_name(self.plugin[self.m].get())['_PLUGIN_DIR'])[-1] + avail_plugins = ps.available_plugins[module][file_name] + plugin = import_plugin_absolute(globals(), + avail_plugins["_PLUGIN_PACKAGE"], + avail_plugins["_PLUGIN_CLASS_NAME"]) + # Update required and optional settings for the plugin + self.req_settings = {'__init__': ps.required_settings[module][self.plugin[self.m].get()]} + self.opt_settings = {'__init__': ps.optional_settings[module][self.plugin[self.m].get()]} + # Tries to upload the settings from the actual library + try: + self.model = plugin().model + meth_req, meth_opt = self.getArgs(self.model.__init__) + if meth_req is not None: + self.req_settings['__init__'] = {**self.req_settings['__init__'], **meth_req} + if meth_opt is not None: + self.opt_settings['__init__'] = {**self.opt_settings['__init__'], **meth_opt} + # Find functions defined for the module + plugin_meth_list = [meth[0] for meth in getmembers(plugin, isfunction) if meth[0][0] != '_'] + # Find available methods for the model + model_meth_list = [meth[0] for meth in getmembers(self.model, ismethod) if meth[0][0] != '_'] + # List intersection + # TODO: use only methods from the model + set_2 = frozenset(model_meth_list) + meth_list = [x for x in plugin_meth_list if x in set_2] + except Exception as exc: + meth_list = [] + + + self.meths_sort = [] + self.method_inputData = {} + self.default_inputData = {} + + if hasattr(self, 'newWindow') and (self.newWindow != None): + self.newWindow.destroy() + self.newWindow = tk.Toplevel(self.controller) + # Window options + self.newWindow.title(self.plugin[self.m].get()+' plugin options') + script_dir = get_lib_parent_dir() + self.tk.call('wm', 'iconphoto', self.newWindow, ImageTk.PhotoImage( + file=os.path.join(os.path.join( + script_dir, + 'utils', + 'resources', + 'Assets', + 'VAILabsIcon.ico')))) + # self.newWindow.geometry("350x400") + + + frame1 = tk.Frame(self.newWindow) + frame2 = tk.Frame(self.newWindow) + frame3 = tk.Frame(self.newWindow) + frame4 = tk.Frame(self.newWindow) + frame5 = tk.Frame(self.newWindow) + frame6 = tk.Frame(self.newWindow, highlightbackground="black", highlightthickness=1) + + # Print settings + tk.Label(frame1, + text="Please indicate your desired options for the plugin.", anchor=tk.N, justify=tk.LEFT).pack(expand=True) + + style = ttk.Style() + style.configure( + "Treeview", background='white', foreground='white', + rowheight=25, fieldbackground='white', # font=self.controller.pages_font ) - style.map('Treeview', background=[('selected', 'grey')]) + style.configure("Treeview.Heading", + # font=self.controller.pages_font + ) + style.map('Treeview', background=[('selected', 'grey')]) - frame2 = tk.Frame(self.newWindow, bg='green') - self.r = 1 - self.create_treeView(frame2) - self.fill_treeview(self.req_settings, self.opt_settings) - frame2.grid(column=0, row=1, sticky="nswe", pady=10, padx=10) - - frame2.grid_rowconfigure(tuple(range(self.r)), weight=1) - frame2.grid_columnconfigure(tuple(range(2)), weight=1) + # Show method selection if there is any + if len(meth_list) > 0: + frame21 = tk.Frame(self.newWindow) + frameButt = tk.Frame(frame3) + frameDrop = tk.Frame(frame3, highlightbackground="black", highlightthickness=1) - frame5 = tk.Frame(self.newWindow) - tk.Label(frame5, - text="Indicate which plugin's output data should be used as input", anchor=tk.N, justify=tk.LEFT).pack(expand=True) - - frame6 = tk.Frame(self.newWindow, highlightbackground="black", highlightthickness=1) + tk.Label(frame21, + text="Add your desired methods in your required order.", anchor=tk.N, justify=tk.LEFT).pack(expand=True) - current = np.where(self.m == np.array(self.id_mod))[0][0] - dataSources = [i for j, i in enumerate(self.module_names) if j not in [1,current]] - - self.plugin_inputData = tk.StringVar(frame6) - dropDown = tk.ttk.OptionMenu(frame6, self.plugin_inputData, dataSources[current-2], *dataSources) + self.meth2add = tk.StringVar(frameDrop) + dropDown = tk.ttk.OptionMenu(frameDrop, self.meth2add, meth_list[0], *meth_list) style.configure("TMenubutton", background="white") dropDown["menu"].configure(bg="white") - dropDown.pack() + dropDown.grid(row=0,column=0) + + tk.Button(frameButt, text='Add', command=self.addMeth).grid(row=0,column=0) + tk.Button(frameButt, text='Delete', command=self.deleteMeth).grid(row=0,column=1) + tk.Button(frameButt, text='Up', command=lambda: self.moveMeth(-1)).grid(row=0,column=2) + tk.Button(frameButt, text='Down', command=lambda: self.moveMeth(+1)).grid(row=0,column=3) + + frame21.grid(column=0, row=2, sticky="ew") + frameButt.grid(column=1, row=0, sticky="w") + frameDrop.grid(column=0, row=0) - self.finishButton = tk.Button( - frame4, text='Finish', command=self.removewindow) - self.finishButton.grid( - column=1, row=0, sticky="es", pady=(0, 10), padx=(0, 10)) - self.finishButton.bind( - "", lambda event: self.removewindow()) - self.newWindow.protocol('WM_DELETE_WINDOW', self.removewindow) + self.r = 1 + self.tree = self.create_treeView(frame2, ['Name', 'Value']) + self.tree.insert(parent='', index='end', iid='__init__', text='', values=tuple(['__init__', '']), + tags=('meth','__init__')) + self.fill_treeview(self.req_settings['__init__'], self.opt_settings['__init__'], '__init__') - frame1.grid(column=0, row=0, sticky="ew") - frame4.grid(column=0, row=20, sticky="se") - frame5.grid(column=0, row=2, sticky="ew") - frame6.grid(column=0, row=3) + tk.Label(frame5, + text="Indicate which plugin's output data should be used as input", anchor=tk.N, justify=tk.LEFT).pack(expand=True) + + current = np.where(self.m == np.array(self.id_mod))[0][0] + dataSources = [i for j, i in enumerate(self.module_names) if j not in [1,current]] + + self.plugin_inputData = tk.StringVar(frame6) + dropDown = tk.ttk.OptionMenu(frame6, self.plugin_inputData, dataSources[current-2], *dataSources) + style.configure("TMenubutton", background="white") + dropDown["menu"].configure(bg="white") + dropDown.pack() + + self.finishButton = tk.Button( + frame4, text='Finish', command=self.removewindow) + self.finishButton.grid( + column=1, row=0, sticky="es", pady=(0, 10), padx=(0, 10)) + self.finishButton.bind( + "", lambda event: self.removewindow()) + self.newWindow.protocol('WM_DELETE_WINDOW', self.removewindow) + + frame1.grid(column=0, row=0, sticky="ew") + frame2.grid(column=0, row=1, sticky="nswe", pady=10, padx=10) + frame3.grid(column=0, row=3, pady=10, padx=10) + frame4.grid(column=0, row=20, sticky="se") + frame5.grid(column=0, row=4, sticky="ew") + frame6.grid(column=0, row=5) + + frame2.grid_rowconfigure(tuple(range(self.r)), weight=1) + frame2.grid_columnconfigure(tuple(range(2)), weight=1) + self.newWindow.grid_rowconfigure(1, weight=2) + self.newWindow.grid_columnconfigure(tuple(range(2)), weight=1) + + def getArgs(self, f): + """ Get required and optional arguments from method. - self.newWindow.grid_rowconfigure(1, weight=2) - self.newWindow.grid_columnconfigure(0, weight=1) + Parameters + ---------- + f : method + method to extract arguments from - def create_treeView(self, tree_frame): + :returns out: two dictionaries with arguments and default value (if optional) + """ + + meth_args = getfullargspec(f).args + if meth_args is not None: + meth_args.remove('self') + meth_def = getfullargspec(f).defaults + if meth_def is None: + meth_def = [] + meth_req = {p: '' for p in meth_args[:(len(meth_args)-len(meth_def))]} + meth_r_opt = {p: meth_def[i] for i,p in enumerate(meth_args[(len(meth_args)-len(meth_def)):])} + + meth_opt = getfullargspec(f).kwonlydefaults + if meth_opt is not None: + meth_opt = {p: meth_opt[p] for p in meth_opt} + if meth_r_opt is not None: + meth_opt = {**meth_r_opt, **meth_opt} + return meth_req, meth_opt + else: + return meth_req, meth_r_opt + + def addMeth(self): + """ Adds selected method in dropdown menu to the plugin tree """ + meth = self.meth2add.get() + self.meths_sort.append(meth) + self.tree.insert(parent='', index='end', iid=meth, text='', values=tuple([meth, '']), + tags=('meth',meth)) + # TODO: Remove X and y? + self.req_settings[meth], self.opt_settings[meth] = self.getArgs(getattr(self.model, meth)) + self.fill_treeview(self.req_settings[meth], self.opt_settings[meth], meth) + + def deleteMeth(self): + """ Deletes selected method in dropdown menu from the plugin tree """ + meth = self.meth2add.get() + if meth in self.meths_sort: + self.meths_sort.remove(meth) + del self.req_settings[meth] + del self.opt_settings[meth] + self.tree.delete(meth) + + def moveMeth(self, m): + meth = self.meth2add.get() + if meth in self.meths_sort and self.tree.index(meth)+m > 0: + idx = self.meths_sort.index(meth) + self.meths_sort.insert(idx+m, self.meths_sort.pop(idx)) + self.tree.move(meth, self.tree.parent(meth), self.tree.index(meth)+m) + + def create_treeView(self, tree_frame, columns_names): """ Function to create a new tree view in the given frame Parameters @@ -571,37 +698,37 @@ def create_treeView(self, tree_frame): tree_scrolly = tk.Scrollbar(tree_frame) tree_scrolly.pack(side=tk.RIGHT, fill=tk.Y) - self.tree = ttk.Treeview(tree_frame, + tree = ttk.Treeview(tree_frame, yscrollcommand=tree_scrolly.set, xscrollcommand=tree_scrollx.set) - self.tree.pack(fill='both', expand=True) + tree.pack(fill='both', expand=True) - tree_scrollx.config(command=self.tree.xview) - tree_scrolly.config(command=self.tree.yview) + tree_scrollx.config(command=tree.xview) + tree_scrolly.config(command=tree.yview) - columns_names = ['Name', 'Type', 'Value'] - self.tree['columns'] = columns_names + tree['columns'] = columns_names # Format columns - self.tree.column("#0", width=20, + tree.column("#0", width=40, minwidth=0, stretch=tk.NO) for n, cl in enumerate(columns_names): - self.tree.column( + tree.column( cl, width=int(self.controller.pages_font.measure(str(cl)))+20, minwidth=50, anchor=tk.CENTER) # Headings for cl in columns_names: - self.tree.heading(cl, text=cl, anchor=tk.CENTER) - self.tree.tag_configure('req', foreground='black', + tree.heading(cl, text=cl, anchor=tk.CENTER) + tree.tag_configure('req', foreground='black', background='#9fc5e8') - self.tree.tag_configure('opt', foreground='black', + tree.tag_configure('opt', foreground='black', background='#cfe2f3') - self.tree.tag_configure('type', foreground='black', + tree.tag_configure('type', foreground='black', background='#E8E8E8') - self.tree.tag_configure('func', foreground='black', + tree.tag_configure('meth', foreground='black', background='#DFDFDF') # Define double-click on row action - self.tree.bind("", self.OnDoubleClick) + tree.bind("", self.OnDoubleClick) + return tree def OnDoubleClick(self, event): """ Executed when a row of the treeview is double clicked. @@ -614,31 +741,54 @@ def OnDoubleClick(self, event): if len(tags) > 0 and tags[0] in ['opt', 'req']: # get column position info x, y, width, height = self.tree.bbox(self.treerow, self.treecol) - # y-axis offset pady = height // 2 - # pady = 0 - - if hasattr(self, 'entry'): - self.entry.destroy() - - self.entry = tk.Entry(self.tree, justify='center') - - if int(self.treecol[1:]) > 0: - value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-1] - value = str(value) if str(value) not in ['default', 'Choose X or Y'] else '' - self.entry.insert(0, value) - # self.entry['selectbackground'] = '#123456' - self.entry['exportselection'] = False - - self.entry.focus_force() - self.entry.bind("", self.on_return) - self.entry.bind("", lambda *ignore: self.entry.destroy()) - - self.entry.place(x=x, + if tags[-1] != 'data': + if hasattr(self, 'entry'): + self.entry.destroy() + self.entry = tk.Entry(self.tree, justify='center') + if int(self.treecol[1:]) > 0: + value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-1] + value = str(value) if str(value) not in ['default', 'Choose X or Y'] else '' + self.entry.insert(0, value) + # self.entry['selectbackground'] = '#123456' + self.entry['exportselection'] = False + + self.entry.focus_force() + self.entry.bind("", self.on_return) + self.entry.bind("", lambda *ignore: self.entry.destroy()) + + self.entry.place(x=x, + y=y + pady, + anchor=tk.W, width=width) + else: + data_list = ['X','Y','X_test','Y_test'] # TODO: Substitute with loaded data + data_list.insert(0,self.default_inputData['_'.join(tags[:-1])]) + data_list = list(np.unique(data_list)) + self.dropDown = tk.ttk.OptionMenu(self.tree, self.method_inputData['_'.join(tags[:-1])], + self.method_inputData['_'.join(tags[:-1])].get(), *data_list) + bg = '#9fc5e8' if tags[0] == 'req' else '#cfe2f3' + self.dropDown["menu"].configure(bg=bg) + style = ttk.Style() + style.configure("new.TMenubutton", background=bg, highlightbackground="black", highlightthickness=1) + self.dropDown.configure(style="new.TMenubutton") + self.dropDown.place(x=x, y=y + pady, anchor=tk.W, width=width) - + + def on_changeOption(self, *args): + """ Executed when the optionmenu is selected and pressed enter. + Saves the value""" + if hasattr(self, 'dropDown'): + value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-2] + tags = self.tree.item(self.treerow)["tags"] + val = self.tree.item(self.treerow)['values'] + new_val = self.method_inputData['_'.join(tags[:-1])].get() + val[int(self.treecol[1:])-1] = new_val + self.tree.item(self.treerow, values=tuple([val[0], new_val])) + self.dropDown.destroy() + self.saved = False + def on_return(self, event): """ Executed when the entry is edited and pressed enter. Saves the edited value""" @@ -646,9 +796,9 @@ def on_return(self, event): val = self.tree.item(self.treerow)['values'] val[int(self.treecol[1:])-1] = self.entry.get() if self.entry.get() != '': - self.tree.item(self.treerow, values=tuple([val[0], val[1], self.entry.get()])) - elif val[2] == '': - self.tree.item(self.treerow, values=tuple([val[0], val[1], 'default'])) + self.tree.item(self.treerow, values=tuple([val[0], self.entry.get()])) + elif val[1] == '': + self.tree.item(self.treerow, values=tuple([val[0], 'default'])) else: self.tree.item(self.treerow, values=val) self.entry.destroy() @@ -661,33 +811,62 @@ def fill_treeview(self, req_settings, opt_settings, parent = ''): :param parent: string type of parent name """ self.tree.insert(parent=parent, index='end', iid=parent+'_req', text='', - values=tuple(['Required settings', '', '']), tags=('type',)) + values=tuple(['Required settings', '']), tags=('type',parent), open=True) self.r+=1 for arg, val in req_settings.items(): - if arg == 'Data': + if arg.lower() in ['x', 'y']: + value = np.array(['X', 'Y'])[arg.lower() == np.array(['x', 'y'])][0] self.tree.insert(parent=parent+'_req', index='end', iid=str(self.r), text='', - values=tuple([arg, val, 'Choose X or Y']), tags=('req',)) + values=tuple([arg, value]), tags=('req',parent,arg,'data')) + self.method_inputData['req_'+parent+'_'+str(arg)] = tk.StringVar(self.tree) + self.method_inputData['req_'+parent+'_'+str(arg)].set(value) + self.method_inputData['req_'+parent+'_'+str(arg)].trace("w", self.on_changeOption) + self.default_inputData['req_'+parent+'_'+str(arg)] = value else: self.tree.insert(parent=parent+'_req', index='end', iid=str(self.r), text='', - values=tuple([arg, val, '']), tags=('req',)) + values=tuple([arg, str(val)]), tags=('req',parent)) self.r+=1 self.tree.insert(parent=parent, index='end', iid=parent+'_opt', text='', - values=tuple(['Optional settings', '', '']), tags=('type',)) + values=tuple(['Optional settings', '']), tags=('type',parent), open=True) self.r+=1 for arg, val in opt_settings.items(): - self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', - values=tuple([arg, val, 'default']), tags=('opt',)) + if arg.lower() in ['x', 'y']: + self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', + values=tuple([arg, val]), tags=('opt',parent,arg,'data')) + self.method_inputData['opt_'+parent+'_'+str(arg)] = tk.StringVar(self.tree) + self.method_inputData['opt_'+parent+'_'+str(arg)].set(val) + self.method_inputData['opt_'+parent+'_'+str(arg)].trace("w", self.on_changeOption) + self.default_inputData['opt_'+parent+'_'+str(arg)] = str(val) + else: + self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', + values=tuple([arg, str(val)]), tags=('opt',parent)) self.r+=1 def removewindow(self): """ Stores settings options and closes window """ - self.req_settings.pop("Data", None) - children = self.get_all_children() - for child in children: - tag = self.tree.item(child)["tags"][0] - if tag in ['req', 'opt']: - val = self.tree.item(child)["values"] - self.settingOptions(tag, val) + # Updates the tree with any unclosed dropDown menu + if hasattr(self, 'dropDown'): + for data in self.method_inputData.keys(): + tags = data.split('_') + el = self.get_element_from_tags(*tags) + val = self.tree.item(el)['values'] + new_val = self.method_inputData[data].get() + val[int(self.treecol[1:])-1] = new_val + self.tree.item(el, values=tuple([val[0], new_val])) + self.dropDown.destroy() + # Updates the modified options and removes the ones that are not + for f in self.tree.get_children(): + for c in self.tree.get_children(f): + for child in self.tree.get_children(c): + tags = self.tree.item(child)["tags"] + if tags[0] in ['req', 'opt']: + if tags[-1] == 'data': + self.updateSettings(tags[0], tags[1], tags[2], self.method_inputData['_'.join(tags[:-1])].get()) + else: + val = self.tree.item(child)["values"] + self.settingOptions(tags[0], f, val) + if hasattr(self, 'model'): + del self.model self.newWindow.destroy() self.newWindow = None self.focus() @@ -699,20 +878,21 @@ def get_all_children(self, item=""): children += self.get_all_children(child) return children - def settingOptions(self, tag, val): + def get_element_from_tags(self, *args): + """ Finds item in tree with specified tags """ + el = set(self.tree.tag_has(args[0])) + for arg in args[1:]: + el = set.intersection(el, set(self.tree.tag_has(arg))) + return list(el)[0] + + def settingOptions(self, tag, f, val): """ Identifies how the data should be stored """ - if val[0] == 'Data': - if val[2] == 'Choose X or Y' or len(val[2]) == 0: - self.updateSettings(tag, val[0], 'X') - else: - self.updateSettings(tag, val[0], val[2]) + if val[1] == 'default' or len(str(val[1])) == 0: + self.updateSettings(tag, f, val[0]) else: - if val[2] == 'default' or len(str(val[2])) == 0: - self.updateSettings(tag, val[0]) - else: - self.updateSettings(tag, val[0], val[2]) + self.updateSettings(tag, f, val[0], val[1]) - def updateSettings(self, tag, key, value = None): + def updateSettings(self, tag, f, key, value = None): """ Return the selected settings Parameters @@ -720,17 +900,40 @@ def updateSettings(self, tag, key, value = None): tag : str tag for the settings """ + + value = self.str_to_bool(value) if tag == 'req': - if value is not None or self.req_settings[key] != value: - self.req_settings[key] = value + if value is not None or self.isNotClose(self.req_settings[f][key], value): + self.req_settings[f][key] = value else: - self.req_settings.pop(key, None) + self.req_settings[f].pop(key, None) elif tag == 'opt': - if value is not None or self.opt_settings[key] != value: - self.opt_settings[key] = value + if self.isNotClose(self.opt_settings[f][key], value): + self.opt_settings[f][key] = value else: - self.opt_settings.pop(key, None) + self.opt_settings[f].pop(key, None) + def isNotClose(self, a, b, rel_tol=1e-09, abs_tol=0.0): + a = self.xml_handler._str_to_num(a) if isinstance(a, (str)) else a + b = self.xml_handler._str_to_num(b) if isinstance(b, (str)) else b + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return abs(a-b) > max(rel_tol * max(abs(a), abs(b)), abs_tol) + else: + return a != b + + def str_to_bool(self, s): + if type(s) is str: + if s == 'True': + return True + elif s == 'False': + return False + elif s == 'None': + return None + else: + return s + else: + return s + def on_return_entry(self, r): """ Changes focus to the next available entry. When no more, focuses on the finish button. @@ -867,7 +1070,6 @@ def upload(self): self.xml_handler = XML_handler() self.xml_handler.load_XML(filename) - # self.xml_handler._print_pretty(self.xml_handler.loaded_modules) modules = self.xml_handler.loaded_modules modout = modules['Output'] # They are generated when resetting @@ -1053,7 +1255,7 @@ def check_quit(self): self.controller.Plugin.set(True) self.controller._show_frame("MainPage") # TODO: Check if loaded - elif len(self.xml_handler.loaded_modules) == 0: + elif not hasattr(self,"xml_handler") or len(self.xml_handler.loaded_modules) == 0: self.controller._show_frame("MainPage") self.controller.Plugin.set(False) else: diff --git a/src/vai_lab/utils/plugins/progressTracker.py b/src/vai_lab/utils/plugins/progressTracker.py index fbb02e27..22fab3fe 100644 --- a/src/vai_lab/utils/plugins/progressTracker.py +++ b/src/vai_lab/utils/plugins/progressTracker.py @@ -3,11 +3,14 @@ from typing import Dict from tkinter import ttk +from inspect import getmembers, isfunction, ismethod, getfullargspec +from vai_lab._plugin_helpers import PluginSpecs + import numpy as np import pandas as pd from PIL import Image, ImageTk from vai_lab.Data.xml_handler import XML_handler -from vai_lab._import_helper import get_lib_parent_dir +from vai_lab._import_helper import get_lib_parent_dir, import_plugin_absolute _PLUGIN_READABLE_NAMES = {"progress_tracker":"default", "progressTracker":"alias", @@ -266,20 +269,43 @@ def add_module(self,boxName: str,x: float,y:float,ini = False,out = False): def optionsWindow(self): """ Function to create a new window displaying the available options of the selected plugin.""" - - mod_idx = np.where(self.m == np.array(self.id_mod))[0][0] - self.module = np.array(self.module_list)[mod_idx] - self.plugin = np.array(self.plugin_list)[mod_idx] - module_type = np.array(self.type_list)[mod_idx] - - self.opt_settings = self.controller._avail_plugins.optional_settings[module_type][self.plugin] - self.req_settings = self.controller._avail_plugins.required_settings[module_type][self.plugin] - if (len(self.opt_settings) != 0) or (len(self.req_settings) != 0): - if hasattr(self, 'newWindow') and (self.newWindow!= None): + + self.mt = self.m -1 if self.m < len(self.module_list)-1 else 1 + + module = np.array(self.module_list)[self.mt == np.array(self.id_mod)][0] + ps = PluginSpecs() + file_name = os.path.split(ps.find_from_class_name(self.plugin_list[self.m])['_PLUGIN_DIR'])[-1] + avail_plugins = ps.available_plugins[module][file_name] + plugin = import_plugin_absolute(globals(), + avail_plugins["_PLUGIN_PACKAGE"], + avail_plugins["_PLUGIN_CLASS_NAME"]) + # Update required and optional settings for the plugin + self.req_settings = {'__init__': ps.required_settings[module][self.plugin_list[self.m]]} + self.opt_settings = {'__init__': ps.optional_settings[module][self.plugin_list[self.m]]} + self.model = plugin().model + meth_req, meth_opt = self.getArgs(self.model.__init__) + if meth_req is not None: + self.req_settings['__init__'] = {**self.req_settings['__init__'], **meth_req} + if meth_opt is not None: + self.opt_settings['__init__'] = {**self.opt_settings['__init__'], **meth_opt} + + # Find functions defined for the module + plugin_meth_list = [meth[0] for meth in getmembers(plugin, isfunction) if meth[0][0] != '_'] + # Find available methods for the model + model_meth_list = [meth[0] for meth in getmembers(self.model, ismethod) if meth[0][0] != '_'] + # List intersection + # TODO: use only methods from the model + set_2 = frozenset(model_meth_list) + meth_list = [x for x in plugin_meth_list if x in set_2] + self.meths_sort = [] + self.method_inputData = {} + self.default_inputData = {} + if (len(self.opt_settings['__init__']) != 0) or (len(self.req_settings['__init__']) != 0): + if hasattr(self, 'newWindow') and (self.newWindow != None): self.newWindow.destroy() self.newWindow = tk.Toplevel(self.controller) # Window options - self.newWindow.title(self.plugin+' plugin options') + self.newWindow.title(self.plugin_list[self.m]+' plugin options') script_dir = get_lib_parent_dir() self.tk.call('wm', 'iconphoto', self.newWindow, ImageTk.PhotoImage( file=os.path.join(os.path.join( @@ -288,15 +314,21 @@ def optionsWindow(self): 'resources', 'Assets', 'VAILabsIcon.ico')))) - self.newWindow.geometry("350x400") - self.raise_above_all(self.newWindow) + # self.newWindow.geometry("350x400") frame1 = tk.Frame(self.newWindow) + frame2 = tk.Frame(self.newWindow) + frame21 = tk.Frame(self.newWindow) + frame3 = tk.Frame(self.newWindow) frame4 = tk.Frame(self.newWindow) - + frame5 = tk.Frame(self.newWindow) + frame6 = tk.Frame(self.newWindow, highlightbackground="black", highlightthickness=1) + frameDrop = tk.Frame(frame3, highlightbackground="black", highlightthickness=1) + frameButt = tk.Frame(frame3) + # Print settings tk.Label(frame1, - text ="Please indicate your desired options for the "+self.plugin+" plugin.", anchor = tk.N, justify=tk.LEFT).pack(expand = True) - + text="Please indicate your desired options for the plugin.", anchor=tk.N, justify=tk.LEFT).pack(expand=True) + style = ttk.Style() style.configure( "Treeview", background='white', foreground='white', @@ -308,14 +340,44 @@ def optionsWindow(self): ) style.map('Treeview', background=[('selected', 'grey')]) - frame2 = tk.Frame(self.newWindow, bg='green') + tk.Label(frame21, + text="Add your desired methods in your required order.", anchor=tk.N, justify=tk.LEFT).pack(expand=True) + + self.meth2add = tk.StringVar(frameDrop) + dropDown = tk.ttk.OptionMenu(frameDrop, self.meth2add, meth_list[0], *meth_list) + style.configure("TMenubutton", background="white") + dropDown["menu"].configure(bg="white") + dropDown.grid(row=0,column=0) + + tk.Button(frameButt, text='Add', command=self.addMeth).grid(row=0,column=0) + tk.Button(frameButt, text='Delete', command=self.deleteMeth).grid(row=0,column=1) + tk.Button(frameButt, text='Up', command=lambda: self.moveMeth(-1)).grid(row=0,column=2) + tk.Button(frameButt, text='Down', command=lambda: self.moveMeth(+1)).grid(row=0,column=3) + self.r = 1 - self.create_treeView(frame2, ['Name', 'Type', 'Value']) - self.fill_treeview(frame2, self.req_settings, self.opt_settings) - frame2.grid(column=0, row=1, sticky="nswe", pady=10, padx=10) + self.tree = self.create_treeView(frame2, ['Name', 'Value']) + self.tree.insert(parent='', index='end', iid='__init__', text='', values=tuple(['__init__', '']), + tags=('meth','__init__')) + self.fill_treeview(self.update_options(self.req_settings['__init__'], self.p_list[self.mt]['options']), + self.update_options(self.opt_settings['__init__'], self.p_list[self.mt]['options']), + '__init__') + meth2add = self.meth2add.get() + for meth in self.p_list[self.mt]['methods']['_order']: + self.meth2add.set(meth) + self.addMeth() + self.meth2add.set(meth2add) + + tk.Label(frame5, + text="Indicate which plugin's output data should be used as input", anchor=tk.N, justify=tk.LEFT).pack(expand=True) + + current = np.where(self.m == np.array(self.id_mod))[0][0] + dataSources = [i for j, i in enumerate(self.module_names) if j not in [1,current]] - frame2.grid_rowconfigure(tuple(range(self.r)), weight=1) - frame2.grid_columnconfigure(tuple(range(2)), weight=1) + self.plugin_inputData = tk.StringVar(frame6) + dropDown = tk.ttk.OptionMenu(frame6, self.plugin_inputData, dataSources[current-2], *dataSources) + style.configure("TMenubutton", background="white") + dropDown["menu"].configure(bg="white") + dropDown.pack() self.finishButton = tk.Button( frame4, text='Finish', command=self.removewindow) @@ -325,14 +387,101 @@ def optionsWindow(self): "", lambda event: self.removewindow()) self.newWindow.protocol('WM_DELETE_WINDOW', self.removewindow) + frameDrop.grid(column=0, row=0) + frameButt.grid(column=1, row=0, sticky="w") frame1.grid(column=0, row=0, sticky="ew") - frame4.grid(column=0, row=2, sticky="se") + frame2.grid(column=0, row=1, sticky="nswe", pady=10, padx=10) + frame21.grid(column=0, row=2, sticky="ew") + frame3.grid(column=0, row=3, pady=10, padx=10) + frame4.grid(column=0, row=20, sticky="se") + frame5.grid(column=0, row=4, sticky="ew") + frame6.grid(column=0, row=5) + + frame2.grid_rowconfigure(tuple(range(self.r)), weight=1) + frame2.grid_columnconfigure(tuple(range(2)), weight=1) self.newWindow.grid_rowconfigure(1, weight=2) - self.newWindow.grid_columnconfigure(0, weight=1) + self.newWindow.grid_columnconfigure(tuple(range(2)), weight=1) - def raise_above_all(self, window): - window.attributes('-topmost', 1) - window.attributes('-topmost', 0) + def update_options(self, base, new): + """ Updates the values in a dictionary with the keys in the new dictionary. + + Parameters + ---------- + base : dict + dictionary that needs to be updated + new : dict + dictionary with keys that will be updated if they exist in base + + :returns out: the updated dictionary + """ + for key in base: + if key in new: + base[key] = new[key] + return base + + def getArgs(self, f): + """ Get required and optional arguments from method. + + Parameters + ---------- + f : method + method to extract arguments from + + :returns out: two dictionaries with arguments and default value (if optional) + """ + + meth_args = getfullargspec(f).args + if meth_args is not None: + meth_args.remove('self') + meth_def = getfullargspec(f).defaults + if meth_def is None: + meth_def = [] + meth_req = {p: '' for p in meth_args[:(len(meth_args)-len(meth_def))]} + meth_r_opt = {p: meth_def[i] for i,p in enumerate(meth_args[(len(meth_args)-len(meth_def)):])} + + meth_opt = getfullargspec(f).kwonlydefaults + if meth_opt is not None: + meth_opt = {p: meth_opt[p] for p in meth_opt} + if meth_r_opt is not None: + meth_opt = {**meth_r_opt, **meth_opt} + return meth_req, meth_opt + else: + return meth_req, meth_r_opt + + def addMeth(self): + """ Adds selected method in dropdown menu to the plugin tree """ + meth = self.meth2add.get() + self.meths_sort.append(meth) + self.tree.insert(parent='', index='end', iid=meth, text='', values=tuple([meth, '']), + tags=('meth',meth)) + # TODO: Remove X and y? + self.req_settings[meth], self.opt_settings[meth] = self.getArgs(getattr(self.model, meth)) + if meth in self.p_list[self.mt]['methods']['_order']: + self.fill_treeview( + self.update_options(self.req_settings[meth], self.p_list[self.mt]['methods'][meth]['options']), + self.update_options(self.opt_settings[meth], self.p_list[self.mt]['methods'][meth]['options']), + meth) + else: + self.fill_treeview( + self.req_settings[meth], + self.opt_settings[meth], + meth) + + def deleteMeth(self): + """ Deletes selected method in dropdown menu from the plugin tree """ + meth = self.meth2add.get() + if meth in self.meths_sort: + self.meths_sort.remove(meth) + del self.req_settings[meth] + del self.opt_settings[meth] + self.tree.delete(meth) + + def moveMeth(self, m): + meth = self.meth2add.get() + if meth in self.meths_sort and self.tree.index(meth)+m > 0: + idx = self.meths_sort.index(meth) + self.meths_sort.insert(idx+m, self.meths_sort.pop(idx)) + self.tree.move(meth, self.tree.parent(meth), self.tree.index(meth)+m) def create_treeView(self, tree_frame, columns_names): """ Function to create a new tree view in the given frame @@ -350,72 +499,97 @@ def create_treeView(self, tree_frame, columns_names): tree_scrolly = tk.Scrollbar(tree_frame) tree_scrolly.pack(side=tk.RIGHT, fill=tk.Y) - self.tree = ttk.Treeview(tree_frame, + tree = ttk.Treeview(tree_frame, yscrollcommand=tree_scrolly.set, xscrollcommand=tree_scrollx.set) - self.tree.pack(fill='both', expand=True) + tree.pack(fill='both', expand=True) - tree_scrollx.config(command=self.tree.xview) - tree_scrolly.config(command=self.tree.yview) + tree_scrollx.config(command=tree.xview) + tree_scrolly.config(command=tree.yview) - self.tree['columns'] = columns_names + tree['columns'] = columns_names # Format columns - self.tree.column("#0", width=20, + tree.column("#0", width=40, minwidth=0, stretch=tk.NO) for n, cl in enumerate(columns_names): - self.tree.column( + tree.column( cl, width=int(self.controller.pages_font.measure(str(cl)))+20, minwidth=50, anchor=tk.CENTER) # Headings for cl in columns_names: - self.tree.heading(cl, text=cl, anchor=tk.CENTER) - self.tree.tag_configure('req', foreground='black', + tree.heading(cl, text=cl, anchor=tk.CENTER) + tree.tag_configure('req', foreground='black', background='#9fc5e8') - self.tree.tag_configure('opt', foreground='black', + tree.tag_configure('opt', foreground='black', background='#cfe2f3') - self.tree.tag_configure('type', foreground='black', + tree.tag_configure('type', foreground='black', background='#E8E8E8') - self.tree.tag_configure('func', foreground='black', + tree.tag_configure('meth', foreground='black', background='#DFDFDF') # Define double-click on row action - self.tree.bind("", self.OnDoubleClick) + tree.bind("", self.OnDoubleClick) + return tree def OnDoubleClick(self, event): """ Executed when a row of the treeview is double clicked. Opens an entry box to edit a cell. """ + # ii = self.notebook.index(self.notebook.select()) self.treerow = self.tree.identify_row(event.y) self.treecol = self.tree.identify_column(event.x) tags = self.tree.item(self.treerow)["tags"] if len(tags) > 0 and tags[0] in ['opt', 'req']: # get column position info x, y, width, height = self.tree.bbox(self.treerow, self.treecol) - # y-axis offset pady = height // 2 - # pady = 0 - - if hasattr(self, 'entry'): - self.entry.destroy() - - self.entry = tk.Entry(self.tree, justify='center') - - if int(self.treecol[1:]) > 0: - value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-1] - value = str(value) if str(value) not in ['default', 'Choose X or Y'] else '' - self.entry.insert(0, value) - # self.entry['selectbackground'] = '#123456' - self.entry['exportselection'] = False - - self.entry.focus_force() - self.entry.bind("", self.on_return) - self.entry.bind("", lambda *ignore: self.entry.destroy()) - - self.entry.place(x=x, + if tags[-1] != 'data': + if hasattr(self, 'entry'): + self.entry.destroy() + self.entry = tk.Entry(self.tree, justify='center') + if int(self.treecol[1:]) > 0: + value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-1] + value = str(value) if str(value) not in ['default', 'Choose X or Y'] else '' + self.entry.insert(0, value) + # self.entry['selectbackground'] = '#123456' + self.entry['exportselection'] = False + + self.entry.focus_force() + self.entry.bind("", self.on_return) + self.entry.bind("", lambda *ignore: self.entry.destroy()) + + self.entry.place(x=x, + y=y + pady, + anchor=tk.W, width=width) + else: + data_list = ['X','Y','X_test','Y_test'] # TODO: Substitute with loaded data + data_list.insert(0,self.default_inputData['_'.join(tags[:-1])]) + data_list = list(np.unique(data_list)) + self.dropDown = tk.ttk.OptionMenu(self.tree, self.method_inputData['_'.join(tags[:-1])], + self.method_inputData['_'.join(tags[:-1])].get(), *data_list) + bg = '#9fc5e8' if tags[0] == 'req' else '#cfe2f3' + self.dropDown["menu"].configure(bg=bg) + style = ttk.Style() + style.configure("new.TMenubutton", background=bg, highlightbackground="black", highlightthickness=1) + self.dropDown.configure(style="new.TMenubutton") + self.dropDown.place(x=x, y=y + pady, anchor=tk.W, width=width) + def on_changeOption(self, *args): + """ Executed when the optionmenu is selected and pressed enter. + Saves the value""" + if hasattr(self, 'dropDown'): + value = self.tree.item(self.treerow)['values'][int(str(self.treecol[1:]))-2] + tags = self.tree.item(self.treerow)["tags"] + val = self.tree.item(self.treerow)['values'] + new_val = self.method_inputData['_'.join(tags[:-1])].get() + val[int(self.treecol[1:])-1] = new_val + self.tree.item(self.treerow, values=tuple([val[0], new_val])) + self.dropDown.destroy() + self.saved = False + def on_return(self, event): """ Executed when the entry is edited and pressed enter. Saves the edited value""" @@ -423,73 +597,106 @@ def on_return(self, event): val = self.tree.item(self.treerow)['values'] val[int(self.treecol[1:])-1] = self.entry.get() if self.entry.get() != '': - self.tree.item(self.treerow, values=tuple([val[0], val[1], self.entry.get()])) - elif val[2] == '': - self.tree.item(self.treerow, values=tuple([val[0], val[1], 'default'])) + self.tree.item(self.treerow, values=tuple([val[0], self.entry.get()])) + elif val[1] == '': + self.tree.item(self.treerow, values=tuple([val[0], 'default'])) else: self.tree.item(self.treerow, values=val) self.entry.destroy() self.saved = False - def fill_treeview(self, frame, req_settings, opt_settings, parent = ''): + def fill_treeview(self, req_settings, opt_settings, parent = ''): """ Adds an entry for each setting. Displays it in the specified row. :param req_settings: dict type of plugin required setting options :param opt_settings: dict type of plugin optional setting options :param parent: string type of parent name """ self.tree.insert(parent=parent, index='end', iid=parent+'_req', text='', - values=tuple(['Required settings', '', '']), tags=('type',)) + values=tuple(['Required settings', '']), tags=('type',parent), open=True) self.r+=1 for arg, val in req_settings.items(): - if arg == 'Data': + if arg.lower() in ['x', 'y']: + value = np.array(['X', 'Y'])[arg.lower() == np.array(['x', 'y'])][0] self.tree.insert(parent=parent+'_req', index='end', iid=str(self.r), text='', - values=tuple([arg, val, 'Choose X or Y']), tags=('req',)) + values=tuple([arg, value]), tags=('req',parent,arg,'data')) + self.method_inputData['req_'+parent+'_'+str(arg)] = tk.StringVar(self.tree) + self.method_inputData['req_'+parent+'_'+str(arg)].set(value) + self.method_inputData['req_'+parent+'_'+str(arg)].trace("w", self.on_changeOption) + self.default_inputData['req_'+parent+'_'+str(arg)] = value else: self.tree.insert(parent=parent+'_req', index='end', iid=str(self.r), text='', - values=tuple([arg, val, '']), tags=('req',)) + values=tuple([arg, val]), tags=('req',parent)) self.r+=1 self.tree.insert(parent=parent, index='end', iid=parent+'_opt', text='', - values=tuple(['Optional settings', '', '']), tags=('type',)) + values=tuple(['Optional settings', '']), tags=('type',parent), open=True) self.r+=1 for arg, val in opt_settings.items(): - self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', - values=tuple([arg, val, 'default']), tags=('opt',)) + if arg.lower() in ['x', 'y']: + self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', + values=tuple([arg, val]), tags=('opt',parent,arg,'data')) + self.method_inputData['opt_'+parent+'_'+str(arg)] = tk.StringVar(self.tree) + self.method_inputData['opt_'+parent+'_'+str(arg)].set(val) + self.method_inputData['opt_'+parent+'_'+str(arg)].trace("w", self.on_changeOption) + self.default_inputData['opt_'+parent+'_'+str(arg)] = str(val) + else: + self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='', + values=tuple([arg, val]), tags=('opt',parent)) self.r+=1 + def raise_above_all(self, window): + window.attributes('-topmost', 1) + window.attributes('-topmost', 0) + def removewindow(self): """ Stores settings options and closes window """ - self.req_settings.pop("Data", None) - children = self.get_all_children() - for child in children: - tag = self.tree.item(child)["tags"][0] - if tag in ['req', 'opt']: - val = self.tree.item(child)["values"] - self.settingOptions(tag, val) + # Updates the tree with any unclosed dropDown menu + if hasattr(self, 'dropDown'): + for data in self.method_inputData.keys(): + tags = data.split('_') + el = self.get_element_from_tags(*tags) + val = self.tree.item(el)['values'] + new_val = self.method_inputData[data].get() + val[int(self.treecol[1:])-1] = new_val + self.tree.item(el, values=tuple([val[0], new_val])) + self.dropDown.destroy() + # Updates the modified options and removes the ones that are not + for f in self.tree.get_children(): + for c in self.tree.get_children(f): + for child in self.tree.get_children(c): + tags = self.tree.item(child)["tags"] + if tags[0] in ['req', 'opt']: + if tags[-1] == 'data': + self.updateSettings(tags[0], tags[1], tags[2], self.method_inputData['_'.join(tags[:-1])].get()) + else: + val = self.tree.item(child)["values"] + self.settingOptions(tags[0], f, val) + del self.model self.newWindow.destroy() self.newWindow = None self.focus() def get_all_children(self, item=""): - """ Iterates over the treeview to get all childer """ + """ Iterates over the treeview to get all children """ children = self.tree.get_children(item) for child in children: children += self.get_all_children(child) return children - def settingOptions(self, tag, val): + def get_element_from_tags(self, *args): + """ Finds item in tree with specified tags """ + el = set(self.tree.tag_has(args[0])) + for arg in args[1:]: + el = set.intersection(el, set(self.tree.tag_has(arg))) + return list(el)[0] + + def settingOptions(self, tag, f, val): """ Identifies how the data should be stored """ - if val[0] == 'Data': - if val[2] == 'Choose X or Y' or len(val[2]) == 0: - self.updateSettings(tag, val[0], 'X') - else: - self.updateSettings(tag, val[0], val[2]) + if val[1] == 'default' or len(str(val[1])) == 0: + self.updateSettings(tag, f, val[0]) else: - if val[2] == 'default' or len(str(val[2])) == 0: - self.updateSettings(tag, val[0]) - else: - self.updateSettings(tag, val[0], val[2]) + self.updateSettings(tag, f, val[0], val[1]) - def updateSettings(self, tag, key, value = None): + def updateSettings(self, tag, f, key, value = None): """ Return the selected settings Parameters @@ -497,17 +704,40 @@ def updateSettings(self, tag, key, value = None): tag : str tag for the settings """ + + value = self.str_to_bool(value) if tag == 'req': - if value is not None or self.req_settings[key] != value: - self.req_settings[key] = value + if value is not None or self.isNotClose(self.req_settings[f][key], value): + self.req_settings[f][key] = value else: - self.req_settings.pop(key, None) + self.req_settings[f].pop(key, None) elif tag == 'opt': - if value is not None or self.opt_settings[key] != value: - self.opt_settings[key] = value + if self.isNotClose(self.opt_settings[f][key], value): + self.opt_settings[f][key] = value else: - self.opt_settings.pop(key, None) + self.opt_settings[f].pop(key, None) + def isNotClose(self, a, b, rel_tol=1e-09, abs_tol=0.0): + a = self.xml_handler._str_to_num(a) if isinstance(a, (str)) else a + b = self.xml_handler._str_to_num(b) if isinstance(b, (str)) else b + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return abs(a-b) > max(rel_tol * max(abs(a), abs(b)), abs_tol) + else: + return a != b + + def str_to_bool(self, s): + if type(s) is str: + if s == 'True': + return True + elif s == 'False': + return False + elif s == 'None': + return None + else: + return s + else: + return s + def on_return_entry(self, r): """ Changes focus to the next available entry. When no more, focuses on the finish button. @@ -790,7 +1020,7 @@ def reset(self): def check_quit(self): self.controller.destroy() - + class CanvasTooltip: ''' It creates a tooltip for a given canvas tag or id as the mouse is