[코드] JAX-FEM 설명

2026. 6. 2. 03:30·연구 Research/최적화 Optimization

2023년에 배포된 라이브러리인 JAX-FEM에서 Sensitivity analysis를 어떻게 하는지 설명하고자 한다.

 

참고문헌 :

Xue, T. et al. JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science. Computer Physics Communications 291, 108802 (2023).

 

 

Inverse design 문제를 푸는 것은 주로 topology optimization (TO)이라고도 표현하는데, 이 때 목적함수에 대한 gradient 값을 계산해야한다. 일반적으로 TO 문제는 Adjoint method를 통해 forward simulation & backward simulation 두 번으로 gradient 값을 효율적으로 계산하게 된다.

 

1. Problem formulation of topology optimization

 

$$\min \quad J(U,\theta) \text{ s.t. } C(U,\theta)=0$$

 

문제 정의는 다음과 같이 $U\in R^N$가 finite element solution vector (the physical response)이고 $\theta \in R^M$가 design parameters일 때 PDE constraint $C(U,\theta)=0$ 하에서 $J(U,\theta)$를 최대화 (또는 최소화)하는 문제이다.

 

2. Adjoint method + JAX-FEM의 방식

 

다음 문제에서 $U(\theta)$는 PDE를 풀면 나오는 implicit function이라고 볼 수 있다. 따라서,

 

$$\min \hat{J}(\theta) \quad \text{s.t.} \hat{J}(\theta)=(U(\theta),\theta)$$ 문제를 푸는 것이라고 할 수 있다. 우리가 원하는 gradient를 알기 위해서 다음과 같이 total derivative를 구한다.

 

 

$\frac{dU}{d\theta}$는 implicit function thoerem으로 존재한다고 가정한다.

또한 앞서 보인 constraint $C(U,\theta)=0$를 양변에 미분을 취하면 다음 식을 얻을 수 있다.

위 두 식을 종합하면

 

 

다음과 같이 식이 유도된다. 일반적인 adjoint method는 위 식에서 adjoint PDE라고 표시된 부분을 어떤 변수 $\lambda$라고 두고 adjoint PDE를 시뮬레이션 하면 $\lambda$를 구할 수 있고, $\partial C/\partial \theta$는 원래 의도했던 시뮬레이션을 통해 구할 수 있다. 따라서 둘을 결합해서 최종적으로 좌변 결과($d\hat{J}/d\theta$)를 구하게 된다. (마지막 term인 $\partial J/\partial \theta$는 $J$의 definition을 통해 구할 수 있다.)

 

이 논문에서는 tangent linear PDE로 언급을 하는데, adjoint PDE를 풀어서 구하나, tangent linear PDE를 풀어서 구하나 결과는 동일하지만 computation 관점에서 뭐가 더 유리한지는 다르다. Design parameters의 수가 클 수록 adjoint PDE 방법을 쓰는 게 낫다.

 

다음과 같은 adjoint variable $\lambda \in R^N$를 정의한다.

이 PDE는 linear PDE이다.

 

 

일반적으로는 $\partial C/\partial \theta$에 대한 expression을 직접 구하는데 이 과정이 지루하고 실수할 가능성이 많아 JAX-FEM에서는 JAX의 VJP 기능으로 자동으로 계산하도록 했다. 

 

다음과 같은 알고리즘을 적용할 때 result로 나오는 것이 $\lambda^* \frac{\partial C}{\partial \theta}$(위 식에서 Hermitian 적용한 것)

 

위 코드에서 v가 adjoint 값이고 vec_jac_prod에 adjoint 값을 넣으면 JAX 내에서 효율적으로 $\lambda^* \frac{\partial C}{\partial \theta}$를 계산한다.누군가는 이렇게 물어볼 수 있다. 어차피 adjoint 값($\lambda$)이 있고 $\frac{\partial C}{\partial \theta}$가 있으면 따로따로 계산해서 곱하면 되지 않을까? 이는 컴퓨터 계산과 관련이 있다. $\lambda^* \frac{\partial C}{\partial \theta}$는 사실 상 $R^M$ 벡터이다. 개별적으로 두 값을 곱하면 Explicit하게 $R^N$ 크기의 $\lambda$와 Matrix인 $\frac{\partial C}{\partial \theta}$를 곱해야한다. 이 때 matrix를 구하기 위한 컴퓨터 계산에서 비효율이 발생한다. 모든 계산에서 matrix 계산을 반복해야하기 때문이다. 따라서 벡터와 행렬 계산으로 이루어질 수 있도록 $\lambda^* \frac{\partial C}{\partial \theta}$를 다이렉트로 계산하는 게 빠르다.

 

하여튼 이러한 workflow로 $\partial \hat{J}/\partial \theta$를 구해 최적화 문제에 활용할 수 있다.

 

 

 

3. 정리

 

JAX-FEM은 Adjoint method를 통해 TO를 하는 게 맞으나, 이 Adjoint 방정식을 풀 때 JAX의 기능을 활용해 자동으로 푼다고 할 수 있다. 이는 직접적인 수식 유도에 비해 효율적이라는 장점이 있다.

 

(이 코드를 직접 써본 분 말로는 JIT을 활용할 수 없는 구조라서 좀 느리다고 하던데 내가 직접 해보지는 않아서 모르겠다.)

 

저작자표시 비영리 변경금지 (새창열림)

'연구 Research > 최적화 Optimization' 카테고리의 다른 글

[최적화] Gradient를 계산하는 방법  (0) 2025.01.07
[Optimization] Linear subspace, Affine subspace  (0) 2024.10.11
[최적화] Introduction to Optimization - Introduction  (0) 2020.09.25
'연구 Research/최적화 Optimization' 카테고리의 다른 글
  • [최적화] Gradient를 계산하는 방법
  • [Optimization] Linear subspace, Affine subspace
  • [최적화] Introduction to Optimization - Introduction
보통의공대생
보통의공대생
수학,프로그래밍,기계항공우주 등 공부하는 기록들을 남깁니다.
  • 보통의공대생
    뛰는 놈 위에 나는 공대생
    보통의공대생
  • 전체
    오늘
    어제
    • 분류 전체보기 (482)
      • 공지 (1)
      • 영어 공부 English Study (40)
        • 텝스 TEPS (7)
        • 글 Article (21)
        • 영상 Video (10)
      • 연구 Research (102)
        • 최적화 Optimization (4)
        • 데이터과학 Data Science (8)
        • 인공지능 Artificial Intelligent (41)
        • 제어 Control (45)
      • 프로그래밍 Programming (36)
        • 매트랩 MATLAB (25)
        • 파이썬 Python (33)
        • 줄리아 Julia (2)
        • C++ (3)
        • 리눅스 우분투 Ubuntu (6)
      • 항공우주 Aeronautical engineeri.. (21)
        • 항법 Navigation (0)
        • 유도 Guidance (0)
      • 기계공학 Mechanical engineering (3)
        • 열역학 Thermodynamics (0)
        • 고체역학 Statics & Solid mechan.. (10)
        • 동역학 Dynamics (1)
        • 유체역학 Fluid Dynamics (0)
      • 수학 Mathematics (35)
        • 선형대수학 Linear Algebra (18)
        • 미분방정식 Differential Equation (3)
        • 확률및통계 Probability & Sta.. (2)
        • 미적분학 Calculus (1)
        • 복소해석학 Complex Analysis (5)
        • 실해석학 Real Analysis (1)
      • 수치해석 Numerical Analysis (29)
      • 확률 및 랜덤프로세스 Random process (2)
      • 추론 & 추정 이론 Estimation (3)
      • 기타 (97)
        • 설계 프로젝트 System Design (8)
        • 논문작성 Writing (58)
        • 세미나 Seminar (2)
        • 생산성 Productivity (3)
      • 실험 Experiment (1)
      • 유학 생활 Daily (8)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    JAX
    Numerical Analysis
    LaTeX
    teps
    Python
    우분투
    Zotero
    에러기록
    생산성
    ChatGPT
    텝스
    pytorch
    수치해석
    WOX
    고체역학
    딥러닝
    Linear algebra
    obsidian
    텝스공부
    Julia
    matplotlib
    논문작성
    Dear abby
    논문작성법
    옵시디언
    IEEE
    Statics
    MATLAB
    머신러닝
    인공지능
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
보통의공대생
[코드] JAX-FEM 설명
상단으로

티스토리툴바