Union Find

Definition

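Union find (also called disjoint set union) is a data structure for tracking which nodes of a graph are connected. It finds the root (the representative) of the connected component a node belongs to, and merges two components into one. This is particularly useful for problems such as Number of Connected Components in a graph.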

For example, consider the following graph:

Pasted image 20230727211902.png

The following operations hold:

  • find(5) -> 0, since 0 is the root of 5
  • after union(2, 5), find(2) -> 0: the union merges the group containing 2 into the group containing 5 (rooted at 0), which leads to the following new graph: Pasted image 20230727212127.png

Basic Algorithm

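The algorithm requires two data structures:

  • parents: List[int]: parents[i] is the parent of node i; a node that is its own parent is the root of its group.
    • For example, the graph after union(2, 5) above can be represented as [0, 0, 0, 0, 2, 0]:
      • the parent of every node except 4 is 0 (0 is its own parent, so it is the root)
      • the parent of 4 is 2
    • Initially, when no nodes are connected, every node is its own parent, so the array is [0, 1, 2, 3, 4, 5]: Pasted image 20230727212711.png
  • size: List[int]: the number of nodes in the group rooted at each node (related to, but not the same as, the rank used in union by rank, which bounds tree height instead) Pasted image 20230727215112.png
    • For the pictured graph (the one before union(2, 5)), this is [4, 0, 2, 0, 0, 0]:
      • root 0 has 4 elements
      • root 2 has 2 elements
      • the other nodes are not roots and store 0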
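As a quick consistency check, the sketch below recomputes the size array from a parents array. It assumes the graph before union(2, 5) is stored as [0, 0, 2, 0, 2, 0] (nodes 0, 1, 3, 5 under root 0; nodes 2, 4 under root 2), which is inferred from the sizes listed above rather than read off the figure. `find_root` and `sizes_from_parents` are hypothetical helper names.

```python
from typing import List


def find_root(parents: List[int], node: int) -> int:
    # Follow parent links until we hit a node that is its own parent
    while node != parents[node]:
        node = parents[node]
    return node


def sizes_from_parents(parents: List[int]) -> List[int]:
    # Count the nodes under each root; non-root entries stay 0
    sizes = [0] * len(parents)
    for node in range(len(parents)):
        sizes[find_root(parents, node)] += 1
    return sizes


before = [0, 0, 2, 0, 2, 0]   # assumed: the graph before union(2, 5)
after = [0, 0, 0, 0, 2, 0]    # the graph after union(2, 5), as listed above
print(sizes_from_parents(before))          # [4, 0, 2, 0, 0, 0]
print(sizes_from_parents(after))           # [6, 0, 0, 0, 0, 0]
print(sizes_from_parents(list(range(6))))  # [1, 1, 1, 1, 1, 1]
```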

Union

To union two groups, we find the root of each group and make one root the parent of the other.

So for example, given the following case:

Pasted image 20230727212908.png

If we call union(4, 3), we are essentially merging the root of 4 (which is 2) with the root of 3 (which is 0).

Now we have two choices: merge 2 into 0, or merge 0 into 2:

Pasted image 20230727213255.png

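Both ways of merging are correct. However, as a rule of thumb, we should merge the smaller group into the larger group to keep the tree shallow: with this rule a node only gets one level deeper when its group at least doubles in size, so no tree grows taller than $\log_2 n$.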

[!important]
This is the reason why we keep track of the size array
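A minimal sketch of union by size on plain lists. It assumes the pictured graph uses the same arrays as the data structures section (parents [0, 0, 2, 0, 2, 0], sizes [4, 0, 2, 0, 0, 0]); the figure itself is not reproduced here, and `union_by_size` is a hypothetical name. For union(4, 3), root 2's group (size 2) ends up under root 0's group (size 4):

```python
from typing import List


def find_root(parents: List[int], node: int) -> int:
    while node != parents[node]:
        node = parents[node]
    return node


def union_by_size(parents: List[int], size: List[int], a: int, b: int) -> bool:
    root_a, root_b = find_root(parents, a), find_root(parents, b)
    if root_a == root_b:
        return False  # already in the same group
    # Make root_a the root of the larger group (on a tie, a's root stays on top)
    if size[root_a] < size[root_b]:
        root_a, root_b = root_b, root_a
    parents[root_b] = root_a          # attach the smaller tree under the larger root
    size[root_a] += size[root_b]
    size[root_b] = 0                  # only roots keep a meaningful size
    return True


parents = [0, 0, 2, 0, 2, 0]
size = [4, 0, 2, 0, 0, 0]
union_by_size(parents, size, 4, 3)
print(parents)  # [0, 0, 0, 0, 2, 0]
print(size)     # [6, 0, 0, 0, 0, 0]
```

Swapping the roots so that `root_a` is always the larger group is just a compact way of writing the two branches of an if/else.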

Implementation

class UnionFind:
    def __init__(self, n: int) -> None:
        self.parents = list(range(n))
        self.componentSize = [1] * n

    def union(self, a: int, b: int) -> bool:
        # find() is defined in the next section
        rootA = self.find(a)
        rootB = self.find(b)

        # If they have the same root, nothing needs to be done
        if rootA == rootB:
            return False

        # Attach the root of the smaller group under the root of the larger group
        if self.componentSize[rootA] < self.componentSize[rootB]:
            self.parents[rootA] = rootB
            self.componentSize[rootB] += self.componentSize[rootA]
            self.componentSize[rootA] = 0
        else:
            self.parents[rootB] = rootA
            self.componentSize[rootA] += self.componentSize[rootB]
            self.componentSize[rootB] = 0

        return True

Time complexity: $O(1)$ plus the cost of the two find calls

  • $O(1)$ for assigning the parent and updating the sizes

Find

The find method returns the root of the group a node belongs to.

Pasted image 20230727215800.png

Given the above graph, we have:

  • find(5) -> 1
  • find(10) -> 6

Implementation

Iteratively

def find(self, node: int):
	currNode = node
	while currNode != self.parents[currNode]:
		parent = self.parents[currNode]
		currNode = parent
	
	return currNode

Time complexity: $O(n)$ in the worst case, when the tree is a single chain
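To make the worst case concrete, here is a sketch that builds a chain (node i points to i - 1, node 0 is the root) and counts how many parent links the iterative find follows. `find_with_hops` is a hypothetical helper; a chain of n nodes needs n - 1 hops from the bottom node.

```python
from typing import List, Tuple


def find_with_hops(parents: List[int], node: int) -> Tuple[int, int]:
    # Iterative find that also reports how many parent links it followed
    hops = 0
    while node != parents[node]:
        node = parents[node]
        hops += 1
    return node, hops


n = 1000
# Worst case: a chain where node i points to i - 1 and node 0 is the root
chain = [max(i - 1, 0) for i in range(n)]
print(find_with_hops(chain, n - 1))  # (0, 999)
```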

Recursively

def find(self, node: int):
	if node != self.parents[node]:
		return self.find(self.parents[node])	
	return node

Time complexity: $O(n)$ in the worst case
Space complexity: $O(n)$ for the recursion stack
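In CPython the recursion space matters in practice: a deep enough chain exceeds the interpreter's recursion limit (`sys.getrecursionlimit()`, 1000 by default) and raises `RecursionError`, while the loop version keeps working. A sketch with hypothetical helper names and a chain built to be twice the current limit:

```python
import sys
from typing import List


def find_recursive(parents: List[int], node: int) -> int:
    # Each level of the tree costs one stack frame
    if node != parents[node]:
        return find_recursive(parents, parents[node])
    return node


def find_iterative(parents: List[int], node: int) -> int:
    # Same walk with a loop: constant extra space
    while node != parents[node]:
        node = parents[node]
    return node


# A chain twice as deep as the interpreter's recursion limit
n = sys.getrecursionlimit() * 2
chain = [max(i - 1, 0) for i in range(n)]

print(find_iterative(chain, n - 1))  # 0

try:
    find_recursive(chain, n - 1)
    overflowed = False
except RecursionError:
    overflowed = True
print(overflowed)  # True
```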

Path compression

As we can see from the example above, find can follow long paths: to get from 5 to its root 1, we have to go through 5 -> 4 -> 3 -> 2 -> 1.

To avoid this, we can compress the path as follows:

Pasted image 20230728100957.png

Now getting from 5 to its root takes a single step: 5 -> 1.
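A sketch of a compressing find on the chain from the text, assuming the figure is exactly 5 -> 4 -> 3 -> 2 -> 1 with 1 as the root (the 6 to 10 part is left out, and index 0 is unused padding so node numbers match). After one call, every node on the path points straight at 1. `find_compress` is a hypothetical name.

```python
from typing import List


def find_compress(parents: List[int], node: int) -> int:
    # Pass 1: walk up to the root
    root = node
    while root != parents[root]:
        root = parents[root]
    # Pass 2: point every node on the path straight at the root
    while node != root:
        parent = parents[node]
        parents[node] = root
        node = parent
    return root


# Index 0 is unused padding; the chain is 5 -> 4 -> 3 -> 2 -> 1
parents = [0, 1, 1, 2, 3, 4]
print(find_compress(parents, 5))  # 1
print(parents)                    # [0, 1, 1, 1, 1, 1]
```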

Intuition

The idea is that union connects the root of one node to the root of the other node, and find re-points the nodes it visits closer to their root. For example:

Pasted image 20230728103148.png

If we now perform union(4, 5) we will:

  • Find the root of 4 : find(4) -> 2
  • Find the root of 5 : find(5) -> 0

Pasted image 20230728103359.png

We connect 2 -> 0.

If we then call union(4, 1), we first compute:

  • find(4) -> 0
  • find(1) -> 0

Since both have the same root, no union is needed. However, find(4) compresses the path along the way:

  1. The parent of 4 is 2
    • Since 4 != 2:
      • set the parent of 4 to the parent of 2, which is 0
    • Move on to 2
  2. The parent of 2 is 0
    • Since 2 != 0:
      • set the parent of 2 to the parent of 0, which is 0
    • Move on to 0
  3. The parent of 0 is 0, so we stop

Pasted image 20230728103546.png
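The walkthrough above can be replayed in code. This sketch assumes the pictured graph starts as parents [0, 0, 2, 0, 2, 0] with sizes [4, 0, 2, 0, 0, 0], which matches find(4) -> 2, find(5) -> 0 and the union(4, 5) step but is an assumption about the figure. `ReplayUnionFind` is a hypothetical class that takes its starting arrays as arguments; its find uses the same rule as the steps, pointing each visited node at its grandparent.

```python
from typing import List


class ReplayUnionFind:
    def __init__(self, parents: List[int], sizes: List[int]) -> None:
        # Start from given arrays so the example can be replayed
        self.parents = parents
        self.sizes = sizes

    def find(self, node: int) -> int:
        # Point each visited node at its grandparent, then move up one level
        while node != self.parents[node]:
            parent = self.parents[node]
            self.parents[node] = self.parents[parent]
            node = parent
        return node

    def union(self, a: int, b: int) -> bool:
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return False
        # Union by size: root_a becomes the root of the larger group
        if self.sizes[root_a] < self.sizes[root_b]:
            root_a, root_b = root_b, root_a
        self.parents[root_b] = root_a
        self.sizes[root_a] += self.sizes[root_b]
        self.sizes[root_b] = 0
        return True


uf = ReplayUnionFind([0, 0, 2, 0, 2, 0], [4, 0, 2, 0, 0, 0])
print(uf.union(4, 5), uf.parents)  # True [0, 0, 0, 0, 2, 0]
print(uf.union(4, 1), uf.parents)  # False [0, 0, 0, 0, 0, 0]
```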

Pseudo

The algorithm works as follows:

Assign as we go

while node != parent[node]:
	next = parent[node]
	parent[node] = parent[next]    # point node at its grandparent
	node = next

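This can look wrong for paths longer than three nodes, since pointing a node at its grandparent does not necessarily point it at the root.

However, find still returns the correct root, because it only ever re-points a node to one of its ancestors. A single call just does not flatten the whole path: each call shortens it (this variant is known as path splitting), and when it runs on every find from the start, the trees stay shallow.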
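To see both points on a path longer than three nodes, here is a sketch on a hypothetical chain 7 -> 6 -> ... -> 1 -> 0. One call does not point everything at the root, but it returns the right root and roughly halves the depth of node 7; a second call shortens it again. `find_split` and `depth` are illustrative names.

```python
from typing import List


def find_split(parents: List[int], node: int) -> int:
    # Assign as we go: each visited node is re-pointed to its grandparent
    while node != parents[node]:
        parent = parents[node]
        parents[node] = parents[parent]
        node = parent
    return node


def depth(parents: List[int], node: int) -> int:
    # Number of parent links between node and its root (read-only)
    hops = 0
    while node != parents[node]:
        node = parents[node]
        hops += 1
    return hops


# Chain 7 -> 6 -> 5 -> 4 -> 3 -> 2 -> 1 -> 0, with 0 as the root
parents = [0, 0, 1, 2, 3, 4, 5, 6]
print(depth(parents, 7))       # 7
print(find_split(parents, 7))  # 0, the correct root
print(parents)                 # [0, 0, 0, 1, 2, 3, 4, 5]
print(depth(parents, 7))       # 4
find_split(parents, 7)
print(depth(parents, 7))       # 2
```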

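Time complexity: $O(\log n)$ per call in the worst case

  • find only walks upward from the node to its root, and union by size keeps every tree at most $\log_2 n$ levels tall. Amortized over many operations, path compression together with union by size brings the cost down to $O(\alpha(n))$, where $\alpha$ is the inverse Ackermann function, which is effectively constant.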
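A small experiment suggests the $\log_2 n$ bound is also tight. It assumes union by size with no path compression (so the measured height is the raw tree height) and merges pairs of equal-size groups round after round, which adds one level per round. The helper names are hypothetical.

```python
import math
from typing import List


def find_root(parents: List[int], node: int) -> int:
    # Plain find without compression, so the measured height is the real one
    while node != parents[node]:
        node = parents[node]
    return node


def union_by_size(parents: List[int], size: List[int], a: int, b: int) -> None:
    root_a, root_b = find_root(parents, a), find_root(parents, b)
    if root_a == root_b:
        return
    # Keep root_a as the root of the larger group
    if size[root_a] < size[root_b]:
        root_a, root_b = root_b, root_a
    parents[root_b] = root_a
    size[root_a] += size[root_b]
    size[root_b] = 0


def height(parents: List[int]) -> int:
    # Longest number of parent links from any node up to its root
    best = 0
    for node in range(len(parents)):
        hops, cur = 0, node
        while cur != parents[cur]:
            cur = parents[cur]
            hops += 1
        best = max(best, hops)
    return best


n = 1024
parents, size = list(range(n)), [1] * n

# Each round merges pairs of equal-size groups: blocks of 1, then 2, 4, ...
step = 1
while step < n:
    for i in range(0, n, 2 * step):
        union_by_size(parents, size, i, i + step)
    step *= 2

print(height(parents), math.log2(n))  # 10 10.0
```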

A simpler way: find the root, then assign

def find(self, node: int) -> int:
    root = node

    # Find the root
    while root != self.parents[root]:
        root = self.parents[root]

    # Point every node on the path directly at the root
    while node != root:
        parent = self.parents[node]
        self.parents[node] = root
        node = parent

    return root

Time complexity: $O(\log n)$

Implementation

class UnionFind:
    def __init__(self, n: int) -> None:
        self.parents = list(range(n))
        self.componentSize = [1] * n

    def union(self, a: int, b: int) -> bool:
        rootA = self.find(a)
        rootB = self.find(b)

        if rootA == rootB:
            return False

        # Union by size: attach the smaller group's root under the larger group's root
        if self.componentSize[rootA] < self.componentSize[rootB]:
            self.parents[rootA] = rootB
            self.componentSize[rootB] += self.componentSize[rootA]
            self.componentSize[rootA] = 0
        else:
            self.parents[rootB] = rootA
            self.componentSize[rootA] += self.componentSize[rootB]
            self.componentSize[rootB] = 0

        return True

    def find(self, node: int) -> int:
        root = node

        # Find the root
        while root != self.parents[root]:
            root = self.parents[root]

        # Path compression: point every node on the path directly at the root
        while node != root:
            parent = self.parents[node]
            self.parents[node] = root
            node = parent

        return root
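With the class in place, the Number of Connected Components problem from the definition reduces to counting successful unions: start with n singleton groups and subtract one each time union returns True. This sketch restates the class so it runs on its own (union is written with a root swap, which behaves the same as the if/else); `count_components` is a hypothetical name, and edges are assumed to be (a, b) pairs over nodes 0 to n - 1.

```python
from typing import List, Tuple


class UnionFind:
    def __init__(self, n: int) -> None:
        self.parents = list(range(n))
        self.componentSize = [1] * n

    def find(self, node: int) -> int:
        root = node
        while root != self.parents[root]:
            root = self.parents[root]
        # Path compression: point every node on the path at the root
        while node != root:
            parent = self.parents[node]
            self.parents[node] = root
            node = parent
        return root

    def union(self, a: int, b: int) -> bool:
        rootA, rootB = self.find(a), self.find(b)
        if rootA == rootB:
            return False
        # Union by size: rootA becomes the root of the larger group
        if self.componentSize[rootA] < self.componentSize[rootB]:
            rootA, rootB = rootB, rootA
        self.parents[rootB] = rootA
        self.componentSize[rootA] += self.componentSize[rootB]
        self.componentSize[rootB] = 0
        return True


def count_components(n: int, edges: List[Tuple[int, int]]) -> int:
    # Every successful union merges two groups into one
    uf = UnionFind(n)
    components = n
    for a, b in edges:
        if uf.union(a, b):
            components -= 1
    return components


print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2
print(count_components(5, [(0, 1), (1, 2), (2, 0)]))  # 3
```

The edge (2, 0) in the second call does not reduce the count, because 0 and 2 already share a root when it is processed.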

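Path compression can also be done in a single pass by re-pointing each node at its grandparent as we go (the "assign as we go" approach above, also known as path splitting):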

def find(self, node: int):
	currNode = node
	while (currNode != self.parents[currNode]):
		parent = self.parents[currNode]
		self.parents[currNode] = self.parents[parent]
		currNode = parent

	return currNode

Or recursively:

def find(self, node: int):
	if node != self.parents[node]:
		self.parents[node] = self.find(self.parents[node])

	return self.parents[node]

Dealing with custom data types

The UnionFind above only works with integers from 0 to n - 1 as input. However, a problem might give us Node objects or other values instead.

In that case, we can switch our data structures from lists to dictionaries keyed by the node itself:

from collections import defaultdict


class Node:
    def __init__(self, val: str) -> None:
        self.val = val

    def __repr__(self) -> str:
        return self.val


class UnionFindNode:
    def __init__(self) -> None:
        # A node missing from parents is its own parent (a root)
        self.parents = {}
        # A node missing from componentSize is a group of size 1
        self.componentSize = defaultdict(lambda: 1)

    def union(self, a: Node, b: Node) -> bool:
        rootA = self.find(a)
        rootB = self.find(b)

        if rootA == rootB:
            return False

        if self.componentSize[rootA] < self.componentSize[rootB]:
            self.parents[rootA] = rootB
            self.componentSize[rootB] += self.componentSize[rootA]
            self.componentSize[rootA] = 0
        else:
            self.parents[rootB] = rootA
            self.componentSize[rootA] += self.componentSize[rootB]
            self.componentSize[rootB] = 0

        return True

    def find(self, node: Node) -> Node:
        currNode = node
        while currNode != self.parents.get(currNode, currNode):
            parent = self.parents.get(currNode, currNode)
            # Path splitting: point the node at its grandparent
            self.parents[currNode] = self.parents.get(parent, parent)
            currNode = parent
        return currNode

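if __name__ == "__main__":
    unionFind = UnionFindNode()
    n1 = Node("1")
    n2 = Node("2")
    n3 = Node("3")
    n4 = Node("4")
    n5 = Node("5")

    unionFind.union(n1, n3)
    unionFind.union(n4, n1)
    unionFind.union(n5, n2)
    unionFind.union(n2, n3)

    # dict() drops the defaultdict wrapper from the printed output
    print(unionFind.parents, dict(unionFind.componentSize))

Output:

{3: 1, 4: 1, 2: 5, 5: 1} {1: 5, 3: 0, 4: 0, 5: 0, 2: 0}

Node 2 still points at 5 rather than at the root 1: its path is only shortened the next time find(2) runs.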
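A variant that registers every key on first use can also list the groups, which helps when the answer is the groups themselves rather than a count. This is a sketch: `UnionFindDict`, `_add` and `groups` are hypothetical names. Keys only need to be hashable (strings here; the Node class above also works, because plain objects hash by identity by default).

```python
from collections import defaultdict
from typing import Dict, Hashable, List


class UnionFindDict:
    def __init__(self) -> None:
        self.parents: Dict[Hashable, Hashable] = {}
        self.componentSize: Dict[Hashable, int] = {}

    def _add(self, x: Hashable) -> None:
        # Register a key the first time we see it, as a group of size 1
        if x not in self.parents:
            self.parents[x] = x
            self.componentSize[x] = 1

    def find(self, x: Hashable) -> Hashable:
        self._add(x)
        root = x
        while root != self.parents[root]:
            root = self.parents[root]
        # Path compression: point every node on the path at the root
        while x != root:
            parent = self.parents[x]
            self.parents[x] = root
            x = parent
        return root

    def union(self, a: Hashable, b: Hashable) -> bool:
        rootA, rootB = self.find(a), self.find(b)
        if rootA == rootB:
            return False
        # Union by size: keep rootA as the root of the larger group
        if self.componentSize[rootA] < self.componentSize[rootB]:
            rootA, rootB = rootB, rootA
        self.parents[rootB] = rootA
        self.componentSize[rootA] += self.componentSize[rootB]
        self.componentSize[rootB] = 0
        return True

    def groups(self) -> List[List[Hashable]]:
        # Bucket every key we have seen under its root
        buckets: Dict[Hashable, List[Hashable]] = defaultdict(list)
        for x in list(self.parents):
            buckets[self.find(x)].append(x)
        return list(buckets.values())


uf = UnionFindDict()
uf.union("a", "b")
uf.union("c", "d")
uf.union("b", "d")
uf.union("x", "y")
uf.find("z")  # an isolated key still forms its own group
print(uf.groups())  # [['a', 'b', 'c', 'd'], ['x', 'y'], ['z']]
```

Registering keys explicitly, instead of relying on `.get` defaults, is what lets isolated keys such as "z" show up in `groups()`.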