[JAX] L-BFGS optimizer로 학습하는 예제 코드
·
연구 Research/인공지능 Artificial Intelligent
Quasi Newton method를 이용해서 JAX에서 최적화를 시키고 싶었는데 JAX 자체가 Deep learning에 포커스가 있고 대부분 딥러닝이 Gradient descent method로 최적화를 하다보니 라이브러리를 찾게 되었다. 대부분 jaxopt라는 라이브러리를 추천했기 때문에 이걸로 수행해보았다. jaxopt에는 jaxopt.ScipyMinimize와 jaxopt.LBFGS가 있는데 다른 분들의 시도를 보니 ScipyMinimize가 더 성능이 괜찮은 것 같다. ScipyMinimize는 scipy에 있는 최적화를 사용한 것이고 LBFGS는 직접 만든 것 같은데 line search 방법 등이 다르다고 한다. import jax import jax.numpy as np from jax ..