제발 RL하면 JAX 합시다
이 PPT 의 필자는 딱히 JAX 와 관련해서 Google 과 접점이 있거나, 행사에 참여하거나, 집필에 참여하거나, 회사에서 이걸로 일을 하거나 하지 않습니다.
또한 RL 만으로 박사를 받거나, 그걸로 학회를 가보거나, 공신력 있는 논문을 퍼블리시 했거나, 체계적인 교육을 받거나 지도를 받은적도 없는 사람입니다.
이 모든 내용은 딱히 공신력도, 증거도, 연구도 없는 개인적 견해이니 흘러들어도 됩니다.
저는 RL을 사랑하는데, 걔도 절 사랑했으면 RL을 ‘짝사랑’ 하는 사람이라고 소개하진 않았겠죠. 지금 한국에 있지도 않았을테고.
면책조항
이 PPT 의 필자는 JAX로 4년 넘게 개인프로젝트를 해왔으며, 석사 졸업논문을 JAX로 완료하였고, 어떠한 공동작업이나 논문의 공식구현, LLM 붐에 탑승하지 않고 순수 JAX 구현 이라는 폐관수련에 가까운 레포들으로 총합 116 스타를 수집했습니다.
동일하게 RL을 ‘공부’ 한지 7년 넘게 일편단심 이었으며, 교수님과 싸우고 1학기 휴학 하면서도 석사 졸업논문을 RL으로 하였고, 어떠한 공동작업이나 논문의 공식구현, LLM 붐에 탑승하지 않고 순수 RL 구현 이라는 폐관수련에 가까운 레포들으로 총합 116 스타를 수집했습니다.
소위 말하는 스타 개발자나 연구자도 아닌데 JAX + RL 이라는 주제로 개인 레포에서 스타를 이만큼 모은거면 인정해 줘야 한다고 생각합니다.
변명조항
그나마
JAX 란?
JAX 란? : Numpy
JAX 란? : Numpy
JAX 란? : AutoGrad
JAX 란? : AutoGrad
JAX 란? : AutoGrad
JAX 란? : XLA
JAX 란? : XLA
JAX 란? : XLA
JAX 란? : XLA
JAX 란? : Flax, Haiku, Equinox, Optax
Flax
JAX 란? : Flax, Haiku, Equinox, Optax
이것이 JAX 다!!: 희망편
GPU
CPU
TPU
함수형으로�잘 짜여진 코드
빠른 학습�빠른 인퍼런스
유연한 모듈 적용
효율적인 ML
이것이 JAX 다!!: 파멸편
평생 구경해볼
일도 없는 TPU
뜨거운 GPU
할일 없는 CPU
처음보면 이해조차 �안가는 코드구조
타입 or 형태 변경으로�JIT 안되는 코드
디버그�(힘들다)
구현 안된 일반적인�데이터구조
컴파일�시간
이것이 JAX 다!!: 성능편
이것이 JAX 다!!: 성능편
평균 : 3배 학습 속도
평균 : 4배 학습 속도
평균 : 2배 학습 속도
평균 : 2배 학습 속도
극적인 학습 속도 상승 X
이것이 JAX 다!!: 성능편
이것이 JAX 다!!: 회사편
First Party
Second Party
DeepMind
Anthropic
Cohere
Apple
xAI
All Claude models
Apple Intelligence�Engine
All Grok models
HuggingFace
JAX supports
Whisper JAX
JAX 를 내가 써야 하나? - RL 이 아니면 : 하면 좋다
VS
JAX 를 내가 써야 하나? - RL 이라면 : 무조건 좋다
VS
RL 에서 JAX 가 더욱 강한 이유: RL 의 특수성?
RL 의 특수성: 복잡한 파이프라인
RL 의 특수성: 복잡한 파이프라인
RL 의 특수성: 복잡한 파이프라인
RL 의 특수성: 복잡한 파이프라인
개인적인 구현에서는 같은 CPU 환경을 해결하는 DQN 에서 2 ~ 3 배의 프레임을 달성하였음
Stable Baselines Jax 에서는 SB3(pytorch) 와 비교하여, 최대 20배의 학습 속도가 나온다고 ‘주장’
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 환경 탐색 루프
RL 의 특수성: 병렬화에 대한 이득
RL 의 특수성: 병렬화에 대한 이득
RL 의 특수성: 병렬화에 대한 이득
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
RL 의 특수성: 인퍼런스의 복잡성
저희 속도 말고 다른 생각을 해봅시다
저희 속도 말고 다른 생각을 해봅시다
저희 속도 말고 다른 생각을 해봅시다
제발 RL하면 JAX 합시다
저희 속도 말고 다른 생각을 해봅시다
저희 속도 말고 다른 생각을 해봅시다
저희 속도 말고 다른 생각을 해봅시다 - 갈!!!!!!
Ex) 분자 예측 시뮬레이션, 등등…
Ex) 바둑, 조합최적화, 등등…
저희 속도 말고 다른 생각을 해봅시다
아니 그… 빠른게 좋긴 한데… 근본적인 해결책은 아니라는거죠…
샘플 많이 쓸 생각 하기 전에 샘플을 최대한 적게 쓰고 잘푸는법을 고민해야 하는거 아님?
저희 속도 말고 다른 생각을 해봅시다
그럼 이대로만 갑시다!
그래서 이렇게 하자고?
저희 속도 말고 다른 생각을 해봅시다
근데 JAX 랑 RL 발표인데�이런 내용 들어가는게 맞음?
샘플 효율적 이면 다 된다는 이야기 잖아?
그 샘플 효율적인 알고리즘의 복잡한 연산 최적화는 조상님이 해주시나?
샘플 효율적 이려면 그만큼 몸비틀어서 알고리즘이 복잡해질텐데
복잡한 워크플로우와 거기서 오는 오버헤드를 JAX 로 최적화 할수 있다니까?
MCTS에 A*도 짜고 저만큼 빨라지는데 안하는게 이상하지 않음?
저희 속도 말고 다른 생각을 해봅시다
둘다 다 하자는 이야기구나!
생각해보면 샘플 효율적 알고리즘은 환경의 속도가 아니라 학습 알고리즘의 최적화 정도 쪽에 주도권이 오니까 JAX로 최적화 하면 그 효과가 더 크겠네?
마치며
다시한번 말씀드리자면�우리 모두 제발 RL하면 JAX 합시다�아니어도 JAX 해보면 좋을 겁니다
감사합니다