-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9b4006f
commit 3d2cd90
Showing
55 changed files
with
31,505 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
# env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
*.pkl | ||
*.npz | ||
logs/ | ||
datasets/ | ||
pretrained-models/ | ||
.vscode/ | ||
outputs | ||
experiments | ||
caches/ | ||
*.eps | ||
*.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
Transparenz und offene Kommunikation sind innerhalb und außerhalb der Volkswagen AG für uns selbstverständlich. Wichtige Informationen über die Volkswagen AG, die Möglichkeiten der Kontaktaufnahme sowie zum Datenschutz finden Sie hier. | ||
|
||
I. Informationen über die Volkswagen AG: siehe | ||
|
||
https://www.volkswagenag.com/de/meta/provider-identification.html | ||
|
||
II. Information über die Erhebung personenbezogener Daten | ||
|
||
(1) Wir stellen Ihnen auf der Plattform der Firma GitHub, Inc., 88 Colin P Kelly Jr Street, San Francisco, CA 94107, USA, ein Repository zur Verfügung, über das u.a. Softwarebibliotheken der VOLKSWAGEN AG dokumentiert und zum Download angeboten werden. Im Folgenden informieren wir über die Erhebung personenbezogener Daten bei Nutzung unseres Repository. | ||
|
||
(2) Personenbezogene Daten sind alle Daten, die auf Sie persönlich beziehbar sind, z.B. E-Mail-Adresse, Nutzerprofil u.ä.. Die Erhebung personenbezogener Daten erfolgt bei Nutzung unseres Repository allein durch die Firma GitHub. | ||
|
||
(3) Wir erhalten von der Firma GitHub lediglich anonymisierte Informationen, also Daten ohne Personenbezug. Diese Informationen betreffen die Aktivitäten der Nutzer von GitHub bzw. unseres Repository. So übermittelt uns GitHub beispielsweise Statistiken über die Nutzung unserer Open-Source-Projekte, etwa über die Anzahl der „Stars“ (Lesezeichen) in Bezug auf unser Repository. | ||
|
||
(4) Verantwortlicher für die Datenerhebung und damit Ansprechpartner für Datenschutzanfragen und Betroffenenrechte ist allein GitHub. Wir verweisen insoweit auf die GitHub Privacy Notice unter https://help.github.com/articles/github-privacy-statement/#how-we-share-the-information-we-collect |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
The MIT License | ||
=============== | ||
|
||
Copyright (C) 2022-2023 Volkswagen Aktiengesellschaft, | ||
Berliner Ring 2, 38440 Wolfsburg, Germany | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from copy import deepcopy | ||
|
||
import numpy as np | ||
import torch | ||
from einops import rearrange | ||
|
||
from aime.data import ArrayDict | ||
|
||
|
||
class RandomActor: | ||
"""Actor that random samples from the action space""" | ||
|
||
def __init__(self, action_space) -> None: | ||
self.action_space = action_space | ||
|
||
def __call__(self, obs): | ||
return self.action_space.sample() | ||
|
||
def reset(self): | ||
pass | ||
|
||
|
||
class PolicyActor: | ||
"""Model-based policy for taking actions""" | ||
|
||
def __init__(self, ssm, policy, eval=True) -> None: | ||
""" | ||
ssm : a state space model | ||
policy : a policy take a hidden state and output the distribution of actions | ||
""" # noqa: E501 | ||
self.ssm = ssm | ||
self.policy = policy | ||
self.eval = eval | ||
|
||
def reset(self): | ||
self.state = self.ssm.reset(1) | ||
self.model_parameter = list(self.ssm.parameters())[0] | ||
|
||
def __call__(self, obs): | ||
obs = ArrayDict(deepcopy(obs)) | ||
obs.to_torch() | ||
obs.expand_dim_equal_() | ||
obs.to(self.model_parameter) | ||
obs.vmap_(lambda v: v.unsqueeze(dim=0)) | ||
|
||
self.state, _ = self.ssm.posterior_step(obs, obs["pre_action"], self.state) | ||
state_feature = self.ssm.get_state_feature(self.state) | ||
action_dist = self.policy(state_feature) | ||
action = action_dist.mode if self.eval else action_dist.sample() | ||
action = action.detach().cpu().numpy()[0] | ||
|
||
return action | ||
|
||
|
||
class StackPolicyActor: | ||
"""Actor for the BCO policy, who needs a stack of observation to operate""" | ||
|
||
def __init__(self, encoder, policy, stack: int) -> None: | ||
self.encoder = encoder | ||
self.policy = policy | ||
self.stack = stack | ||
self.embs = [] | ||
|
||
def reset(self): | ||
self.embs = [] | ||
self.model_parameter = list(self.policy.parameters())[0] | ||
|
||
def __call__(self, obs): | ||
obs = ArrayDict(deepcopy(obs)) | ||
obs.to_torch() | ||
obs.expand_dim_equal_() | ||
obs.to(self.model_parameter) | ||
obs.vmap_(lambda v: v.unsqueeze(dim=0)) | ||
emb = self.encoder(obs) | ||
|
||
if len(self.embs) == 0: | ||
for _ in range(self.stack): | ||
self.embs.append(emb) | ||
else: | ||
self.embs.pop(0) | ||
self.embs.append(emb) | ||
|
||
emb = torch.stack(self.embs) | ||
|
||
emb = rearrange(emb, "t b f -> b (t f)") | ||
|
||
action = self.policy(emb) | ||
action = action.detach().cpu().numpy()[0] | ||
|
||
return action | ||
|
||
|
||
class GuassianNoiseActorWrapper: | ||
def __init__(self, actor, noise_level, action_space) -> None: | ||
self._actor = actor | ||
self.noise_level = noise_level | ||
self.action_space = action_space | ||
|
||
def reset(self): | ||
return self._actor.reset() | ||
|
||
def __call__(self, obs): | ||
action = self._actor(obs) | ||
action = action + self.noise_level * np.random.randn(*action.shape) | ||
action = np.clip(action, self.action_space.low, self.action_space.high) | ||
return action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
defaults: | ||
- _self_ | ||
- env: ??? | ||
- world_model: rssmo | ||
|
||
model_name: ??? | ||
demonstration_dataset_name: ??? | ||
|
||
freeze_model: True | ||
random_policy: True | ||
num_expert_trajectories: 100 | ||
algo_name: aime | ||
seed: 42 | ||
log_name: "${env.name}/${environment_setup}/${algo_name}/${demonstration_dataset_name}/${model_name}/${num_expert_trajectories}/use_idm=${use_idm}/${world_model.idm_mode}/kl_only=${kl_only}" | ||
horizon: 50 | ||
batch_size: 50 | ||
batch_per_epoch: 100 | ||
epoch: 500 | ||
test_period: 10 | ||
num_test_trajectories: 10 | ||
final_num_test_trajectories: 100 | ||
|
||
use_fp16: false | ||
model_lr: 3e-4 | ||
grad_clip: 100.0 | ||
|
||
environment_setup: visual | ||
|
||
use_idm: False | ||
kl_only: False | ||
|
||
env: | ||
action_repeat: 2 | ||
render: False | ||
|
||
world_model: | ||
nll_reweight: dim_wise | ||
idm_mode: detach | ||
|
||
min_std: null | ||
|
||
kl_scale: 1.0 | ||
free_nats: 0.0 | ||
kl_rebalance: null | ||
|
||
encoders: | ||
tabular: | ||
name: identity | ||
visual: | ||
name: cnn_ha | ||
|
||
decoders: | ||
tabular: | ||
name: smlp | ||
hidden_size: 128 | ||
hidden_layers: 2 | ||
visual: | ||
name: cnn_ha | ||
|
||
probes: | ||
tabular: | ||
name: dmlp | ||
hidden_size: 128 | ||
hidden_layers: 2 | ||
visual: | ||
name: cnn_ha | ||
|
||
policy: | ||
hidden_size: 128 | ||
hidden_layers: 2 | ||
|
||
vnet: | ||
hidden_size: 128 | ||
hidden_layers: 2 |
Oops, something went wrong.