JAX에서 함수 연산 속도를 빠르게 하기 위해서 @jit을 사용해서 jit compile을 한다.
보통 입력이 jax array나 튜플이면 크게 문제가 되지 않는데 나의 경우에는 사용하고 싶은 함수 내용이 복잡해지면서 class를 쓰고 싶었다. 특정 변수들을 모아놓은 집합을 계속 옮겨야 했기 때문이다.(structure와 비슷하게)
그런데 class를 argument로 쓰면 다음과 같은 오류를 보게 된다.
TypeError: Cannot interpret value of type <class 'example'> as an abstract array; it does not have a dtype attribute
이 때 사용해야 하는 것이 pytree이다.
1. Pytree의 정의
In JAX, we use the term pytree to refer to a tree-like structure built out of container-like Python objects
pytree에 대한 설명은 다음 글에 자세히 나와있다. 본 글에서는 pytree 개념이 무엇인지, 그리고 클래스와 같이 내가 원하는 형태를 어떻게 pytree로 만들 수 있는지에 대해 다루고자 한다.
기본적으로 생각할 것은 JAX에서 다룰 수 있는 형태의 자료형이라고 생각하면 된다.
2. 기본적인 pytree 등록
기본 방법과 class를 등록하는 방법 2가지가 있다.
일단 공통적으로 tree를 unflatten, flatten하는 메서드를 포함해야 한다.
class c_temp():
def __init__(self) -> None:
pass
def _tree_flatten(self):
attributes = [attr for attr in dir(self) if not attr.startswith('__') and not callable(getattr(self, attr))]
data = [getattr(self, name, None) for name in attributes] # children
# attributes are used for keys
return (tuple(data), attributes)
@classmethod
def _tree_unflatten(cls, keys, data):
""" Args:
keys : the opaque data that was specified during flattening of the
current treedef.
data : the unflattened children
"""
bufmanage = cls()
for name, data_tmp in zip(keys, data):
setattr(bufmanage, name, data_tmp)
return bufmanage
pass
from jax import tree_util
tree_util.register_pytree_node(c_temp,
c_temp._tree_flatten,
c_temp._tree_unflatten)
위에서는 class이기 때문에 class 안의 메서드로 tree_flatten과 tree_unflatten을 정의했지만 따로 함수로 정의해도 무방하다.
flatten은 class 내부의 데이터를 튜플 형태로 내보내는 로직이고, unflatten은 그 정보를 이용해 다시 내가 정의한 tree를 복원하는 로직이다. class이기 때문에 setattr, getatter를 사용했지만 다른 자료형이라면 다르게 접근할 수 있을 것이다.
class를 사용하는 경우에는 tree_util.register_pytree_node 대신 register_pytree_node_class를 이용해도 된다.
이 매서드는 class를 바로 pytree로 사용할 수 있도록 해준다.
@jax.tree_util.register_pytree_node_class
class c_inpParams():
def __init__(self):
pass
def tree_flatten(self):
attributes = [attr for attr in dir(self) if not attr.startswith('__') and not callable(getattr(self, attr))]
data = [getattr(self, name, None) for name in attributes] # children
# attributes are used for keys
return (tuple(data), attributes)
@classmethod
def tree_unflatten(cls, keys, data):
inpParams = cls()
for name, data_tmp in zip(keys, data):
setattr(inpParams, name, data_tmp)
return inpParams
pass
마찬가지로 tree_flatten과 tree_unflatten을 정의한다.
https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.register_pytree_node_class.html
https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part-10-pytrees-in-jax