# Python Efficient O(k) Insert and Sum using Trie

• A standard `Trie`-based solution in which each node stores the sum of the values of all keys that pass through it (that is, every key in its subtree).

For insertion, we first check whether the key has been inserted before, using a separate key-to-value dictionary. If it has, we compute the difference between the previous and the new value; if not, the difference is simply the new value. We then add this difference to every node as we traverse down the trie.

Sum is simple: each node already holds the total for its subtree, so we only need to walk down to the node for the prefix and return its count (or 0 if the prefix is not in the trie).

This results in both operations being O(k), where k is the length of the string/prefix.

- Yangshun

```python
class TrieNode(object):
    def __init__(self, count=0):
        # Sum of the values of every key that passes through this node.
        self.count = count
        self.children = {}


class MapSum(object):

    def __init__(self):
        """
        Initialize the trie root and a map of each key's current value.
        """
        self.root = TrieNode()
        self.keys = {}

    def insert(self, key, val):
        """
        :type key: str
        :type val: int
        :rtype: None
        """
        # Time: O(k)
        # For an existing key, only the change from its old value is
        # propagated; for a new key the old value counts as 0.
        delta = val - self.keys.get(key, 0)
        self.keys[key] = val

        curr = self.root
        curr.count += delta
        for char in key:
            if char not in curr.children:
                curr.children[char] = TrieNode()
            curr = curr.children[char]
            curr.count += delta

    def sum(self, prefix):
        """
        :type prefix: str
        :rtype: int
        """
        # Time: O(k)
        curr = self.root
        for char in prefix:
            if char not in curr.children:
                return 0
            curr = curr.children[char]
        return curr.count
```
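To see the delta update in action, here is a small self-contained sketch that re-creates the counting trie above in compact form and prints the count stored on each node along the path for "apple". The helper `path_counts` is hypothetical and exists only to make the node counts visible.

```python
class TrieNode(object):
    def __init__(self):
        self.count = 0      # sum of the values of all keys through this node
        self.children = {}


class MapSum(object):
    def __init__(self):
        self.root = TrieNode()
        self.keys = {}

    def insert(self, key, val):
        delta = val - self.keys.get(key, 0)
        self.keys[key] = val
        curr = self.root
        curr.count += delta
        for char in key:
            if char not in curr.children:
                curr.children[char] = TrieNode()
            curr = curr.children[char]
            curr.count += delta

    def sum(self, prefix):
        curr = self.root
        for char in prefix:
            if char not in curr.children:
                return 0
            curr = curr.children[char]
        return curr.count


def path_counts(ms, word):
    """Hypothetical helper: the count on the node for each prefix of `word`."""
    counts, curr = [], ms.root
    for char in word:
        curr = curr.children[char]
        counts.append(curr.count)
    return counts


ms = MapSum()
ms.insert("apple", 3)
ms.insert("app", 2)
print(path_counts(ms, "apple"))  # [5, 5, 5, 3, 3]
ms.insert("apple", 5)            # delta = 5 - 3 = 2 on every node of the path
print(path_counts(ms, "apple"))  # [7, 7, 7, 5, 5]
print(ms.sum("ap"))              # 7
```

Overwriting "apple" touches only its own path, and the "app" contribution is left intact because only the difference is added.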

• @yangshun thanks 😀
In the `TrieNode` structure, the `val` field can be removed, right?

• @jerom.chai-gmail.com Yep good catch, have fixed it. Thanks!

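• Alternatively, use a dictionary directly. This is simpler, but every `sum` call scans all stored keys, so it costs O(n·k) for n keys instead of O(k):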

```python
class MapSum(object):

    def __init__(self):
        self.dict = {}

    def insert(self, key, val):
        # Inserting an existing key simply overwrites its old value.
        self.dict[key] = val

    def sum(self, prefix):
        # Scans every stored key on each call.
        total = 0
        for key, val in self.dict.items():
            if key.startswith(prefix):
                total += val
        return total
```
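Because the dictionary version is so simple, it also makes a convenient reference oracle for testing the trie. Below is a sketch of a randomized cross-check; the class names `DictMapSum` and `TrieMapSum`, the two-letter alphabet, and the trial counts are arbitrary choices for illustration, not part of either solution above.

```python
import random


class DictMapSum(object):
    """Brute-force reference: scans every key on each sum()."""

    def __init__(self):
        self.dict = {}

    def insert(self, key, val):
        self.dict[key] = val

    def sum(self, prefix):
        total = 0
        for key, val in self.dict.items():
            if key.startswith(prefix):
                total += val
        return total


class TrieNode(object):
    def __init__(self):
        self.count = 0
        self.children = {}


class TrieMapSum(object):
    """Counting trie: each node holds the sum of all keys passing through it."""

    def __init__(self):
        self.root = TrieNode()
        self.keys = {}

    def insert(self, key, val):
        delta = val - self.keys.get(key, 0)
        self.keys[key] = val
        curr = self.root
        curr.count += delta
        for char in key:
            if char not in curr.children:
                curr.children[char] = TrieNode()
            curr = curr.children[char]
            curr.count += delta

    def sum(self, prefix):
        curr = self.root
        for char in prefix:
            if char not in curr.children:
                return 0
            curr = curr.children[char]
        return curr.count


def random_word(rng, alphabet="ab", max_len=4):
    # Tiny alphabet and short words so keys and prefixes overlap often.
    return "".join(rng.choice(alphabet) for _ in range(rng.randint(0, max_len)))


def cross_check(trials=2000, seed=0):
    rng = random.Random(seed)
    ref, trie = DictMapSum(), TrieMapSum()
    for _ in range(trials):
        if rng.random() < 0.5:
            key, val = random_word(rng), rng.randint(-5, 5)
            ref.insert(key, val)
            trie.insert(key, val)
        else:
            prefix = random_word(rng)
            assert ref.sum(prefix) == trie.sum(prefix), prefix
    return True


print(cross_check())  # True
```

Keeping the alphabet small is deliberate: it forces frequent overwrites of existing keys, which is exactly the case the delta logic has to get right.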

• Here is a trick I recently learned: using a recursive `defaultdict` to build a trie in just a few lines of code.

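```python
import collections

COUNT = None  # sentinel key for a node's running sum; never equal to a character


def _trie():
    ret = collections.defaultdict(_trie)
    ret[COUNT] = 0
    return ret


class MapSum(object):

    def __init__(self):
        """
        Initialize the trie root and a map of each key's current value.
        """
        self.key_val = {}
        self.root = _trie()

    def insert(self, key, val):
        """
        :type key: str
        :type val: int
        :rtype: None
        """
        diff = val - self.key_val.get(key, 0)
        self.key_val[key] = val
        root = self.root
        for c in key:
            root[COUNT] += diff
            root = root[c]  # missing children are created on demand
        root[COUNT] += diff

    def sum(self, prefix):
        """
        :type prefix: str
        :rtype: int
        """
        root = self.root
        for c in prefix:
            # Check membership first so that querying an absent prefix
            # does not create empty nodes as a side effect.
            if c not in root:
                return 0
            root = root[c]
        return root[COUNT]
```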
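If the recursive `defaultdict` looks opaque, this small sketch shows the trick on its own: indexing a missing character creates a fresh sub-trie whose `COUNT` entry already starts at 0, so no node class or existence check is needed when inserting. The variable names `root` and `node` are illustrative only.

```python
import collections

COUNT = None  # sentinel key for the running sum; a string character is never None


def _trie():
    # Every missing key is filled in with another _trie(), recursively.
    ret = collections.defaultdict(_trie)
    ret[COUNT] = 0
    return ret


root = _trie()
node = root["a"]["b"]          # both levels are created on first access
node[COUNT] += 4

print(root["a"][COUNT], node[COUNT])   # 0 4
print(set(root["a"]) == {COUNT, "b"})  # True
print("z" in root)                     # False: `in` never creates a node
root["z"]                              # ...but plain indexing does
print("z" in root)                     # True
```

The last two lines show why a read-only lookup such as `sum` should test membership with `in` before indexing: otherwise every query for an unknown prefix leaves empty nodes behind.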
