Skip to content

Commit

Permalink
better verbose output for agent.ask()
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Oct 12, 2023
1 parent 8db416a commit 57e71f3
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def tell(self, new_table=None, append=True, train=True, **kwargs):
"""
Inform the agent about new inputs and targets for the model.
If run with no arguments, it will just reconstruct all the models.
If run with no arguments, it will just reconstruct all the models.
"""

new_table = pd.DataFrame() if new_table is None else new_table
Expand Down Expand Up @@ -193,6 +193,9 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, **acq

start_time = ttime.monotonic()

if self.verbose:
print(f'finding points with acquisition function "{acq_func_name}" ...')

if acq_func_type in ["analytic", "monte_carlo"]:
if not self.initialized:
raise RuntimeError(
Expand Down Expand Up @@ -228,9 +231,6 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, **acq
acq_func_meta["read_only_values"] = read_only_X

else:

acqf_objective = None

if acq_func_name == "random":
acquisition_X = torch.rand()
acq_func_meta = {"name": "random", "args": {}}
Expand All @@ -247,10 +247,16 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, **acq
else:
raise ValueError()

# define dummy acqf objective
acqf_objective = None

acq_func_meta["duration"] = duration = ttime.monotonic() - start_time

if self.verbose:
print(f"found points {acquisition_X} with acqf {acq_func_meta['name']} in {duration:.01f} seconds (obj = {acqf_objective})")
summary = pd.DataFrame(acquisition_X, columns=self.dofs.subset(active=True, read_only=False).names)
summary.insert(0, "acqf", acqf_objective)

print(f"found points in {duration:.03f} seconds:\n" + summary.__repr__())

if route and n > 1:
routing_index = utils.route(self.dofs.subset(active=True, read_only=False).readback, acquisition_X)
Expand All @@ -266,7 +272,7 @@ def acquire(self, acquisition_inputs):
"""
try:
acquisition_devices = self.dofs.subset(active=True, read_only=False).devices
#read_only_devices = self.dofs.subset(active=True, read_only=True).devices
# read_only_devices = self.dofs.subset(active=True, read_only=True).devices

# the acquisition plan always takes as arguments:
# (things to move, where to move them, things to trigger once you get there)
Expand Down Expand Up @@ -350,12 +356,12 @@ def reset(self):
def benchmark(
self, output_dir="./", runs=16, n_init=64, learning_kwargs_list=[{"acq_func": "qei", "n": 4, "iterations": 16}]
):
cache_limits = {dof.name:dof.limits for dof in self.dofs}
cache_limits = {dof.name: dof.limits for dof in self.dofs}

for run in range(runs):
for dof in self.dofs:
offset = 0.25 * np.ptp(dof.limits) * np.random.uniform(low=-1, high=1)
dof.limits = (dof.limits[0] + offset, dof.limits[1] + offset)
dof.limits = (cache_limits[dof.name][0] + offset, cache_limits[dof.name][1] + offset)

self.reset()

Expand Down

0 comments on commit 57e71f3

Please sign in to comment.