1451 - 직사각형으로 나누기

📚 1451 - 직사각형으로 나누기

직사각형으로 나누기

이해

✔️ (1, 1)부터 (n, m)까지 각 자리의 총합
각 자리 총합

(1,1) ~ (2, 2)까지 총합이라고 했을 때
(2, 2) = (1, 2) + (2, 1) - (1, 1)

→ 왜 (1, 1)을 빼는가? (1, 2)와 (2, 1)에 (1, 1)이 2번 더해진다.

 

(2, 2) ~ (3, 3)까지 총합이라고 했을 때
(3, 3) = (2, 3) + (3, 2) - (2, 2)

→ 왜 (2, 2)를 빼는가? (2, 3)와 (3, 2)에 (2, 2)가 2번 더해진다.

ex)
기존 행렬

136
578
1078

 

(1, 1) 행열 부터 (n, m) 행렬까지 총합이다.
|1|4|10|
|-|-|-|
|6|16|30|
|16|33|55|

 

✔️ 각 자리수 총합을 이용해 (a, b)부터 (n, m)까지 합
현재 (1, 1) 행열 부터 (3, 3) 행열까지 모두 합으로 이루어져 있다.

14(1+3)10(1+3+6)
6(1+5)16(7+5+3+1)30
16(1+5+10)3355

만약 (2, 2) 부터 (3, 3)까지 총합은 어떻게 될까?
(1, 1)부터 시작 보았을 때
(2, 2) ~ (3, 3)이외 나머지는 빼줘야 한다.

xxx
x
x

근데 생각해보면

  • (1, 1) ~ (3, 1) 의 총합은 (3, 1)에 저장되어 있다.
  • (1, 1) ~ (1, 3) 의 총합은 (1, 3)에 저장되어 있다.

그래서 55(3, 3) - 16(3, 1) - 10(1, 3) + 1(1, 1) 이다.

→ 왜 1(1, 1)을 더했을까?
→ 보면, (3, 1) 총합에서 (1, 1)이 추가되고, (1, 3) 총합에서 (1, 1)이 추가된다. 그러므로 (1, 1)이 2번 빠지게 되므로 +1(1, 1)이 필요하다.

 

테스트 소스

arr = [[0, 1, 3, 6], [0, 5, 7, 8],[0, 10,7,8]]

print(sum(arr[0]) + sum(arr[1]))

result = [[0] * 5 for _ in range(5)]

arr = [[0]] + arr

print("arr : ", arr)

for i in range(1, 4):
    for j in range(1, 4):
        result[i][j] = result[i - 1][j] + result[i][j - 1] - result[i - 1][j - 1] + arr[i][j]


def sum_cal(x1, y1, x2, y2):
    print(result[x2][y2], result[x2][y1 - 1], result[x1 - 1][y2], result[x1 - 1][y1 - 1])
    return result[x2][y2] - result[x2][y1 - 1] - result[x1 - 1][y2] + result[x1 - 1][y1 - 1]


print("result : ", result)
print(sum_cal(2, 2, 3, 3))

 

결과

30
arr :  [[0], [0, 1, 3, 6], [0, 5, 7, 8], [0, 10, 7, 8]]
result :  [[0, 0, 0, 0, 0], [0, 1, 4, 10, 0], [0, 6, 16, 30, 0], [0, 16, 33, 55, 0], [0, 0, 0, 0, 0]]
55 16 10 1
30

 

✔️ 직사각형 세 방면으로 나눌 때

(1)
|o|x|t|
|-|-|-|-|
|o|x|t|

  • (1, 1) ~ (n, i)
  • (1, i+1) ~ (n, j)
  • (1, j+1) ~ (n, m)

(2)
|o|o|
|-|-|-|-|
|x|x|
|t|t|

  • (1, 1) ~ (i, m)
  • (i+1, 1) ~ (j, m)
  • (j+1, 1) ~ (n, m)

(3)
|o|x|x|
|-|-|-|-|
|o|t|t|

  • (1, 1) ~ (n, i)
  • (1, i+1) ~ (j, m)
  • (j+1, i+1) ~ (n, m)

(4)
|o|o|x|
|-|-|-|-|
|t|t|x|

  • (1, 1) ~ (i, j)
  • (i+1, 1) ~ (n, j)
  • (1, j+1) ~ (n, m)

(5)
|o|o|
|-|-|-|-|
|x|t|
|x|t|

  • (1, 1) ~ (i, m)
  • (i+1, 1) ~ (n, j)
  • (i+1, j+1) ~ (n, m)

(6)
|o|x|
|-|-|-|-|
|o|x|
|t|t|

  • (1, 1) ~ (i, j)
  • (1, j+1) ~ (i, m)
  • (i+1, 1) ~ (n, m)

 

소스

import sys

read = sys.stdin.readline

n, m = map(int, read().split())

arr = [[0] * (m + 1)]

for i in range(n):
    arr.append([0] + list(map(int, read().strip())))

sum_rec = [[0] * (m + 1) for _ in range(n + 1)]

# (1, 1)부터 (n, m)까지 각 자리의 총합
for i in range(1, n + 1):
    for j in range(1, m + 1):
        sum_rec[i][j] = arr[i][j] + sum_rec[i - 1][j] + sum_rec[i][j - 1] - sum_rec[i - 1][j - 1]


# 각 자리수 총합을 이용해 (a, b)부터 (n, m)까지 합
def sum_of_rec(x1, y1, x2, y2):
    # print(sum_rec[x2][y2], sum_rec[x1 - 1][y2], sum_rec[x2][y1 - 1], y2)
    return sum_rec[x2][y2] - sum_rec[x1 - 1][y2] - sum_rec[x2][y1 - 1] + sum_rec[x1 - 1][y1 - 1]


result = 0
ans = 0

# (1) 전체 직사각형을 세로로만 분할한 경우
for i in range(1, m - 1):
    for j in range(i + 1, m):
        num1 = sum_of_rec(1, 1, n, i)
        num2 = sum_of_rec(1, i + 1, n, j)
        num3 = sum_of_rec(1, j + 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

# (2) 전체 직사각형을 가로로만 분할한 경우
for i in range(1, n - 1):
    for j in range(i + 1, n):
        num1 = sum_of_rec(1, 1, i, m)
        num2 = sum_of_rec(i + 1, 1, j, m)
        num3 = sum_of_rec(j + 1, 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

# (3) 전체 세로 분할 후, 우측 가로 분할한 경우
for i in range(1, m):
    for j in range(1, n):
        num1 = sum_of_rec(1, 1, n, i)
        num2 = sum_of_rec(1, i + 1, j, m)
        num3 = sum_of_rec(j + 1, i + 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

# (4) 전체 세로 분할 후, 좌측 가로 분할한 경우
for i in range(1, n):
    for j in range(1, m):
        num1 = sum_of_rec(1, 1, i, j)
        num2 = sum_of_rec(i + 1, 1, n, j)
        num3 = sum_of_rec(1, j + 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

# (5) 전체 가로 분할 후, 하단 세로 분할한 경우
for i in range(1, n):
    for j in range(1, m):
        num1 = sum_of_rec(1, 1, i, m)
        num2 = sum_of_rec(i + 1, 1, n, j)
        num3 = sum_of_rec(i + 1, j + 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

# (6) 전체 가로 분할 후, 상단 세로 분할한 경우
for i in range(1, n):
    for j in range(1, m):
        num1 = sum_of_rec(1, 1, i, j)
        num2 = sum_of_rec(1, j + 1, i, m)
        num3 = sum_of_rec(i + 1, 1, n, m)

        if result < num1 * num2 * num3:
            result = num1 * num2 * num3

print(result)

 

채점 결과

 

문제를 풀다 변수 하나 잘못 입력해서 계속 틀렸었는데(1위치에 i 입력), 찾는데 오래 걸림, 변수 많이 사용시 노트나 따로 정리하고 소스 작성하자!!!

좋은 웹페이지 즐겨찾기