forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlinalg_grad.py
More file actions
25 lines (23 loc) · 842 Bytes
/
linalg_grad.py
File metadata and controls
25 lines (23 loc) · 842 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Gradients for operators defined in linalg_ops.py."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Back-propagate through a MatrixInverse op.

  Computes -inv^T * grad * inv^T, where inv is the op's own first output,
  so the already-computed inverse is reused instead of being rebuilt.

  Args:
    op: The MatrixInverse operation; only `op.outputs[0]` is read.
    grad: Incoming gradient handed to this registered gradient function.

  Returns:
    The negated matrix product described above.
  """
  inverse = op.outputs[0]
  # Right-hand factor first: grad * inv^T.
  grad_times_inv_t = math_ops.matmul(grad, inverse, transpose_b=True)
  # Then left-multiply by inv^T.
  sandwiched = math_ops.matmul(inverse, grad_times_inv_t, transpose_a=True)
  return -sandwiched
@ops.RegisterGradient("BatchMatrixInverse")
def _BatchMatrixInverseGrad(op, grad):
  """Back-propagate through a BatchMatrixInverse op.

  Mirrors the non-batched gradient, but uses `batch_matmul` with its
  `adj_x` / `adj_y` flags on the op's own first output:
  -adj(inv) * grad * adj(inv).

  Args:
    op: The BatchMatrixInverse operation; only `op.outputs[0]` is read.
    grad: Incoming gradient handed to this registered gradient function.

  Returns:
    The negated batched matrix product described above.
  """
  inverse = op.outputs[0]
  # Right-hand factor first: grad * adj(inv), via adj_y on the second operand.
  grad_times_inv_adj = math_ops.batch_matmul(grad, inverse, adj_y=True)
  # Then left-multiply by adj(inv), via adj_x on the first operand.
  sandwiched = math_ops.batch_matmul(inverse, grad_times_inv_adj, adj_x=True)
  return -sandwiched