Given an unbalanced binary tree, print its leaf nodes, then remove those leaves. Print the new leaf nodes and repeat until the tree is empty (the root is printed last).
For example:
        1
       / \
      2   3
     / \
    4   5
Print: [4 5 3], [2], [1]
// Height of the subtree rooted at `root`: 0 for an empty tree, 1 for a leaf.
int height(TreeNode * root)
{
    if (root == nullptr)
        return 0;
    const int leftDepth = height(root->left);
    const int rightDepth = height(root->right);
    return (leftDepth < rightDepth ? rightDepth : leftDepth) + 1;
}
// Post-order walk: returns the height of `root` (0 for nullptr, 1 for a leaf)
// and appends root->val to res[height - 1], i.e. to the round in which the node
// would be removed as a leaf. Within a round, values appear left to right.
// `res` is grown on demand, so callers do not have to pre-size it.
int dfs(vector<vector<int>> &res,TreeNode * root)
{
    if (!root) return 0;
    // Visit the subtrees in separate statements. When both calls were arguments
    // of one max(...) call, their evaluation order was unspecified, so the
    // order of values inside each level depended on the compiler.
    const int leftHeight = dfs(res, root->left);
    const int rightHeight = dfs(res, root->right);
    const int n = 1 + max(leftHeight, rightHeight);
    // Avoid out-of-bounds access when res was not pre-sized to the tree height.
    if (res.size() < static_cast<size_t>(n))
        res.resize(n);
    res[n - 1].push_back(root->val);
    return n;
}
// Returns the tree's values grouped by removal round: levels[0] holds the
// original leaves, levels.back() holds the root. Empty for a null root.
vector<vector<int>> func(TreeNode * root)
{
    // One bucket per level, sized up front from the tree height.
    vector<vector<int>> levels(height(root));
    dfs(levels, root);
    return levels;
}
/*
TreeNode* a1=new TreeNode(3);
a1->left =new TreeNode(1);
a1->left->right =new TreeNode(2);
a1->right =new TreeNode(5);
a1->right->left =new TreeNode(4);
a1->right->right =new TreeNode(7);
a1->right->right->left =new TreeNode(6);
a1->left->left =new TreeNode(0);
a1->left->left->left =new TreeNode(-1);
func(a1);
*/
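O(n) time and space, bottom-up: a node's height (1 for a leaf) is the round in which it is removed, so each node is appended to res[height - 1]. I hope it is correct!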
Very good idea, GoGoDong! My idea is the same, but I optimized it slightly by not computing the tree height in advance.
int dropByLevel(TreeNode root,Map<Integer,ArrayList<TreeNode>> map) {
if (root != null) {
int l = dropByLevel(root.left, map);
int r = dropByLevel(root.right, map);
int h = Math.max(l, r) + 1;
ArrayList<TreeNode> arr = map.containsKey(h) ? map.get(h) : new ArrayList<>():;
arr.add(root);
map.put(h, arr);
return h;
}
return 0;
}
/**
 * Prints each removal round of {@code root}, one line per round, starting with
 * the original leaves and ending with the root.
 */
void test (TreeNode root) {
    // TreeMap iterates keys in ascending order, so rounds print from height 1
    // (the leaves) up to the root. HashMap gives no ordering guarantee.
    Map<Integer, ArrayList<TreeNode>> map = new java.util.TreeMap<>();
    dropByLevel(root, map);
    for (ArrayList<TreeNode> level : map.values()) {
        System.out.println(level);
    }
}
Recursive approach in Java:
It is similar to the ones above. It keeps track of each node's level and deletes the leaf from its parent: it sets parent.left or parent.right to null and lets the garbage collector reclaim the child node. It then adds the child node's value to the list at index `level`.
public class PrintAndRemoveLeafNodes {
    /**
     * Collects the tree's values grouped by the round in which each node becomes a leaf.
     * Index 0 holds the original leaves, and index 1 the nodes that are leaves once those
     * are removed. The pattern continues, so the root is in the last list.
     *
     * <p>Destructive: each node is unlinked from its parent as it is collected, so every
     * node's left/right links end up null.
     *
     * @param root tree root; may be null
     * @return one list per removal round; an empty list for a null root
     */
    public static List<List<Integer>> printAndRemoveLeafNodes(TreeNode root) {
        // The accumulator is local rather than a static field. Concurrent or nested
        // calls therefore cannot overwrite each other, and the class does not keep
        // the last result alive.
        List<List<Integer>> leaves = new ArrayList<List<Integer>>();
        if (root == null)
            return leaves;
        // Use a dummy node as parent for root so root is unlinked like any other child.
        TreeNode dummy = new TreeNode(-1);
        dummy.left = root;
        printAndRemoveLeafNodes(dummy, root, true, leaves);
        return leaves;
    }

    /**
     * Post-order helper. It first collects {@code child}'s subtree, then unlinks
     * {@code child} from {@code parent} and records its value at index {@code level}.
     *
     * @param left   true if child is parent.left, false if it is parent.right
     * @param leaves output lists, one per removal round
     * @return the number of rounds needed to remove child's subtree (a leaf returns 1)
     */
    private static int printAndRemoveLeafNodes(TreeNode parent, TreeNode child, boolean left,
                                               List<List<Integer>> leaves) {
        int level = 0;
        if (child.left != null)
            level = printAndRemoveLeafNodes(child, child.left, true, leaves);
        if (child.right != null)
            level = Math.max(level, printAndRemoveLeafNodes(child, child.right, false, leaves));
        // level is now the round in which child becomes a leaf; unlink it.
        if (left)
            parent.left = null;
        else
            parent.right = null;
        // The subtrees have already created lists 0..level-1, so at most one new list is needed.
        if (leaves.size() < level + 1)
            leaves.add(new ArrayList<Integer>());
        leaves.get(level).add(child.val);
        return level + 1;
    }
}
First, record the parent of each node and the number of children it has. A simple DFS does this.
Then traverse the leaves bottom-up with a BFS. A leaf has no children. When we process a node (at this point it is a leaf), its parent has one fewer child; if the parent has become a leaf, we enqueue it. The queue is first-in, first-out, so nodes are dequeued in nondecreasing depth order. The child that finally turns a parent into a leaf therefore has the largest depth among its siblings, and the variable `depth` tells us which list each answer belongs in.
class Node:
    """Binary-tree node: a value plus optional left and right children."""

    def __init__(self, val):
        self.val = val
        # Children start empty; callers attach subtrees by assignment.
        self.left = self.right = None
def solve(root):
    """Group tree values by the round in which each node is removed as a leaf.

    Round 0 holds the original leaves. Round k holds the nodes that become
    leaves once rounds 0..k-1 are gone, so the root is always in the last
    round. Within a round, values follow the order the nodes were enqueued.

    Args:
        root: tree root with ``val``, ``left`` and ``right`` attributes, or None.

    Returns:
        A list of per-round value lists; ``[]`` for an empty tree.
        The tree itself is not modified.
    """
    import collections

    num_children = collections.Counter()
    parent = {}

    def dfs(node, par=None):
        if node:
            num_children[node] = bool(node.left) + bool(node.right)
            parent[node] = par
            dfs(node.left, node)
            dfs(node.right, node)

    dfs(root)
    # .items() works on Python 3 (iteritems() was Python 2 only).
    queue = collections.deque(
        (node, 0) for node, num in num_children.items() if num == 0
    )
    ans = []
    while queue:
        # FIFO keeps dequeued depths nondecreasing. When a parent's last child
        # is processed, that child therefore has the largest depth of its
        # siblings. Popping from the end (LIFO) could give a parent too small
        # a depth.
        node, depth = queue.popleft()
        while len(ans) <= depth:
            ans.append([])
        ans[depth].append(node.val)
        par = parent[node]
        if par is None:  # the root has no parent to update
            continue
        num_children[par] -= 1
        if num_children[par] == 0:
            queue.append((par, depth + 1))
    return ans
Simple traversal
Complexity:
Worst case: the tree is slanted to the right or left (a single path). Each round pops one leaf, so the tree shrinks by one node per round and
$$T(n) = n + (n-1) + (n-2) + \cdots + 1 = \frac{n(n+1)}{2} = O(n^2).$$
Best case: the tree is a complete balanced binary tree. Each round pops about half of the remaining nodes, so the per-round costs form a geometric series:
$$T(n) = n + \frac{n}{2} + \frac{n}{4} + \cdots \le 2n = O(n).$$
class TreeNode(object):
    """Binary-tree node used by popLeaves/purge: a value and two child links."""

    def __init__(self, val):
        # Both child links start out empty.
        self.left = None
        self.right = None
        self.val = val
def popLeaves(node, leaves):
    """Strip the current leaves from the subtree rooted at ``node``.

    Each leaf's value is appended to ``leaves`` in left-to-right order, and
    the leaf is unlinked by returning None in its place. Returns the
    surviving subtree root (None if ``node`` was empty or itself a leaf),
    so the caller can reassign its child pointer.
    """
    if not node:
        return None
    is_leaf = node.left is None and node.right is None
    if is_leaf:
        leaves.append(node.val)
        return None
    # Internal node: recurse into both sides and relink whatever survives.
    node.left = popLeaves(node.left, leaves)
    node.right = popLeaves(node.right, leaves)
    return node
def purge(root):
    """Repeatedly strip leaves from the tree until nothing is left.

    Each round's leaf values (collected by popLeaves) become one inner list
    of the result, starting with the original leaves. The full result is
    printed once and then returned. The tree is consumed in the process:
    popLeaves reassigns child links as leaves are removed.
    """
    rounds = []
    while root:
        current = []
        root = popLeaves(root, current)
        rounds.append(current)
    print(rounds)
    return rounds