Skip to content

Commit

Permalink
create state once
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 16, 2024
1 parent 399c5d6 commit 6ee75fd
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions test/neuralgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def sub(initial_state, all_forcings):
self.eval_era5 = eval_era5
self.all_forcings = all_forcings
self.outer_steps = outer_steps

inputs = self.model.inputs_from_xarray(self.eval_era5.isel(time=0))
input_forcings = self.model.forcings_from_xarray(self.eval_era5.isel(time=0))
rng_key = jax.random.key(42) # optional for deterministic models
self.initial_state = self.model.encode(inputs, input_forcings, rng_key)

def test(self):
for name, pipe, _ in pipelines:
Expand All @@ -148,14 +153,10 @@ def test(self):
print("name=", name, res)

def run_on_fn(self, fn, steps=1):
inputs = self.model.inputs_from_xarray(self.eval_era5.isel(time=0))
input_forcings = self.model.forcings_from_xarray(self.eval_era5.isel(time=0))
rng_key = jax.random.key(42) # optional for deterministic models
initial_state = self.model.encode(inputs, input_forcings, rng_key)
map(
lambda x: x.block_until_ready(),
fn(
initial_state,
self.initial_state,
self.all_forcings,
),
)
Expand All @@ -166,7 +167,7 @@ def run_on_fn(self, fn, steps=1):
))""",
globals={
"fn": fn,
"initial_state": initial_state,
"initial_state": self.initial_state,
"all_forcings": self.all_forcings,
},
).timeit(steps)
Expand Down

0 comments on commit 6ee75fd

Please sign in to comment.