forked from keiohta/tf2rl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
61 lines (53 loc) · 1.91 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import platform
from setuptools import setup, find_packages
# Runtime dependencies that do not depend on the TensorFlow version.
install_requires = [
    "cpprb>=8.1.1",
    "setuptools>=41.0.0",
    "numpy>=1.16.0",
    "joblib",
    "future",
    "scipy",
    "scikit-image"]

# TensorFlow "major.minor" used to pick compatible companion packages.
# If TensorFlow is already installed, its version is used. Otherwise
# TensorFlow itself is added as a requirement pinned to this default.
tf_version = "2.4"  # Default Version
try:
    import tensorflow as tf
except ImportError:
    install_requires.append(f"tensorflow=={tf_version}")
else:
    # Drop the last dotted component, e.g. "2.4.1" -> "2.4".
    tf_version = tf.version.VERSION.rsplit('.', 1)[0]

# Known-compatible tensorflow-probability / tensorflow-addons releases,
# keyed by TensorFlow "major.minor".
compatible_tfp = {"2.4": ["tensorflow-probability~=0.12.0"],
                  "2.3": ["tensorflow-probability~=0.11.0"],
                  "2.2": ["tensorflow-probability~=0.10.0"],
                  "2.1": ["tensorflow-probability~=0.8.0"],
                  "2.0": ["tensorflow-probability~=0.8.0"]}
compatible_tfa = {"2.4": ["tensorflow_addons~=0.13.0"],
                  "2.3": ["tensorflow_addons~=0.13.0"],
                  "2.2": ["tensorflow_addons==0.11.2"],
                  "2.1": ["tensorflow_addons~=0.9.1"],
                  "2.0": ["tensorflow_addons~=0.6.0"]}


def tf_companion_requirements(version, system):
    """Return the tensorflow-probability / tensorflow-addons requirements.

    Args:
        version: TensorFlow version as "major.minor", e.g. "2.4".
        system: Platform name in the form returned by ``platform.system()``.

    Returns:
        A new list of requirement strings. A version missing from
        ``compatible_tfp`` / ``compatible_tfa`` gets unpinned requirements
        instead of raising KeyError, so installing next to a newer
        TensorFlow does not abort setup.
    """
    # Copy so that extending never mutates the shared table entries.
    requirements = list(compatible_tfp.get(version, ["tensorflow-probability"]))
    # tensorflow-addons does not support tf2.0 on Windows
    if not (system == 'Windows' and version == "2.0"):
        requirements.extend(compatible_tfa.get(version, ["tensorflow_addons"]))
    return requirements


# extend() instead of append(*...): append(*x) raises TypeError as soon as a
# table entry lists more than one requirement.
install_requires.extend(tf_companion_requirements(tf_version, platform.system()))
# gym requirement for specific TensorFlow "major.minor" versions:
# TF 2.0 and 2.1 are paired with gym[atari] below 0.21.0, and every other
# version falls through to the >=0.21.0 default on the next line.
compatible_gym = dict.fromkeys(("2.0", "2.1"), "gym[atari]<0.21.0")
gym_version = compatible_gym.get(tf_version, "gym[atari]>=0.21.0")

# Optional dependency groups selectable as ``tf2rl[<group>]``.
extras_require = dict(
    tf=["tensorflow>=2.0.0"],
    tf_gpu=["tensorflow-gpu>=2.0.0"],
    examples=[gym_version, "opencv-python"],
    test=["coveralls", gym_version, "matplotlib", "opencv-python", "future"],
)
# Static distribution metadata. The dependency lists passed alongside it
# (install_requires / extras_require) are computed earlier in this script.
package_metadata = {
    "name": "tf2rl",
    "version": "1.1.5",
    "description": "Deep Reinforcement Learning for TensorFlow2",
    "url": "https://github.com/keiohta/tf2rl",
    "author": "Kei Ohta",
    "author_email": "dev.ohtakei@gmail.com",
    "license": "MIT",
}

setup(
    packages=find_packages("."),
    install_requires=install_requires,
    extras_require=extras_require,
    **package_metadata,
)