Rust의 병렬 행렬 곱셈

5892 단어 rustmachinelearning
대부분의 선형 대수학은 행렬 곱셈을 중심으로 이루어집니다. 이것은 O(n^3) 작업이므로 성능을 개선하려면 다른 기술이 필요합니다. 오늘은 행렬 곱셈을 병렬화하여 12스레드 머신에서 속도를 4배 향상시킨 방법을 보여드리겠습니다.

내 예제에서는 nalgebra 크레이트를 사용하여 행렬 저장소를 처리합니다(동적으로 크기가 조정된 행렬을 열 주요 순서로 Vec로 저장). 이것은 그다지 중요하지 않습니다. 정말 중요한 유일한 것은 행 및 열 인덱스로 행렬을 인덱싱하는 기능입니다.

단일 스레드 구현



행렬을 곱하려면 rhs의 열을 반복한 다음 각 열에 대해 lhs의 행을 반복해야 합니다. 그런 다음 lhs 행과 rhs 열을 압축하여 이 두 배열의 내적을 취합니다.

let l_shape = lhs.shape();
let r_shape = rhs.shape();

// check for shape compatibility here...

// the multiplication
let result: Vec<f64> = (0..r_shape.1).flat_map(move |rj| {
    (0..l_shape.0).flat_map(move |li| {
        (0..r_shape.0)
            .zip(0..l_shape.1)
            .map(move |(ri, lj)| {
                lhs.index((li, lj)) * rhs.index((ri, rj))
            })
            .sum::<f64>()
    })
})
.collect();

// result is a vec in column-major order


병렬화



이제 rayon를 사용하여 이 작업을 병렬화하겠습니다. 생각보다 훨씬 간단합니다!

let result: Vec<f64> = (0..r_shape.1).into_par_iter().flat_map(move |rj| {
    (0..l_shape.0).into_par_iter().flat_map(move |li| {
        (0..r_shape.0)
            .zip(0..l_shape.1)
            .map(move |(ri, lj)| {
                lhs.index((li, lj)) * rhs.index((ri, rj))
            })
            .sum::<f64>()
    })
})
.collect();



그게 다야! 범위 다음에 into_par_iter를 추가하기만 하면 됩니다. Intel i5-10600K(12) @ 4.800GHz에서 1000x1000 매트릭스에 1000x1 벡터를 곱한 결과 평균 실행 시간이 8ms에서 2ms로 줄었습니다. 속도가 75% 향상되었습니다.

좋은 웹페이지 즐겨찾기