[JAX] NaN, Inf 값 처리 및 조건에 맞는 요소 찾기
·
프로그래밍 Programming
JAX에서 값을 쓰다가 어려운 부분이 있으면 대체로 numpy에 있는 함수들과 비슷한 것이 많아서 찾기 쉽다. 1. NaN 값 찾기 import jax.numpy as jnp a = jnp.array([jnp.nan,1,0,jnp.nan]) x = jnp.isnan(a) print(x) >> [ True False False True] 2. +-Inf 값 찾기 import jax.numpy as jnp a = jnp.array([jnp.inf,1,0,-jnp.inf]) x = jnp.isinf(a) print(x) >> [ True False False True] 1,2번 항목을 보면 boolean array으로 나오기 때문에 위 코드에서 'a'라는 array에 대해 indexing하면 NaN값을 추출할..