"""
deploy.py
Provide a lightweight server/client implementation for deploying OpenVLA models (through the HF AutoClass API) over a
REST API. This script implements *just* the server, with specific dependencies and instructions below.
Note that for the *client*, usage just requires numpy/json-numpy, and requests; example usage below!
Dependencies:
=> Server (runs OpenVLA model on GPU): `pip install uvicorn fastapi json-numpy`
=> Client: `pip install requests json-numpy`

Client (Standalone) Usage (assuming a server running on 0.0.0.0:8000):
    ```
    import requests
    import json_numpy

    json_numpy.patch()
    import numpy as np

    action = requests.post(
        "http://0.0.0.0:8000/act",
        json={"image": np.zeros((256, 256, 3), dtype=np.uint8), "instruction": "do something"},
    ).json()
    ```

Note that if your server is not accessible on the open web, you can use ngrok, or forward ports to your client via ssh:
    => `ssh -L 8000:localhost:8000 USER@<SERVER_IP>`
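
The server also supports a "double-encoded" payload for clients where installing or globally patching `json_numpy`
is undesirable: send a single `encoded` field holding the request serialized as a string (see `predict_action`
below). A minimal sketch of that path (assuming `json_numpy` mirrors the stdlib `json` API with `dumps`/`loads`):
    ```
    import json_numpy
    import numpy as np
    import requests

    payload = {"image": np.zeros((256, 256, 3), dtype=np.uint8), "instruction": "do something"}
    response = requests.post("http://0.0.0.0:8000/act", json={"encoded": json_numpy.dumps(payload)}).json()
    action = json_numpy.loads(response)
    ```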
"""

# ruff: noqa: E402
import json_numpy

json_numpy.patch()

import json
import logging
import os.path
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Union

import draccus
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor


# === Utilities ===
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str:
    # Cast to `str` so the substring check also works when a `Path` is passed (the type hint allows both)
    if "v01" in str(openvla_path):
        return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT:"
    else:
        return f"In: What action should the robot take to {instruction.lower()}?\nOut:"


# === Server Interface ===
class OpenVLAServer:
    def __init__(self, openvla_path: Union[str, Path], attn_implementation: Optional[str] = "flash_attention_2") -> None:
        """
        A simple server for OpenVLA models; exposes `/act` to predict an action for a given image + instruction.
            => Takes in {"image": np.ndarray, "instruction": str, "unnorm_key": Optional[str]}
            => Returns  {"action": np.ndarray}
        """
        self.openvla_path, self.attn_implementation = openvla_path, attn_implementation
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        # Load VLA Model using HF AutoClasses
        self.processor = AutoProcessor.from_pretrained(self.openvla_path, trust_remote_code=True)
        self.vla = AutoModelForVision2Seq.from_pretrained(
            self.openvla_path,
            attn_implementation=attn_implementation,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)

        # [Hacky] Load Dataset Statistics from Disk (if passing a path to a fine-tuned model)
        if os.path.isdir(self.openvla_path):
            with open(Path(self.openvla_path) / "dataset_statistics.json", "r") as f:
                self.vla.norm_stats = json.load(f)

    def predict_action(self, payload: Dict[str, Any]) -> Union[JSONResponse, str]:
        try:
            if double_encode := "encoded" in payload:
                # Support cases where `json_numpy` is hard to install, and numpy arrays are "double-encoded" as strings
                assert len(payload.keys()) == 1, "Only uses encoded payload!"
                payload = json.loads(payload["encoded"])

            # Parse payload components
            image, instruction = payload["image"], payload["instruction"]
            unnorm_key = payload.get("unnorm_key", None)

            # Run VLA Inference
            prompt = get_openvla_prompt(instruction, self.openvla_path)
            inputs = self.processor(prompt, Image.fromarray(image).convert("RGB")).to(self.device, dtype=torch.bfloat16)
            action = self.vla.predict_action(**inputs, unnorm_key=unnorm_key, do_sample=False)
            if double_encode:
                return JSONResponse(json_numpy.dumps(action))
            else:
                return JSONResponse(action)
        except:  # noqa: E722
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'image': np.ndarray, 'instruction': str}\n"
                "You can optionally provide an `unnorm_key: str` to specify the dataset statistics you want to use "
                "for de-normalizing the output actions."
            )
            return "error"

    def run(self, host: str = "0.0.0.0", port: int = 8000) -> None:
        self.app = FastAPI()
        self.app.post("/act")(self.predict_action)
        uvicorn.run(self.app, host=host, port=port)


@dataclass
class DeployConfig:
    # fmt: off
    openvla_path: Union[str, Path] = "openvla/openvla-7b"   # HF Hub Path (or path to local run directory)

    # Server Configuration
    host: str = "0.0.0.0"                                   # Host IP Address
    port: int = 8000                                        # Host Port
    # fmt: on


@draccus.wrap()
def deploy(cfg: DeployConfig) -> None:
    server = OpenVLAServer(cfg.openvla_path)
    server.run(cfg.host, port=cfg.port)


if __name__ == "__main__":
    deploy()
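

# Example launch (a sketch; `draccus.wrap()` derives CLI flags such as `--openvla_path` from the `DeployConfig` fields):
#   python deploy.py --openvla_path openvla/openvla-7b --host 0.0.0.0 --port 8000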