-
Notifications
You must be signed in to change notification settings - Fork 1
/
extend_function.py
73 lines (60 loc) · 2.65 KB
/
extend_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
Helpers to extend a function beyond its domain as straight lines that follow the function's derivative at each endpoint.
"""
def get_extensions(extension_args):
    """Resolve how far to extend the domain before its start and after its end.

    ``extension_args.extra_t`` is a shorthand that sets both sides at once.
    When it is given, any explicitly set ``extra_start_t`` / ``extra_end_t``
    must equal it.

    Args:
        extension_args: object exposing ``extra_t``, ``extra_start_t`` and
            ``extra_end_t`` attributes (each a number or None).

    Returns:
        Tuple ``(extra_start_t, extra_end_t)``; values left unset become 0.0.

    Raises:
        ValueError: if ``extra_t`` disagrees with an explicitly set side.
    """
    shared = extension_args.extra_t
    if shared is None:
        before = extension_args.extra_start_t
        after = extension_args.extra_end_t
    else:
        # Start side is checked first, so its error wins if both conflict.
        sides = (
            ("extra_start_t", extension_args.extra_start_t),
            ("extra_end_t", extension_args.extra_end_t),
        )
        for side_name, side_value in sides:
            if side_value is not None and shared != side_value:
                raise ValueError(
                    "%s and extra_t both set, but do not agree" % side_name)
        before = after = shared
    return (0.0 if before is None else before,
            0.0 if after is None else after)
def build_time_t(start_t, end_t, num_time_steps, extension_args):
    """Build a function mapping an integer time step to a time value.

    The interval ``[start_t, end_t]`` is widened by the extensions returned
    from ``get_extensions(extension_args)``: the start extension is subtracted
    from ``start_t`` and the end extension is added to ``end_t``. Time steps
    are mapped linearly onto the widened interval, so step 0 gives the widened
    start and step ``num_time_steps`` gives the widened end.

    Returns:
        A callable ``time_t(time_step)``.
    """
    extra_start_t, extra_end_t = get_extensions(extension_args)
    # Zero (falsy) extensions leave the corresponding endpoint untouched.
    lo = start_t - extra_start_t if extra_start_t else start_t
    hi = end_t + extra_end_t if extra_end_t else end_t
    width = hi - lo

    def time_t(time_step):
        return lo + width * time_step / num_time_steps

    return time_t
def build_extension(base_f_t, t0, epsilon=0.001):
    """Return a straight-line extension of ``base_f_t`` through ``(t0, f(t0))``.

    The slope is estimated with a central finite difference, so ``base_f_t``
    is evaluated at ``t0``, ``t0 + epsilon`` and ``t0 - epsilon``. A summary
    line with the point, slope and value is printed.

    Args:
        base_f_t: callable mapping a time ``t`` to a number.
        t0: the point where the line touches ``base_f_t``.
        epsilon: half-width of the finite-difference stencil; must be > 0.
            Defaults to 0.001.

    Returns:
        A callable ``extension_t(t)`` returning ``f0 + derivative * (t - t0)``.

    Raises:
        ValueError: if ``epsilon`` is not positive.
    """
    if epsilon <= 0:
        raise ValueError("epsilon must be positive, got %r" % (epsilon,))
    f0 = base_f_t(t0)
    derivative = (base_f_t(t0 + epsilon) - base_f_t(t0 - epsilon)) / (epsilon * 2)
    print("Extension at %.4f. Derivative %.4f f0 %.4f" % (t0, derivative, f0))

    def extension_t(t):
        return f0 + derivative * (t - t0)

    return extension_t
def extend_f_t(time_t, base_f_t, start_t, end_t, extension_args):
    """Wrap ``base_f_t`` so it continues as a straight line outside its domain.

    The returned ``f_t(time_step)`` converts the step to a time ``t`` with
    ``time_t`` and then evaluates:
      * the linear extension built at ``start_t`` when a start extension is
        configured and ``t < start_t``;
      * the linear extension built at ``end_t`` when an end extension is
        configured and ``t > end_t``;
      * ``base_f_t(t)`` otherwise.

    An extension is built only for a side that is actually configured.
    Building one samples ``base_f_t`` slightly outside ``[start_t, end_t]``,
    where the function may be undefined, and prints a summary line, so
    unconfigured sides must not trigger it.

    Raises:
        ValueError: propagated from ``get_extensions`` when ``extra_t``
            conflicts with ``extra_start_t`` / ``extra_end_t``.
    """
    extra_start_t, extra_end_t = get_extensions(extension_args)
    begin_f_t = build_extension(base_f_t, start_t) if extra_start_t else None
    end_f_t = build_extension(base_f_t, end_t) if extra_end_t else None

    def f_t(time_step):
        t = time_t(time_step)
        if begin_f_t is not None and t < start_t:
            return begin_f_t(t)
        if end_f_t is not None and t > end_t:
            return end_f_t(t)
        return base_f_t(t)

    return f_t
def add_extend_args(parser, default_extra_t=None):
    """Register the domain-extension options on an argparse parser.

    Adds ``--extra_t`` (both sides, defaulting to ``default_extra_t``),
    ``--extra_start_t`` and ``--extra_end_t`` (each defaulting to None), all
    parsed as floats. These are the attributes read by ``get_extensions``.
    """
    option_specs = (
        ('--extra_t', default_extra_t,
         'Extra time to build the model as a straight line before & after the domain'),
        ('--extra_start_t', None,
         'Extra time to build the model as a straight line before the domain'),
        ('--extra_end_t', None,
         'Extra time to build the model as a straight line after the domain'),
    )
    for flag, default, help_text in option_specs:
        parser.add_argument(flag, default=default, type=float, help=help_text)