-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent_handler.py
210 lines (178 loc) · 6.86 KB
/
agent_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# Client that connects an RL agent to the octopus environment over a TCP socket
import argparse
import glob
# from sample_agents import MyAgent
# from sample_agents import RandomAgent
import os
import random
import socket
import subprocess
import sys
from time import sleep
import matplotlib.pyplot as plt
import numpy as np
from sample_agents import RLAgent
class AgentHandler(object):
    """Connect an agent to the environment over a line-based TCP protocol.

    Every message is a text line terminated by ``LINE_SEPARATOR``.  The
    handler asks the environment for the task dimensions, lazily creates the
    agent, and then alternates between receiving observations and sending
    actions until one of the stop conditions in :meth:`run` is met.
    """

    # Upper bound on steps within one episode before run() returns
    # (only enforced when num_episodes is not None).
    STAGE_STEPS_NUM = 1000

    def __init__(self, host, port, num_episodes, agent_params):
        """Store connection settings; no socket is opened and no agent is built here.

        Args:
            host: Host name of the environment server.
            port: TCP port of the environment server.
            num_episodes: Episode limit for one run() call, or None.
            agent_params: Parameter mapping forwarded to the agent constructor.
        """
        self.agent_params = agent_params
        self.num_episodes = num_episodes
        self.ChosenAgent = RLAgent
        self.LINE_SEPARATOR = '\n'
        self.BUF_SIZE = 4096  # in bytes
        self.connected = False
        self.host = host
        self.sock = None
        self.port = port
        self.agent = None

    def connect(self, host, port):
        """Open the TCP connection unless one is already open.

        Prints a message and exits the process with status 1 when the
        environment cannot be reached.
        """
        if self.connected:
            return
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # The port may arrive as a string (e.g. straight from the command line).
            self.sock.connect((host, int(port)))
        except socket.error:
            print('Unable to contact environment at the given host/port.')
            # Release the half-open socket before leaving.
            self.disconnect()
            sys.exit(1)
        self.connected = True

    def disconnect(self):
        """Close the socket (if any) and mark the handler as disconnected."""
        if self.sock:
            self.sock.close()
            self.sock = None
        self.connected = False

    def send_str(self, s):
        """Send ``s`` followed by the line separator, encoded as UTF-8."""
        # sendall() keeps writing until the whole line is out; send() may stop early.
        self.sock.sendall(bytes(s + self.LINE_SEPARATOR, encoding='utf8'))

    def receive(self, numTokens):
        """Block until at least ``numTokens`` complete lines have arrived.

        Returns:
            The list of complete lines read so far.  It can hold more than
            ``numTokens`` entries when the peer sent extra lines in the same
            chunk; a trailing unterminated fragment is dropped.

        Raises:
            ConnectionError: The peer closed the connection before enough
                lines arrived.
        """
        # The last element always holds the (possibly empty) unfinished line.
        data = ['']
        while len(data) <= numTokens:
            chunk = self.sock.recv(self.BUF_SIZE)
            if not chunk:
                # recv() returning b'' means the peer hung up; without this
                # check the loop would spin forever.
                raise ConnectionError(
                    'Environment closed the connection while waiting for data.')
            rawData = data[-1] + str(chunk, encoding='utf8')
            del data[-1]
            data = data + rawData.split(self.LINE_SEPARATOR)
        del data[-1]
        return data

    def send_action(self, action):
        """Send every component of ``action`` on its own line."""
        for a in action:
            # The decimal point is replaced by a comma before sending
            # (presumably the environment parses decimal commas - TODO confirm).
            self.send_str(str(a).replace('.', ','))

    def run(self):
        """Run the agent against the currently running environment.

        Returns when, with an episode limit set, an episode reaches
        ``STAGE_STEPS_NUM`` steps or a reward equal to 10 is received, or when
        more than ``num_episodes`` episodes have finished.  On a malformed
        step response the agent is shut down and the process exits with
        status 1.
        """
        self.connect(self.host, self.port)
        self.send_str('GET_TASK')
        data = self.receive(2)
        stateDim = int(data[0])
        actionDim = int(data[1])

        # Instantiate the agent once and reuse it across run() calls.
        if self.agent is None:
            # Use the parameters given to this handler rather than a
            # module-level global that only exists when run as a script.
            self.agent = self.ChosenAgent.Agent(stateDim, actionDim, self.agent_params)

        self.send_str('START_LOG')
        self.send_str(self.agent.getName())

        stage = 0
        while True:
            self.send_str('START')
            step = 0
            data = self.receive(2 + stateDim)
            terminalFlag = int(data[0])
            # data[1] is not used; the state starts at index 2.
            state = list(map(float, data[2:]))
            action = self.agent.start(state)

            while not terminalFlag:
                self.send_str('STEP')
                self.send_str(str(actionDim))
                self.send_action(action)

                data = self.receive(3 + stateDim)
                if not (len(data) == stateDim + 3):
                    print('Communication error: calling agent.cleanup()')
                    self.safe_exit()
                    sys.exit(1)

                reward = float(data[0])
                terminalFlag = int(data[1])
                # data[2] is not used; the state starts at index 3.
                state = list(map(float, data[3:]))
                action = self.agent.step(reward, state)

                if self.num_episodes is not None:
                    step += 1
                    if step >= self.STAGE_STEPS_NUM or reward == 10:
                        return

            stage += 1
            # Comparing an int with None raises TypeError, so only apply the
            # limit when one is configured.
            if self.num_episodes is not None and stage > self.num_episodes:
                break
            if random.random() < 0.01:
                self.agent.reset_epsilon()

    def safe_exit(self):
        """Let the agent shut down cleanly, if one has been created."""
        # run() may fail before the agent exists; cleanup must not raise then.
        if self.agent is not None:
            self.agent.safe_exit()
class EnvironmentHandler(object):
    """Launch the Java environment for one test file at a time and score it.

    Test files are the ``*.xml`` files found in ``test_dir``.  With
    ``plot_score`` enabled the tests are visited in ascending order of the
    number embedded in their file name, and the collected scores are plotted
    and written to ``score.csv`` once every test has been run.
    """

    def __init__(self, port, test_dir, gui, plot_score):
        """Remember the settings and load the list of test files.

        Args:
            port: Port passed to the environment process.
            test_dir: Directory searched for ``*.xml`` test files.
            gui: GUI mode argument passed to the environment process.
            plot_score: Whether to run tests in order and plot the scores.
        """
        self.plot_score = plot_score
        self.gui = gui
        self.port = port
        self.test_dir = test_dir
        self.tests = []
        self.load_tests()
        self.process = None
        self.score = []
        self.tests_angle = []
        self.current_test = 0

    @staticmethod
    def _test_angle(test_path):
        """Return the number embedded in a test file name.

        The file name is expected to be a 5-character prefix, the number, and
        a 4-character extension (e.g. ``.xml``).  Both ``/`` and ``\\`` are
        accepted as directory separators, so the result does not depend on
        the platform or on how deep the test directory is.
        """
        file_name = test_path.replace('\\', '/').rsplit('/', 1)[-1]
        return float(file_name[:-4][5:])

    def start_next(self):
        """Remove old ``*.log`` files and start the environment on the next test.

        The next test is the current one in order when plotting, otherwise a
        random one.
        """
        if self.plot_score:
            test = self.tests[self.current_test]
        else:
            test = random.choice(self.tests)
        print('New test', test)
        # stop() reads the first *.log it finds, so stale logs must go first.
        for log_file in glob.glob('*.log'):
            os.remove(log_file)
        self.process = subprocess.Popen(
            ['java', '-jar', 'octopus-environment.jar', self.gui, test, str(self.port)])
        # Fixed pause after launching the process before the caller connects.
        sleep(1)

    def stop(self):
        """Stop the environment, report the score and advance to the next test.

        The score is the mean over log lines of the sum of the numbers on
        each line.  After the last test the index wraps to 0 and, when
        plotting is enabled, the scores are plotted and saved to
        ``score.csv``.
        """
        if self.process is not None:
            self.process.terminate()
            # Reap the child so it does not linger; escalate if it ignores
            # the termination request.
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()
        sleep(0.5)

        log = glob.glob('*.log')
        if log:
            with open(log[0], 'r') as log_file:
                score = np.mean([sum(float(r) for r in line.split()) for line in log_file])
            print('Score:', score)
            if self.plot_score:
                self.score.append(score)
                self.tests_angle.append(self._test_angle(self.tests[self.current_test]))

        self.current_test += 1
        if self.current_test >= len(self.tests):
            self.current_test = 0
            if self.plot_score:
                plt.plot(self.tests_angle, self.score)
                plt.xlabel('Początkowy kąt')
                plt.ylabel('Wynik')
                plt.show()
                with open("score.csv", 'w', encoding='utf8') as file:
                    file.write("angle;score\n")
                    for ang, sc in zip(self.tests_angle, self.score):
                        file.write(f"{ang};{sc}\n")

    def load_tests(self):
        """Collect the ``*.xml`` files of ``test_dir`` into ``self.tests``.

        When plotting is enabled the list is sorted by the number embedded in
        each file name.
        """
        self.tests.extend(glob.glob(os.path.join(self.test_dir, '*.xml')))
        if self.plot_score:
            self.tests = sorted(self.tests, key=self._test_angle)
if __name__ == '__main__':
    # Command-line entry point: repeatedly start the environment on a test,
    # run the agent against it, then stop the environment, until interrupted
    # or an error occurs.
    parser = argparse.ArgumentParser()
    parser.add_argument('gui', choices=['internal', 'external', 'external_gui'])
    # type=int so a port given on the command line is not left as a string,
    # which socket.connect() rejects.
    parser.add_argument('--p', dest='port', default=1410, required=False, type=int)
    parser.add_argument('--h', dest='host', default='localhost', required=False)
    parser.add_argument('--t', dest='test_dir', default='tests', required=False)
    parser.add_argument('--e', dest='episodes', default=5, required=False, type=int)
    # store_false: args.no_learning is True by default and becomes False
    # when --no_learning is passed, so it directly holds "is learning".
    parser.add_argument('--no_learning', action='store_false', default=True)
    parser.add_argument('--plot_score', action='store_true', default=False)
    args = parser.parse_args()

    agent_params = {
        RLAgent.Agent.IS_LEARNING: args.no_learning
    }
    agent_handler = AgentHandler(args.host, args.port, args.episodes, agent_params)
    env_handler = EnvironmentHandler(args.port, args.test_dir, args.gui, args.plot_score)

    try:
        while True:
            env_handler.start_next()
            agent_handler.run()
            agent_handler.disconnect()
            env_handler.stop()
    except (KeyboardInterrupt, Exception) as e:
        print(e)
        try:
            agent_handler.safe_exit()
        finally:
            # Always release the socket and the environment process, even
            # when the agent's own shutdown fails.
            agent_handler.disconnect()
            env_handler.stop()
    print('Agent has stopped safely')