forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
control_ops_grad_test.py
49 lines (40 loc) · 1.85 KB
/
control_ops_grad_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from caffe2.python import core, test_util, workspace
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
from caffe2.python.model_helper import ModelHelper
import numpy as np
class TestControl(test_util.TestCase):
    """Tests for helpers imported from caffe2.python.control_ops_grad."""

    def test_disambiguate_grad_if_op_output(self):
        """Check that renaming output 0 of an "If" grad op is propagated.

        After ``disambiguate_grad_if_op_output`` is called, the new name
        must appear both as the op's own first output and as the first
        output of the second operator inside its ``else_net`` argument.
        """
        # Blobs referenced by the hand-built gradient op below.
        for blob_name, blob_value in (
            ("cond", np.array(True)),
            ("then_grad", np.array(1)),
            ("else_grad", np.array(2)),
        ):
            workspace.FeedBlob(blob_name, blob_value)

        # "then" branch: a single Copy that writes input_grad.
        then_branch = ModelHelper(name="then_test_model")
        then_branch.net.Copy("then_grad", "input_grad")

        # "else" branch: two Copy ops; the second one writes input_grad.
        else_branch = ModelHelper(name="else_test_model")
        else_branch.net.Copy("else_grad", "else_temp_grad")
        # NOTE(review): this reads "else_temp", not the "else_temp_grad"
        # blob written by the previous op -- confirm this is intended.
        else_branch.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators, the forward pass has to expose the
        # else-branch temporary as one of its outputs, which later leads to
        # a gradient op shaped like this one:
        if_grad_op = core.CreateOperator(
            "If",
            ["cond", "then_grad", "else_grad"],
            ["input_grad", "else_temp_grad"],
            then_net=then_branch.net.Proto(),
            else_net=else_branch.net.Proto(),
        )

        # In certain cases another branch of the net also generates
        # input_grad, and _DisambiguateGradOpOutput in core.py is called
        # to give this op's output a distinct name.
        renamed_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(if_grad_op, 0, renamed_output)

        self.assertEqual(if_grad_op.output[0], renamed_output)
        for net_arg in if_grad_op.arg:
            if net_arg.name != "else_net":
                # The only other argument expected on the op is then_net.
                self.assertEqual(net_arg.name, "then_net")
                continue
            # op[1] is the second Copy of the else branch (writes input_grad).
            self.assertEqual(net_arg.n.op[1].output[0], renamed_output)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()