taichi.ad
- taichi.ad.grad_for(primal)
Generates a decorator for the customized gradient function of primal. See grad_replaced() for examples.
- Parameters
primal (Callable) – The primal function, which must be decorated by grad_replaced().
- Returns
The decorator used to decorate customized gradient function.
- Return type
Callable
- taichi.ad.grad_replaced(func)
A decorator for a Python function that customizes its gradient in Taichi’s autodiff system, e.g. ti.Tape() and kernel.grad(). With this decorator, Taichi’s autodiff system calls a user-defined gradient function instead of the automatically derived gradient of the decorated function. The customized gradient function must be decorated by grad_for().
- Parameters
func (Callable) – The Python function to be decorated.
- Returns
The decorated function.
- Return type
Callable
Example:
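>>> import taichi as ti
>>> ti.init(arch=ti.cpu)
>>>
>>> # Both fields need gradient storage: y.grad is read, x.grad is written.
>>> x = ti.field(ti.f32, shape=8, needs_grad=True)
>>> y = ti.field(ti.f32, shape=8, needs_grad=True)
>>>
>>> @ti.kernel
>>> def multiply(a: ti.f32):
>>>     for I in ti.grouped(x):
>>>         y[I] = x[I] * a
>>>
>>> @ti.kernel
>>> def multiply_grad(a: ti.f32):
>>>     # Hand-written backward pass: since y = x * a, dL/dx = dL/dy * a
>>>     for I in ti.grouped(x):
>>>         x.grad[I] += y.grad[I] * a
>>>
>>> @ti.ad.grad_replaced
>>> def foo(a):
>>>     multiply(a)
>>>
>>> @ti.ad.grad_for(foo)
>>> def foo_grad(a):
>>>     multiply_grad(a)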
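A hand-written gradient is easy to get wrong, so it is worth checking against finite differences. The sketch below assumes plain Python lists as stand-ins for the fields x and y, and a hypothetical scalar loss L = sum(w[i] * y[i]) so that dL/dy = w. For y = x * a, the chain rule gives dL/dx[i] = dL/dy[i] * a, and a central-difference estimate should agree with it.

```python
from typing import List

# Plain lists are hypothetical stand-ins for the Taichi fields x and y.


def multiply(x: List[float], a: float) -> List[float]:
    """Forward pass: y[i] = x[i] * a."""
    return [xi * a for xi in x]


def multiply_grad(y_grad: List[float], a: float) -> List[float]:
    """Backward pass by the chain rule: dL/dx[i] = dL/dy[i] * a."""
    return [gi * a for gi in y_grad]


def loss(y: List[float], w: List[float]) -> float:
    """Hypothetical scalar loss L = sum(w[i] * y[i]), so dL/dy[i] = w[i]."""
    return sum(wi * yi for wi, yi in zip(w, y))


def finite_diff_grad(x: List[float], a: float, w: List[float],
                     eps: float = 1e-6) -> List[float]:
    """Central-difference estimate of dL/dx[i] for each i."""
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += eps
        xm = list(x)
        xm[i] -= eps
        diff = loss(multiply(xp, a), w) - loss(multiply(xm, a), w)
        grads.append(diff / (2 * eps))
    return grads


x = [1.0, -2.0, 0.5]
w = [0.3, 1.5, -2.0]
a = 4.0
analytic = multiply_grad(w, a)  # dL/dy = w for this loss
numeric = finite_diff_grad(x, a, w)
for g_analytic, g_numeric in zip(analytic, numeric):
    assert abs(g_analytic - g_numeric) < 1e-6
```

Dividing by a instead of multiplying would make the assertion fail for any a other than ±1, which is exactly the kind of mistake this check is meant to catch.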