[JAX] JAX vmap에 대한 설명
·
연구 Research/인공지능 Artificial Intelligent
아직 jax가 한글화가 많이 안 되어있어서 기본적인 기능은 내가 적어놓으려고 한다. jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None) function을 argument axes에 대해서 mapping해주는 기능. fun : mapping할 function in_axes : function에 들어가는 input을 의미한다. 정수, None, Python container(tuple/list/dict 모두 가능)을 지원한다. 이는 모두 mapping할 input array 축을 의미한다. 만약 fun의 argument가 array이면 in_axes에는 정수, None, 튜플(Integer, None..