forked from openvla/openvla
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathvla.py
319 lines (226 loc) · 11.6 KB
/
vla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""
vla.py
Draccus dataclass definition for a VLAConfig object, with registered subclasses for each VLA experiment and
model configuration thereof. A given VLA model (`policy`) configures the following attributes:
    - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, LIBERO-90, etc.)
    - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip-224px+7b`)
    - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning)
    - Observation / Action Interface (image sequence length, wrist-camera image, action tokenizer)
    - Training / Optimization Hyperparameters
"""
from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path
from typing import Optional, Union
from draccus import ChoiceRegistry
@dataclass
class VLAConfig(ChoiceRegistry):
    """Base Draccus configuration shared by every VLA experiment in this module.

    Fields without a default must be provided by a subclass. `Exp_SigLIP_224px_Bridge` sets all
    of them, and the other `Exp_*` configs in this module derive from it. Field declaration order
    defines the generated dataclass `__init__` signature, so new required fields must be declared
    before the defaulted fields at the bottom.
    """
    # fmt: off
    vla_id: str                                     # Unique VLA Policy ID that fully specifies a configuration variant
    base_vlm: Union[str, Path]                      # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip-224px+7b`)
    freeze_vision_backbone: bool                    # Freeze Vision Backbone Parameters (akin to pretraining)
    freeze_llm_backbone: bool                       # Freeze LLM Backbone parameters
    unfreeze_last_llm_layer: bool                   # Unfreeze final layer of LLM (only takes effect if LLM is frozen)

    # Data Mixture Parameters
    data_mix: str                                   # Unique Dataset/Mixture ID (e.g., `bridge`, `oxe_magic_soup`, `libero_90`)
    shuffle_buffer_size: int                        # Size of Shuffle Buffer (every config in this module uses 256_000)

    # Optimization Parameters
    epochs: int                                     # Epochs to Run (in case `max_steps` is not specified)
    max_steps: Optional[int]                        # [Optional] Max Gradient Steps to Run (overrides `epochs`)
    save_every_n_steps: Optional[int]               # [Optional] presumably the periodic checkpoint interval in steps; confirm None semantics in training loop

    expected_world_size: int                        # Expected # of GPUs =>> allows us to gate training on hardware
    global_batch_size: int                          # Global Batch Size (divided across processes / world size)
    per_device_batch_size: int                      # Per-Device Batch Size (per-process / individual GPU)
                                                    #   =>> # of accumulation steps is auto-computed

    learning_rate: float                            # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay)
    weight_decay: float                             # Weight Decay for AdamW Optimizer
    max_grad_norm: float                            # Max Grad Norm (for global gradient clipping)
    lr_scheduler_type: str                          # LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay")
    warmup_ratio: float                             # Fraction of Steps to Warmup (for warmup LR schedulers)

    train_strategy: str                             # Train Strategy (every config in this module uses "fsdp-full-shard")

    # Observation / Action Interface
    action_tokenizer: str                           # Action tokenizer ID ("action_tokenizer" or "extra_action_tokenizer" in this module)
    image_sequence_len: int                         # Images per sample (1 by default, 2 for `-t2`/`-wrist` variants); presumably history frames -- confirm in data pipeline
    use_wrist_image: bool                           # Include the wrist-camera image (True only for the `-wrist` variant)

    # Enable Gradient/Activation Checkpointing (for the LLM Backbone)
    enable_gradient_checkpointing: bool = True      # Enable Gradient/Activation Checkpointing during Training

    # Mixed Precision Training via Torch Native AMP (`autocast`)
    enable_mixed_precision_training: bool = True    # Enable Traditional BF16 Mixed Precision
    reduce_in_full_precision: bool = True           # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision

    # fmt: on
# === OpenVLA Training Configurations ===
# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge =
@dataclass
class Exp_SigLIP_224px_Bridge(VLAConfig):
    """[8 GPU] Fast-iteration baseline: SigLIP 224px (7B) VLM fully fine-tuned on `bridge`.

    This class assigns a value to every required `VLAConfig` field. The other experiment configs
    in this module derive from it and override only what differs.
    """

    # --- Identity, base VLM, and action tokenizer ---
    vla_id: str = "siglip-224px+mx-bridge"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    action_tokenizer: str = "action_tokenizer"

    # --- Trainable parameters: nothing frozen (full fine-tuning) ---
    freeze_vision_backbone: bool = False
    freeze_llm_backbone: bool = False
    unfreeze_last_llm_layer: bool = False

    # --- Observations: a single image per sample, wrist image disabled ---
    image_sequence_len: int = 1
    use_wrist_image: bool = False

    # --- Data ---
    data_mix: str = "bridge"
    shuffle_buffer_size: int = 256_000

    # --- Run length, checkpointing, and hardware layout ---
    epochs: int = 1000
    max_steps: Optional[int] = None
    save_every_n_steps: Optional[int] = 25000
    expected_world_size: int = 8
    global_batch_size: int = 256
    per_device_batch_size: int = 32
    train_strategy: str = "fsdp-full-shard"

    # --- Optimizer and LR schedule ---
    learning_rate: float = 2e-5
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    lr_scheduler_type: str = "constant"
    warmup_ratio: float = 0.0
# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge =
@dataclass
class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
    """[8 GPU] SigLIP 224px + Bridge with the vision backbone frozen (the "icy" variant)."""

    freeze_vision_backbone: bool = True
    vla_id: str = "siglip-224px-icy+mx-bridge"
    base_vlm: Union[str, Path] = "siglip-224px+7b"  # same as the parent; restated explicitly
# = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge =
@dataclass
class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge):
    """[8 GPU] Fast iteration: DINO-SigLIP 224px (7B) Prismatic VLM on the `bridge` mixture."""

    base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
    vla_id: str = "prism-dinosiglip-224px+mx-bridge"
    data_mix: str = "bridge"  # same as the parent; restated explicitly
# = [64 GPU] SigLIP 224px + OXE Magic Soup =
@dataclass
class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
    """[64 GPU] SigLIP 224px (7B) VLM trained on the OXE Magic Soup mixture."""

    # Scale out: 64 GPUs x 32 samples each.
    expected_world_size: int = 64
    per_device_batch_size: int = 32
    global_batch_size: int = 2048

    data_mix: str = "oxe_magic_soup"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    vla_id: str = "siglip-224px+mx-oxe-magic-soup"
# = [8 GPU] Qwen2.5 0.5B DINO-SigLIP 224px + OXE Magic Soup =
@dataclass
class Exp_Qwen25_DinoSigLIP_224px_0_5B_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge):
    """[8 GPU] Qwen2.5 0.5B + DINO-SigLIP 224px Prismatic VLM on the OXE Magic Soup mixture.

    Uses `extra_action_tokenizer` instead of the parent's default `action_tokenizer`.
    """

    vla_id: str = "prism-qwen25-dinosiglip-224px+0_5b+mx-oxe-magic-soup"
    base_vlm: Union[str, Path] = "prism-qwen25-extra-dinosiglip-224px+0_5b"
    action_tokenizer: str = "extra_action_tokenizer"
    data_mix: str = "oxe_magic_soup"

    # Hardware / batch layout, pinned explicitly (identical to the parent's values).
    per_device_batch_size: int = 32
    global_batch_size: int = 256
    expected_world_size: int = 8
@dataclass
class Exp_Qwen25_DinoSigLIP_224px_0_5B_LIBERO_90(Exp_Qwen25_DinoSigLIP_224px_0_5B_OXE_Magic_Soup):
    """[8 GPU] Qwen2.5 0.5B DINO-SigLIP 224px config retargeted at the `libero_90` data mix.

    Inherits the base VLM and `extra_action_tokenizer` from the OXE Magic Soup variant.
    """

    data_mix: str = "libero_90"
    vla_id: str = "prism-qwen25-dinosiglip-224px+0_5b+mx-libero-90"

    # Hardware / batch layout, restated explicitly (same as the parent).
    expected_world_size: int = 8
    per_device_batch_size: int = 32
    global_batch_size: int = 256
@dataclass
class Exp_Qwen25_DinoSigLIP_224px_T2_0_5B_LIBERO_90(Exp_Qwen25_DinoSigLIP_224px_0_5B_LIBERO_90):
    """LIBERO-90 Qwen2.5 0.5B variant with two images per sample (the "t2" in its ID).

    Presumably the two images are consecutive timesteps; confirm in the data pipeline.
    """

    image_sequence_len: int = 2
    vla_id: str = "prism-qwen25-dinosiglip-224px-t2+0_5b+mx-libero-90"
@dataclass
class Exp_Qwen25_DinoSigLIP_224px_wrist_0_5B_LIBERO_90(Exp_Qwen25_DinoSigLIP_224px_0_5B_LIBERO_90):
    """LIBERO-90 Qwen2.5 0.5B variant that also feeds the wrist-camera image.

    NOTE(review): this sets `image_sequence_len = 2`, exactly like the `-t2` variant, although its
    `vla_id` only mentions "wrist". Confirm whether the wrist image is meant to occupy the second
    sequence slot, or whether the sequence length here should be 1.
    """

    use_wrist_image: bool = True
    image_sequence_len: int = 2
    vla_id: str = "prism-qwen25-dinosiglip-224px-wrist+0_5b+mx-libero-90"
# = [8 GPU] Qwen2.5 0.5B DINO-SigLIP 224px + Bridge =
@dataclass
class Exp_Qwen25_DinoSigLIP_224px_0_5B_Bridge(Exp_SigLIP_224px_Bridge):
    """[8 GPU] Qwen2.5 0.5B + DINO-SigLIP 224px Prismatic VLM on Bridge.

    Unlike the other Bridge configs, which use the `"bridge"` mix, this one names the dataset
    `"bridge_dataset"` directly.
    """

    base_vlm: Union[str, Path] = "prism-qwen25-extra-dinosiglip-224px+0_5b"
    vla_id: str = "prism-qwen25-dinosiglip-224px+0_5b+mx-bridge"
    action_tokenizer: str = "extra_action_tokenizer"
    data_mix: str = "bridge_dataset"  # direct dataset

    # Hardware / batch layout, pinned explicitly (identical to the parent's values).
    global_batch_size: int = 256
    per_device_batch_size: int = 32
    expected_world_size: int = 8
@dataclass
class Exp_DinoSigLIP_224px_LIBERO_90(Exp_DinoSigLIP_224px_Bridge):
    """[8 GPU] DINO-SigLIP 224px (7B) Prismatic VLM on the `libero_90` data mix."""

    vla_id: str = "prism-dinosiglip-224px+mx-libero-90"
    data_mix: str = "libero_90"

    # Hardware / batch layout, restated explicitly (same as the parent).
    global_batch_size: int = 256
    expected_world_size: int = 8
    per_device_batch_size: int = 32
# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ =
@dataclass
class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge):
    """[64 GPU] DINO-SigLIP 224px (7B) Prismatic VLM on OXE Magic Soup++.

    Training is done in two stages: a mixture that includes DROID for ~70% of training, followed
    by resampling. `"oxe_magic_soup_plus"` is the alternative mixture ID for this config. The
    active value, `"oxe_magic_soup_plus_minus"`, is presumably the post-resampling mixture;
    confirm against the dataset mixture registry.
    """

    vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus"
    base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b"
    data_mix: str = "oxe_magic_soup_plus_minus"

    # Scale out: 64 GPUs x 32 samples each.
    per_device_batch_size: int = 32
    expected_world_size: int = 64
    global_batch_size: int = 2048
# === OpenVLA Fine-tuning Configurations ===
# = [8 GPU] SigLIP 224px + T-DROID =
@dataclass
class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    """[8 GPU] SigLIP 224px full fine-tuning on the T-DROID carrot-in-bowl task."""

    data_mix: str = "tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl"
@dataclass
class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge):
    """[8 GPU] SigLIP 224px full fine-tuning on the T-DROID pour-corn-in-pot task."""

    data_mix: str = "tdroid_pour_corn_in_pot"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot"
# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning =
@dataclass
class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    """[8 GPU] Partial fine-tuning on T-DROID carrot-in-bowl: vision frozen, LLM trained ("icy")."""

    vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl"
    data_mix: str = "tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    # Which backbones are updated.
    freeze_llm_backbone: bool = False
    freeze_vision_backbone: bool = True
@dataclass
class Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    """[8 GPU] Partial fine-tuning on T-DROID carrot-in-bowl: only the final LLM layer is trained.

    Both backbones are frozen, and `unfreeze_last_llm_layer` re-enables the last LLM layer.
    """

    vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl"
    data_mix: str = "tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    # Which backbones are updated.
    unfreeze_last_llm_layer: bool = True
    freeze_llm_backbone: bool = True
    freeze_vision_backbone: bool = True
@dataclass
class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge):
    """[8 GPU] "Sandwich" fine-tuning on T-DROID carrot-in-bowl.

    The vision backbone and the final LLM layer are trained; the rest of the LLM stays frozen.
    """

    vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl"
    data_mix: str = "tdroid_carrot_in_bowl"
    base_vlm: Union[str, Path] = "siglip-224px+7b"

    # Which backbones are updated.
    unfreeze_last_llm_layer: bool = True
    freeze_llm_backbone: bool = True
    freeze_vision_backbone: bool = False
# === [8 GPU] SigLIP 224px + FrankaWipe ===
@dataclass
class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge):
    """[8 GPU] SigLIP 224px full fine-tuning on the `droid_wipe` (FrankaWipe) data mix."""

    data_mix: str = "droid_wipe"
    base_vlm: Union[str, Path] = "siglip-224px+7b"
    vla_id: str = "siglip-224px+mx-droid_wipe"
# === Define a VLA Registry Enum for Reference & Validation ===
@unique
class VLARegistry(Enum):
    """Enumeration of every VLA experiment config defined in this module.

    Each member's value is a config *class* (not an instance). `@unique` makes class creation fail
    if two members are bound to the same config class. The module-level registration loop iterates
    this enum, so member order is the order in which configs are registered with `VLAConfig`.
    """

    # Sanity Check Configurations =>> BridgeV2 (plus the DINO-SigLIP LIBERO-90 variant)
    SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge
    DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge
    DINOSIGLIP_224PX_MX_LIBERO_90 = Exp_DinoSigLIP_224px_LIBERO_90

    # SigLIP Frozen Backbone Experiment
    FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge

    # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup
    SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup

    # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++
    DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus

    # [OpenVLA 0.5B] Qwen2.5 0.5B + DINO-SigLIP 224px backbones (OXE, LIBERO-90 variants, Bridge)
    QWEN25_DINOSIGLIP_224PX_0_5B_MX_OXE_MAGIC_SOUP = Exp_Qwen25_DinoSigLIP_224px_0_5B_OXE_Magic_Soup
    QWEN25_DINOSIGLIP_224PX_0_5B_LIBERO_90 = Exp_Qwen25_DinoSigLIP_224px_0_5B_LIBERO_90
    QWEN25_DINOSIGLIP_224PX_T2_0_5B_LIBERO_90 = Exp_Qwen25_DinoSigLIP_224px_T2_0_5B_LIBERO_90
    QWEN25_DINOSIGLIP_224PX_WRIST_0_5B_LIBERO_90 = Exp_Qwen25_DinoSigLIP_224px_wrist_0_5B_LIBERO_90
    QWEN25_DINOSIGLIP_224PX_0_5B_BRIDGE = Exp_Qwen25_DinoSigLIP_224px_0_5B_Bridge

    # === TDROID Fine-tuning Configs ===
    SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl
    SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot
    SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl
    SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl
    SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl

    # === DROID Fine-tuning Configs ===
    SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe

    @property
    def vla_id(self) -> str:
        """Return the `vla_id` default declared on this member's config class."""
        # Dataclass field defaults are stored as class attributes, so no instance is needed.
        return self.value.vla_id
# Register every VLARegistry config class with VLAConfig's draccus ChoiceRegistry, keyed by the
# class's own `vla_id` default (the same string the enum's `vla_id` property returns).
for vla_variant in VLARegistry:
    VLAConfig.register_subclass(
        vla_variant.value.vla_id,
        vla_variant.value,
    )