내가 필요해서 정리하는 JAX(1)(JAX의 기초부터 XLA까지)

2023. 4. 18. 16:54AI 기술

JAX를 이해하기전 꼭 필요한 프레임워크 지식

1. Google

Google은 2015년에 기계학습 라이브러리 Tensorflow를 오픈소스로 공개한다.

 

2. Facebook(+ Microsoft)

Facebook은 2016년에 딥러닝 라이브러리 Pytorch를 오픈소스로 공개한다.

 

JAX는 Google에서 개발되었고, 이 말은 XLA가 자유롭게 가능하다는 것이다. 물론 최근(2019년)에 pytorch도 XLA가 가능하게끔 google과 facebook간의 협약이 있었다고 하지만, pytorch를 주로 쓰는 나인데도 XLA라는 용어를 이번에 처음듣게 되었다. 그래서 JAX를 알아보기 전에 XLA를 간단히 알고가려 한다.

 

XLA(=Accelerated Linear Algebra)는 Tensorflow의 서브 프로젝트로 그래프 연산의 최적화를 목적으로 하는 complier 이다.
이로인해 얻는 이득은 실행속도의 비약적인 상승과 메모리 사용량을 낮추어 사실상, 자원의 부담을 크게 덜어준다. 

 

 

 

JAX를 사용하는 이유

JAX의 진실

여러분들이 JAX를 처음 들었거나, 전에 들어본적이 있다면 어떤 종류의 것이라고 생각할 것인가?

대부분은, 차세대 딥러닝 프레임워크라고 생각할 것이다.

하지만, 그것은 JAX의 일부분일 뿐이다.

JAX는 딥러닝 프레임워크나 라이브러리가 아니다.  (딥러닝은 JAX가 할 수 있는것의 하위 집합일 뿐이다.)

 

그래서 JAX가 뭔데?
JAX is a high performacne, numerical computing libary

 

높은 성능의 기계학습을 위한 자동미분(Auto grad) + XLA(Tensorflow model compiler)

 

JAX는 XLA를 사용하여 Numpy 프로그램을 GPU(or TPU)에 올려서 coplie 할 수 있게한다.

 

JAX는 그럼에도 불고하고 

 

JAX vs Pytorch

 

                                                                                             JAX                                                        Pytorch

자동 미분 자동미분지원, 역전파 알고리즘을 자동으로 계산가능 자동미분지원, 역전파 알고리즘을 자동으로 계산가능
컴파일러 JIT 컴파일러를 내정하여 모델을 최적화 마찬가지로 JIT 컴파일러 사용
호환성 Pytorch와 Numpy 호환가능, 쉽게 변환하여 사용. Numpy 계산을 빠르게 가능 Numpy를 Tensor 변환 과정 필요
분산 훈련 사용가능 사용가능

 

JAX vs Numpy
import jax.numpy
import numpy
import math
import jax
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
def test_cal(x):
    return x*x*x
num = numpy.random.randn(10000, 10000).astype(dtype='float32')
print(num)

##result##
[[-1.8918115  -0.3903762  -1.179873   ...  0.5916424   0.15500794
   0.22397971]
 [ 0.8290976   0.42299742 -1.2149614  ...  0.38769847  0.66221833
  -1.0417027 ]
 [ 0.5024024   0.42093113  0.20726153 ... -0.48331222 -0.3716934
  -0.7382673 ]
 ...
 [ 1.784616   -0.8545765   0.25831875 ...  0.10218267  1.3564131
  -0.24079399]
 [-1.5958557  -0.54271454 -0.6889319  ...  0.58296067  0.2260301
   0.27385828]
 [ 1.2249873   0.12241207 -0.2876963  ...  0.4413192  -0.88594097
  -0.15549691]]
%timeit -n5 test_cal(num)


##result##
98.4 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
jax_fn = jax.jit(test_cal)
num = jax.numpy.array(num)  ## << GPU에 array가 올라가서 연산을 준비함
%timeit jax_fn(num)

##result##
810 µs ± 235 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

 

 

jax.numpy에 올리는 순간 GPU를 사용하는 것을 볼 수있다.

 

* 생각보다 GPU 메모리를 많이 잡아먹는다.. 왜그러지? --> 물론 A100 8G짜리라 부담은 가지 않지만 더적은 GPU를 사용했을때 어떨지 궁금하긴하다...

 

 

Trend

최근 나온 논문들, 및 github의 code를 살펴보면 jax로 구성된 것들이 꽤 보이기 시작한다.

심지어, 원래는 Tensorflow / Pytorch로 짜여졌던 코드들이 JAX version으로도 출시되고 있다.