[JAX] Pytree 사용법 + 실용적인 방법들
·
카테고리 없음
JAX에서 함수 연산 속도를 빠르게 하기 위해서 @jit을 사용해서 jit compile을 한다.보통 입력이 jax array나 튜플이면 크게 문제가 되지 않는데 나의 경우에는 사용하고 싶은 함수 내용이 복잡해지면서 class를 쓰고 싶었다. 특정 변수들을 모아놓은 집합을 계속 옮겨야 했기 때문이다.(structure와 비슷하게) 그런데 class를 argument로 쓰면 다음과 같은 오류를 보게 된다. TypeError: Cannot interpret value of type as an abstract array; it does not have a dtype attribute 이 때 사용해야 하는 것이 pytree이다. 1. Pytree의 정의 In JAX, we use the term pyt..