question on trax
Dennis Lee Bieber
wlfraed at ix.netcom.com
Wed Aug 18 12:42:59 EDT 2021
On Tue, 17 Aug 2021 17:50:59 +0200, joseph pareti <joepareti54 at gmail.com>
declaimed the following:
>In the following code, where does tl.Fn come from? I see it nowhere in the
>documentation, i.e. I was looking for trax.layers.Fn:
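"layers" (conventionally imported as "tl") pulls in a whole slew of submodules using

	from xxx import *

so that all the submodule names sit at the same level. Fn is defined in
base.py, but it is reachable as trax.layers.Fn, and therefore as tl.Fn.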
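To see the re-export mechanism in isolation, here is a minimal sketch that
builds a tiny throwaway package and checks that the name exposed at the package
level is the very same object defined in the submodule. The package name
"demolayers" and its stub Fn are hypothetical stand-ins, not trax code; only the
standard library is used.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def build_demo_package(root: Path) -> None:
    """Create a hypothetical package 'demolayers' that mimics the pattern."""
    pkg = root / "demolayers"
    pkg.mkdir()
    # The submodule that actually defines the function (stub body).
    (pkg / "base.py").write_text(
        "def Fn(name, f, n_out=1):\n"
        "    return (name, f, n_out)\n"
    )
    # The package __init__ re-exports every public name of the submodule.
    (pkg / "__init__.py").write_text("from demolayers.base import *\n")


with tempfile.TemporaryDirectory() as tmp:
    build_demo_package(Path(tmp))
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new directory is scanned
    try:
        import demolayers as dl
        import demolayers.base

        # Same object, merely reachable under two names.
        same_object = dl.Fn is demolayers.base.Fn
        defined_in = dl.Fn.__module__
    finally:
        sys.path.remove(tmp)

print(same_object)
print(defined_in)
```

With trax installed, the analogous check would be whether tl.Fn is
trax.layers.base.Fn, and inspect.getsourcefile(tl.Fn) should point at base.py.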
https://github.com/google/trax/blob/master/trax/layers/base.py
From line 748 on...
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """
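The n_in / n_out rule in that docstring can be illustrated with a toy stand-in.
This is a hypothetical sketch, not how trax implements Fn: ToyFnLayer simply
infers n_in from the positional parameters of f via inspect.signature, takes
n_out as stated by the caller, and checks both at call time. Plain Python
scalars and the builtin max replace jnp.maximum, so "element-wise" here means
a single element.

```python
import inspect
from typing import Any, Callable


class ToyFnLayer:
    """Hypothetical weightless layer that wraps a pure function."""

    def __init__(self, name: str, f: Callable[..., Any], n_out: int = 1):
        params = list(inspect.signature(f).parameters.values())
        positional = (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
        # Mirror the docstring's rule: plain positional args only, no defaults.
        for p in params:
            if p.kind not in positional or p.default is not inspect.Parameter.empty:
                raise ValueError(f"{name}: f may take only positional args "
                                 f"without defaults, got {p}")
        self.name = name
        self.f = f
        self.n_in = len(params)  # inferred from f's signature
        self.n_out = n_out       # must be given explicitly when it is not 1

    def __call__(self, *inputs: Any) -> Any:
        if len(inputs) != self.n_in:
            raise TypeError(f"{self.name}: expected {self.n_in} inputs, "
                            f"got {len(inputs)}")
        outputs = self.f(*inputs)
        # Several outputs come back as a tuple; one output is returned bare.
        got = len(outputs) if isinstance(outputs, tuple) else 1
        if got != self.n_out:
            raise ValueError(f"{self.name}: promised {self.n_out} outputs, "
                             f"f returned {got}")
        return outputs


sum_and_max = ToyFnLayer("SumAndMax",
                         lambda x0, x1: (x0 + x1, max(x0, x1)), n_out=2)
print(sum_and_max.n_in, sum_and_max.n_out)
print(sum_and_max(3, 5))
```

Leaving out n_out=2 in the example above would make this toy layer reject the
two-element result, which is the situation the docstring warns about.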
--
Wulfraed Dennis Lee Bieber AF6VN
wlfraed at ix.netcom.com http://wlfraed.microdiversity.freeddns.org/