使用情境
在 Leetcode 寫到一題:
現在有 n 台電腦以及一些 cables 將電腦點對點連接,問需要移動至少幾條 cable 才能讓在所有電腦連成單一網路。
以 graph 的角度來看,電腦就是 nodes,cables 就是 edges。
要將整張 graph 連接起來,至少需要 n-1 個 edges。若一個 graph 裡面有超過 n-1 個 edges,剩下的就是多出來的 edges,可以供我們拿來移動的 edges。
所以第一件事就是要檢查 edges 數量 >= n-1。
當檢查完畢之後,我們有至少 n-1 條 edges,一定可以用這些 edges 將所有 nodes 連接起來。
因為題目只問需要移動幾條 edges,我們可以假設我們移動的都是多出來的 edges,不必去動原本的 n-1 個 edges。
假設原本的 graph 被切分成分離的 m 塊 connected components,則我們需要移動 m-1 個 edges 去連接,因此問題變成了找出目前有幾塊 connected components。
這個問題的一個標準做法就是使用 disjoint set:
def makeConnected(self, n: int, connections: List[List[int]]) -> int:
# 檢查 edges 數量 >= n-1
if len(connections) < n-1:
return -1
# 找出目前有幾塊 connected components
ds = DisjointSet(range(n))
for u, v in connections:
ds.union(u, v)
return len(ds) - 1
實作 Disjoint Set
Disjoint set 的特性,是將一個大集合裡面分為 n 個子集合,這些子集合本身是 disjoint,無交集的。並提供兩個方法:
find/1
:查找某個元素在哪個集合,實務上會選擇其中一個「家長」當作代表union/2
:將兩個元素所在的集合合併
在上個部分,我假設 DisjointSet
已經寫好了,而我們要實作的則是 initialization 和 union
方法,而 union
會需要查找輸入的元素所在的集合,因此 find
當然也必須實作。
教科書做法是使用 set forest 實作,也就是每個子集合都是一個 tree,每個 node 只需要一個 pointer 指向其 parent,root 為「家長」。這樣上面的方法所做的事情就是:
find/1
:一路網上查找「家長」union/2
:將兩棵樹合併,找到家長之後,將其中一個的 parent 指向另一個
如果是這樣的話我們可以這樣實作:
class DisjointSet:
def __init__(self, elements):
self.parents = [n for n in elements]
self.count = len(self.parents)
def find(self, element):
n = self.parents[element]
while self.parents[n] != n:
n = self.parents[n]
return n
def union(self, u, v):
u = self.find(u)
v = self.find(v)
if u != v:
self.parents[u] = v
self.count -= 1
def __len__(self):
return self.count
進一步優化
但其實這個物件還有可以優化的地方:find
會重複執行,如果 tree 很深,find
的時間就會越來越長。由於 node 在 tree 內部的位置並不是重點,我們希望能夠讓向上查找家長這件事變快,也就是保持 tree 越淺越好。為此可以做兩件事:
- 加入 rank 概念,代表該子集合的最大可能深度,保持每個 set 深度的平衡,避免某個 tree 的深度太高。把 rank 記錄在家長的 node 即可
- Path compression,有點像是把每次
find
的結果存起來,做法是當find
做完時,把沿路找到的 node 都掛載家長下方,這樣下次在這個子集合裡跑 find 時就會加速許多。
所以 DisjointSet
改寫為:
class DisjointSet: # with rank and path compression
def __init__(self, elements):
self.sets = [Node() for n in elements]
self.count = len(self.sets)
def find(self, element):
return self._find(self.sets[element])
def _find(self, n):
if n.parent != n:
# path compression
n.parent = self._find(n.parent)
return n.parent
def union(self, u, v):
u = self.find(u)
v = self.find(v)
if u != v:
# 把 rank 小的掛到 rank 大的下方
if u.rank < v.rank:
u.parent = v
else:
v.parent = u
if v.rank == u.rank:
u.rank += 1
self.count -= 1
def __len__(self):
return self.count
這邊用到了 Node 物件,其實就是 C 的 struct 的概念。其實也可以只使用 tuple,但我選擇寫成物件比較清楚:
class Node:
def __init__(self):
self.parent = self
self.rank = 0
以 leetcode 實測,前後執行時間分別為 1748 ms 及 560 ms,相差了三倍。