JAX에서 값을 쓰다가 어려운 부분이 있으면 대체로 numpy에 있는 함수들과 비슷한 것이 많아서 찾기 쉽다.
1. NaN 값 찾기
import jax.numpy as jnp
a = jnp.array([jnp.nan,1,0,jnp.nan])
x = jnp.isnan(a)
print(x)
>> [ True False False True]
2. +-Inf 값 찾기
import jax.numpy as jnp
a = jnp.array([jnp.inf,1,0,-jnp.inf])
x = jnp.isinf(a)
print(x)
>> [ True False False True]
1,2번 항목을 보면 boolean array으로 나오기 때문에 위 코드에서 'a'라는 array에 대해 indexing하면 NaN값을 추출할 수 있고 역으로 not 연산자를 통해서 nan(또는 inf)가 아닌 값을 추출할 수 있다.
a[x]
>> [nan nan]
a[~x]
>> [1.,0.]
3. 조건에 맞는 인덱스 찾기
위에서는 값만 추출한다면 인덱스를 찾는 방법은 기존의 numpy 등에서 쓰는 방식과 거의 차이가 없다.
a = jnp.array([jnp.nan,1,0,-jnp.nan])
x = jnp.isnan(a)
print(jnp.where(x))
>> (Array([0, 3], dtype=int64),)
print(x.nonzero())
>> (Array([0, 3], dtype=int64),)
다음과 같이 where 또는 .nonzero()를 통해 구할 수 있는데
이 때 x라는 array는 True, False로 구성된 값들이기 때문에 이를 응용하면,
특정 조건에 해당하는 array를 구할 수 있다.
a = jnp.array([1.0, 2.0, 3.0])
x = a < 1.5
print(jnp.where(x))
>> (Array([0], dtype=int64),)
print(a[x])
>> [1.]
동작이 numpy와 동일해서 굳이 헷갈릴 이유는 없는 듯하다.