-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbacktracking.py
54 lines (43 loc) · 1.84 KB
/
backtracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
def armijo_search(func, q, eta, p, fx, t, tau, max_iter=20):
    """
    Backtracking (Armijo) line search: shrink the step `eta` by factor `tau`
    until the sufficient-decrease test ``func(q + eta * p) <= fx - eta * t``
    holds, or the iteration budget is exhausted.

    Args:
        func (callable): Objective; called on `q + eta * p`.
        q: Current point (anything `func` accepts and that supports `q + eta * p`).
        eta (float): Initial trial step size.
        p: Search direction.
        fx: Objective value at `q`, i.e. ``func(q)``.
        t (float): Sufficient-decrease slope term (typically c * |grad . p|).
        tau (float): Shrink factor applied to `eta` on each failure (0 < tau < 1).
        max_iter (int): Maximum number of shrink attempts. Default is 20.

    Returns:
        eta (float): Accepted step size, or the last (fully shrunk) candidate
        if no step satisfied the condition within `max_iter` attempts.
    """
    for j in range(max_iter):
        if func(q + eta * p) <= fx - eta * t:
            # Report once, with the true number of completed shrink steps.
            # (Previously this printed on every failed iteration, misreporting
            # the count before the search had finished.)
            print(f'Armijo line search took {j} iterations')
            return eta
        eta *= tau
    # Budget exhausted without satisfying the condition; say so explicitly.
    print(f'Armijo line search did not converge in {max_iter} iterations')
    return eta
def two_way_backtracking_line_search(f, grad_f, x, direction, c1=1e-4, c2=0.9, alpha0=1.0, max_iter=100):
    """
    Two-way backtracking line search satisfying the (weak) Wolfe conditions.

    The step size is halved while the Armijo (sufficient decrease) condition
    fails, and doubled while the curvature (Wolfe) condition fails, stopping
    when both hold or the iteration budget is exhausted.

    Args:
        f (callable): Objective function. Takes a tensor `x` as input and returns a scalar.
        grad_f (callable): Gradient of the objective function. Takes a tensor `x` and returns a tensor.
        x (torch.Tensor): Current point.
        direction (torch.Tensor): Descent direction (must satisfy grad_f(x) . direction < 0).
        c1 (float): Armijo (sufficient decrease) parameter. Default is 1e-4.
        c2 (float): Wolfe (curvature) parameter. Default is 0.9.
        alpha0 (float): Initial step size. Default is 1.0.
        max_iter (int): Safety cap on step-size adjustments. Prevents the
            halve/double cycle from looping forever when no step satisfies
            both conditions (the unbounded loop could otherwise oscillate or
            spin after `alpha` underflows to 0). Default is 100.

    Returns:
        alpha (float): Step size satisfying both conditions, or the last
        candidate after `max_iter` adjustments.

    Raises:
        ValueError: If `direction` is not a descent direction.
    """
    alpha = alpha0
    phi_0 = f(x)
    grad_phi_0 = grad_f(x)
    phi_prime_0 = torch.dot(grad_phi_0, direction)
    # Explicit exception instead of `assert`: asserts are stripped under -O,
    # which would let a non-descent direction through silently.
    if phi_prime_0 >= 0:
        raise ValueError("Direction must be a descent direction.")
    for _ in range(max_iter):
        # Evaluate the candidate step
        x_next = x + alpha * direction
        phi_next = f(x_next)
        # Check Armijo condition
        if phi_next > phi_0 + c1 * alpha * phi_prime_0:
            alpha /= 2  # Too ambitious: reduce step size
        else:
            grad_phi_next = grad_f(x_next)
            phi_prime_next = torch.dot(grad_phi_next, direction)
            # Check Wolfe (curvature) condition
            if phi_prime_next < c2 * phi_prime_0:
                alpha *= 2  # Too timid: increase step size
            else:
                break  # Both conditions satisfied
    return alpha