二分查找模板

二分查找写错,99% 是因为没把循环不变量 (loop invariant) 想清楚。锁死一个半开区间 [lo, hi) 的模板,>=><=< 四种边界全用同一份代码。

直觉:不是”找元素”,是”切一刀”

教科书把二分讲成”在有序数组里找 target”,所以大家一上来就纠结:< 还是 <=mid 要不要 -1right 初值是 n 还是 n-1?——全都是表象。

真正的二分查找在做的事是:给定一个单调的 boolean 谓词 pred(i)(前面一段全 False、后面一段全 True),找到第一个 True 的位置。 仅此而已。”找 target” 只是”第一个 ≥ target 的位置 + 验证一下”的特例。

把视角切到谓词,所有边界纠结消失。代价是你必须维护好循环不变量

答案永远落在 [lo, hi) 之间。lo 左边全 False,hi 右边(含 hi 自己代表的位置)全 True。

为什么 off-by-one 那么常出错?因为闭区间 [lo, hi] 模板里,hi = mid - 1hi = mid 取决于 mid 是否已被验证,绕一圈很容易写漏。半开区间 [lo, hi) 天然让 hi 表示”还没验证的右端”,每次 hi = mid 永远合法——mid 还没被验证、不能算”已排除”,所以它就是新的”未验证右端”。

下面动画演示标准模板查找”第一个 ≥ 7”,以及旋转数组里同一套循环不变量怎么用。点”下一步”看每次循环 lohi 如何收缩,注意:hi 始终指向”已知是 True 的位置(或越界哨兵 n)”,循环结束时 lo == hi 就是答案。

一个模板搞定四种边界

1
2
3
4
5
6
7
8
9
10
11
def lower_bound(n, pred):
"""返回 [0, n] 中第一个让 pred(i) 为 True 的下标。
若全 False,返回 n。要求 pred 单调:False...False True...True"""
lo, hi = 0, n
while lo < hi:
mid = (lo + hi) // 2
if pred(mid):
hi = mid # mid 是 True,答案 ≤ mid
else:
lo = mid + 1 # mid 是 False,答案 > mid
return lo

记住这一份代码就够。四种边界全部转化为”构造合适的 pred“:

你想找 pred(i) 返回
第一个 arr[i] >= x arr[i] >= x lo
第一个 arr[i] > x arr[i] > x lo
最后一个 arr[i] <= x arr[i] > x lo - 1
最后一个 arr[i] < x arr[i] >= x lo - 1

口诀:找最左满足,直接套;找最右满足,反转谓词然后减一。 第二行的”反转”利用了”满足条件的元素在排好序数组里必然连续”这条性质——第一个不满足的位置,前一个就是最后一个满足的。

为什么这模板不死循环、不越界?三件事撑住:

  1. mid = (lo + hi) // 2 向下取整,所以 mid < hihi = mid 一定让区间变小。
  2. lo = mid + 1 显然让区间变小。
  3. hi 初值取 n 而不是 n - 1——允许”答案不存在”的情况返回 n,省掉一堆边界判断。返回后用 lo < n and arr[lo] == x 之类的语句验证一下即可。

变种:旋转数组与二分答案

旋转数组找最小值

数组 [4,5,6,7,0,1,2] 不是全局有序,但仍然可以二分——关键是找一个单调的谓词

1
2
3
4
5
6
7
8
9
def find_min_rotated(nums):
lo, hi = 0, len(nums) - 1
while lo < hi:
mid = (lo + hi) // 2
if nums[mid] <= nums[hi]:
hi = mid # mid 已在右半段(含最小值)
else:
lo = mid + 1 # 断崖在右边
return nums[lo]

谓词是 nums[i] <= nums[hi_init]?不完全是——这里跟 nums[hi] 比的是动态右端点,本质上利用”右半段所有元素都 ≤ 数组末尾”这条单调性。这是动画第二个场景演示的内容:注意”断崖”标记,并对照看 lohi 怎么把搜索范围压到 0 那个位置。

二分答案 (binary search on answer)

当问题是”最小化最大值”或”最大化最小值”时,答案本身在某个区间里单调可判定——直接对答案二分:

1
2
3
4
5
6
7
8
9
10
11
12
13
def koko_eating(piles, h):
# LC 875:最小的吃速 k,使得能在 h 小时内吃完
def can_finish(k):
return sum((p + k - 1) // k for p in piles) <= h

lo, hi = 1, max(piles) + 1 # [1, max+1) 半开
while lo < hi:
mid = (lo + hi) // 2
if can_finish(mid):
hi = mid
else:
lo = mid + 1
return lo

完全同一个模板,只是 pred 换成了 can_finish。识别二分答案的信号:题目问”最小的 X,使得 Y 成立”,且 X 越大越容易满足 Y(或反过来)。

经典题

记忆法 / 易错点

  • 第一条铁律:写代码前先口头说出 pred(i) 是什么、单调方向是哪边。 说不清楚就别动手,必错。
  • 区间永远写半开 [lo, hi) 闭区间模板有死循环陷阱(lo = mid 不让区间变小),半开没有。
  • hi 初值取 n,不是 n - 1 允许返回 n 表示”答案不存在”,比 -1 哨兵省事得多。
  • mid = (lo + hi) // 2 向下取整,配 hi = mid;不要换成 lo = mid 配向下取整——会死循环。 实在要保留左端,必须改成 mid = (lo + hi + 1) // 2。建议根本不要写这种变体。
  • 返回 lo 后必须验证。 lower_bound 找的是”第一个 True 的位置”,不保证那里就是 target——可能越界、可能不等于 target。多写一行 lo < n and arr[lo] == x 比心算边界靠谱。
  • 溢出:Python 不用管;C++/Java 一定写 mid = lo + (hi - lo) // 2
  • 旋转数组要跟 nums[hi] 比,不能跟 nums[lo] 比。lo 比在 [1,2,3] 这种没旋转的情况会判错——因为 nums[mid] >= nums[lo] 总成立但断崖不在右边。