JAX in Cell 차세대 차별 가능한 입자셀 시뮬레이터
초록
JAX‑in‑Cell은 1차원 3속도(1D3V) 전자기 입자‑셀(PIC) 코드를 JAX로 완전 구현한 프레임워크이다. JIT 컴파일과 자동 벡터화를 활용해 CPU·GPU·TPU에서 전통적인 C/Fortran 기반 PIC 코드와 동등한 성능을 달성한다. 명시적 보리스와 암시적 크랭크‑니콜슨(피카드 반복) 두 가지 시간 적분 스킴을 제공하며, 자동 미분을 통해 물리 파라미터 최적화와 베이지안 추론 등 차별 가능한 시뮬레이션을 가능하게 한다.
상세 분석
JAX‑in‑Cell은 Vlasov‑Maxwell 방정식을 Yee 격자 위에 이산화하고, 입자‑그리드 상호작용을 삼각형(3‑point) 스플라인 커널로 일관되게 처리한다. 전하와 전류는 연속 방정식에 기반한 전하 보존 스키마를 사용해 격자에 deposit되며, 디지털 필터링을 통해 고주파 노이즈와 그리드 히팅을 억제한다. 시간 적분은 두 가지 경로를 제공한다. 첫 번째는 전통적인 보리스 알고리즘을 이용한 명시적 스킴으로, 입자 속도 회전에 로렌츠 힘을 정확히 반영한다. 두 번째는 전자기장과 입자 움직임을 동시에 해결하는 암시적 크랭크‑니콜슨 스킴으로, 피카드 반복을 통해 수렴한다. 이 암시적 방법은 에너지 보존 특성이 뛰어나며, 큰 시간 스텝에서도 수치적 안정성을 유지한다. JAX의 jax.lax.scan과 jax.vmap을 활용해 전체 타임스텝을 함수형으로 구현함으로써 상태를 불변 튜플로 전달하고, 자동 미분이 가능한 단일 함수 형태를 만든다. 이는 파라미터에 대한 그래디언트를 직접 계산할 수 있게 해, 두‑스트림 불안정성의 성장률을 드리프트 속도에 대해 최적화하는 예시와 같이 고차원 최적화 문제에 바로 적용할 수 있다. 또한, 다중 종(species) 입자를 하나의 전역 배열에 통합해 SIMD 패러다임에 최적화했으며, 전통적인 객체 지향 PIC 구현과 달리 GPU 워프 발산을 최소화한다. 성능 평가에서는 AMD EPYC CPU 대비 NVIDIA A100 GPU에서 약 100배 가량 가속을 보였으며, 32‑bit 연산 모드가 메모리 사용량을 절감하지만 64‑bit와 비교해 미세한 물리적 차이를 야기할 수 있음을 보고한다. 전체적으로 JAX‑in‑Cell은 현대 머신러닝 인프라와 자연스럽게 결합되는 차별 가능한 전자기 PIC 플랫폼으로, 교육용 스크립트와 대규모 연구 코드 사이의 격차를 메우는 역할을 수행한다.
댓글 및 학술 토론
Loading comments...
의견 남기기