-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutils.py
23 lines (21 loc) · 882 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
def save(args, save_name, model, wandb, ep=None):
    """Write the model's ``state_dict`` to ``./trained_models/`` and hand it to wandb.

    The checkpoint path is
    ``./trained_models/<args.run_name><save_name>[<ep>].pth``; the ``<ep>``
    part is only present when ``ep`` is not ``None`` (so ``ep=0`` is kept).

    Args:
        args: Object exposing a ``run_name`` attribute used as filename prefix.
        save_name: Text appended to the run name in the filename.
        model: Object whose ``state_dict()`` result is passed to ``torch.save``.
        wandb: Object exposing ``save(path)``; called with the checkpoint path.
        ep: Optional tag (e.g. an epoch number) appended to the filename.
    """
    import os

    save_dir = './trained_models/'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    # Identity check, not ``== None``: an ``ep`` of 0 must still be appended.
    suffix = "" if ep is None else str(ep)
    path = f"{save_dir}{args.run_name}{save_name}{suffix}.pth"

    torch.save(model.state_dict(), path)
    wandb.save(path)
def collect_random(env, dataset, num_samples=200):
    """Fill ``dataset`` with transitions produced by randomly sampled actions.

    Runs ``num_samples`` steps in ``env``. Each step samples an action from
    ``env.action_space``, applies it with ``env.step`` (which must return a
    4-tuple) and stores ``(state, action, reward, next_state, done)`` through
    ``dataset.add``. Whenever ``done`` is truthy the environment is reset and
    collection continues from the state returned by ``env.reset()``.

    Args:
        env: Object providing ``reset()``, ``step(action)`` and
            ``action_space.sample()``.
        dataset: Object providing ``add(state, action, reward, next_state, done)``.
        num_samples: Number of environment steps to record.
    """
    current = env.reset()
    for _step in range(num_samples):
        chosen = env.action_space.sample()
        following, reward, finished, _info = env.step(chosen)
        dataset.add(current, chosen, reward, following, finished)
        # Start a fresh episode after a terminal step, else keep going.
        current = env.reset() if finished else following