Skip to content

Commit

Permalink
add filter of dirname
Browse files Browse the repository at this point in the history
  • Loading branch information
gxywy committed Jun 8, 2021
1 parent 7fbc479 commit acbfb3f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ optional arguments:
--xformat x-axis format
--xlim x-axis limitation (default: None)
--log_dir log dir (default: ./)
--filter filter of dirname
--filename csv filename
--show show figure
--save save figure
Expand All @@ -125,6 +126,7 @@ finally, the learning curves looks like this:
## Features

- [x] custom logger, style, key, label, interval, and so on ...
- [x] filter of directory name
- [x] multi-experiment plotter
- [x] x-axis formatter features
- [x] compatible with [OpenAI-baseline](https://github.com/openai/baselines) monitor data style
Expand Down
15 changes: 8 additions & 7 deletions rl_plotter/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def load_csv_results(dir, filename="monitor"):
#df.headers = headers # HACK to preserve backwards compatibility
return df

def load_results(root_dir_or_dirs="./", filename="monitor"):
def load_results(root_dir_or_dirs="./", filename="monitor", filter=""):

if isinstance(root_dir_or_dirs, str):
rootdirs = [osp.expanduser(root_dir_or_dirs)]
Expand All @@ -149,14 +149,15 @@ def load_results(root_dir_or_dirs="./", filename="monitor"):
for rootdir in rootdirs:
assert osp.exists(rootdir), "%s doesn't exist"%rootdir
for dirname, dirs, files in os.walk(rootdir):
result = {'dirname' : dirname, "data": None}
if filter in dirname:
result = {'dirname' : dirname, "data": None}

file_re = re.compile(r'(\d+\.)?(\d+\.)?' + filename + r'\.csv')
if any([f for f in files if file_re.match(f)]):
result['data'] = pandas.DataFrame(load_csv_results(dirname, filename))
file_re = re.compile(r'(\d+\.)?(\d+\.)?' + filename + r'\.csv')
if any([f for f in files if file_re.match(f)]):
result['data'] = pandas.DataFrame(load_csv_results(dirname, filename))

if result['data'] is not None:
allresults.append(result)
if result['data'] is not None:
allresults.append(result)
return allresults


Expand Down
13 changes: 9 additions & 4 deletions rl_plotter/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ def main():
parser.add_argument('--time_interval', type=float, default=1,
help='parameters about time, x axis time interval (default: 1)')

parser.add_argument('--xformat', default='eng',
parser.add_argument('--xformat', default='',
help='x-axis format')
parser.add_argument('--xlim', type=int, default=None,
help='x-axis limitation (default: None)')

parser.add_argument('--log_dir', default='./',
help='log dir (default: ./)')
parser.add_argument('--filename', default='monitor',
parser.add_argument('--filter', default='',
help='filter of dirname')
parser.add_argument('--filename', default='evaluator',
help='csv filename')
parser.add_argument('--show', action='store_true',
help='show figure')
Expand Down Expand Up @@ -92,7 +94,7 @@ def main():
args.xkey = 'total_steps'
args.ykey = 'mean_score'

allresults = pu.load_results(args.log_dir, filename=args.filename)
allresults = pu.load_results(args.log_dir, filename=args.filename, filter=args.filter)
pu.plot_results(allresults,
fig_length=args.fig_length,
fig_width=args.fig_width,
Expand Down Expand Up @@ -125,7 +127,10 @@ def main():
elif args.xformat == 'log':
ax.xaxis.set_major_formatter(mticker.LogFormatter())
elif args.xformat == 'sci':
ax.xaxis.set_major_formatter(mticker.LogFormatterSciNotation())
#ax.xaxis.set_major_formatter(mticker.LogFormatterSciNotation())
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=True)
else:
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0), useMathText=False)

if args.xlim is not None:
plt.xlim((0, args.xlim))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="rl_plotter",
version="2.2.3",
version="2.2.4",
author="Gong Xiaoyu",
author_email="gxywy@hotmail.com",
description="A plotter for reinforcement learning (RL)",
Expand Down

0 comments on commit acbfb3f

Please sign in to comment.