From f01e78f08c734c77b1d399d92778aa14b975855f Mon Sep 17 00:00:00 2001 From: rlyu Date: Mon, 13 Jan 2025 15:49:34 -0800 Subject: [PATCH] Update copy and extend --- torchdrivesim/kinematic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchdrivesim/kinematic.py b/torchdrivesim/kinematic.py index 4e29093..44e5a54 100644 --- a/torchdrivesim/kinematic.py +++ b/torchdrivesim/kinematic.py @@ -218,16 +218,17 @@ def fit_action(self, future_state: Tensor, current_state: Optional[Tensor] = Non def copy(self, other=None): if other is None: - other = self.__class__(models=[m.copy() for m in self.models], batch_assignments=self.batch_assignments, dt=self.dt) + other = self.__class__(models=[m.copy() for m in self.models], model_assignments=self.model_assignments, dt=self.dt) other.set_params(**self.get_params()) other.set_state(self.get_state()) return other def extend(self, n: int): enlarge = lambda x: x.unsqueeze(1).expand((x.shape[0], n) + x.shape[1:]).reshape((n * x.shape[0],) + x.shape[1:]) + state = self.get_state() self.model_assignments = enlarge(self.model_assignments) self.map_param(enlarge) - self.set_state(enlarge(self.get_state())) + self.set_state(enlarge(state)) def select_batch_elements(self, idx): self.model_assignments = self.model_assignments[idx]