์๋ฃ๊ตฌ์กฐ, ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ(Segment Tree) [Python]
๐ฆ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ(Segment Tree)๋?
์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ ๋ฐฐ์ด์์ ํน์ ๊ตฌ๊ฐ์ ๋ํ ์ ๋ณด๋ฅผ ๋น ๋ฅด๊ฒ ์ฟผ๋ฆฌํ๊ฑฐ๋ ์์ ํ ์ ์๊ฒ ํด์ฃผ๋ ์ด์ง ํธ๋ฆฌ ๊ธฐ๋ฐ์ ์๋ฃ๊ตฌ์กฐ์
๋๋ค. ์ด ์๋ฃ๊ตฌ์กฐ๋ ์ฃผ๋ก ๋ฒ์ ์ฟผ๋ฆฌ ๋ฌธ์ ์ ํ์ฉ๋๋ฉฐ, ๊ทธ ์ค์์๋ ๊ตฌ๊ฐ ํฉ์ ๋น ๋ฅด๊ฒ ๊ณ์ฐํ๋ ๋ฐ์ ์์ฃผ ์ฌ์ฉ๋ฉ๋๋ค.
์๋ฆฌ
์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ๊ฐ ๋
ธ๋๋ ๋ฐฐ์ด์ ํน์ ๊ตฌ๊ฐ์ ๋ํํฉ๋๋ค.
๋ฆฌํ ๋
ธ๋๋ ๋ฐฐ์ด์ ๊ฐ๋ณ ์์๋ฅผ ๋ํ๋
๋๋ค.
๋ถ๋ชจ ๋
ธ๋๋ ์์ ๋
ธ๋๋ค์ ํฉ์ ์ ์ฅํฉ๋๋ค.
์ด๋ฐ ๋ฐฉ์์ผ๋ก ํธ๋ฆฌ์ ๋ฃจํธ ๋
ธ๋๋ ๋ฐฐ์ด ์ ์ฒด์ ํฉ์ ์ ์ฅํ๊ฒ ๋ฉ๋๋ค.
์์
๋ฐฐ์ด A [1, 3, 5, 7, 9, 11]์ ๊ณ ๋ คํด๋ด
์๋ค.
์ด๋ Tree ๋ [0, 36, 9, 27, 4, 5, 16, 11, 1, 3, 7, 9] ๊ฐ ๋ฉ๋๋ค.
๊ฐ ๋
ธ๋์๋ ๋ฐฐ์ด์ ๊ตฌ๊ฐ ํฉ์ด ์ ์ฅ๋์ด ์์ต๋๋ค.
๋ฆฌํ ๋
ธ๋์๋ ์ฃผ์ด์ง ๋ฐฐ์ด(A) ๊ฐ๋ค์ด ์ ์ฅ๋๊ณ ๋ด๋ถ ๋
ธ๋์๋ ์์ ๋
ธ๋์ ํฉ์ด ์ ์ฅ๋ฉ๋๋ค.
๋ฃจํธ ๋ ธ๋(36)๋ ์ ์ฒด ๋ฐฐ์ด์ ํฉ์ ๋ํ๋ ๋๋ค.
์ข์ธก ์์ ๋ ธ๋(9)๋ ๋ฐฐ์ด์ ์ฒซ ๋ฒ์งธ๋ถํฐ ์ธ ๋ฒ์งธ ์์๊น์ง์ ํฉ(A[0:2])์ ๋ํ๋ ๋๋ค.
์ฐ์ธก ์์ ๋ ธ๋(27)๋ ๋ฐฐ์ด์ ๋ค ๋ฒ์งธ๋ถํฐ ์ฌ์ฏ ๋ฒ์งธ ์์๊น์ง์ ํฉ(A[3:5])์ ๋ํ๋ ๋๋ค.
์ด์ ๊ฐ์ด ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ ๊ตฌ๊ฐ์ ํฉ์ ๋น ๋ฅด๊ฒ ๊ณ์ฐํ๊ฑฐ๋, ๋ฐฐ์ด์ ํน์ ์์์ ๊ฐ์ ๋ณ๊ฒฝํ ๋๋ O(logN) ์๊ฐ๋ณต์ก๋๋ฅผ ์ ์งํ๋ฉด์ ๊ตฌ๊ฐ์ ํฉ์ ์ ๋ฐ์ดํธํ ์ ์๊ฒ ํด์ค๋๋ค.
๐๐ฟ ํ์ด์ฌ ์ฝ๋ ๊ตฌํ
1. init ๋ฉ์๋
def __init__(self, arr):
# ๋ฐฐ์ด์ ๊ธธ์ด
self.n = len(arr)
# ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ์ฌ์ด์ฆ๋ ์๋ ๋ฐฐ์ด์ 4๋ฐฐ๋ก ์ค์ (์ถฉ๋ถํ ๊ณต๊ฐ ํ๋ณด)
self.tree = [0] * (4 * self.n)
# ํธ๋ฆฌ ๋น๋
self.build(0, self.n-1, 1, arr)
- self.n = len(arr):
์ฃผ์ด์ง ๋ฐฐ์ด์ ๊ธธ์ด๋ฅผ ์ ์ฅํฉ๋๋ค. ์ด๋ ๋์ค์ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ๋ฒ์๋ฅผ ์ ํ ๋ ํ์ํฉ๋๋ค. - self.tree = [0] * (4 * self.n):
์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ฅผ ์ ์ฅํ ๋ฆฌ์คํธ๋ฅผ ์ด๊ธฐํํฉ๋๋ค. ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ํฌ๊ธฐ๋ ์๋ ๋ฐฐ์ด์ ๊ธธ์ด์ 4๋ฐฐ๊ฐ ๋๋๋ก ์ค์ ํฉ๋๋ค. ์ด๋ ๊ฒ ์ค์ ํ๋ ์ด์ ๋ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๊ฐ ์์ ์ด์ง ํธ๋ฆฌ์ ํํ๋ฅผ ๊ฐ์ง๋ฉฐ, ์ด๋ฅผ ๋ฐฐ์ด๋ก ํํํ ๋ ์ถฉ๋ถํ ๊ณต๊ฐ์ ํ๋ณดํ๊ธฐ ์ํจ์ ๋๋ค. - self.build(0, self.n-1, 1, arr):
์ฃผ์ด์ง ๋ฐฐ์ด์ ๊ธฐ๋ฐ์ผ๋ก ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ฅผ ๊ตฌ์ถํ๋ ๋ฉ์๋๋ฅผ ํธ์ถํฉ๋๋ค. ์ด ๋, ์ ์ฒด ๋ฐฐ์ด์ ๋ฒ์์ธ 0๋ถํฐ self.n-1๊น์ง๋ฅผ ๋์์ผ๋ก ํ๋ฉฐ, ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ๋ฃจํธ ๋ ธ๋๋ 1์ด๋ผ๋ ์ธ๋ฑ์ค๋ฅผ ๊ฐ์ง๋๋ค.
2. build ๋ฉ์๋
build ๋ฉ์๋๋ ์ด๊ธฐ ๋ฐฐ์ด์ ๋ฐ์ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ฅผ ๊ตฌ์ถํฉ๋๋ค.
๋ถํ ์ ๋ณต ์๊ณ ๋ฆฌ์ฆ์ ์ด์ฉํฉ๋๋ค.
def build(self, start, end, node, arr):
# ๋ฆฌํ ๋
ธ๋์ธ ๊ฒฝ์ฐ ๋ฐฐ์ด์ ๊ฐ์ ๊ทธ๋๋ก ๋์
if start == end:
self.tree[node] = arr[start]
return self.tree[node]
# ๊ตฌ๊ฐ์ ์ ๋ฐ์ผ๋ก ๋๋์ด ์ฌ๊ท์ ์ผ๋ก ํธ๋ฆฌ ์์ฑ
mid = (start + end) // 2
self.tree[node] = self.build(start, mid, node*2, arr) + self.build(mid+1, end, node*2+1, arr)
return self.tree[node]
- start, end๋ ํ์ฌ ๊ตฌ๊ฐ์ ๋ํ๋ ๋๋ค. ์ฒ์ ํธ์ถ ์ ์ ์ฒด ๋ฐฐ์ด ๋ฒ์๋ฅผ ์๋ฏธํฉ๋๋ค.
- node๋ ํ์ฌ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์์์ ์์น (๋
ธ๋ ๋ฒํธ)๋ฅผ ์๋ฏธํฉ๋๋ค.
๋์๋ฐฉ์
- ๋ง์ฝ start์ end๊ฐ ๋์ผํ๋ฉด (์ฆ, ๋ฆฌํ ๋ ธ๋์ ๋๋ฌํ๋ค๋ฉด) ํด๋น ๋ ธ๋์ ๋ฐฐ์ด์ ํด๋น ์ธ๋ฑ์ค ๊ฐ์ ์ ์ฅํฉ๋๋ค.
- ๊ทธ๋ ์ง ์๋ค๋ฉด, ํ์ฌ ๊ตฌ๊ฐ์ ๋ ๋ถ๋ถ์ผ๋ก ๋๋๊ณ ์ฌ๊ท์ ์ผ๋ก ์ผ์ชฝ๊ณผ ์ค๋ฅธ์ชฝ ์์ ๋ ธ๋๋ฅผ ๋น๋ํฉ๋๋ค.
- ํ์ฌ ๋ ธ๋์ ๊ฐ์, ์ผ์ชฝ ์์ ๋ ธ๋์ ์ค๋ฅธ์ชฝ ์์ ๋ ธ๋์ ํฉ์ผ๋ก ์ค์ ํฉ๋๋ค.
3. query ๋ฉ์๋
query ๋ฉ์๋๋ ์ง์ ๋ ๊ตฌ๊ฐ์ ํฉ์ ๊ณ์ฐํฉ๋๋ค.
def query(self, left, right, node, node_left, node_right):
# ๊ตฌ๊ฐ์ด ๊ฒน์น์ง ์๋ ๊ฒฝ์ฐ
if right < node_left or node_right < left:
return 0
# ํ์ฌ ๋
ธ๋ ๊ตฌ๊ฐ์ด ์ฟผ๋ฆฌ ๊ตฌ๊ฐ์ ์์ ํ ํฌํจ๋ ๊ฒฝ์ฐ
if left <= node_left and node_right <= right:
return self.tree[node]
# ๊ทธ๋ ์ง ์๋ค๋ฉด ๋ ๋ถ๋ถ์ผ๋ก ๋๋์ด ํฉ์ ๊ณ์ฐ
mid = (node_left + node_right) // 2
return self.query(left, right, node*2, node_left, mid) + self.query(left, right, node*2+1, mid+1, node_right)
query ๋ฉ์๋๋ ์ง์ ๋ ๊ตฌ๊ฐ์ ํฉ์ ๊ณ์ฐํฉ๋๋ค.
- left, right๋ ๊ตฌ๊ฐ ํฉ์ ๊ตฌํ๊ณ ์ ํ๋ ๋ฒ์์ ๋๋ค.
- node_left, node_right๋ ํ์ฌ ๋ ธ๋๊ฐ ์ปค๋ฒํ๋ ๋ฒ์์ ๋๋ค.
๋์ ๋ฐฉ์
- ๋ง์ฝ ํ์ฌ ๋ ธ๋์ ๋ฒ์์ ์์ฒญ๋ ๋ฒ์๊ฐ ๊ฒน์น์ง ์์ผ๋ฉด, 0์ ๋ฐํํฉ๋๋ค.
- ์์ฒญ๋ ๋ฒ์๊ฐ ํ์ฌ ๋ ธ๋์ ๋ฒ์์ ์์ ํ ํฌํจ๋๋ค๋ฉด, ํ์ฌ ๋ ธ๋์ ๊ฐ์ ๋ฐํํฉ๋๋ค.
- ๊ทธ๋ ์ง ์๋ค๋ฉด, ์์ฒญ๋ ๋ฒ์๋ฅผ ๋ ๋ถ๋ถ์ผ๋ก ๋๋์ด ์ฌ๊ท์ ์ผ๋ก ํฉ์ ๊ณ์ฐํ๊ณ , ๋ ๊ฒฐ๊ณผ๋ฅผ ํฉ์ณ ๋ฐํํฉ๋๋ค.
4. update ๋ฉ์๋
update ๋ฉ์๋๋ ๋ฐฐ์ด์ ํน์ ์์ ๊ฐ์ ๋ณ๊ฒฝํ๊ณ , ์ด์ ๋ฐ๋ผ ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ๋ฅผ ๊ฐฑ์ ํฉ๋๋ค.
def update(self, index, new_value, node, node_left, node_right):
# ๊ตฌ๊ฐ ๋ฐ์ธ ๊ฒฝ์ฐ
if index < node_left or node_right < index:
return self.tree[node]
# ๋ฆฌํ ๋
ธ๋์ธ ๊ฒฝ์ฐ ๊ฐ์ ๊ฐฑ์
if node_left == node_right:
self.tree[node] = new_value
return self.tree[node]
# ๋ฆฌํ ๋
ธ๋๊ฐ ์๋ ๊ฒฝ์ฐ, ์์ ๋
ธ๋๋ค๋ ๊ฐฑ์ ํด์ผ ํจ
mid = (node_left + node_right) // 2
self.tree[node] = self.update(index, new_value, node*2, node_left, mid) + self.update(index, new_value, node*2+1, mid+1, node_right)
return self.tree[node]
- index๋ ๊ฐฑ์ ํ๋ ค๋ ์์์ ์์น์ ๋๋ค.
- new_value๋ ๊ฐฑ์ ํ๋ ค๋ ๊ฐ์ ๋๋ค.
๋์ ๋ฐฉ์
- ๋ง์ฝ ํ์ฌ ๋ ธ๋์ ๋ฒ์๊ฐ ๊ฐฑ์ ํ๋ ค๋ ์์๋ฅผ ํฌํจํ์ง ์์ผ๋ฉด, ํ์ฌ ๋ ธ๋์ ๊ฐ์ ๋ฐํํฉ๋๋ค.
- ๋ง์ฝ ๋ฆฌํ ๋ ธ๋์ ๋๋ฌํ๋ค๋ฉด, ์์์ ๊ฐ์ ๊ฐฑ์ ํฉ๋๋ค.
- ๊ทธ๋ ์ง ์๋ค๋ฉด, ํด๋น ์์๋ฅผ ํฌํจํ๋ ์์ ๋ ธ๋๋ค์ ์ฌ๊ท์ ์ผ๋ก ๊ฐฑ์ ํ๊ณ , ํ์ฌ ๋ ธ๋์ ๊ฐ์ ์์ ๋ ธ๋๋ค์ ํฉ์ผ๋ก ๊ฐฑ์ ํฉ๋๋ค.
5. ์ ์ฒด ์ฝ๋
class SegmentTree:
def __init__(self, arr):
# ๋ฐฐ์ด์ ๊ธธ์ด
self.n = len(arr)
# ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ์ ์ฌ์ด์ฆ๋ ์๋ ๋ฐฐ์ด์ 4๋ฐฐ๋ก ์ค์ (์ถฉ๋ถํ ๊ณต๊ฐ ํ๋ณด)
self.tree = [0] * (4 * self.n)
# ํธ๋ฆฌ ๋น๋
self.build(0, self.n-1, 1, arr)
# ์ธ๊ทธ๋จผํธ ํธ๋ฆฌ ๋น๋ ํจ์
def build(self, start, end, node, arr):
# ๋ฆฌํ ๋
ธ๋์ธ ๊ฒฝ์ฐ ๋ฐฐ์ด์ ๊ฐ์ ๊ทธ๋๋ก ๋์
if start == end:
self.tree[node] = arr[start]
return self.tree[node]
# ๊ตฌ๊ฐ์ ์ ๋ฐ์ผ๋ก ๋๋์ด ์ฌ๊ท์ ์ผ๋ก ํธ๋ฆฌ ์์ฑ
mid = (start + end) // 2
self.tree[node] = self.build(start, mid, node*2, arr) + self.build(mid+1, end, node*2+1, arr)
return self.tree[node]
# ๊ตฌ๊ฐ ํฉ์ ๊ณ์ฐํ๋ ํจ์
def query(self, left, right, node, node_left, node_right):
# ๊ตฌ๊ฐ์ด ๊ฒน์น์ง ์๋ ๊ฒฝ์ฐ
if right < node_left or node_right < left:
return 0
# ํ์ฌ ๋
ธ๋ ๊ตฌ๊ฐ์ด ์ฟผ๋ฆฌ ๊ตฌ๊ฐ์ ์์ ํ ํฌํจ๋ ๊ฒฝ์ฐ
if left <= node_left and node_right <= right:
return self.tree[node]
# ๊ทธ๋ ์ง ์๋ค๋ฉด ๋ ๋ถ๋ถ์ผ๋ก ๋๋์ด ํฉ์ ๊ณ์ฐ
mid = (node_left + node_right) // 2
return self.query(left, right, node*2, node_left, mid) + self.query(left, right, node*2+1, mid+1, node_right)
# ๋ฐฐ์ด์ ํน์ ์์๋ฅผ ๊ฐฑ์ ํ๋ ํจ์
def update(self, index, new_value, node, node_left, node_right):
# ๊ตฌ๊ฐ ๋ฐ์ธ ๊ฒฝ์ฐ
if index < node_left or node_right < index:
return self.tree[node]
# ๋ฆฌํ ๋
ธ๋์ธ ๊ฒฝ์ฐ ๊ฐ์ ๊ฐฑ์
if node_left == node_right:
self.tree[node] = new_value
return self.tree[node]
# ๋ฆฌํ ๋
ธ๋๊ฐ ์๋ ๊ฒฝ์ฐ, ์์ ๋
ธ๋๋ค๋ ๊ฐฑ์ ํด์ผ ํจ
mid = (node_left + node_right) // 2
self.tree[node] = self.update(index, new_value, node*2, node_left, mid) + self.update(index, new_value, node*2+1, mid+1, node_right)
return self.tree[node]
# arr์ ๊ฐ ์ค์
A = [1, 3, 5, 7, 9, 11]
seg_tree = SegmentTree(A)
# ์์ ์ฟผ๋ฆฌ ๋ฐ ์ถ๋ ฅ
print(seg_tree.query(1, 3, 1, 0, seg_tree.n-1)) # 1~3 ์ธ๋ฑ์ค ๊ตฌ๊ฐ ํฉ
# ๊ฐ ๋ณ๊ฒฝ ์์: arr์ 2๋ฒ์งธ ์์๋ฅผ 10์ผ๋ก ๋ณ๊ฒฝ
seg_tree.update(2, 10, 1, 0, seg_tree.n-1)
print(seg_tree.query(1, 3, 1, 0, seg_tree.n-1)) # 1~3 ์ธ๋ฑ์ค ๊ตฌ๊ฐ ํฉ ์ถ๋ ฅ
A = [1, 3, 5, 7, 9, 11]
15 [3, 5, 7]
A = [1, 3, 10, 7, 9, 11]
20 [3, 10, 7]