[Package] 시험 최적 운송 문제의 포장 | POT, OTT-jax, Optimal Transport.jl

최상의 운송은 최근 몇 년 동안 기계 학습 분야에서 주목받은 화제다.
https://speakerdeck.com/eumesy/how-to-leverage-optimal-transport
https://sites.google.com/view/uda-0x-seminar/home/0x03
전생에도 Optimal Transport에 관한 PDF 파일(또는 책이라고 할까요)에 대한 기록을 쓴 적이 있다(아래의 Qita 검색을 보십시오).
https://qiita.com/search?q=takilog+computational+optimal+transport
그동안 파이톤의 대표적인 포장은 POT (Python Optimal Transport), 줄리아의 포장OptimalTransport.jl 등이었다.최근 온라인 서핑을 하다 구글리서치가 삼가 만든 JAX 패키지Optimal Transport Tools (OTT)를 발견해 비교해봤다.

전제 조건


원래 JAX의 포장을 생각해 봤는데 익숙하지 않아서 유감입니다
POT와 OTT-jax의 비교는 공식 문서에 공개된 겁니다.
이런 기사를 참고하면 GPU와 TPU 등을 지원하는 고속 Numpy라고 보면 된다. OT의 계산에서는 Sinkhorn과 Entropy-regularized Sinkhorn이 매트릭스 계산을 많이 하는데,빠른 속도로 실현될 조짐이 보입니다. 우선 다음과 같은 OT의 실례로 (\mathbf{a},\mathbf{b},\mathbf{C}). 다음은 Juria의 코드입니다.
using Distances

n = 10 # data size
dims = 3 # random data dimension
a = rand(n)
b = rand(n)
x = rand(n, dim) # data は 1次元目
y = rand(n, dim) # ↑

# 1次元目をデータとしたpair-wise distance
# 距離はSqEuclidean()
C = pairwise(Distances.SqEuclidean(), x, y, dims=1) 
이렇게 주변 분포\mathbf{a},\mathbf{b}와 운송비용\mathbf{C}을 만든 후 나머지는 Pythn/Juria 코드만 호출합니다.

Python


설치에 대해서는 https://zenn.dev/koshian2/articles/af6758a5f3efc2를 참조하십시오.

CPU


먼저 CPU 환경에서 실험을 실시했는데 반복 횟수는 1000이고 오차의 끝 조건은 공식과 같이 1e-2이다.
실행 중인 로컬 환경은 Ryzen7 Pro4750G+32GB main memory의 로컬 PC입니다. 데이터의 차원은 공식적으로는 크지만 로컬 환경이기 때문에 nkin[32641281285651211024]으로 설정합니다. 엔트로피의 정규화 계수는 공식과 같습니다\epsilon\in[0.10.01]. 결과는 아래 도표와 같습니다.
계산 시간은 행렬 계산의 무게에 따라 증가하는 모습을 볼 수 있다.

GPU


이어 GPU 환경에서 실험을 진행했는데 반복 횟수는 1000이고 오차의 종료 조건은 공식과 같이 1e-2이다.
실행 환경은 Google Colaboratory의 Free가 사용하는 노트북 환경으로, 실행 시간을 GPU로 설정한 후 같은 노트북에서 실험했다.
결과는 그림과 같다. POT에 관해서는 CPU의 집행이 거의 같은 행동을 보였다. 한편, OTT-jax에 관해서는 n=1024의 환경까지 계산 시간이 거의 고정되어 있다. 배경에 따라 사용하는 JAX의 성질(일본어는 이상하다).고속 연산을 하고 있는 모습을 볼 수 있다. 공식 노드북에서 n=4096의 범위를 상당한 속도로 계산할 수 있음을 암시하고 있다.

Julia


그런데 요즘은 근처에서 간혹 이름이 들리는 주리아라는 언어가 있다. 여기에는 Optimal Transport용 프로그램 라이브러리도 있다. 비슷한 예제가 주리아에서 시행되고 있으니 행동을 살펴보자. 이것에 관해서도 CUDA에 있다.jl 등의 지원이 있었지만, 이번에는 쥬리아의 GPU 환경이 검증되지 않아 CPU 버전과만 비교하기로 했다.
파이톤은 *%timeit%*를 이용해 시간 측정을 했고, 주리아 버전은 게으름을 피우며 적당한 5번의 평균치만 얻었다. 실험 부분의 실시는 이렇다.
using Statistics
using Distances
using OptimalTransport

dim = 3
for n in [32, 64, 128, 256, 512, 1024]
    # 人工データ
    x = rand(n, dim)
    y = rand(n, dim)
    a = rand(n)
    b = rand(n)
    a ./= sum(a)
    b ./= sum(b)
    C = pairwise(Distances.SqEuclidean(), x, y, dims=1)

    # ここはクソ測定コードです
    sol = Float64[]
    for eps in [0.01, 0.1]
        for _ in 1:5
            t1 = time()
            P = sinkhorn(a, b, C, eps; maxiter=1000, atol=1e-2)
            tsolve = time() - t1
            push!(sol, tsolve)
        end

        # store mean(sol) to plot
    end
end
알고리즘의 구현이지만 일반 싱커온 알고리즘(Sinkhorn Gibbs),log-domain이 구현한(Sinkhorn Stabilized) 등으로 구현한 것 같아 두 가지를 비교했다.POT와 OTT-jax에서threshold=1e-2 등 지정한 것이 아톨이라는 매개 변수에 해당한다는 것을 읽었기 때문에 아톨로 지정했습니다.
결과적으로.SS는 SinkhornStabilized를 사용하고 SO는 일반적인 Sinkhorn을 사용합니다.

파이톤과 주리아의 랜덤 수 데이터가 맞지 않아 촉각만 봤을 뿐, 가볍게 본 느낌도 주리아의 실현이 상당히 좋은 느낌이었기 때문이다.jl을 사용하는 환경에서 가속화된 Optimal Transport.제이엘도 비교해보고 싶은데.

좋은 웹페이지 즐겨찾기