
13 May 2021
9:59 a.m.
gobot1234yt@gmail.com wrote:
> I really like this idea; however, I was wondering whether there would be any way to add support for classes that wrap a dataclass, as I don't think this is a particularly uncommon use case.
Thank you for bringing this up. I had to extend dataclass to support specifying "JAX PyTrees" whose elements can be "static" or "nonstatic". I ended up extending the MyPy plugin as well, which is a maintenance burden.
https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/dataclasses/datacl...
https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py
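To make the "static"/"nonstatic" split concrete, here is a minimal stdlib-only sketch of a decorator that wraps `dataclasses.dataclass` and marks some fields as static through field metadata. It is not tjax's implementation: `static_field`, `pytree_dataclass`, `flatten`, `unflatten` and the `static_fields`/`nonstatic_fields` class attributes are all hypothetical names, freezing the class is an assumed design choice, and JAX itself is not imported. In real JAX code, a flatten/unflatten pair like this would be registered with `jax.tree_util.register_pytree_node`.

```python
import dataclasses
from typing import Any, Tuple, Type, TypeVar

T = TypeVar("T")


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: a dataclass field flagged as static (not a PyTree leaf)."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(metadata=metadata, **kwargs)


def pytree_dataclass(cls: Type[T]) -> Type[T]:
    """Hypothetical wrapper: apply dataclasses.dataclass (frozen, by assumption),
    then record which fields are static and which are PyTree children."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    # Setting class attributes is allowed; frozen=True only blocks instance assignment.
    cls.static_fields = tuple(f.name for f in fields if f.metadata.get("static", False))
    cls.nonstatic_fields = tuple(f.name for f in fields if not f.metadata.get("static", False))
    return cls


def flatten(obj: Any) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Split an instance into (children, aux_data): nonstatic values, then static ones."""
    children = tuple(getattr(obj, name) for name in type(obj).nonstatic_fields)
    aux = tuple(getattr(obj, name) for name in type(obj).static_fields)
    return children, aux


def unflatten(cls: Type[T], aux: Tuple[Any, ...], children: Tuple[Any, ...]) -> T:
    """Rebuild an instance; assumes every field is an __init__ parameter."""
    kwargs = dict(zip(cls.static_fields, aux))
    kwargs.update(zip(cls.nonstatic_fields, children))
    return cls(**kwargs)


@pytree_dataclass
class Layer:
    weights: list
    activation: str = static_field(default="relu")


layer = Layer(weights=[1.0, 2.0])
children, aux = flatten(layer)
print(children, aux)  # ([1.0, 2.0],) ('relu',)
assert unflatten(Layer, aux, children) == layer
```

This runs fine, but a static type checker has no built-in knowledge that `pytree_dataclass` synthesizes `__init__` the way `dataclasses.dataclass` does, so `Layer(weights=...)` is not understood as a dataclass constructor without extra help such as a plugin. That is the maintenance burden described above.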