D001单修区查,单查区修,区查区修模板,如何为?
摘要:树状数组(Binary Indexed Tree, BIT,也称 Fenwick Tree)是算法竞赛中极其常用的数据结构。它的核心优势在于:代码短小精悍、常数极小、内存占用少。 虽然它的功能是线段树的子集,但在处理“动态前缀和”相关问题时
树状数组(Binary Indexed Tree, BIT,也称 Fenwick Tree)是算法竞赛中极其常用的数据结构。它的核心优势在于:代码短小精悍、常数极小、内存占用少。
虽然它的功能是线段树的子集,但在处理“动态前缀和”相关问题时,BIT 通常是首选。
来自董晓算法,对两个关键函数原理的图解
BIT 结构模板为
class BIT:
def __init__(self, n):
self.tree = [0] * (n + 1) # 初始化为一颗空树
def add(self, i, val):
while i < len(self.tree):
self.tree[i] += val
i += i & -i
def sum(self, i):
s = 0
while i > 0:
s += self.tree[i]
i -= i & -i
return s
def query(self, l, r): # 查询区间为 [l, r]
return self.sum(r) - self.sum(l - 1)
已知大小为 \(n\) 的数组,有两种初始化方式一种为 \(O(n\log n)\) 一种为 \(O(n)\) 。
前者只需要调用 \(n\) 次 add 函数即可,具体为
bit = BIT(n)
for i, x in enumerate(a):
bit.add(i + 1, x) # 注意这里应该要 1-based 初始化
后者的话在初始化时传入的是数组而不是数组大小,具体为
def __init__(self, arr): # 传入数组
n = len(arr)
self.tree = [0] + arr[:] # 列表加法可能会炸内存
for i in range(1, n + 1):
j = i + (i & -i)
if j <= n:
self.tree[j] += self.tree[i]
P3374 【模板】树状数组 1 点修区查
import sys
if 1:
inp = lambda: sys.stdin.readline().strip()
II = lambda: int(inp())
MII = lambda: map(int, inp().split())
LII = lambda: list(MII())
Max = lambda x, y: x if x > y else y
Min = lambda x, y: x if x < y else y
class BIT:
def __init__(self, n):
self.tree = [0] * (n + 1)
def add(self, i, val):
while i < len(self.tree):
self.tree[i] += val
i += i & -i
def sum(self, i):
s = 0
while i > 0:
s += self.tree[i]
i -= i & -i
return s
def query(self, l, r):
return self.sum(r) - self.sum(l - 1)
def main():
n, q = MII()
a = LII()
bit = BIT(n)
for i, x in enumerate(a):
bit.add(i + 1, x) # 1-based 初始化
outs = []
for _ in range(q):
o = LII()
if o[0] == 1:
bit.add(o[1], o[2])
else:
outs.append(bit.query(o[1], o[2]))
print(*outs, sep='\n')
if __name__ == "__main__":
main()
这是树状数组最基本的演变形式,其它功能主要通过差分思想实现切换。比如可以使用一颗空树状数组当作差分数组,区修就变成了对这可空树进行两次点修,点查就变成了原始值加修改值了,即 a[i] + sum(i) 。
P3368 【模板】树状数组 2 区修点查
# 与上述代码相同的模板不再展示,直接复制即可
def main():
n, q = MII()
a = LII()
bit = BIT(n)
outs = []
for _ in range(q):
o = LII()
if o[0] == 1:
l, r, v = o[1:]
bit.add(l, v)
bit.add(r + 1, -v) # 1-based 下的差分操作
else:
idx = o[1]
outs.append(a[idx - 1] + bit.sum(idx))
print(*outs, sep='\n')
if __name__ == "__main__":
main()
