Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multirom #249

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ def __init__(self, parameters=None, snapshots=None):
if parameters is None and snapshots is None:
return

if parameters is None:
parameters = [None] * len(snapshots)
elif snapshots is None:
snapshots = [None] * len(parameters)

if len(parameters) != len(snapshots):
raise ValueError

for param, snap in zip(parameters, snapshots):
self.add(Parameter(param), Snapshot(snap))

Expand Down Expand Up @@ -74,7 +79,9 @@ def __len__(self):

def __str__(self):
""" Print minimal info about the Database """
return str(self.parameters_matrix)
s = 'Database with {} snapshots and {} parameters'.format(
self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])
return s

def add(self, parameter, snapshot):
"""
Expand Down
8 changes: 4 additions & 4 deletions ezyrb/plugin/automatic_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _train_shift_network(self, db):

n_epoch += 1

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]
self.reference_snapshot = reference_snapshot
Expand All @@ -154,11 +154,11 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

def fom_postprocessing(self, rom):
def predict_postprocessing(self, rom):

ref_space = self.reference_snapshot.space

for param, snap in rom._full_database._pairs:
for param, snap in rom.predict_full_database._pairs:
input_shift = np.hstack([
ref_space.reshape(-1, 1),
np.ones(shape=(ref_space.shape[0], 1))*param.values])
Expand Down
43 changes: 39 additions & 4 deletions ezyrb/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,53 @@ class Plugin(ABC):
All the classes that implement the input-output mapping should be inherited
from this class.
"""
def fom_preprocessing(self, rom):
def fit_preprocessing(self, rom):
""" Void """
pass

def rom_preprocessing(self, rom):
def fit_before_reduction(self, rom):
""" Void """
pass

def rom_postprocessing(self, rom):
def fit_after_reduction(self, rom):
""" Void """
pass

def fit_before_approximation(self, rom):
""" Void """
pass

def fom_postprocessing(self, rom):
def fit_after_approximation(self, rom):
""" Void """
pass

def fit_postprocessing(self, rom):
""" Void """
pass

def predict_preprocessing(self, rom):
""" Void """
pass

def predict_before_approximation(self, rom):
""" Void """
pass

def predict_after_approximation(self, rom):
""" Void """
pass

def predict_before_expansion(self, rom):
""" Void """
pass

def predict_after_expansion(self, rom):
""" Void """
pass

def predict_postprocessing(self, rom):
""" Void """
pass



10 changes: 5 additions & 5 deletions ezyrb/plugin/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __init__(self, shift_function, interpolator, parameter_index=0,
self.parameter_index = parameter_index
self.reference_index = reference_index

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]

Expand All @@ -68,10 +68,10 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

rom._full_database = db
rom.database = db

def fom_postprocessing(self, rom):
for param, snap in rom._full_database._pairs:
def predict_postprocessing(self, rom):
for param, snap in rom.predict_full_database._pairs:
snap.space = (
rom.database._pairs[self.reference_index][1].space +
self.__shift_function(param.values)
Expand Down
Loading
Loading