1 of 56

제발 RL하면 JAX 합시다

  1. JAX 란?
  2. 이것이 JAX 다!!
  3. JAX 를 내가 써야 하나?
  4. RL 에서 JAX 가 더욱 강한 이유
  5. 저희 속도 말고 다른 생각을 해봅시다

2 of 56

이 PPT 의 필자는 딱히 JAX 와 관련해서 Google 과 접점이 있거나, 행사에 참여하거나, 집필에 참여하거나, 회사에서 이걸로 일을 하거나 하지 않습니다.

또한 RL 만으로 박사를 받거나, 그걸로 학회를 가보거나, 공신력 있는 논문을 퍼블리시 했거나, 체계적인 교육을 받거나 지도를 받은적도 없는 사람입니다.

이 모든 내용은 딱히 공신력도, 증거도, 연구도 없는 개인적 견해이니 흘러들어도 됩니다.

저는 RL을 사랑하는데, 걔도 절 사랑했으면 RL을 ‘짝사랑’ 하는 사람이라고 소개하진 않았겠죠. 지금 한국에 있지도 않았을테고.

면책조항

3 of 56

이 PPT 의 필자는 JAX로 4년 넘게 개인프로젝트를 해왔으며, 석사 졸업논문을 JAX로 완료하였고, 어떠한 공동작업이나 논문의 공식구현, LLM 붐에 탑승하지 않고 순수 JAX 구현 이라는 폐관수련에 가까운 레포들으로 총합 116 스타를 수집했습니다.

동일하게 RL을 ‘공부’ 한지 7년 넘게 일편단심 이었으며, 교수님과 싸우고 1학기 휴학 하면서도 석사 졸업논문을 RL으로 하였고, 어떠한 공동작업이나 논문의 공식구현, LLM 붐에 탑승하지 않고 순수 RL 구현 이라는 폐관수련에 가까운 레포들으로 총합 116 스타를 수집했습니다.

소위 말하는 스타 개발자나 연구자도 아닌데 JAX + RL 이라는 주제로 개인 레포에서 스타를 이만큼 모은거면 인정해 줘야 한다고 생각합니다.

변명조항

그나마

4 of 56

JAX 란?

  • JAX = Numpy + AutoGrad + XLA
    • Numpy : 수치계산 라이브러리
    • AutoGrad : 모든 함수에 대한 자동 도함수 변환
    • XLA : 가속화된 대수계산 컴파일러(CPU/GPU/TPU)�

5 of 56

JAX 란? : Numpy

  • Numpy
    • 원래의 대부분 numpy 함수들이 그대로 이식되어 있으며, 이를 통해 대부분의 연산을 해 낼수 있고 gpu 같은 하드웨어 가속이 자연스럽게 적용되므로. 강력한 성능을 쉽게 사용할 수 있음
    • JAX 자체에서는 대부분 numpy 연산만이 구현되어 있으며, 뉴럴 네트워크의 구현은 numpy 연산으로 우회적으로 구현 가능

6 of 56

JAX 란? : Numpy

  • Numpy
    • 원래의 대부분 numpy 함수들이 그대로 이식되어 있으며, 이를 통해 대부분의 연산을 해 낼수 있고 gpu 같은 하드웨어 가속이 자연스럽게 적용되므로. 강력한 성능을 쉽게 사용할 수 있음
    • JAX 자체에서는 대부분 numpy 연산만이 구현되어 있으며, 뉴럴 네트워크의 구현은 numpy 연산으로 우회적으로 구현 가능

7 of 56

JAX 란? : AutoGrad

  • AutoGrad
    • jax.grad : 입력된 ’함수’ 의 인수에 대하여 출력의 그래디언트를 반환하는 함수를 출력
    • 인수라면 뉴럴넷의 파라미터 또한 가능하므로, 이를 통해 backprop 된 그래디언트 값을 구할수 있고, 이를 통해 뉴럴넷을 학습 가능함
    • 그 외에 2차 3차 도함수도 아주 쉽게 변환이 가능하므로, 이러한 도함수들을 이용한 복잡한 구조의 설계 가능

8 of 56

JAX 란? : AutoGrad

  • AutoGrad
    • jax.grad : 입력된 ’함수’ 의 인수에 대하여 출력의 그래디언트를 반환하는 함수를 출력
    • 인수라면 뉴럴넷의 파라미터 또한 가능하므로, 이를 통해 backprop 된 그래디언트 값을 구할수 있고, 이를 통해 뉴럴넷을 학습 가능함
    • 그 외에 2차 3차 도함수도 아주 쉽게 변환이 가능하므로, 이러한 도함수들을 이용한 복잡한 구조의 설계 가능

9 of 56

JAX 란? : AutoGrad

  • AutoGrad
    • jax.grad : 입력된 ’함수’ 의 인수에 대하여 출력의 그래디언트를 반환하는 함수를 출력
    • 인수라면 뉴럴넷의 파라미터 또한 가능하므로, 이를 통해 backprop 된 그래디언트 값을 구할수 있고, 이를 통해 �뉴럴넷을 학습 가능함
    • 그 외에 2차 3차 도함수도 아주 쉽게 변환이 가능하므로, 이러한 도함수들을 이용한 복잡한 구조의 설계 가능

10 of 56

JAX 란? : XLA

  • XLA
    • XLA 적용 조건 : 입력과 출력의 형태, 길이, 타입 전부가 변경되지 않는 ‘함수’ 로 작성 되어야 함
    • Just In Time(JIT) Compilation : ‘함수’ 를 지정된 입력 형식에 맞춰 컴파일해 실행
    • vmap : 동일한 하드웨어 내에서 함수의 병렬화를 진행
    • pmap : 동시에 다른 하드웨어들 내에서 함수의 병렬화를 진행 (Multi GPU Support)

11 of 56

JAX 란? : XLA

  • XLA
    • XLA 적용 조건 : 입력과 출력의 형태, 길이, 타입 전부가 변경되지 않는 ‘함수’ 로 작성 되어야 함
    • Just In Time(JIT) Compilation : ‘함수’ 를 지정된 입력 형식에 맞춰 컴파일해 실행
    • vmap : 동일한 하드웨어 내에서 함수의 병렬화를 진행
    • pmap : 동시에 다른 하드웨어들 내에서 함수의 병렬화를 진행 (Multi GPU Support)

12 of 56

JAX 란? : XLA

  • XLA
    • XLA 적용 조건 : 입력과 출력의 형태, 길이, 타입 전부가 변경되지 않는 ‘함수’ 로 작성 되어야 함
    • Just In Time(JIT) Compilation : ‘함수’ 를 지정된 입력 형식에 맞춰 컴파일해 실행
    • vmap : 동일한 하드웨어 내에서 함수의 병렬화를 진행
    • pmap : 동시에 다른 하드웨어들 내에서 함수의 병렬화를 진행 (Multi GPU Support)

13 of 56

JAX 란? : XLA

  • XLA
    • XLA 적용 조건 : 입력과 출력의 형태, 길이, 타입 전부가 변경되지 않는 ‘함수’ 로 작성 되어야 함
    • Just In Time(JIT) Compilation : ‘함수’ 를 지정된 입력 형식에 맞춰 컴파일해 실행
    • vmap : 동일한 하드웨어 내에서 함수의 병렬화를 진행
    • pmap : 동시에 다른 하드웨어들 내에서 함수의 병렬화를 진행 (Multi GPU Support)

14 of 56

JAX 란? : Flax, Haiku, Equinox, Optax

  • 순수 JAX 는 위에서 설명한 Numpy, AutoGrad, XLA 이 3가지만 구현되어 있으므로, ML 연구(뉴럴넷) 을 위해서는 아래와 같은 라이브러리를 사용하여 뉴럴넷을 구현하고 학습해야함
    • Flax : 구글 브레인의 메인스트림 뉴럴넷 구현 라이브러리
    • Haiku : Deepmind 의 메인스트림 뉴럴넷 구현 라이브러리(였던것). 딥마인드가 구글 브레인과 합쳐지며 버려짐…
    • Equinox : 어느 개인이 관리중인 뉴럴넷 구현 라이브러리, 하지만 Flax 의 대안으로 떠오르고 있음. 인기 많음
    • Optax : JAX 의 유니버설한 Optimizer 구현. Flax 든 Haiku 든 Equinox 든 Optax 를 통해 학습됨

Flax

15 of 56

JAX 란? : Flax, Haiku, Equinox, Optax

  • JAX 로 작성되는 코드는 모두 순수 “함수” 형태를 띄어야 하므로, NN module 은 Function 의 역할을 함
  • NN의 Param 들은 그 함수의 인수로써 입력되는 형태로 작성됨
  • 그 외엔 생각 외로 PyTorch 랑 크게 다르진 않다

16 of 56

이것이 JAX 다!!: 희망편

GPU

CPU

TPU

함수형으로�잘 짜여진 코드

빠른 학습�빠른 인퍼런스

유연한 모듈 적용

효율적인 ML

  • 장점
    • XLA 기능을 기반으로 한 압도적인 최적화와 속도

    • 여러 하드웨어 종류(CPU, GPU, TPU) 에서 단순하고 유니버설 하게 동작하고 최적화된 코드

    • 함수형 프로그래밍을 기반으로 한 유연한 모듈 변경

    • 짜다 보면 역으로 OOP 보다 훨씬 직관적 일수도?

17 of 56

이것이 JAX 다!!: 파멸편

평생 구경해볼

일도 없는 TPU

뜨거운 GPU

할일 없는 CPU

처음보면 이해조차 �안가는 코드구조

타입 or 형태 변경으로�JIT 안되는 코드

디버그�(힘들다)

구현 안된 일반적인�데이터구조

컴파일�시간

  • 단점
    • 함수의 입출력 종류 제한이 엄격해 복잡한 구조의 프로그램의 경우 JIT 이 되는 구조를 만들기 어려움

    • Python 을 주로 개발해온 ML 개발자 들에게 함수형 프로그램은 많이 낫 설 수 있음

    • List, Set, Dictionary 등의 동적 형태 데이터구조 사용 X – 당신에겐 고정 크기 np.array 와 tuple 들이 있을 뿐

    • 디버그가 어려움

18 of 56

이것이 JAX 다!!: 성능편

  • 그럼에도 불구하고 말 그대로 강력하다
    • XLA 는 다른 수치연산 ‘컴파일러’ 들 과도 비교해도 상당히 강력하고
    • 충분히 뉴럴 네트워크를 구현할 수 있지만, 그것만 구현할 수 있는 것도 아니다
    • Numpy + AutoGrad + XLA 오직 그것 만으로도 정말 많은 것이 가능하다

19 of 56

이것이 JAX 다!!: 성능편

  • 다양한 네트워크와 예시에서 학습 시간 측정 비교

  • 일반적인 모델 학습에서도 기본 2배정도의 학습 속도 향상

  • Pytorch 는 이미 PyTorch Lightning 을 기반으로 최적화 된 상태

  • Vision Transformer 에서는 1.1배정도의 적은 속도 상승(아마 Torch Vision 지원으로 인한 최적화가 원인 인듯)

평균 : 3배 학습 속도

평균 : 4배 학습 속도

평균 : 2배 학습 속도

평균 : 2배 학습 속도

극적인 학습 속도 상승 X

20 of 56

이것이 JAX 다!!: 성능편

  • OpenAI 의 Whisper 모델을 JAX 로 최적화한 파이프라인으로 인퍼런스 했을때, 기존의 Pytorch 구현과 비교해 최대 70배가 넘는 속도로 실행할 수 있음

  • Python 위에서 동작하는 프로그램, 동시에 학습용으로 작성된 프로그램을 C++ 나 여타 최적화된 Cuda 프로그램, 또는 TensorRT 에 가깝게 또는 그 이상으로 최적화 할 수 있다는 것은 매우 매력적인 일

21 of 56

이것이 JAX 다!!: 회사편

First Party

Second Party

Google

DeepMind

Anthropic

Cohere

Apple

xAI

  • Gemini…
  • Alpha Fold…
  • All Researches…

All Claude models

Apple Intelligence�Engine

All Grok models

HuggingFace

JAX supports

Whisper JAX

22 of 56

JAX 를 내가 써야 하나? - RL 이 아니면 : 하면 좋다

  • 그냥 지도학습이나 비지도 학습을 하는거라면 ‘엄청나게’ 큰 매리트 까지는 없음
  • PyTorch 는 많은 자료와 추가적인 최적화 라이브러리(Flash Transformer 등등…)도 많으며, 학습 파이프라인이 이미 많이 최적화 되어있음
  • 그럼에도 불구하고, JAX 의 기본적인 속도 향상은 상당히 매력적이고 강력함

VS

23 of 56

JAX 를 내가 써야 하나? - RL 이라면 : 무조건 좋다

VS

  • 하지만 RL 은 주로 구성되는 연산들의 특수성으로 인해 JAX 는 RL 을 매우 강력하게 최적화 할 수 있음!!
  • 현재 딥마인드에서 발표되는 대부분의 RL 논문은 JAX 를 사용하고 있으며, 그 외의 RL 연구 커뮤니티에서도 많이 이용되고 있는 추세

24 of 56

RL 에서 JAX 가 더욱 강한 이유: RL 의 특수성?

  1. 복잡한 학습 파이프라인

  • 환경 탐색 루프

  • 병렬화에 따른 이득

  • 인퍼런스의 복잡성

25 of 56

RL 의 특수성: 복잡한 파이프라인

  • RL 의 ‘학습’ 에서 가장 중요하고, 오래 걸리는 연산은 사실 grad 를 계산하고 optimize 하는 것이 아님

  • 근사함수인 Model 을 인퍼런스 하고, 데이터를 다시 재가공해서 모델이 예측해야 할 값을 재생산 하는 연산이 가장 오래 걸리고 중요

  • 이러한 복잡한 파이프라인을 최적화 하는것이 RL 의 ‘알고리즘’ 성능 최적화 에서 가장 핵심적

26 of 56

RL 의 특수성: 복잡한 파이프라인

  • 물론 Pytorch 에서 저 연산들을 CPU 에서 수행하는 것은 아니지만, 저렇게 많은 파이프라인의 중간중간 연결을 CPU 에서 관리하고 연결하고 실행시키면서 많은 GPU 사용률이 낭비됨

  • 매우 무거운 모델을 실행하는 것뿐만 아닌 간단하게 곱하고, 더하고, 나누는 과정이 전부 python 으로 된 CPU 코드에서 관리 되는것.

27 of 56

RL 의 특수성: 복잡한 파이프라인

  • 하지만 JAX 의 JIT(Just In Time) compilation 을 사용하면 위와 같이 하나의 GPU code 로 변환이 가능하며, 중간 중간 CPU로 관리해가며 생기던 속도 저하를 거의 완전히 배제할 수 있음

  • 모델을 실행하고, 라벨과 그래디언트에 따라 최적화만 하면 되는 지도학습이나 비지도학습은 이렇게 최적화 되는 단계가 적지만, RL 은 이러한 중간 과정이 매우 많기때문에, 같은 환경에서 DQN의 torch 구현과 jax 구현의 동작 속도는 매우 큰 차이가 남

28 of 56

RL 의 특수성: 복잡한 파이프라인

개인적인 구현에서는 같은 CPU 환경을 해결하는 DQN 에서 2 ~ 3 배의 프레임을 달성하였음

Stable Baselines Jax 에서는 SB3(pytorch) 와 비교하여, 최대 20배의 학습 속도가 나온다고 ‘주장’

29 of 56

RL 의 특수성: 환경 탐색 루프

  • RL 은 모델을 학습 하는것 뿐만 아닌 ‘환경’ 과 계속해서 상호작용 하며, 변화된 정책에 따른 결과를 수집해 가며 학습이 필요함

  • 하지만. Env 는 CPU 에서 작성되었고, 결정을 내리는 Model 은 GPU 에 존재하므로 필연적으로 계속해서 CPU와 GPU 메모리간 통신이 이루어져야 함

  • 하지만 state 하나를 model 이 평가하는 데에는 오래 걸리지 않고, 환경이 행동 하나를 진행하는데에도 오래 걸리지 않는다면… 학습 과정중 대부분을 차지 하는건 통신

30 of 56

RL 의 특수성: 환경 탐색 루프

  • GPU 에서 데이터를 처리하는것은 CPU보다 훨씬 빠를테지만, 처리 시간의 감소는 완전히 정비례하지 않음
  • 위 그림처럼, 데이터를 전송하는데 걸리는 시간이 명백히 존재

31 of 56

RL 의 특수성: 환경 탐색 루프

  • 그렇다면 환경을 GPU 에서 실행한다면?

  • GPU 와 CPU 사이에서 발생하는 통신이 필요 없어 매우 높은 GPU 사용률을 달성 가능함

  • 또한, 이 환경과의 ’통신’ 과 환경의 동작도 ‘파이프라인’이므로… jax.jit 으로 모든 학습 루프를 하나의 GPU 코드로써 실행 가능

  • 이로서 이전보다 매우 잘 최적화된 학습 루프를 구성할 수 있음

32 of 56

RL 의 특수성: 환경 탐색 루프

  • 하지만 매우 치명적인 문제가 존재

  • 기존의 Python 이나 그 외, CPU 에서 동작하는 환경을 GPU 에 들어가도록 만들 수는 없음

  • Jax 로 순수하게 작성된 환경 코드가 필요하며, JAX의 특수한 문법을 완벽히 맞추면서 매우 복잡한 환경을 작성하는 것은 거의 고문에 가까움

33 of 56

RL 의 특수성: 환경 탐색 루프

  • 그럼에도 불구하고, 인내의 과실은 달다

  • GPU 에서 모든 환경 탐색 과정을 포함한 학습 루프를 JIT 해 동작시키는 End-to-End RL Training 은 Pytorch 로 작성된 Clean RL 과 비교하여 10배 이상의 속도를 보장

34 of 56

RL 의 특수성: 환경 탐색 루프

  • 이러한 명확한 이점에 힘입어, 많은 환경들이 새롭게 JAX 네이티브로 작성되고 있음.

35 of 56

RL 의 특수성: 병렬화에 대한 이득

  • RL 의 알고리즘에 따라 다르지만, 대부분 한 epoch 동안 수집되는 데이터의 수가 많으면 많을수록 성능과 학습 시간에서 매우 큰 이득이 있음

  • 이 때문에 Offpolicy 에서는 Ape-X 라는 방법론, Policy Optimize 에서는 기본적으로 대부분 병렬 환경을 사용하며. 이어서는 싱글 프로세스에서는 부족한 규모의 병렬화를 위해 Impala, DDPPO 같은 방법론들까지 등장

  • 하지만, CPU에서의 빠른 병렬화된 환경을 위해서는 프로세스들의 분리가 필요하지만… GPU 는?

36 of 56

RL 의 특수성: 병렬화에 대한 이득

  • JAX 로 잘 작성되고, 조건을 만족하는 함수는 jax.vmap 을 사용하여, 해당 동작을 병렬적으로 동작하는 함수로 바로 변환 가능

  • 이를 통해 CPU 환경들을 사용해야하는 알고리즘들과 달리 학습 구조가 복잡해질 필요가 없고 선형적인 샘플 수집, 성능 향상 가능

37 of 56

RL 의 특수성: 병렬화에 대한 이득

  • 이러한 특성으로 인해, JAX로 작성된 환경은 거의 선형적으로 샘플 수집 속도를 늘릴 수 있고

  • JAX 로 End to End 로 작성된 알고리즘은 거의 선형적으로 학습 규모를 늘릴 수 있음.

38 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • 모든 강화학습이 그런 건 아니지만, 일부 매우 복잡한 문제를 풀어야 하는 경우 액션을 도출하는 인퍼런스 자체가 매우 복잡한 경우들이 있음
  • 이 경우 복잡한 로직에 따라 모델의 계산을 여러번 거치는 파이프라인이 필요하며, 이런 로직은 단순 코드 최적화 뿐만 아니라, 하드웨어간 데이터 전송이 빈번 하기 때문에 연산 가속이 어려움
  • 가장 대표적인 경우가 MCTS 를 사용하는 AlphaZero, MuZero 류의 알고리즘

39 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • Deepmind 가 작성한 MCTS 의 JAX 구현

  • 서치 루프 전체를 JAX 로 구현하여 C++ 구현과 비교해도 경쟁력이 있다고 주장

  • AlphaZero 의 경우 환경이 JAX 로 작성되어야 하고, Muzero 는 환경이 월드 모델이므로 기본적으로 JAX 로 구현됨

40 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • 기존의 CPU 와 GPU 로 함께 작성될 경우 AlphaZero 와 Muzero 또한 이전에 설명한 데이터 전송 문제와 직면

  • CPU 에서 작성된 코드가 C++ 로 잘 최적화 되어있어도, 이런 문제로 GPU 의 실제 사용량이 크게 낮을 수 있음

41 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • GPU 에서 전체 루프가 동작할 경우 이런 문제가 배제됨

  • 하지만 GPU의 싱글코어 속도는 CPU 에 비하여 매우 느리므로, 아직까지는 압도적인 성능차이를 얻을 수는 없음

  • 이를 위해선, MCTS 의 구현이 배치 단위로 재구성이 이루어져야 함

42 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • 개인 프로젝트로 진행한 A* 의 JAX 구현

  • 서치 루프 전체를 JAX 로 구현

  • 이전 구현(Python & JAX) 대비 1500 배�C++ 구현(C++ & torch(batch)) 대비 30 배�속도

43 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • 여기서 또한 데이터 전송 문제가 발생

  • 뉴럴넷으로 동작하는 월드 모델을 사용할 경우 거의 모든 작업에 데이터 전송이 필요

44 of 56

RL 의 특수성: 인퍼런스의 복잡성

  • 여기서도 GPU 에서 전체 루프가 동작할 경우 이런 문제가 배제됨

  • 이를 통해 상당한 성능적 이득을 낼수 있음

  • 월드 모델을 사용하는 A* 의 경우 기존 논문의 결과(C++ & Pytorch) 와 비교하여 300 배 이상의 성능 차이를 보임

45 of 56

저희 속도 말고 다른 생각을 해봅시다

46 of 56

저희 속도 말고 다른 생각을 해봅시다

47 of 56

저희 속도 말고 다른 생각을 해봅시다

제발 RL하면 JAX 합시다

48 of 56

저희 속도 말고 다른 생각을 해봅시다

49 of 56

저희 속도 말고 다른 생각을 해봅시다

  • 환경을 빠르게 만들어서 해결하려고 하면 진짜 끝이 없다

  • 왜? 차라리 모든 트라젝토리 다 모아서 DNN 없는 Q러닝으로 풀자고 하지

  • 시뮬레이터의 원본 코드가 있어도 어려운 환경의 JAX 화 작업을 시뮬레이터를 어떻게 만들어야 하는지 막막한 실제 문제에 어떻게 적용할건가?

  • 남은 평생을 새 도메인에 맞춘 시뮬레이션을 만들고 최적화 하는 데에 보낼 건가?

50 of 56

저희 속도 말고 다른 생각을 해봅시다 - 갈!!!!!!

  • 정말 ‘중요한’ 문제는 시뮬레이터를 만들 가치가 있다

Ex) 분자 예측 시뮬레이션, 등등…

  • 시뮬레이터를 만드는건 쉬운데, 푸는건 어려운 문제들도 많이 있다

Ex) 바둑, 조합최적화, 등등…

  • Mujoco JAX, Nvidia Isaac sim, Genesis 등 어려운 3D 와 그런 시뮬레이션을 위한 GPU 시뮬레이터 들 잔뜩 나오는데 뭐가 문제야!

  • 빠를수록 좋은건 당연한건데 왜 트집임? 설마 점마 ‘분탕’ 아님?

51 of 56

저희 속도 말고 다른 생각을 해봅시다

아니 그… 빠른게 좋긴 한데… 근본적인 해결책은 아니라는거죠…

샘플 많이 쓸 생각 하기 전에 샘플을 최대한 적게 쓰고 잘푸는법을 고민해야 하는거 아님?

52 of 56

저희 속도 말고 다른 생각을 해봅시다

그럼 이대로만 갑시다!

그래서 이렇게 하자고?

53 of 56

저희 속도 말고 다른 생각을 해봅시다

근데 JAX 랑 RL 발표인데�이런 내용 들어가는게 맞음?

샘플 효율적 이면 다 된다는 이야기 잖아?

그 샘플 효율적인 알고리즘의 복잡한 연산 최적화는 조상님이 해주시나?

샘플 효율적 이려면 그만큼 몸비틀어서 알고리즘이 복잡해질텐데

복잡한 워크플로우와 거기서 오는 오버헤드를 JAX 로 최적화 할수 있다니까?

MCTS에 A*도 짜고 저만큼 빨라지는데 안하는게 이상하지 않음?

54 of 56

저희 속도 말고 다른 생각을 해봅시다

둘다 다 하자는 이야기구나!

생각해보면 샘플 효율적 알고리즘은 환경의 속도가 아니라 학습 알고리즘의 최적화 정도 쪽에 주도권이 오니까 JAX로 최적화 하면 그 효과가 더 크겠네?

55 of 56

마치며

  • 아직 ML 연구와 구현중 PyTorch 는 무시할 수 없는 거인
  • 하지만 JAX 는 명확한 강점으로 점점 커져가고 있는 대안 중 하나
  • 역으로 최신 RL 에서는 이제 메인스트림에 가까워지고 있는 중

56 of 56

다시한번 말씀드리자면�우리 모두 제발 RL하면 JAX 합시다�아니어도 JAX 해보면 좋을 겁니다

감사합니다