2023년에 배포된 라이브러리인 JAX-FEM에서 Sensitivity analysis를 어떻게 하는지 설명하고자 한다.
참고문헌 :
Inverse design 문제를 푸는 것은 주로 topology optimization (TO)이라고도 표현하는데, 이 때 목적함수에 대한 gradient 값을 계산해야한다. 일반적으로 TO 문제는 Adjoint method를 통해 forward simulation & backward simulation 두 번으로 gradient 값을 효율적으로 계산하게 된다.
1. Problem formulation of topology optimization
$$\min \quad J(U,\theta) \text{ s.t. } C(U,\theta)=0$$
문제 정의는 다음과 같이 $U\in R^N$가 finite element solution vector (the physical response)이고 $\theta \in R^M$가 design parameters일 때 PDE constraint $C(U,\theta)=0$ 하에서 $J(U,\theta)$를 최대화 (또는 최소화)하는 문제이다.
2. Adjoint method + JAX-FEM의 방식
다음 문제에서 $U(\theta)$는 PDE를 풀면 나오는 implicit function이라고 볼 수 있다. 따라서,
$$\min \hat{J}(\theta) \quad \text{s.t.} \hat{J}(\theta)=(U(\theta),\theta)$$ 문제를 푸는 것이라고 할 수 있다. 우리가 원하는 gradient를 알기 위해서 다음과 같이 total derivative를 구한다.

$\frac{dU}{d\theta}$는 implicit function thoerem으로 존재한다고 가정한다.
또한 앞서 보인 constraint $C(U,\theta)=0$를 양변에 미분을 취하면 다음 식을 얻을 수 있다.

위 두 식을 종합하면

다음과 같이 식이 유도된다. 일반적인 adjoint method는 위 식에서 adjoint PDE라고 표시된 부분을 어떤 변수 $\lambda$라고 두고 adjoint PDE를 시뮬레이션 하면 $\lambda$를 구할 수 있고, $\partial C/\partial \theta$는 원래 의도했던 시뮬레이션을 통해 구할 수 있다. 따라서 둘을 결합해서 최종적으로 좌변 결과($d\hat{J}/d\theta$)를 구하게 된다. (마지막 term인 $\partial J/\partial \theta$는 $J$의 definition을 통해 구할 수 있다.)
이 논문에서는 tangent linear PDE로 언급을 하는데, adjoint PDE를 풀어서 구하나, tangent linear PDE를 풀어서 구하나 결과는 동일하지만 computation 관점에서 뭐가 더 유리한지는 다르다. Design parameters의 수가 클 수록 adjoint PDE 방법을 쓰는 게 낫다.
다음과 같은 adjoint variable $\lambda \in R^N$를 정의한다.

이 PDE는 linear PDE이다.
일반적으로는 $\partial C/\partial \theta$에 대한 expression을 직접 구하는데 이 과정이 지루하고 실수할 가능성이 많아 JAX-FEM에서는 JAX의 VJP 기능으로 자동으로 계산하도록 했다.
다음과 같은 알고리즘을 적용할 때 result로 나오는 것이 $\lambda^* \frac{\partial C}{\partial \theta}$(위 식에서 Hermitian 적용한 것)

위 코드에서 v가 adjoint 값이고 vec_jac_prod에 adjoint 값을 넣으면 JAX 내에서 효율적으로 $\lambda^* \frac{\partial C}{\partial \theta}$를 계산한다.누군가는 이렇게 물어볼 수 있다. 어차피 adjoint 값($\lambda$)이 있고 $\frac{\partial C}{\partial \theta}$가 있으면 따로따로 계산해서 곱하면 되지 않을까? 이는 컴퓨터 계산과 관련이 있다. $\lambda^* \frac{\partial C}{\partial \theta}$는 사실 상 $R^M$ 벡터이다. 개별적으로 두 값을 곱하면 Explicit하게 $R^N$ 크기의 $\lambda$와 Matrix인 $\frac{\partial C}{\partial \theta}$를 곱해야한다. 이 때 matrix를 구하기 위한 컴퓨터 계산에서 비효율이 발생한다. 모든 계산에서 matrix 계산을 반복해야하기 때문이다. 따라서 벡터와 행렬 계산으로 이루어질 수 있도록 $\lambda^* \frac{\partial C}{\partial \theta}$를 다이렉트로 계산하는 게 빠르다.
하여튼 이러한 workflow로 $\partial \hat{J}/\partial \theta$를 구해 최적화 문제에 활용할 수 있다.

3. 정리
JAX-FEM은 Adjoint method를 통해 TO를 하는 게 맞으나, 이 Adjoint 방정식을 풀 때 JAX의 기능을 활용해 자동으로 푼다고 할 수 있다. 이는 직접적인 수식 유도에 비해 효율적이라는 장점이 있다.
(이 코드를 직접 써본 분 말로는 JIT을 활용할 수 없는 구조라서 좀 느리다고 하던데 내가 직접 해보지는 않아서 모르겠다.)
'연구 Research > 최적화 Optimization' 카테고리의 다른 글
| [최적화] Gradient를 계산하는 방법 (0) | 2025.01.07 |
|---|---|
| [Optimization] Linear subspace, Affine subspace (0) | 2024.10.11 |
| [최적화] Introduction to Optimization - Introduction (0) | 2020.09.25 |