Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup unused path. #282

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: ['3.10']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.8.0
Expand Down
12 changes: 6 additions & 6 deletions vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def init_model():
filename = best.filename
logging.info('Selected fillename="%s" for "%s" with final_val=%.3f',
filename, model_or_filename, best.final_val)
pretrained_path = os.path.join(config.pretrained_dir,
f'{config.model.model_name}.npz')
else:
# ViT / Mixer papers
filename = config.model.model_name
Expand All @@ -140,7 +138,7 @@ def init_model():
optax.sgd(
learning_rate=lr_fn,
momentum=0.9,
accumulator_dtype='bfloat16',
accumulator_dtype=config.optim_dtype,
),
)

Expand Down Expand Up @@ -212,7 +210,7 @@ def init_model():
(step == total_steps)):

accuracies = []
lt0 = time.time()
tt0 = time.time()
for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
logits = infer_fn_repl(
dict(params=params_repl), test_batch['image'])
Expand All @@ -223,8 +221,7 @@ def init_model():
accuracy_test = np.mean(accuracies)
img_sec_core_test = (
config.batch_eval * ds_test.cardinality().numpy() /
(time.time() - lt0) / jax.device_count())
lt0 = time.time()
(time.time() - tt0) / jax.device_count())

lr = float(lr_fn(step))
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
Expand All @@ -237,14 +234,17 @@ def init_model():
accuracy_test=accuracy_test,
lr=lr,
img_sec_core_test=img_sec_core_test))
lt0 += time.time() - tt0

# Store checkpoint.
if ((config.checkpoint_every and step % config.eval_every == 0) or
step == total_steps):
tt0 = time.time()
checkpoint_path = flax_checkpoints.save_checkpoint(
workdir, (flax.jax_utils.unreplicate(params_repl),
flax.jax_utils.unreplicate(opt_state_repl), step), step)
logging.info('Stored checkpoint at step %d to "%s"', step,
checkpoint_path)
lt0 += time.time() - tt0

return flax.jax_utils.unreplicate(params_repl)
Loading