We build a simple trie in which each node's value is the sum of the values of all words in the trie that share the prefix ending at that node.
Let me explain with a simple example. Suppose we have an empty trie. Insert "apple" with value=3. While inserting a key, we increase the value of every node along its path by the key's value.
Here's what the trie looks like after the insertion:

Now let's insert "app" with value=2. Again, we increase the value of every node along the path while inserting "app". Therefore, the first 3 nodes (a, p, p) will have value=5, while the remaining l and e nodes keep value=3.

Now observe that the query for the prefix "ap" ends at the first p node, which has value=5. This is the desired behaviour: both "app" and "apple" share the prefix "ap", so we return the sum of their values.
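The walkthrough so far can be reproduced with a minimal Python sketch. The names `Node`, `add` and `prefix_sum` are made up for this illustration and are not part of the solution code further down; the sketch only assumes that every node keeps a running total of the values of the keys passing through it.

```python
class Node:
    """One trie node; `total` is the sum of values of all keys passing through it."""

    def __init__(self):
        self.children = {}
        self.total = 0


def add(root, word, value):
    """Walk `word` from the root, adding `value` to every node on the path."""
    node = root
    for ch in word:
        node = node.children.setdefault(ch, Node())
        node.total += value


def prefix_sum(root, prefix):
    """Return the total stored where `prefix` ends, or 0 if the path is missing."""
    node = root
    for ch in prefix:
        node = node.children.get(ch)
        if node is None:
            return 0
    return node.total


root = Node()
add(root, "apple", 3)
add(root, "app", 2)
print(prefix_sum(root, "ap"))    # 5: both "app" and "apple" pass through this node
print(prefix_sum(root, "appl"))  # 3: only "apple" reaches the l node
```

A query never has to look below the node where the prefix ends, because each insertion already pushed its value into every node on the path.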
So far so good. But what happens when we want to update the value of a word? For example, suppose I want to change the value of "apple" from 3 to 1.
To solve this problem, we first need to figure out how to delete the value of "apple". Fortunately, this can be done simply by inserting the negative of the old value. For example, we can delete the value of "apple" by inserting -1 * old_value = -3.

Notice that queries like "appl" now give result=0, since there is no longer any word in the trie with the prefix "appl".
Now we can just reinsert the new value for "apple", i.e. 1.

Note that queries like "ap" will now return 3, which is correct since "apple" -> 1 and "app" -> 2.
When updating the value of a key, we can combine the deleting and reinserting steps into one by inserting -(old value) + (new value). To know the old value, we keep a separate map from each key to its current value.
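The combined update can be sketched in Python as follows. This is an illustrative sketch rather than the reference solution: the class names `DeltaNode` and `DeltaMapSum` are hypothetical, and it assumes a plain dictionary is used to remember the value last inserted for each key.

```python
class DeltaNode:
    def __init__(self):
        self.children = {}
        self.total = 0  # sum of values of all keys passing through this node


class DeltaMapSum:
    """Hypothetical wrapper that turns every insert into a delta on the trie."""

    def __init__(self):
        self.root = DeltaNode()
        self.values = {}  # key -> value currently stored for that key

    def insert(self, key, val):
        # -(old value) + (new value); a key seen for the first time has old value 0.
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            node = node.children.setdefault(ch, DeltaNode())
            node.total += delta

    def sum(self, prefix):
        node = self.root
        for ch in prefix:
            node = node.children.get(ch)
            if node is None:
                return 0
        return node.total


m = DeltaMapSum()
m.insert("apple", 3)
m.insert("app", 2)
m.insert("apple", 1)   # applies a delta of -3 + 1 = -2 along "apple"
print(m.sum("ap"))     # 3
print(m.sum("appl"))   # 1
```

Applying the delta in one pass walks the key once instead of twice, and the nodes end up with the same totals as a delete followed by a reinsert.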
Complexity
Code
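// TypeScript: every trie node stores the sum of the values of all keys passing through it.
class MapSum {
    // Remembers the current value of each key so that an update can be applied as a delta.
    public map = new Map<string, number>();
    public tri = new Tri();

    insert(key: string, val: number): void {
        const oldVal = this.map.get(key) ?? 0;
        this.map.set(key, val);
        // Remove the old value and add the new one in a single pass.
        this.tri.insert(key, -oldVal + val);
    }

    sum(prefix: string): number {
        return this.tri.query(prefix);
    }
}

class TriNode {
    public children = new Map<string, TriNode>();
    public val: number = 0;
}

class Tri {
    public root = new TriNode();

    // Add `value` to every node on the path spelled by `word`, creating nodes as needed.
    insert(word: string, value: number): void {
        let node = this.root;
        for (let i = 0; i < word.length; i++) {
            let next = node.children.get(word[i]);
            if (next === undefined) {
                next = new TriNode();
                node.children.set(word[i], next);
            }
            node = next;
            node.val += value;
        }
    }

    // Return the value stored at the node where `prefix` ends, or 0 if no such path exists.
    query(prefix: string): number {
        let node = this.root;
        for (let i = 0; i < prefix.length; i++) {
            const next = node.children.get(prefix[i]);
            if (next === undefined) return 0;
            node = next;
        }
        return node.val;
    }
}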
# Python: an alternative that stores each value only at the key's last node
# and sums the whole subtree below the prefix on every query.
import collections


class TrieNode:
    def __init__(self):
        self.children = dict()
        self.val = None
        self.is_leaf = False  # True when a key ends at this node

    def set_val(self, val):
        self.val = val

    def set_leaf(self):
        self.is_leaf = True


class MapSum:
    def __init__(self):
        self.trie = TrieNode()

    def insert(self, key: str, val: int) -> None:
        curr = self.trie
        for ch in key:
            if ch not in curr.children:
                curr.children[ch] = TrieNode()
            curr = curr.children[ch]
        # Overwrite the value at the key's last node, so re-inserting a key replaces its value.
        curr.set_val(val)
        curr.set_leaf()

    def sum(self, prefix: str) -> int:
        curr = self.trie
        sum_prefix = 0
        # Walk down to the node where the prefix ends.
        for ch in prefix:
            if ch not in curr.children:
                return 0
            curr = curr.children[ch]
        # The prefix itself may be a key.
        if curr.is_leaf:
            sum_prefix += curr.val
        # Breadth-first search over the subtree, adding the value of every key found.
        queue = collections.deque([curr])
        while queue:
            curr = queue.popleft()
            for child, node in curr.children.items():
                queue.append(node)
                if node.is_leaf:
                    sum_prefix += node.val
        return sum_prefix


# Your MapSum object will be instantiated and called as such:
# obj = MapSum()
# obj.insert(key,val)
# param_2 = obj.sum(prefix)