Codeforces Round #510 (Div. 2) D. Petya and Array 좌압과 세그먼트 트리

생각보다 구현에 시간이 걸렸다.

주제


  • $n\leq 10^5$개의 각 요소 $- 10^9\leq a_i\leq 10^9$로 구성된 배열 $a$가 주어집니다.
  • 이 연속하는 부분 배열의 합이, $- 10^{12}\leq t\leq 10^{12} $ 미만의 것은 모두로 몇인가? (배열이 중복되고 정렬되지 않음)

  • 이렇게 생각했다



    문제는 바꿔 말하면, 구간 $[l, r)$의 부분합이 $t$ 미만의 부분합이 되는, $l,r$의 수를 요구하고 싶다. 라고 말할 수 있습니다.

    누적 합을 취하여 세그먼트 트리에 넣습니다. 즉, $a_0$에서 있는 $a_i$까지의 합을 index로, 출현 횟수를 값으로 하는 세그먼트 트리 $st$를 가집니다. 다만, 값은 부를 취할 수 있는 것과, 값의 범위가 매우 크기 때문에, 좌압을 실시합니다. 예를 들어,
    예제를 예로 들겠습니다. $[5,-1,3,4,-1]$이면 누적 합은 $[5,4,7,11,10]$이므로 좌압을 위한 테이블 $[4,5 ,7,10,11]$를 가지고 $1,0,2,4,3$로 변환합니다.
    이때 st는 $[1,1,1,1,1]$라는 테이블이 됩니다.

    여기서 $l=0$로 고정했다고 가정합니다. 이 때, $ t = 4 $이면, 누적 합의 테이블 중에서 t 미만이되는 수를 열거하면 되므로, 상기에서 계산했다 (누적 합의 카운트를 좌압한 수치를 세그먼트 트리로 하여 있다)st에 쿼리하면 됩니다. 쿼리 대상은 $t=4$ 미만이지만 좌압 후 값이어야 합니다. $t=0$는 좌압하면 index가 $0$입니다. $st.query[0,0) = 0$가 대답이 됩니다.

    다음으로 $l=1$로 고정했다고 하자. 우선, 이 처리에 들어가기 전에, $i=0$에 대응하는 누적합의 요소를 지우고, 세그먼트 트리의 해당 요소를 $-1$ 합니다.
    이때 다시 누적합을 계산하고 싶습니다만, 이것을 반복하면 $O(N^2)$의 시간이 걸립니다. 이 때문에, 반대로, $t$를 누적합에 맞추어 변경합니다. 그런데, 최초로 계산한 누적합 $[5,4,7,11,10]$입니다만, $l=1$로부터 누적합을 실시하면 $[-1, 2, 6, 5]$입니다. 이제 이것을 보면 첫 번째 누적 합의 두 번째 요소부터 첫 번째 요소의 $ 5 $를 뺀 것을 볼 수 있습니다. 누적 합의 성질을 생각하면 첫 번째 요소가 빠졌기 때문에 분명합니다.
    라는 것은, $l=1$일 때 최초로 구한 각 요소로부터 a_0의 요소의 값(이 경우는 5)을 뺀, t=4 미만의 요소이면 좋기 때문에, 바꾸어 말하면, 최초로 구한 각 요소로부터 a_0의 요소의 값이 t=9 미만이면 됩니다.
    여기서 좌압 테이블 $[4,5,7,10,11]$에서 $9$ 미만의 index를 생각하면 $3$입니다. 이때 lower bound(Python이라면 bisect left)로 생각합니다. 이대로 [l, index)를 쿼리하면 마치 $l=0$일 때의 쿼리와 같이 $l=1$일 때의 조합이 요구됩니다.

    그림과 같이 다음과 같이 됩니다.


    구현


    def do():
        st = segmentTreeSum()
        n, t = map(int, input().split())
        dat = list(map(int, input().split()))
        dattotal = []
        total = 0
        segtreeList = [0] * (200000 + 10)
        zatsu = set()
        for x in dat:
            total += x
            dattotal.append(total)
            zatsu.add(total)
        zatsu = list(zatsu)
        zatsu.sort()
        zatsuTable = dict()
        zatsuTableRev = dict()
        for ind, val in enumerate(zatsu):
            zatsuTable[val] = ind
            zatsuTableRev[ind] = val
        buf = []
        for x in dattotal:
            buf.append(zatsuTable[x])
            segtreeList[zatsuTable[x]] += 1
        st.load(segtreeList)
    
        from bisect import bisect_left, bisect_right
    
        offset = 0
        res = 0
        for i in range(n): # x = total from 0 to curren
            x = dattotal[i]
            curvalind = zatsuTable[x]
            targetval = t + offset# target val
            targetind = bisect_left(zatsu, targetval)
            cnt = st.query(0, targetind )
            res += cnt
            st.addValue(curvalind, -1)
            offset += dat[i] # for next offset
        print(res)
    do()
    

    좋은 웹페이지 즐겨찾기