# Use divide and conquer to solve this problem.

• The complexity should be O(log(n)).

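```cpp
class Solution {
public:
    // Counts the nodes in the subtree rooted at `root`. Visits every node: O(n).
    int getNodeNum(TreeNode* root)
    {
        if (root == nullptr)
        {
            return 0;
        }

        int res = 1;
        if (root->left)
        {
            res += getNodeNum(root->left);
        }

        if (root->right)
        {
            res += getNodeNum(root->right);
        }

        return res;
    }

    // Returns the k-th smallest value (1-based) by comparing k with the size
    // of the left subtree and recursing into exactly one side.
    int kthSmallest(TreeNode* root, int k) {
        int left = getNodeNum(root->left);
        if (1 + left == k)
        {
            return root->val;  // the root itself is the k-th smallest
        }
        else if (k > left + 1)
        {
            // Skip the whole left subtree and the root.
            return kthSmallest(root->right, k - left - 1);
        }
        else
        {
            return kthSmallest(root->left, k);
        }
    }
};
```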

• The complexity should be O(log(n))

Not even close. A single `getNodeNum` call is already O(n), and your solution makes one per level it descends, so it is actually O(n^2) in the worst case (a skewed, list-like tree).
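One way to see the quadratic worst case is to count node visits. The sketch below is a hypothetical instrumented copy of the solution: the global `visits` counter, the `buildLeftChain` helper, and the minimal `TreeNode` struct are assumptions for illustration, not LeetCode's code. On a left-skewed chain with k = 1, every level recounts the whole remaining chain.

```cpp
#include <cassert>
#include <vector>

// Minimal stand-in for LeetCode's TreeNode (assumption for this sketch).
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    explicit TreeNode(int v) : val(v), left(nullptr), right(nullptr) {}
};

// Hypothetical instrumentation: number of non-null nodes getNodeNum touches.
long long visits = 0;

int getNodeNum(TreeNode* root) {
    if (root == nullptr) return 0;
    ++visits;
    return 1 + getNodeNum(root->left) + getNodeNum(root->right);
}

// Same logic as the original kthSmallest; requires 1 <= k <= node count.
int kthSmallest(TreeNode* root, int k) {
    assert(root != nullptr);
    int left = getNodeNum(root->left);
    if (1 + left == k) return root->val;
    if (k > left + 1) return kthSmallest(root->right, k - left - 1);
    return kthSmallest(root->left, k);
}

// Hypothetical helper: a valid but list-shaped BST with n at the root,
// n-1 as its left child, ..., and 1 at the bottom. `owned` collects the
// nodes so the caller can delete them.
TreeNode* buildLeftChain(int n, std::vector<TreeNode*>& owned) {
    TreeNode* root = nullptr;
    for (int v = 1; v <= n; ++v) {
        TreeNode* node = new TreeNode(v);
        node->left = root;  // all smaller values hang off the left
        root = node;
        owned.push_back(node);
    }
    return root;
}
```

For example, `kthSmallest(buildLeftChain(1000, owned), 1)` returns 1 after 499500 visits, which is (n-1) + (n-2) + ... + 0 = n(n-1)/2 for n = 1000.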

• You have to be more careful about the recursive calls. Think about how you can return as early as possible and count only the nodes you actually need. Ideally you would have something like a node count stored at each node, but since you don't, you want to avoid recomputing it so many times.
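If you do control the tree type, the "node count stored at each node" idea can be sketched as follows. `SizedNode`, `sizeOf`, `insert`, and `destroy` are hypothetical names (LeetCode's `TreeNode` has no size field). Each query then follows a single root-to-node path, O(h), which is O(log(n)) only when the tree is balanced; this plain insert does not rebalance.

```cpp
#include <cassert>

// Hypothetical BST node that caches the size of its subtree
// (LeetCode's TreeNode has no such field; this is an assumption).
struct SizedNode {
    int val;
    int size;  // number of nodes in this subtree, including this one
    SizedNode* left;
    SizedNode* right;
    explicit SizedNode(int v) : val(v), size(1), left(nullptr), right(nullptr) {}
};

int sizeOf(const SizedNode* node) { return node ? node->size : 0; }

// Plain (unbalanced) BST insert that keeps every `size` on the path
// up to date. Duplicates go to the right.
SizedNode* insert(SizedNode* root, int val) {
    if (root == nullptr) return new SizedNode(val);
    if (val < root->val) root->left = insert(root->left, val);
    else root->right = insert(root->right, val);
    root->size = 1 + sizeOf(root->left) + sizeOf(root->right);
    return root;
}

// Same divide-and-conquer idea as the original, but the left-subtree size
// is read in O(1), so one query costs O(h).
int kthSmallest(const SizedNode* root, int k) {
    assert(root != nullptr && k >= 1 && k <= root->size);
    while (true) {
        int left = sizeOf(root->left);
        if (k == left + 1) return root->val;
        if (k <= left) {
            root = root->left;
        } else {
            k -= left + 1;  // skip the left subtree and the root
            root = root->right;
        }
    }
}

void destroy(SizedNode* root) {
    if (!root) return;
    destroy(root->left);
    destroy(root->right);
    delete root;
}
```

The trade-off is that every insert and delete must maintain `size`, which is why this only works when you own the data structure.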

• Here's a memoized version of your solution, based on your solution and the answer here on Stack Overflow. It will also perform better when you need to call the method repeatedly with different values of k.

```python
# TreeNode (fields: val, left, right) is supplied by LeetCode.
class Solution:
    def __init__(self):
        # Memo: node -> size of the subtree rooted at that node.
        # An instance attribute, so the cache is not shared by every Solution object.
        self.h = {}

    # @param {TreeNode} root
    # @param {integer} k
    # @return {integer}
    def kthSmallest(self, root, k):
        left = self.countNodes(root.left)
        if 1 + left == k:
            return root.val
        if k > left + 1:
            return self.kthSmallest(root.right, k - left - 1)
        else:
            return self.kthSmallest(root.left, k)

    def countNodes(self, root):
        if root is None:
            return 0
        if root in self.h:
            return self.h[root]
        else:
            result = 1
            if root.left:
                if root.left in self.h:
                    result += self.h[root.left]
                else:
                    self.h[root.left] = self.countNodes(root.left)
                    result += self.h[root.left]
            if root.right:
                if root.right in self.h:
                    result += self.h[root.right]
                else:
                    self.h[root.right] = self.countNodes(root.right)
                    result += self.h[root.right]
            self.h[root] = result

        return result
```

• Whoa, that's a lot of code. Why do you have both the caller and the callee memoize? It's enough when the callee does it:

```python
def countNodes(self, root):
    if root is None:
        return 0
    if root in self.h:
        return self.h[root]

    result = 1
    if root.left:
        result += self.countNodes(root.left)
    if root.right:
        result += self.countNodes(root.right)

    self.h[root] = result
    return result
```
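For the C++ version at the top of the thread, the same callee-only memoization could be sketched with a `std::unordered_map<TreeNode*, int>` cache. The `MemoSolution` name and the minimal `TreeNode` struct are assumptions for this sketch. Each subtree size is computed at most once while the cache lives, so counting costs O(n) in total, and repeated queries on an unchanged tree approach O(h) each. The cache goes stale if the tree is modified.

```cpp
#include <cassert>
#include <unordered_map>

// Minimal stand-in for LeetCode's TreeNode (assumption for this sketch).
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int v, TreeNode* l = nullptr, TreeNode* r = nullptr)
        : val(v), left(l), right(r) {}
};

class MemoSolution {
public:
    // Requires 1 <= k <= number of nodes.
    int kthSmallest(TreeNode* root, int k) {
        assert(root != nullptr);
        int left = countNodes(root->left);
        if (k == left + 1) return root->val;
        if (k > left + 1) return kthSmallest(root->right, k - left - 1);
        return kthSmallest(root->left, k);
    }

private:
    // node -> size of its subtree; only valid while the tree is unchanged.
    std::unordered_map<TreeNode*, int> cache;

    // Only the callee memoizes; callers just call countNodes.
    int countNodes(TreeNode* root) {
        if (root == nullptr) return 0;
        auto it = cache.find(root);
        if (it != cache.end()) return it->second;
        int result = 1 + countNodes(root->left) + countNodes(root->right);
        cache[root] = result;
        return result;
    }
};
```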

• I think the code below is better.

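• I have revised my code as shown below:

```cpp
// In-order traversal: appends values in ascending order and stops
// descending once k values have been collected.
void dfs(TreeNode* root, int k, vector<int>& ascending)
{
    // `>=` (not `==`) so the early exit still fires after the parent
    // has pushed past k; size_t cast avoids a signed/unsigned compare.
    if (root == nullptr || ascending.size() >= static_cast<size_t>(k))
    {
        return;
    }

    dfs(root->left, k, ascending);
    if (ascending.size() >= static_cast<size_t>(k))
    {
        return;  // the k-th smallest was already found in the left subtree
    }
    ascending.push_back(root->val);
    dfs(root->right, k, ascending);
}

int kthSmallest(TreeNode* root, int k)
{
    vector<int> ascending;
    dfs(root, k, ascending);
    return ascending[k - 1];  // the k-th value visited in order
}
```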
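The vector in the revised version is only needed for its last element. A sketch of the same in-order idea with a countdown and an explicit stack stops as soon as the k-th node is popped: O(h + k) time and O(h) extra space, and being iterative it cannot overflow the call stack on a deep skewed tree. The minimal `TreeNode` struct and the name `kthSmallestIterative` are assumptions for this sketch.

```cpp
#include <cassert>
#include <stack>

// Minimal stand-in for LeetCode's TreeNode (assumption for this sketch).
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int v, TreeNode* l = nullptr, TreeNode* r = nullptr)
        : val(v), left(l), right(r) {}
};

// Iterative in-order traversal: nodes are popped in ascending order,
// so the k-th pop is the answer. Requires 1 <= k <= number of nodes.
int kthSmallestIterative(TreeNode* root, int k) {
    std::stack<TreeNode*> path;
    TreeNode* cur = root;
    while (cur != nullptr || !path.empty()) {
        // Walk as far left as possible, remembering the way back up.
        while (cur != nullptr) {
            path.push(cur);
            cur = cur->left;
        }
        cur = path.top();
        path.pop();
        if (--k == 0) return cur->val;  // k-th smallest reached: stop early
        cur = cur->right;
    }
    assert(false && "k is larger than the number of nodes");
    return -1;
}
```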


• The time complexity of `getNodeNum` is O(n), so this algorithm is at least O(n). For a balanced tree, T(n) = O(n) + T(n/2), so the complexity is O(n).
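Unrolling the recurrence shows both cases side by side. Here c is a constant cost per visited node and T(1) = 0 because only node visits are counted (both are notation introduced for this derivation). A balanced tree halves the recounted subtree at each level; a left-skewed chain queried with k = 1 only shrinks it by one.

$$
\begin{aligned}
\text{balanced:}\quad T(n) &\le cn + T(n/2) \le cn\left(1 + \tfrac{1}{2} + \tfrac{1}{4} + \cdots\right) = 2cn = O(n),\\
\text{left-skewed},\ k = 1:\quad T(n) &= c(n-1) + T(n-1) = c\sum_{i=0}^{n-1} i = \frac{c\,n(n-1)}{2} = O(n^2).
\end{aligned}
$$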
