Hi, thank you very much for providing the code! I've installed JAX with CUDA, so evaluating kernel_fn is now faster. However, build_with_representer_proxy_batch is still quite slow. I assume this is due to solve_bilevel_opt_representer_proxy, which requires computing the implicit gradient, or is it something else? Is there a way to make the computation faster? Thank you!
To speed up build_with_representer_proxy_batch, you can set max_outer_it and max_inner_it to smaller values. However, the representer proxy is only practical for small coresets (< 500 points) due to its computational complexity; for larger coresets, I would recommend the Nyström proxy (build_with_nystrom_proxy, setting nr_presampled_transforms=1 if no data augmentation is used).
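To illustrate why the Nyström route scales better, here is a minimal NumPy sketch of the standard Nyström kernel approximation (this is generic background, not this repository's actual implementation): instead of forming the full n×n kernel matrix, it approximates it from m landmark columns, so cost grows with m rather than n. The RBF kernel and the landmark-sampling scheme below are illustrative choices, not taken from the codebase.

```python
import numpy as np

def rbf_kernel(X, Y, gamma=0.1):
    # Pairwise squared Euclidean distances between rows of X and Y.
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * d2)

rng = np.random.default_rng(0)
n, m, d = 500, 50, 5                      # n data points, m landmarks
X = rng.normal(size=(n, d))
landmarks = X[rng.choice(n, size=m, replace=False)]

# Nystrom pieces: C is n x m, W is m x m.
# The expensive n x n matrix is only reconstructed here for comparison;
# in practice one works with C and W directly.
C = rbf_kernel(X, landmarks)
W = rbf_kernel(landmarks, landmarks)
K_approx = C @ np.linalg.pinv(W) @ C.T

K_exact = rbf_kernel(X, X)
rel_err = np.linalg.norm(K_exact - K_approx) / np.linalg.norm(K_exact)
print(f"relative approximation error: {rel_err:.3f}")
```

The same trade-off drives the proxy choice above: the representer proxy works with full kernel solves inside a bilevel optimization, while a Nyström-style approximation keeps the per-iteration cost tied to the (small) number of landmark/coreset points.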