Skip to content

Commit

Permalink
Fix issue with stored values and changed test data
Browse files Browse the repository at this point in the history
  • Loading branch information
sevisal committed Sep 15, 2023
1 parent 397e3ac commit 1667f40
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/vai_lab/Core/vai_lab_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _execute_exit_point(self, specs):

with open(rel_to_abs(specs['plugin']['options']['outpath']), 'wb') as handle:
pickle.dump(data_out, handle, protocol=pickle.HIGHEST_PROTOCOL)

def _parse_loop_condition(self, condition):
try:
condition = int(condition)
Expand Down Expand Up @@ -169,7 +170,6 @@ def _execute(self, specs):

if not _tracker['terminate']:
self.load_config_file(self._xml_handler.filename)
# pass
else:
print('Pipeline terminated')
exit()
Expand Down
58 changes: 33 additions & 25 deletions src/vai_lab/_plugin_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(self, plugin_globals: dict) -> None:
"""
self.X = None
self.Y = None
self.X_tst = None
self.Y_tst = None
self.X_test = None
self.Y_test = None

self._PLUGIN_READABLE_NAMES: dict
self._PLUGIN_MODULE_OPTIONS: dict
Expand Down Expand Up @@ -78,8 +78,8 @@ def _parse_config(self):
"""Parse incoming data and args, sets them as class variables"""
self.X = np.array(self._get_data_if_exist(self._data_in, "X"))
self.Y = np.array(self._get_data_if_exist(self._data_in, "Y")).ravel()
self.X_tst = self._get_data_if_exist(self._data_in, "X_test")
self.Y_tst = np.array(self._get_data_if_exist(self._data_in, "Y_test")).ravel()
self.X_test = self._get_data_if_exist(self._data_in, "X_test")
self.Y_test = np.array(self._get_data_if_exist(self._data_in, "Y_test")).ravel()
self._clean_options()

def _get_data_if_exist(self, data_dict: dict, key: str, default=None):
Expand Down Expand Up @@ -122,10 +122,14 @@ def _parse_options_dict(self,options_dict:Dict):
options_dict[key] = self.X
elif val == 'Y':
options_dict[key] = self.Y
elif val == 'X':
options_dict[key] = self.X_ts
elif val == 'Y_tst':
options_dict[key] = self.Y_tst
elif val == 'X_test':
options_dict[key] = self.X_test
elif val == 'Y_test':
options_dict[key] = self.Y_test
elif key.lower() == 'x':
options_dict[key] = self.X
elif key.lower() == 'y':
options_dict[key] = self.Y
return options_dict

def _clean_options(self):
Expand All @@ -142,25 +146,25 @@ def _test(self, data: DataInterface) -> DataInterface:
if self._PLUGIN_MODULE_OPTIONS['Type'] == 'classification':
print('Training accuracy: %.2f%%' %
(self.score([self.X, self.Y])*100)) # type: ignore
if self.Y_tst is not None:
if self.Y_test is not None:
print('Test accuracy: %.2f%%' %
(self.score([self.X_tst, self.Y_tst])*100))
if self.X_tst is not None:
data.append_data_column("Y_pred", self.predict([self.X_tst]))
(self.score([self.X_test, self.Y_test])*100))
if self.X_test is not None:
data.append_data_column("Y_pred", self.predict([self.X_test]))
return data
elif self._PLUGIN_MODULE_OPTIONS['Type'] == 'regression':
print('Training R2 score: %.3f' %
(self.score([self.X, self.Y]))) # type: ignore
if self.Y_tst is not None:
if self.Y_test is not None:
print('Test R2 score: %.3f' %
(self.score([self.X_tst, self.Y_tst])))
if self.X_tst is not None:
data.append_data_column("Y_pred", self.predict([self.X_tst]))
(self.score([self.X_test, self.Y_test])))
if self.X_test is not None:
data.append_data_column("Y_pred", self.predict([self.X_test]))
return data
elif self._PLUGIN_MODULE_OPTIONS['Type'] == 'clustering':
print('Clustering completed')
if self.X_tst is not None:
data.append_data_column("Y_pred", self.predict([self.X_tst]))
if self.X_test is not None:
data.append_data_column("Y_pred", self.predict([self.X_test]))
return data
else:
return data
Expand Down Expand Up @@ -194,10 +198,14 @@ def _clean_solver_options(self):
_cleaned[key] = self.X
elif val == 'Y':
_cleaned[key] = self.Y
elif val == 'X':
_cleaned[key] = self.X_tst
elif val == 'Y_tst':
_cleaned[key] = self.Y_tst
elif val == 'X_test':
_cleaned[key] = self.X_test
elif val == 'Y_test':
_cleaned[key] = self.Y_test
elif key.lower() == 'x':
_cleaned[key] = self.X
elif key.lower() == 'y':
_cleaned[key] = self.Y
return _cleaned

def init(self):
Expand Down Expand Up @@ -229,12 +237,12 @@ def fit(self, options={}):
+str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.')
raise

def transform(self, data: DataInterface, options={}) -> DataInterface:
def transform(self, options={}) -> DataInterface:
try:
if type(options) == list:
data.append_data_column("X", pd.DataFrame(self.model.transform(*options)))
return pd.DataFrame(self.model.transform(*options))
else:
data.append_data_column("X", pd.DataFrame(self.model.transform(**options)))
return pd.DataFrame(self.model.transform(**options))
except Exception as exc:
print('The plugin encountered an error when transforming the data with '
+str(list(self._PLUGIN_READABLE_NAMES.keys())[list(self._PLUGIN_READABLE_NAMES.values()).index('default')])+': '+str(exc)+'.')
Expand Down
19 changes: 9 additions & 10 deletions src/vai_lab/utils/plugins/pluginCanvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,11 @@ def optionsWindow(self):
model_meth_list = [meth[0] for meth in getmembers(self.model, ismethod) if meth[0][0] != '_']
# List intersection
# TODO: use only methods from the model
meth_list = list(set(plugin_meth_list) & set(model_meth_list))[::-1]
set_2 = frozenset(model_meth_list)
meth_list = [x for x in plugin_meth_list if x in set_2]
self.meths_sort = []

self.method_inputData = {}
self.default_inputData = {}
if (len(self.opt_settings['__init__']) != 0) or (len(self.req_settings['__init__']) != 0):
if hasattr(self, 'newWindow') and (self.newWindow != None):
self.newWindow.destroy()
Expand Down Expand Up @@ -751,9 +753,9 @@ def OnDoubleClick(self, event):
y=y + pady,
anchor=tk.W, width=width)
else:
data_list = self.default_inputData['_'.join(tags[:-1])] + list(
set(['X','Y','X_tst','Y_tst']) - set([self.method_inputData['_'.join(tags[:-1])].get()])
- set(self.default_inputData['_'.join(tags[:-1])]))
data_list = ['X','Y','X_test','Y_test'] # TODO: Substitute with loaded data
data_list.insert(0,self.default_inputData['_'.join(tags[:-1])])
data_list = list(np.unique(data_list))
self.dropDown = tk.ttk.OptionMenu(self.tree, self.method_inputData['_'.join(tags[:-1])],
self.method_inputData['_'.join(tags[:-1])].get(), *data_list)
bg = '#9fc5e8' if tags[0] == 'req' else '#cfe2f3'
Expand Down Expand Up @@ -799,8 +801,6 @@ def fill_treeview(self, req_settings, opt_settings, parent = ''):
:param opt_settings: dict type of plugin optional setting options
:param parent: string type of parent name
"""
self.method_inputData = {}
self.default_inputData = {}
self.tree.insert(parent=parent, index='end', iid=parent+'_req', text='',
values=tuple(['Required settings', '']), tags=('type',parent), open=True)
self.r+=1
Expand All @@ -812,7 +812,7 @@ def fill_treeview(self, req_settings, opt_settings, parent = ''):
self.method_inputData['req_'+parent+'_'+str(arg)] = tk.StringVar(self.tree)
self.method_inputData['req_'+parent+'_'+str(arg)].set(value)
self.method_inputData['req_'+parent+'_'+str(arg)].trace("w", self.on_changeOption)
self.default_inputData['req_'+parent+'_'+str(arg)] = [value]
self.default_inputData['req_'+parent+'_'+str(arg)] = value
else:
self.tree.insert(parent=parent+'_req', index='end', iid=str(self.r), text='',
values=tuple([arg, val]), tags=('req',parent))
Expand All @@ -827,7 +827,7 @@ def fill_treeview(self, req_settings, opt_settings, parent = ''):
self.method_inputData['opt_'+parent+'_'+str(arg)] = tk.StringVar(self.tree)
self.method_inputData['opt_'+parent+'_'+str(arg)].set(val)
self.method_inputData['opt_'+parent+'_'+str(arg)].trace("w", self.on_changeOption)
self.default_inputData['opt_'+parent+'_'+str(arg)] = [str(val)]
self.default_inputData['opt_'+parent+'_'+str(arg)] = str(val)
else:
self.tree.insert(parent=parent+'_opt', index='end', iid=str(self.r), text='',
values=tuple([arg, val]), tags=('opt',parent))
Expand Down Expand Up @@ -1060,7 +1060,6 @@ def upload(self):

self.xml_handler = XML_handler()
self.xml_handler.load_XML(filename)
# self.xml_handler._print_pretty(self.xml_handler.loaded_modules)
modules = self.xml_handler.loaded_modules
modout = modules['Output']
# They are generated when resetting
Expand Down

0 comments on commit 1667f40

Please sign in to comment.