Skip to content

Commit

Permalink
Merge pull request #82 from inverted-ai/fix-compound-kinematic-model-…
Browse files Browse the repository at this point in the history
…argument

Update copy and extend for compound km
  • Loading branch information
Ruishenl authored Jan 14, 2025
2 parents 2c9d8c3 + f01e78f commit 5cfbd5d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchdrivesim/kinematic.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,17 @@ def fit_action(self, future_state: Tensor, current_state: Optional[Tensor] = Non

def copy(self, other=None):
if other is None:
other = self.__class__(models=[m.copy() for m in self.models], batch_assignments=self.batch_assignments, dt=self.dt)
other = self.__class__(models=[m.copy() for m in self.models], model_assignments=self.model_assignments, dt=self.dt)
other.set_params(**self.get_params())
other.set_state(self.get_state())
return other

def extend(self, n: int):
enlarge = lambda x: x.unsqueeze(1).expand((x.shape[0], n) + x.shape[1:]).reshape((n * x.shape[0],) + x.shape[1:])
state = self.get_state()
self.model_assignments = enlarge(self.model_assignments)
self.map_param(enlarge)
self.set_state(enlarge(self.get_state()))
self.set_state(enlarge(state))

def select_batch_elements(self, idx):
self.model_assignments = self.model_assignments[idx]
Expand Down

0 comments on commit 5cfbd5d

Please sign in to comment.