# Java/Python clear solution with UnionFind Class (Weighting and Path compression)

• Union Find is an abstract data structure supporting `find` and `unite` on disjoint sets of objects, typically used to solve the network connectivity problem.

The two operations are defined like this:

`find(a,b)` : do `a` and `b` belong to the same set?

`unite(a,b)` : if `a` and `b` are not in the same set, unite the sets they belong to.

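With this data structure, our problem can be solved very quickly. Every position adds a new piece of land that starts as its own island; if the new land connects two islands `a` and `b`, we unite them into one. So each new land increases the island count by one, and each successful `unite` decreases it by one. The answer after each step is the number of disjoint sets.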

The following algorithm is derived from Princeton's lecture note on Union Find in Algorithms and Data Structures. It is a well-organized note with clear illustrations, going from the naive QuickFind to the version with Weighting and Path compression.
With Weighting and Path compression, the algorithm runs in `O((M+N) log* N)`, where `M` is the number of operations (unite and find), `N` is the number of objects, and `log*` is the iterated logarithm, while the naive version runs in `O(MN)`.
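For contrast, the naive QuickFind mentioned above can be sketched in Python as follows. This is a minimal sketch of my own (the class name `QuickFind` is illustrative), not code from the note: `find` is a single comparison, but every `unite` relabels the whole array, which is where the `O(MN)` bound comes from.

```python
class QuickFind:
    """Naive eager union-find: id[i] is the set label of i itself."""

    def __init__(self, n):
        self.id = list(range(n))

    def find(self, p, q):
        # O(1): two objects are connected iff they carry the same label
        return self.id[p] == self.id[q]

    def unite(self, p, q):
        pid, qid = self.id[p], self.id[q]
        if pid == qid:
            return
        # O(N) per call: relabel every member of p's set, so M unions cost O(MN)
        for i in range(len(self.id)):
            if self.id[i] == pid:
                self.id[i] = qid


qf = QuickFind(5)
qf.unite(0, 1)
qf.unite(1, 2)
print(qf.find(0, 2), qf.find(0, 3))  # True False
print(qf.id)                         # [2, 2, 2, 3, 4]
```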

For our problem, if there are `N` positions, then there are `O(N)` operations and `N` objects, so the total is `O(N log*N)`, not counting the `O(mn)` cost of array initialization.

Note that `log*N` is almost constant (for `N = 2^65536`, `log*N = 5`) in this universe, so the algorithm is almost linear in `N`.

However, if the map is very big, then the initialization of the arrays can cost a lot of time when `mn` is much larger than `N`. In this case we should consider using a hashmap/dictionary for the underlying data structure to avoid this overhead.

Of course, we could put all the functionality into the Solution class, which would make the code a lot shorter. But from a design point of view, a separate class dedicated to the data structure is more readable and reusable.

I implemented the idea with a 2D interface to better fit the problem.

Java

```java
import java.util.ArrayList;
import java.util.List;

public class Solution {

    private int[][] dir = {{0, 1}, {0, -1}, {-1, 0}, {1, 0}};

    public List<Integer> numIslands2(int m, int n, int[][] positions) {
        UnionFind2D islands = new UnionFind2D(m, n);
        List<Integer> ans = new ArrayList<>();
        for (int[] position : positions) {
            int x = position[0], y = position[1];
            if (islands.getID(x, y) > 0) {      // repeated position: already land
                ans.add(islands.size());
                continue;
            }
            int p = islands.add(x, y);          // the new land starts as its own island
            for (int[] d : dir) {
                // 0 for water or off-grid cells; otherwise a node in the neighbor's set
                int q = islands.getID(x + d[0], y + d[1]);
                if (q > 0 && !islands.find(p, q))
                    islands.unite(p, q);
            }
            ans.add(islands.size());
        }
        return ans;
    }
}

class UnionFind2D {
    private int[] id;   // id[i] == 0 means cell i is water; index 0 itself is never used
    private int[] sz;
    private int m, n, count;

    public UnionFind2D(int m, int n) {
        this.count = 0;
        this.n = n;
        this.m = m;
        this.id = new int[m * n + 1];
        this.sz = new int[m * n + 1];
    }

    // 1-based cell index, so that 0 can mean "water"
    public int index(int x, int y) { return x * n + y + 1; }

    public int size() { return this.count; }

    public int getID(int x, int y) {
        if (0 <= x && x < m && 0 <= y && y < n)
            return id[index(x, y)];
        return 0;
    }

    public int add(int x, int y) {
        int i = index(x, y);
        id[i] = i; sz[i] = 1;
        ++count;
        return i;
    }

    public boolean find(int p, int q) {
        return root(p) == root(q);
    }

    public void unite(int p, int q) {
        int i = root(p), j = root(q);
        if (sz[i] < sz[j]) { // weighted quick union: attach the smaller tree
            id[i] = j; sz[j] += sz[i];
        } else {
            id[j] = i; sz[i] += sz[j];
        }
        --count;
    }

    private int root(int i) {
        for (; i != id[i]; i = id[i])
            id[i] = id[id[i]]; // path compression (halving)
        return i;
    }
}
//Runtime: 20 ms
```

Python (using dict)

```python
class Solution(object):
    def numIslands2(self, m, n, positions):
        ans = []
        islands = Union()
        for p in map(tuple, positions):
            if p not in islands.id:             # ignore repeated positions
                islands.add(p)                  # the new land starts as its own island
                for dp in (0, 1), (0, -1), (1, 0), (-1, 0):
                    q = (p[0] + dp[0], p[1] + dp[1])
                    if q in islands.id:         # only land cells are keys, so no bounds check
                        islands.unite(p, q)
            ans += [islands.count]
        return ans


class Union(object):
    def __init__(self):
        self.id = {}
        self.sz = {}
        self.count = 0

    def add(self, p):
        self.id[p] = p
        self.sz[p] = 1
        self.count += 1

    def root(self, i):
        while i != self.id[i]:
            self.id[i] = self.id[self.id[i]]    # path compression (halving)
            i = self.id[i]
        return i

    def unite(self, p, q):
        i, j = self.root(p), self.root(q)
        if i == j:
            return
        if self.sz[i] > self.sz[j]:             # weighting: attach the smaller tree
            i, j = j, i
        self.id[i] = j
        self.sz[j] += self.sz[i]
        self.count -= 1

#Runtime: 300 ms
```

• Thanks @peisi for sharing such nice material on UnionFind. For anyone who may be confused by `log*`: it is the number of times we need to take `log` (base `2`) of a number to bring it down to `1`. You may refer to page 31 of the linked note for it :-)

• Yes. I forgot to mention that.

https://en.wikipedia.org/wiki/Iterated_logarithm

• Very nice. Thanks so much for the solution.

• I would like to clarify the following line:
`int q = islands.getID(x + d[0], y + d[1]);`

Should it call `getID` or `index`?
When it calls `islands.unite(p, q)`, `p` is the index returned by `add`. So I think `q` should also be an index and not an ID.
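One way to see why the line works as written: `getID` returns `id[index(x, y)]`, the neighbor's current parent rather than its index. The parent always belongs to the same set as the cell, so `root(id[i]) == root(i)`, and since index `0` is never used, a land cell never yields `0`. The Python sketch below checks this on a small hypothetical forest that mirrors the Java arrays; the helper `root` copies the path-halving loop from the post.

```python
def root(parent, i):
    """Follow parent links to the set representative, halving the path like the post's root()."""
    while i != parent[i]:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i


# Hypothetical 1-based layout (index 0 unused, value 0 means water):
# 3 -> 2 -> 1 form one set rooted at 1, node 4 is a singleton, node 5 is water.
parent = [0, 1, 1, 2, 4, 0]

for i in (1, 2, 3, 4):
    q = parent[i]                              # what getID returns: the current parent
    assert q > 0                               # land never reads as 0
    assert root(parent, q) == root(parent, i)  # the parent is in the same set as i

print(parent[5] == 0)  # True: water still reads as 0, so `q > 0` filters it out
```

So using the returned parent as `q` is safe, although calling `index` and checking `id[...] > 0` separately would make the intent clearer.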

• Same idea, here is my concise solution (20 ms):

```java
import java.util.ArrayList;
import java.util.List;

public class Solution {
    private static final int[][] dir = {{0, 1}, {1, 0}, {0, -1}, {-1, 0}};

    public List<Integer> numIslands2(int n, int m, int[][] positions) {
        int[][] map = new int[n + 2][m + 2];   // one cell of padding on each side: no bounds checks
        List<Integer> ans = new ArrayList<>();
        int islandN = 0;
        UnionSet us = new UnionSet(n, m);

        for (int[] p : positions) {
            if (map[p[0] + 1][p[1] + 1] == 0) { // skip repeated positions
                map[p[0] + 1][p[1] + 1] = 1;
                islandN++;
                for (int[] d : dir)
                    if (map[p[0] + d[0] + 1][p[1] + d[1] + 1] > 0 && us.union(p[0], p[1], p[0] + d[0], p[1] + d[1]))
                        islandN--;
            }
            ans.add(islandN);                   // record the count after every position
        }
        return ans;
    }

    private class UnionSet {
        int n, m;
        int[] p, size;

        public UnionSet(int a, int b) {
            n = a; m = b;
            p = new int[getID(n, m)];
            size = new int[getID(n, m)];
        }

        private int getID(int i, int j) {
            return i * m + j + 1; // ensure no id == 0
        }

        private int find(int i) {
            if (p[i] == 0) { // == 0 means not yet initialized
                p[i] = i;
                size[i] = 1;
            }
            p[i] = (p[i] == i) ? i : find(p[i]); // full path compression
            return p[i];
        }

        private boolean union(int i1, int j1, int i2, int j2) { // true if it merges two different sets
            int s1 = find(getID(i1, j1)), s2 = find(getID(i2, j2));
            if (s1 == s2) return false;
            if (size[s1] > size[s2]) {
                p[s2] = s1;
                size[s1] += size[s2];
            } else {
                p[s1] = s2;
                size[s2] += size[s1];
            }
            return true;
        }
    }
}
```

• `UnionFind2D islands = new UnionFind2D(m, n);` takes O(m*n) time, since it allocates two arrays of size `m * n + 1`.

• For weighting why not use the depth of the tree instead of the size of the tree?
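Weighting by height is a known alternative, usually called union by rank. Once path compression is added, the real depth of a tree keeps shrinking and would be costly to track exactly, so the stored rank is only an upper bound on the height. Union by size and union by rank both keep trees at `O(log N)` height and both give the near-linear bound when combined with path compression. A minimal Python sketch, using the hypothetical class name `UnionByRank` and the same dict-based layout as the `Union` class in the post:

```python
class UnionByRank(object):
    """Hypothetical variant of the Union class weighted by rank instead of size."""

    def __init__(self):
        self.id = {}
        self.rank = {}   # upper bound on the height of the tree rooted at each key
        self.count = 0

    def add(self, p):
        self.id[p] = p
        self.rank[p] = 0
        self.count += 1

    def root(self, i):
        while i != self.id[i]:
            self.id[i] = self.id[self.id[i]]  # path halving; ranks are not updated
            i = self.id[i]
        return i

    def unite(self, p, q):
        i, j = self.root(p), self.root(q)
        if i == j:
            return False
        if self.rank[i] < self.rank[j]:
            i, j = j, i
        self.id[j] = i                        # hang the lower-rank tree under the higher one
        if self.rank[i] == self.rank[j]:
            self.rank[i] += 1                 # height grows only when two equal ranks merge
        self.count -= 1
        return True


u = UnionByRank()
for cell in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    u.add(cell)
u.unite((0, 0), (0, 1))
u.unite((1, 0), (1, 1))
u.unite((0, 0), (1, 0))
print(u.count, u.rank[u.root((1, 1))])  # 1 2
```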

• Note that `log*N` is almost constant (for `N` = 265536, `log*N` = 5) in this universe

265536 isn't that big...

• @StefanPochmann It is actually N = 2*65536, logN = 5

> @StefanPochmann It is actually N = 2*65536, logN = 5

That's not what it says there. Also, about what you just wrote: If N is 2*65536 then logN isn't 5 unless you have an extremely unusual base, which you should provide.

Just pointing out markdown accidents in hopes they'll get fixed.

• @StefanPochmann I thought it was my typo, but it turns out to be a formatting bug. Here it is, formatted properly:

It is `N = 2**65536`, `log*N = 5`.
`log*(n)` counts how many times `log()` must be applied to get down to 1.
Here, when `N = 2**65536`, `log(log(log(log(log(N))))) = 1`; it takes 5 applications of `log`, so `log*N = 5`.
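The definition can be checked directly. A small Python sketch (the function name `iterated_log2` is my own) that counts how many base-2 logarithms it takes to bring a number down to 1 or below:

```python
import math


def iterated_log2(x):
    """Return log*(x): how many times log2 must be applied before x <= 1."""
    count = 0
    while x > 1:
        x = math.log2(x)  # CPython's math.log2 also accepts ints too large for a float
        count += 1
    return count


print(iterated_log2(2 ** 65536))  # 5  (2**65536 -> 65536 -> 16 -> 4 -> 2 -> 1)
print(iterated_log2(65536))       # 4
```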

• Very good code and excellent explanation.

I like posts that link to articles. Reading papers is far more interesting and important than just finding a solution to the problem.

• @dietpepsi Thanks for sharing. Nice solution.
I don't understand why the `id` and `sz` arrays need length `m * n + 1`, so I tried `m * n` instead. No error occurs, but the results are not correct.
For the test case:

```
3
3
[[0,0],[0,1],[1,2],[2,1]]
```

I got `[1,2,3,4]` while we should get `[1,1,2,3]`

I don't see the difference between `m * n + 1` and `m * n`, and I am also confused about why the two results differ. Can anyone explain? Thanks a lot.

• @Tōsaka-Rin

Please don't overthink it. Notice:

```java
public int index(int x, int y) { return x * n + y + 1; }
```

and

```java
int i = index(x, y);
id[i] = i; sz[i] = 1;
```

When `x = 0` and `y = 0`, index `i` takes its minimum value `1`.
When `x = m - 1` and `y = n - 1`, index `i` takes its maximum value `m * n`.

So `i` starts from 1 and its range is `[1, m * n]`.
The `id[]` array is 0-indexed, so its size must be `m * n + 1`, with `id[0]` left unchanged.

• @zzhai Thanks for explaining.
I see where the min index and max value come but I define the index to be `idx = x * n + y` then the range should start from `0` to `m * n - 1`. So..idx starts from 0 and ends up with m * n - 1. I dont quite understand why we need to have another set but unchanged value(id[0] as you said). Anyway, thanks for explanation.

• @zzhai So the difference is you initially fill the id array with -1?
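The reason index `0` has to stay unused is that `getID` returns `0` both for water and for off-grid cells, and `id[]` starts zero-filled. With 0-based indices, the land cell `(0, 0)` gets index `0`, so `id[0] = 0` after `add`, and its neighbors read it as water. The hypothetical Python function below mirrors `UnionFind2D` with a configurable `offset` (1 as in the post, 0 as in the question) and reproduces both results for the test case above. Filling `id` with `-1` and testing `q >= 0` should also work with 0-based indices, but that variant is not shown here.

```python
def num_islands2(m, n, positions, offset):
    """Hypothetical mirror of UnionFind2D; offset=1 keeps index 0 free as the water sentinel."""
    size = m * n + offset
    parent = [0] * size   # 0 means "water", as in the Java id[] array
    sz = [0] * size
    count = 0
    ans = []

    def root(i):
        while i != parent[i]:
            parent[i] = parent[parent[i]]  # path halving, as in the post
            i = parent[i]
        return i

    def get_id(x, y):
        # Like getID: 0 for off-grid or water, otherwise the stored parent.
        if 0 <= x < m and 0 <= y < n:
            return parent[x * n + y + offset]
        return 0

    for x, y in positions:
        p = x * n + y + offset
        parent[p], sz[p] = p, 1
        count += 1
        for dx, dy in ((0, 1), (0, -1), (-1, 0), (1, 0)):
            q = get_id(x + dx, y + dy)
            if q > 0:  # with offset=0, land at (0, 0) has parent 0 and looks like water
                i, j = root(p), root(q)
                if i != j:
                    if sz[i] < sz[j]:
                        i, j = j, i
                    parent[j] = i      # attach the smaller tree under the larger one
                    sz[i] += sz[j]
                    count -= 1
        ans.append(count)
    return ans


grid = [[0, 0], [0, 1], [1, 2], [2, 1]]
print(num_islands2(3, 3, grid, offset=1))  # [1, 1, 2, 3]
print(num_islands2(3, 3, grid, offset=0))  # [1, 2, 3, 4]
```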

• @dietpepsi is 265536 a typo by any chance? I thought you meant N = 2^65536 there instead of N = 265536.

• Really love your python version. Neat and clean.
