[JAX] 기본 Neural Networks 모델
·
연구 Research/인공지능 Artificial Intelligent
가장 simple하게 신경망을 구성하는 방법에 대해서 저장해놓은 글이다.차츰 업데이트 할 예정 1. 기본 학습 코드import jaximport jax.numpy as jnpfrom jax import grad, jit, vmapfrom jax import random# Define a simple neural network modeldef init_params(layer_sizes, key): params = [] for i in range(1, len(layer_sizes)): key, subkey = random.split(key) w = random.normal(subkey, (layer_sizes[i-1], layer_sizes[i])) b = j..