[NumPy] Axis

41802 단어 pythonnumpynumpy

axis는 축을 의미한다. 그리고 통상적으로 N차원 그래프는 아래와 같은 축을 가진다.

  • 1차원 그래프는 x축
  • 2차원 그래프는 y축, x축
  • 3차원 그래프는 z축, y축, x축


그리고 이 축은 위의 그림처럼 좌표평면에 구현할 수 있다. 1차원 그래프는 직선 상의 모음이고, 2차원 그래프는 수학을 배울 때 접했다.
반면 3차원 그래프는 직관적이지는 않다. 이로 인해 NumPy에서 3차원 배열의 연산을 이해하기 어렵다. 특히 위에 사진처럼 y축은 양수 값을 가지는 것이 아니라, 음수 값을 가진 형태로 반환된다. 따라서 3차원 배열을 시각화한 연산 과정을 서술한다.


우선 NumPy로 3차원 배열을 생성해보자.

import numpy as np
arr_3d = np.array([[[1, 3, 5, 7],
                   [2, 4, 6, 8]],
                 
                  [[10, 30, 50, 70],
                   [20, 40, 60, 80]],
                 
                  [[100, 300, 500, 700],
                   [200, 400, 600, 800]]])
>>> print(arr_3d)
[[[  1   3   5   7]
  [  2   4   6   8]]

 [[ 10  30  50  70]
  [ 20  40  60  80]]

 [[100 300 500 700]
  [200 400 600 800]]]

>>> arr_3d.shape
(3, 2, 4)

(z=3,y=2,x=4)(z = 3, y = 2, x = 4)

그리고 axis는 왼쪽에서 오른쪽으로 진행되는 숫자가 순서대로 가장 멀리서부터 안쪽으로의 tuple의 개수이다. 즉 가장 바깥쪽 tuple에 3개로 구분된 데이터가 들어있고, 중간에 2개로 구분된 데이터가 들어있으며, 가장 안쪽에 4개로 구분된 데이터가 들어있는 것이다. 즉 각 대괄호 안에 들어있는 데이터의 수로 이해할 수 있다.

위에서 언급했듯 y축은 음수 영역에 표현되는 것에 주의하자.


이를 다시 (z, y, z)로 표현하면 다음과 같다.

zz = 0
for z in arr_3d:
    yy = 0
    for y in z:
        xx = 0
        for x in y:
            print("(%d, %d, %d) = %d" %(zz, yy, xx, x))
            xx += 1  
        yy += 1
    zz += 1
(0, 0, 0) = 1
(0, 0, 1) = 3
(0, 0, 2) = 5
(0, 0, 3) = 7
(0, 1, 0) = 2
(0, 1, 1) = 4
(0, 1, 2) = 6
(0, 1, 3) = 8
(1, 0, 0) = 10
(1, 0, 1) = 30
(1, 0, 2) = 50
(1, 0, 3) = 70
(1, 1, 0) = 20
(1, 1, 1) = 40
(1, 1, 2) = 60
(1, 1, 3) = 80
(2, 0, 0) = 100
(2, 0, 1) = 300
(2, 0, 2) = 500
(2, 0, 3) = 700
(2, 1, 0) = 200
(2, 1, 1) = 400
(2, 1, 2) = 600
(2, 1, 3) = 800


axis=0

axis=0일 때 sum을 하면 어떻게 될까?

>>> np.sum(arr_3d, axis=0)
array([[111, 333, 555, 777],
       [222, 444, 666, 888]])

arr[z][y][x]arr[axis=0][axis=1][axis=2]이고 axis=0은 3차원 배열에서 z축을 의미한다. 따라서 np.sum(arr_3d, axis=0)은 z축을 기준으로 합친다는 의미이다.


이를 시각화하면, z축을 제거하며 합쳐진다.