enjoy.py
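
# enjoy.py: load a trained actor-critic checkpoint and roll the policy out
# deterministically in a single environment, rendering every step. For Bullet
# environments, the PyBullet debug camera additionally tracks the agent's torso.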
import argparse
import os
import types

import numpy as np
import torch
from torch.autograd import Variable

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_normalize import VecNormalize

from envs import make_env

parser = argparse.ArgumentParser(description='RL')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num-stack', type=int, default=4,
                    help='number of frames to stack (default: 4)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='log interval, one log per n updates (default: 10)')
parser.add_argument('--env-name', default='PongNoFrameskip-v4',
                    help='environment to train on (default: PongNoFrameskip-v4)')
parser.add_argument('--load-dir', default='./trained_models/',
                    help='directory to load trained models from (default: ./trained_models/)')
parser.add_argument('--gen', action='store_true', default=False,
                    help='use the generic policy')
parser.add_argument('--att', action='store_true', default=False,
                    help='act with attention')
args = parser.parse_args()
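
# Example invocation (illustrative; assumes a checkpoint saved by the companion
# training script under <load-dir>/<env-name>.pt):
#
#     python enjoy.py --env-name PongNoFrameskip-v4 --load-dir ./trained_models/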

env = make_env(args.env_name, args.seed, 0, './tmp/test/' + args.env_name + '/')
env = DummyVecEnv([env])

if args.gen:
    # With --gen, --load-dir is treated as the path to the checkpoint file itself.
    actor_critic, ob_rms = torch.load(args.load_dir)
else:
    actor_critic, ob_rms = torch.load(os.path.join(args.load_dir, args.env_name + ".pt"))
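
# The checkpoint is assumed to be a (policy, observation-running-stats) pair
# saved together by the training script; ob_rms is only used below, in the
# 1-D (continuous-control) observation case.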

if len(env.observation_space.shape) == 1:
    # 1-D observations (continuous control): renormalize with the saved stats.
    env = VecNormalize(env, ret=False)
    env.ob_rms = ob_rms

    # An ugly hack to remove updates: rebind _obfilt so the running statistics
    # are only applied, never updated, at evaluation time.
    def _obfilt(self, obs):
        if self.ob_rms:
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon),
                          -self.clipob, self.clipob)
        return obs

    env._obfilt = types.MethodType(_obfilt, env)
    render_func = env.venv.envs[0].render
else:
    render_func = env.envs[0].render

obs_shape = env.observation_space.shape
obs_shape = (obs_shape[0] * args.num_stack, *obs_shape[1:])
current_obs = torch.zeros(1, *obs_shape)
states = torch.zeros(1, actor_critic.state_size)
masks = torch.zeros(1, 1)
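
# states carries the policy's recurrent hidden state between steps (unused by
# feed-forward policies); masks is 0.0 on the step after an episode ends and
# 1.0 otherwise, and is used below to zero the frame stack across episodes.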

def update_current_obs(obs):
    # Shift the stacked frames left and write the newest observation into the
    # last slot.
    shape_dim0 = env.observation_space.shape[0]
    obs = torch.from_numpy(obs).float()
    if args.num_stack > 1:
        current_obs[:, :-shape_dim0] = current_obs[:, shape_dim0:]
    current_obs[:, -shape_dim0:] = obs
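
# current_obs is a rolling window of the last num_stack frames, concatenated
# along the channel dimension; with the default num_stack=4 and standard 84x84
# Atari preprocessing (an assumption about envs.make_env), its shape is
# (1, 4, 84, 84).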

# render_func('human')
obs = env.reset()
update_current_obs(obs)

if args.env_name.find('Bullet') > -1:
    import pybullet as p

    # Find the humanoid torso body so the debug camera can track it below.
    torsoId = -1
    for i in range(p.getNumBodies()):
        if p.getBodyInfo(i)[0].decode() == "torso":
            torsoId = i

for i in range(50000):
    value, action, _, states = actor_critic.act(Variable(current_obs, volatile=True),
                                                Variable(states, volatile=True),
                                                Variable(masks, volatile=True),
                                                deterministic=True)
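    # Note: Variable(..., volatile=True) is the pre-0.4 PyTorch idiom for
    # running inference without building an autograd graph. On PyTorch >= 0.4
    # the equivalent would be roughly (a sketch, not code from this repo):
    #
    #     with torch.no_grad():
    #         value, action, _, states = actor_critic.act(
    #             current_obs, states, masks, deterministic=True)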
    if i % 1000 == 0:
        print("STEP: ", i)
    states = states.data
    cpu_actions = action.data.squeeze(1).cpu().numpy()

    # Observe reward and next obs
    obs, reward, done, _ = env.step(cpu_actions)

    # done is a length-1 array (a single env in the DummyVecEnv), so its
    # truthiness is the episode-termination flag.
    masks.fill_(0.0 if done else 1.0)

    if done:
        # Reseed the underlying env on episode end so successive episodes differ.
        sd = np.random.randint(1000, size=1)
        print("New Random", sd[0])
        env.envs[0].seed(int(sd[0]))

    # Zero the stacked frames across episode boundaries.
    if current_obs.dim() == 4:
        current_obs *= masks.unsqueeze(2).unsqueeze(2)
    else:
        current_obs *= masks
    update_current_obs(obs)

    if args.env_name.find('Bullet') > -1:
        if torsoId > -1:
            # Keep the PyBullet debug camera pointed at the tracked torso.
            distance = 5
            yaw = 0
            humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId)
            p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos)

    render_func('human')

# Debug: list the wrapped environment's attributes.
print(dir(env.envs[0]))