```
# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
class Solution(object):
    def countUnivalSubtrees(self, root):
        """Count the subtrees in which every node has the same value.

        A subtree is "unival" when both of its child subtrees (if present)
        are unival and each present child's value equals the node's value.
        An empty tree contributes nothing; every leaf counts as one.

        The traversal is an iterative post-order walk with an explicit
        stack, so arbitrarily deep (e.g. chain-shaped) trees do not hit
        Python's recursion limit.

        :type root: TreeNode
        :rtype: int
        """
        count = 0
        # id(node) -> whether the subtree rooted at node is unival.  Keyed by
        # id() so TreeNode does not need to be hashable; all nodes stay alive
        # for the duration of the call, so ids cannot be reused.
        unival = {}
        # Each entry is (node, children_done).  A node is pushed first with
        # children_done=False, then re-pushed with True beneath its children
        # so it is evaluated only after both children have been (post-order).
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if node is None:
                continue
            if not children_done:
                stack.append((node, True))
                stack.append((node.right, False))
                stack.append((node.left, False))
                continue

            is_unival = True
            for child in (node.left, node.right):
                if child is None:
                    continue
                if not unival[id(child)] or child.val != node.val:
                    is_unival = False
            unival[id(node)] = is_unival
            if is_unival:
                count += 1

        # Kept for backward compatibility: earlier versions exposed the
        # result on the instance as well as returning it.
        self.count = count
        return count
```