Python 實作 Disjoint Set 與 Union Find

使用情境

在 Leetcode 寫到一題:

現在有 n 台電腦以及一些 cables 將電腦點對點連接,問需要移動至少幾條 cable 才能讓在所有電腦連成單一網路。
以 graph 的角度來看,電腦就是 nodes,cables 就是 edges。

要將整張 graph 連接起來,至少需要 n-1 個 edges。若一個 graph 裡面有超過 n-1 個 edges,剩下的就是多出來的 edges,可以供我們拿來移動的 edges。
所以第一件事就是要檢查 edges 數量 >= n-1

當檢查完畢之後,我們有至少 n-1 條 edges,一定可以用這些 edges 將所有 nodes 連接起來。
因為題目只問需要移動幾條 edges,我們可以假設我們移動的都是多出來的 edges,不必去動原本的 n-1 個 edges。

假設原本的 graph 被切分成分離的 m 塊 connected components,則我們需要移動 m-1 個 edges 去連接,因此問題變成了找出目前有幾塊 connected components

這個問題的一個標準做法就是使用 disjoint set:

def makeConnected(self, n: int, connections: List[List[int]]) -> int:
    # 檢查 edges 數量 >= n-1
    if len(connections) < n-1:
        return -1
    
    # 找出目前有幾塊 connected components
    ds = DisjointSet(range(n))
    for u, v in connections:
        ds.union(u, v)
    
    return ds.count_sets() - 1

實作 Disjoint Set

Disjoint set 的特性,是將一個大集合裡面分為 n 個子集合,這些子集合本身是 disjoint,無交集的。並提供兩個方法:

  • find/1:查找某個元素在哪個集合,實務上會選擇其中一個「家長」當作代表
  • union/2:將兩個元素所在的集合合併

在上個部分,我假設 DisjointSet 已經寫好了,而我們要實作的則是 initialization 和 union 方法,而 union 會需要查找輸入的元素所在的集合,因此 find 當然也必須實作。

教科書做法是使用 set forest 實作,也就是每個子集合都是一個 tree,每個 node 只需要一個 pointer 指向其 parent,root 為「家長」。這樣上面的方法所做的事情就是:

  • find/1:一路網上查找「家長」
  • union/2:將兩棵樹合併,找到家長之後,將其中一個的 parent 指向另一個

如果是這樣的話我們可以這樣實作:

class DisjointSet:
    def __init__(self, elements):
        self.parents = [n for n in elements]
        self.count = len(self.parents)
        
    def find(self, element):
        n = self.parents[element]
        while self.parents[n] != n:
            n = self.parents[n]
        return n
        
    def union(self, u, v):
        u = self.find(u)
        v = self.find(v)
        if u != v:
            self.parents[u] = v
            self.count -= 1

    def count_sets(self):
        return self.count

進一步優化

但其實這個物件還有可以優化的地方:find 會重複執行,如果 tree 很深,find 的時間就會越來越長。由於 node 在 tree 內部的位置並不是重點,我們希望能夠讓向上查找家長這件事變快,也就是保持 tree 越淺越好。為此可以做兩件事:

  • 加入 rank 概念,代表該子集合的最大可能深度,保持每個 set 深度的平衡,避免某個 tree 的深度太高。把 rank 記錄在家長的 node 即可
  • Path compression,有點像是把每次 find 的結果存起來,做法是當 find 做完時,把沿路找到的 node 都掛載家長下方,這樣下次在這個子集合裡跑 find 時就會加速許多。

所以 DisjointSet 改寫為:

class DisjointSet:  # with rank and path compression
    def __init__(self, elements):
        self.sets = [Node(n) for n in elements]
        self.count = len(self.sets)
        
    def find(self, element):
        n = self.sets[element]
        path_node = []
        while n.parent != n:
            n = n.parent
            path_node.append(n) # 記錄路上的 nodes

        # path compression
        for v in path_node:
            v.parent = n
        return n
        
    def union(self, u, v):
        u = self.find(u)
        v = self.find(v)
        if u != v:
            # 把 rank 小的掛到 rank 大的下方
            if u.rank < v.rank:
                u.parent = v
            else:
                v.parent = u
                if v.rank == u.rank:
                    u.rank += 1
            self.count -= 1

    def count_sets(self):
        return self.count

這邊用到了 Node 物件,其實就是 C 的 struct 的概念。其實也可以只使用 tuple,但我選擇寫成物件比較清楚:

class Node:
    def __init__(self, n):
        self.parent = self
        self.rank = 0

以 leetcode 實測,前後執行時間分別為 1748 ms 及 560 ms,相差了三倍。

 
comments powered by Disqus