# Print and remove leaf nodes until root node left

• Given an unbalanced binary tree, print its leaf nodes, then remove those leaf nodes. Print the new leaf nodes and repeat until only the root node is left.

For example:

``````
          1
         / \
        2   3
       / \
      4   5
``````

Print: [4 5 3], [2], [1]
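A useful observation: a node is printed in round k exactly when its height is k, counting a leaf as height 1. In the example, 4, 5 and 3 have height 1, 2 has height 2, and 1 has height 3. Here is a minimal Python sketch that computes these heights for the example. It assumes a hypothetical `Node` class and distinct node values.

```python
class Node:
    """Hypothetical minimal binary-tree node used only for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def heights(root):
    """Map each node value to its height (a leaf has height 1).

    Assumes node values are distinct, since they are used as dict keys.
    """
    out = {}

    def walk(node):
        if node is None:
            return 0
        h = 1 + max(walk(node.left), walk(node.right))
        out[node.val] = h
        return h

    walk(root)
    return out


# The example tree: 1 has children 2 and 3; 2 has children 4 and 5.
example = Node(1, Node(2, Node(4), Node(5)), Node(3))
print(heights(example))  # {4: 1, 5: 1, 2: 2, 3: 1, 1: 3}
```

Grouping values by height therefore reproduces [4 5 3], [2], [1].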

• @evanl1208 what is meant by drop? Shall we remove leaves first?

• @elmirap yes, I mean remove those leaf nodes.

• @evanl1208 thank you

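• ``````
#include <algorithm>
#include <vector>
using namespace std;

// LeetCode-style node definition (assumed; the code below relies on it)
struct TreeNode {
    int val;
    TreeNode *left, *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

int height(TreeNode * root)
{
    if(!root) return 0;
    return 1+max(height(root->left),height(root->right));
}

// Post-order: a node's height n (leaves have n = 1) is the round in which
// it is removed, so its value goes into bucket res[n-1].
int dfs(vector<vector<int>> &res,TreeNode * root)
{
    if(!root) return 0;
    int n=1+max(dfs(res,root->left),dfs(res,root->right));
    res[n-1].push_back(root->val);
    return n;
}

vector<vector<int>> func(TreeNode * root)
{
    int n=height(root);
    vector<vector<int>> res(n);   // one bucket per round
    dfs(res,root);
    return res;
}

/*
TreeNode* a1=new TreeNode(3);
a1->left =new TreeNode(1);
a1->left->right =new TreeNode(2);
a1->right =new TreeNode(5);
a1->right->left =new TreeNode(4);
a1->right->right =new TreeNode(7);
a1->right->right->left =new TreeNode(6);
a1->left->left =new TreeNode(0);
a1->left->left->left =new TreeNode(-1);

func(a1);
*/
``````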

O(n) time and space, bottom-up. Hope it is correct!

• Very good idea, GoGoDong! My idea is the same, but I optimize it a little by avoiding computing the height in advance.

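``````
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Assumes LeetCode's TreeNode class (int val; TreeNode left, right).
int dropByLevel(TreeNode root, Map<Integer, List<Integer>> map) {
    if (root == null) return 0;
    int l = dropByLevel(root.left, map);
    int r = dropByLevel(root.right, map);
    int h = Math.max(l, r) + 1;   // height of this node; leaves have h = 1
    // store the value in the bucket for its height (= its removal round)
    map.computeIfAbsent(h, k -> new ArrayList<>()).add(root.val);
    return h;
}

void test(TreeNode root) {
    // TreeMap keeps heights sorted, so the rounds print in removal order
    Map<Integer, List<Integer>> map = new TreeMap<>();
    dropByLevel(root, map);
    for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
        System.out.println(entry.getValue());
    }
}
``````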

• @GoGoDong But you are not removing them; you are only printing.
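One way to address this while keeping a single pass is to compute heights post-order, as in the answers above. Once a node's subtrees are finished, clear its `left`/`right` pointers, so every node has been unlinked from its parent by the end. Below is a minimal Python sketch using a hypothetical `Node` class. It produces the same groups as removing leaves round by round, but it does not build the intermediate trees between rounds. Within each round it lists leaves left to right.

```python
class Node:
    """Hypothetical minimal binary-tree node used only for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def print_and_remove(root):
    """Group values by removal round (= height) and detach child links as we go."""
    rounds = []

    def visit(node):
        if node is None:
            return 0
        # Post-order: both subtrees are grouped and detached before this node.
        h = 1 + max(visit(node.left), visit(node.right))
        node.left = node.right = None  # its children are gone, so it is a leaf now
        if len(rounds) < h:            # first node seen at this height
            rounds.append([])
        rounds[h - 1].append(node.val)
        return h

    visit(root)
    return rounds


tree = Node(1, Node(2, Node(4), Node(5)), Node(3))
print(print_and_remove(tree))                 # [[4, 5, 3], [2], [1]]
print(tree.left is None, tree.right is None)  # True True
```

The root itself is "removed" when the caller drops its reference.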

• @evanl1208 Does the order of printing leaves matter?

• Recursive approach in Java:

It is similar to the ones above. It keeps track of the `level` and deletes the leaf from its `parent` (it sets `parent.left` or `parent.right` to `null` and lets the garbage collector collect the `child` node). Then it adds the child node's value to the list at index `level`.

``````
import java.util.ArrayList;
import java.util.List;

// Assumes LeetCode's TreeNode class (int val; TreeNode left, right; TreeNode(int)).
public class PrintAndRemoveLeafNodes {

    private static List<List<Integer>> leaves;

    public static List<List<Integer>> printAndRemoveLeafNodes(TreeNode root) {
        leaves = new ArrayList<List<Integer>>();

        if (root == null)
            return leaves;

        // use a dummy node as parent for root
        TreeNode dummy = new TreeNode(-1);
        dummy.left = root;
        printAndRemoveLeafNodes(dummy, root, true);

        return leaves;
    }

    private static int printAndRemoveLeafNodes(TreeNode parent, TreeNode child, boolean left) {
        int level = 0;

        if (child.left != null)
            level = printAndRemoveLeafNodes(child, child.left, true);
        if (child.right != null)
            level = Math.max(level, printAndRemoveLeafNodes(child, child.right, false));
        // know the level at this point; child is now a leaf
        // remove the leaf
        if (left)
            parent.left = null;
        else
            parent.right = null;

        // add leaf to list of that level (create the list the first time we reach it)
        if (leaves.size() < level + 1)
            leaves.add(new ArrayList<Integer>());
        leaves.get(level).add(child.val);
        return level + 1;
    }

}
``````

• Would a queue work? Do a BFS until you reach a leaf node and delete it, then at the end just recurse on the root node again, stopping when root is nullptr?
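Here is a minimal sketch of the queue idea, under one reading of it. The `Node`, `is_leaf` and `peel_with_bfs` names are hypothetical. Each round runs a BFS from the root and records and unlinks every child that is currently a leaf. The loop ends once the root itself is a leaf.

This approach works, but leaves within a round come out in level order, so the example gives [[3, 4, 5], [2], [1]] rather than [4 5 3]. Each round also rescans the remaining tree, so a skewed tree costs O(n^2).

```python
from collections import deque


class Node:
    """Hypothetical minimal binary-tree node used only for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_leaf(node):
    return node.left is None and node.right is None


def peel_with_bfs(root):
    """Repeat a BFS pass that records and unlinks the current leaves."""
    rounds = []
    while root is not None:
        if is_leaf(root):              # only the root is left
            rounds.append([root.val])
            break
        current = []
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for side in ("left", "right"):
                child = getattr(node, side)
                if child is None:
                    continue
                if is_leaf(child):
                    current.append(child.val)
                    setattr(node, side, None)  # unlink the leaf from its parent
                else:
                    queue.append(child)        # explore it later in this pass
        rounds.append(current)
    return rounds


tree = Node(1, Node(2, Node(4), Node(5)), Node(3))
print(peel_with_bfs(tree))  # [[3, 4, 5], [2], [1]]
```

A node that loses its last child during a pass is not recorded in that same pass. Its parent already checked it before its children were unlinked.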

• First, record the parent of each node and the number of children it has. We can do this with a simple DFS.

Now peel the tree bottom-up with a BFS over the leaves (nodes with no children). When we process a node, which at that point is a leaf, its parent has one child fewer. If the parent has just become a leaf, we enqueue it.

A FIFO queue processes nodes in nondecreasing `depth` order. So the child that turns its parent into a leaf is the deepest one, and `depth + 1` is exactly the round in which the parent is removed. The variable `depth` therefore tells us where to put the answer.

``````
import collections

class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None

def solve(root):
    num_children = collections.Counter()  # node -> number of non-null children
    parent = {}                           # node -> parent (None for the root)

    def dfs(node, par=None):
        if node:
            num_children[node] = bool(node.left) + bool(node.right)
            parent[node] = par
            dfs(node.left, node)
            dfs(node.right, node)

    dfs(root)

    # Must be a FIFO queue (popleft), not a stack, so depths come out in order
    leaves = collections.deque((node, 0) for node, num in num_children.items() if num == 0)
    ans = []

    while leaves:
        node, depth = leaves.popleft()
        while len(ans) <= depth:
            ans.append([])
        ans[depth].append(node.val)

        par = parent[node]
        if par is not None:                 # the root has no parent
            num_children[par] -= 1
            if num_children[par] == 0:      # parent just became a leaf
                leaves.append((par, depth + 1))

    return ans
``````
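The queue-versus-stack choice matters here. A parent is assigned `depth + 1` from whichever child happens to empty it last, and only a FIFO queue guarantees that child is the deepest one. The sketch below runs the same peeling both ways; the `peel` function and its `fifo` flag are hypothetical names. The test tree has node 1 with a leaf child 2 and a right chain 3, 4, 5.

```python
import collections


class Node:
    """Hypothetical minimal binary-tree node used only for this sketch."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def peel(root, fifo=True):
    """Peel leaves bottom-up; fifo=True uses a queue, fifo=False a stack."""
    num_children, parent = {}, {}

    def fill(node, par):
        # Pre-order DFS recording each node's parent and child count.
        if node is not None:
            parent[node] = par
            num_children[node] = (node.left is not None) + (node.right is not None)
            fill(node.left, node)
            fill(node.right, node)

    fill(root, None)
    work = collections.deque(
        (node, 0) for node, count in num_children.items() if count == 0)
    ans = []
    while work:
        node, depth = work.popleft() if fifo else work.pop()
        while len(ans) <= depth:
            ans.append([])
        ans[depth].append(node.val)
        par = parent[node]
        if par is not None:
            num_children[par] -= 1
            if num_children[par] == 0:  # the parent has just become a leaf
                work.append((par, depth + 1))
    return ans


# 1 has a leaf child 2 on the left and a chain 3 -> 4 -> 5 on the right.
tree = Node(1, Node(2), Node(3, Node(4, Node(5))))
print(peel(tree, fifo=True))   # [[2, 5], [4], [3], [1]]
print(peel(tree, fifo=False))  # [[5, 2], [4, 1], [3]]
```

With the stack, node 1 is reported in the second round next to 4, although 3 is only removed in the third round.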

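• Simple traversal: repeatedly pop all current leaves until the tree is empty.

Complexity:

Worst case: the tree is slanted entirely to the right or left. Each pass pops only one leaf, so the tree shrinks by 1 per pass, and each pass visits every remaining node:

T(n) = n + (n-1) + (n-2) + ... + 1 = n(n+1)/2 = O(n^2)

Best case: the tree is a complete balanced binary tree. Each pass pops about half of the remaining nodes, so the work per pass halves:

T(n) = n + n/2 + n/4 + ... < 2n = O(n)

There are about log(n) passes, but their costs form a geometric series, so the total is linear.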

``````
class TreeNode(object):
    def __init__(self, val):
        self.val = val
        self.left = self.right = None

def popLeaves(node, leaves):
    """Remove every current leaf under `node`, collecting values left to right."""
    if node is None:
        return None
    if node.left is None and node.right is None:
        leaves.append(node.val)  # node is a leaf, save it in a list
        return None              # unlinking it from tree
    # node is internal node
    node.left = popLeaves(node.left, leaves)
    node.right = popLeaves(node.right, leaves)
    return node

def purge(root):
    result = []
    while root:
        leaves = []
        root = popLeaves(root, leaves)  # pop leaves in current state of a tree
        result.append(leaves)           # put all collected leaves in a result
    print(result)
    return result
``````
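The two complexity cases can be checked numerically. The sketch below repeats the same leaf-popping pass and counts how many nodes it touches. The helpers `count_visits`, `skewed` and `perfect` are hypothetical, written only for this check. A chain of n nodes costs n(n+1)/2 visits. A perfect tree with n = 2^h - 1 nodes costs 2^(h+1) - 2 - h visits, which stays below 2n.

```python
class TreeNode:
    """Hypothetical minimal node, mirroring the class in the answer above."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_visits(root):
    """Run leaf-popping passes until the tree is empty; count nodes visited."""
    visits = 0

    def pop_leaves(node):
        nonlocal visits
        if node is None:
            return None
        visits += 1
        if node.left is None and node.right is None:
            return None                  # leaf: unlink it
        node.left = pop_leaves(node.left)
        node.right = pop_leaves(node.right)
        return node

    while root is not None:
        root = pop_leaves(root)
    return visits


def skewed(n):
    """A right-slanted chain of n nodes."""
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, right=root)
    return root


def perfect(h):
    """A perfect binary tree of height h (2**h - 1 nodes)."""
    return None if h == 0 else TreeNode(h, perfect(h - 1), perfect(h - 1))


print(count_visits(skewed(100)))  # 5050 = 100 * 101 / 2
print(count_visits(perfect(10)))  # 2036 for n = 1023 nodes, below 2n = 2046
```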
