Customized Operator¶
Sometimes it is useful to define a customized operator with its own derivative. The following code, taken from our example, defines a softmax cross-entropy loss operator:
```python
import numpy
from minpy.core import customop

@customop('numpy')
def my_softmax(x, y):
    probs = numpy.exp(x - numpy.max(x, axis=1, keepdims=True))
    probs /= numpy.sum(probs, axis=1, keepdims=True)
    N = x.shape[0]
    loss = -numpy.sum(numpy.log(probs[numpy.arange(N), y])) / N
    return loss

def my_softmax_grad(ans, x, y):
    def grad(g):
        N = x.shape[0]
        probs = numpy.exp(x - numpy.max(x, axis=1, keepdims=True))
        probs /= numpy.sum(probs, axis=1, keepdims=True)
        probs[numpy.arange(N), y] -= 1
        probs /= N
        return probs
    return grad

my_softmax.def_grad(my_softmax_grad)
```
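As a sanity check of the mathematics (this is not part of the operator API), the backward formula used in `my_softmax_grad` can be verified against a numerical finite-difference gradient of the same loss, using only NumPy:

```python
import numpy

def softmax_loss(x, y):
    # Forward pass: softmax followed by mean negative log-likelihood,
    # the same computation as in my_softmax.
    probs = numpy.exp(x - numpy.max(x, axis=1, keepdims=True))
    probs /= numpy.sum(probs, axis=1, keepdims=True)
    N = x.shape[0]
    return -numpy.sum(numpy.log(probs[numpy.arange(N), y])) / N

def softmax_loss_grad(x, y):
    # Backward pass: (probs - one_hot(y)) / N, as in my_softmax_grad.
    probs = numpy.exp(x - numpy.max(x, axis=1, keepdims=True))
    probs /= numpy.sum(probs, axis=1, keepdims=True)
    N = x.shape[0]
    probs[numpy.arange(N), y] -= 1
    return probs / N

rng = numpy.random.RandomState(0)
x = rng.randn(4, 3)
y = numpy.array([0, 2, 1, 0])

# Numerical gradient by central differences.
eps = 1e-6
num_grad = numpy.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += eps
        xm[i, j] -= eps
        num_grad[i, j] = (softmax_loss(xp, y) - softmax_loss(xm, y)) / (2 * eps)

assert numpy.allclose(softmax_loss_grad(x, y), num_grad, atol=1e-5)
```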
As in the example, the forward pass of the operator is defined in an ordinary Python function. The only difference is the decorator @customop('numpy').
The decorator turns the function into a class instance with the same name as the function. The customop decorator takes one of two options:

- @customop('numpy'): assumes the arrays in the input and output of the user-defined function are both NumPy arrays.
- @customop('mxnet'): assumes the arrays in the input and output of the user-defined function are both MXNet NDArrays.
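To make the "function becomes a class instance" behavior concrete, here is a minimal toy sketch of the mechanism. The real decorator in the library also handles array conversion between NumPy and MXNet; the class and names below are illustrative, not the actual implementation:

```python
class CustomOp:
    """Toy stand-in showing how a decorator can turn a function into a
    callable instance that carries its own gradient rules."""
    def __init__(self, func):
        self.func = func
        self.grads = {}          # argnum -> gradient-factory function

    def __call__(self, *args):
        return self.func(*args)

    def def_grad(self, grad_func, argnum=0):
        self.grads[argnum] = grad_func

def customop(mode):
    # 'mode' would select NumPy vs. MXNet array handling; ignored here.
    def decorator(func):
        return CustomOp(func)
    return decorator

@customop('numpy')
def double(x):
    return 2 * x

# 'double' is now a CustomOp instance, yet still callable like a function.
double.def_grad(lambda ans, x: lambda g: 2 * g)
```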
Registering derivatives for a customized operator¶
To register a derivative, first define a function that takes the operator's output and inputs as parameters and returns a function, as in the example above. The returned function takes the upstream gradient as input and returns the downstream gradient. In other words, it describes how the gradient is transformed during backpropagation through this customized operator with respect to a particular input variable.
Once the derivative function is defined, register it with def_grad, as shown in the example above. In fact, my_softmax.def_grad(my_softmax_grad) is shorthand for my_softmax.def_grad(my_softmax_grad, argnum=0). Use argnum to specify which input variable the given derivative is bound to.
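For illustration, the argnum convention can be seen on a two-input operator in plain NumPy. The names multiply_grad_a and multiply_grad_b are hypothetical; each follows the same (ans, *inputs) -> grad(g) factory pattern as my_softmax_grad, and the dictionary stands in for the per-argnum registry that def_grad populates:

```python
import numpy

def multiply(a, b):
    return a * b

# Gradient factories in the def_grad style: (ans, *inputs) -> grad(g).
def multiply_grad_a(ans, a, b):
    return lambda g: g * b       # downstream gradient w.r.t. the first input

def multiply_grad_b(ans, a, b):
    return lambda g: g * a       # downstream gradient w.r.t. the second input

# argnum selects which input a derivative is bound to, mirroring
# my_softmax.def_grad(my_softmax_grad, argnum=0).
grads = {0: multiply_grad_a, 1: multiply_grad_b}

a = numpy.array([1.0, 2.0])
b = numpy.array([3.0, 4.0])
ans = multiply(a, b)
g = numpy.ones_like(ans)         # upstream gradient

da = grads[0](ans, a, b)(g)      # [3. 4.]
db = grads[1](ans, a, b)(g)      # [1. 2.]
```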