All in O(1), with detailed explanation


  •

    The main idea is to maintain an ordered two-dimensional doubly-linked list (let's call it the matrix for convenience), in which each row corresponds to a value and all keys in the same row have that value.

    Suppose we get the following key-value pairs after some increment operations ("A": 4 means "A" has been increased four times, so its value is 4, and so on):

    "A": 4, "B": 4, "C": 2, "D": 1
    

    Then one possible matrix may look like this:

    row0: val = 4, strs = {"A", "B"}
    row1: val = 2, strs = {"C"}
    row2: val = 1, strs = {"D"}
    

    If we can guarantee that the rows are in descending order of value, then getMaxKey()/getMinKey() are easy to implement in O(1) time: the first key in the first row always has the maximal value, and the first key in the last row always has the minimal value.
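
    A minimal sketch of why the ordering makes both queries trivial, assuming a hypothetical stripped-down Row (just a value and its keys) and the example matrix above: with the rows in descending order, the maximum is read from front() and the minimum from back() of a std::list, both in O(1).

```cpp
#include <cassert>
#include <list>
#include <string>

// Hypothetical stripped-down row: one value and the keys that currently hold it.
struct Row {
    int val;
    std::list<std::string> strs;
};

// Assumes rows are kept in descending order of val, so the maximum value
// is always in the first row and the minimum value in the last row.
std::string maxKey(const std::list<Row>& matrix) {
    return matrix.empty() ? "" : matrix.front().strs.front();
}

std::string minKey(const std::list<Row>& matrix) {
    return matrix.empty() ? "" : matrix.back().strs.front();
}

// The example matrix from the text: "A": 4, "B": 4, "C": 2, "D": 1.
std::list<Row> exampleMatrix() {
    return {{4, {"A", "B"}}, {2, {"C"}}, {1, {"D"}}};
}
```

    On the example, maxKey returns "A" and minKey returns "D"; an empty matrix yields "".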

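    Once a key is increased, we look at the row just before its current row (call it last_row). If last_row.val == current_row.val + 1, we move the key from the current row to last_row. Otherwise (including when the current row is the first row), we insert a new row with value current_row.val + 1 before the current row and move the key to the new row. If the current row becomes empty, we remove it. The decrement operation works the same way with the row just after the current one. Either way, the rows keep their descending order.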

    For example, after inc("D"), the matrix becomes:

    row0: val = 4, strs = {"A", "B"}
    row1: val = 2, strs = {"C", "D"}
    

    After inc("D") again:

    row0: val = 4, strs = {"A", "B"}
    row1: val = 3, strs = {"D"}
    row2: val = 2, strs = {"C"}
    

    Now the key problem is how to maintain the matrix in O(1) time when a key is increased or decreased by 1.

    The answer is a hash map. By using a hash map to track the position of each key in the matrix, we can reach any key in O(1). And since the matrix is stored as linked lists, the insert/move operations are all O(1) as well.
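
    To make the idea concrete, here is a minimal sketch (standard library only) of a hash map that stores, for each key, an iterator to its row and an iterator to its node inside that row. The names Row, Matrix, Index, place, valueOf and removeKey are hypothetical and only for illustration; the O(1) lookup is on average, as usual for unordered_map.

```cpp
#include <cassert>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

struct Row {
    int val;
    std::list<std::string> strs;
};

using Matrix = std::list<Row>;
// key -> (iterator to its row, iterator to its node inside that row)
using Index = std::unordered_map<
    std::string, std::pair<Matrix::iterator, std::list<std::string>::iterator>>;

// Puts key at the front of `row` and records exactly where it lives.
void place(Index& idx, Matrix::iterator row, const std::string& key) {
    row->strs.push_front(key);
    idx[key] = std::make_pair(row, row->strs.begin());
}

// A key's value: one hash lookup plus one dereference, no scanning.
int valueOf(const Index& idx, const std::string& key) {
    return idx.at(key).first->val;
}

// Removes a key from its row without searching that row.
void removeKey(Matrix& matrix, Index& idx, const std::string& key) {
    auto row = idx.at(key).first;
    auto col = idx.at(key).second;
    row->strs.erase(col);                      // O(1): unlinks exactly one node
    if (row->strs.empty()) matrix.erase(row);  // drop a row once it has no keys
    idx.erase(key);
}
```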

    The pseudocode of inc() is as follows (dec() is similar):

    if the key isn't in the matrix:
        if the matrix is empty or the value of its last row isn't 1:
            append a new row with value 1 to the end of the matrix, and put the key in it;
        else:
            put the key in the last row of the matrix;
    else:
        let last_row be the row just before the key's current row;
        if the current row is the first row or last_row.value != current_row.value + 1:
            insert a new row with value current_row.value + 1 before the current row, and move the key to it;
        else:
            move the key from the current row to last_row;
        if the current row is now empty:
            remove it from the matrix;
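
    One possible variation on the move step, sketched here and not taken from the post: std::list::splice(pos, other, it) relinks a single node from one list into another in constant time, and iterators to the moved element stay valid (they now refer into the destination list). With it, the stored column iterator never changes and only the row half of the hash-map entry needs updating. The names Row, Matrix, Index and incKey are hypothetical.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

struct Row {
    int val;
    std::list<std::string> strs;
};
using Matrix = std::list<Row>;  // kept in descending order of val
using Index = std::unordered_map<
    std::string, std::pair<Matrix::iterator, std::list<std::string>::iterator>>;

// Sketch of inc() following the pseudocode, but moving the key's node with
// splice instead of erase + push_front.
void incKey(Matrix& matrix, Index& idx, const std::string& key) {
    auto found = idx.find(key);
    if (found == idx.end()) {
        // A new key has value 1; a row with value 1 can only be the last row.
        if (matrix.empty() || matrix.back().val != 1) matrix.push_back({1, {}});
        auto row = std::prev(matrix.end());
        row->strs.push_front(key);
        idx[key] = std::make_pair(row, row->strs.begin());
        return;
    }
    auto row = found->second.first;
    auto col = found->second.second;
    // Reuse the previous row if it holds val + 1, otherwise create it above `row`.
    auto target = (row != matrix.begin() && std::prev(row)->val == row->val + 1)
                      ? std::prev(row)
                      : matrix.insert(row, Row{row->val + 1, {}});
    target->strs.splice(target->strs.begin(), row->strs, col);  // O(1) relink
    found->second.first = target;  // col still points at the same node
    if (row->strs.empty()) matrix.erase(row);
}
```

    Replaying the example above (A and B four times, C twice, D once, then D twice more) yields rows with values 4, 3 and 2, matching the matrices shown earlier.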
    

    Here is the code.

    #include <iterator>
    #include <list>
    #include <string>
    #include <unordered_map>
    #include <utility>
    using namespace std;
    
    class AllOne {
    public:
        // One row of the matrix: every key in strs currently has the value val.
        struct Row {
            list<string> strs;
            int val;
            Row(const string &s, int x) : strs({s}), val(x) {}
        };
    
        // key -> (iterator to its row, iterator to its node inside that row)
        unordered_map<string, pair<list<Row>::iterator, list<string>::iterator>> strmap;
        // Rows are kept in strictly descending order of val.
        list<Row> matrix;
    
        /** Initialize your data structure here. */
        AllOne() {
            
        }
        
        /** Inserts a new key <Key> with value 1. Or increments an existing key by 1. */
        void inc(string key) {
            if (strmap.find(key) == strmap.end()) {
                // A new key has value 1, and a row with value 1 can only be the last row.
                if (matrix.empty() || matrix.back().val != 1) {
                    auto newrow = matrix.emplace(matrix.end(), key, 1);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                else {
                    auto newrow = --matrix.end();
                    newrow->strs.push_front(key);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
            }
            else {
                auto row = strmap[key].first;
                auto col = strmap[key].second;
                // Test for begin() before stepping back: decrementing begin() is undefined behavior.
                if (row == matrix.begin() || prev(row)->val != row->val + 1) {
                    // No row with value val + 1 yet: create one right above the current row.
                    auto newrow = matrix.emplace(row, key, row->val + 1);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                else {
                    auto newrow = prev(row);
                    newrow->strs.push_front(key);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                // Remove the key from its old row, and drop that row if it is now empty.
                row->strs.erase(col);
                if (row->strs.empty()) matrix.erase(row);
            }
        }
        
        /** Decrements an existing key by 1. If Key's value is 1, remove it from the data structure. */
        void dec(string key) {
            if (strmap.find(key) == strmap.end()) {
                return;
            }
            else {
                auto row = strmap[key].first;
                auto col = strmap[key].second;
                if (row->val == 1) {
                    // The key leaves the structure entirely.
                    row->strs.erase(col);
                    if (row->strs.empty()) matrix.erase(row);
                    strmap.erase(key);
                    return;
                }
                auto nextrow = row;
                ++nextrow;
                if (nextrow == matrix.end() || nextrow->val != row->val - 1) {
                    // No row with value val - 1 yet: create one right below the current row.
                    auto newrow = matrix.emplace(nextrow, key, row->val - 1);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                else {
                    auto newrow = nextrow;
                    newrow->strs.push_front(key);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                row->strs.erase(col);
                if (row->strs.empty()) matrix.erase(row);
            }
        }
        
        /** Returns one of the keys with maximal value. */
        string getMaxKey() {
            return matrix.empty() ?  "" : matrix.front().strs.front();
        }
        
        /** Returns one of the keys with Minimal value. */
        string getMinKey() {
            return matrix.empty() ?  "" : matrix.back().strs.front();
        }
    };
    

  •

    @ivan.chan
    Great solution, linking the values in order!
    You arrange the values and keys in the form of a matrix like this:

    1 - [key - key - key - ...]
    |
    2 - [key - key - key - ...]
    |
    3 - [key - key - key - ...]
    |
    ...
    ...
    

    The values column (called row in your code) is kept in order, so getMinKey and getMaxKey are clearly O(1).
    You also maintain a hash map <key, <row_iterator, column_iterator>> to find a key in O(1); as a result, inc and dec are O(1) too, because they only involve moving pointers (iterators).
    Brilliant idea to link the values! I only linked the keys and had no idea how to get min/max in O(1) :(


  •

    I'll note that this is pedantry, as I don't believe that the question is possible without O(1) amortized.

    Quoting the post "0ms, all in O(1), with detailed explanation":

    if (strmap.find(key) == strmap.end()) {

    unordered_map::find is O(n) in the worst case, although O(1) on average. http://www.cplusplus.com/reference/unordered_map/unordered_map/find/

    Worst case: linear in container size.


  •

    @lano1
    Yes, O(1) means average runtime here. I don't think this problem has a solution that is O(1) in the worst case, so an average O(1) is acceptable. I think everyone knows the time complexities of unordered_map operations, so I didn't emphasize that the O(1) is on average : )


  •

    Hey @ivancjw
    Thanks for submitting your code, and for the detailed explanation. I am curious about something. Whenever we have to move a key to a new row, e.g. after an increase makes it the new max, we have to scan the inner list of keys to find it, remove it from that list and create a new one, yes? How is the list scan an O(1) operation?


  •

    @mtu_wa_watu
    Well, since we store the iterators of the keys in a hash map, we can locate any key in O(1):

    unordered_map<string, pair<list<Row>::iterator, list<string>::iterator>> strmap;
    

    list::insert and list::erase are O(1) here, since only one element is inserted or erased.
    Reference: http://www.cplusplus.com/reference/list/list/insert/


  •

    @ivancjw Thanks for your response. It looks like erase is not an O(1) operation:

    http://www.cplusplus.com/reference/list/list/erase/

    Complexity: Linear in the number of elements erased (destructions).

  •

    @mtu_wa_watu
    Each call erases only one element (the key), so the number of elements erased is 1, not n, in the code above.


  •

    @ivancjw just reread my own quote. <facepalm> Thanks!


  •

    @ivancjw A question: I am not sure about the C++ semantics below (taken from your inc function), and it would be great if you could help. Specifically, say we have two keys, A and B, that both need to be inserted into the same row.
    After you insert A, strmap[A] = (row1, iterator to the beginning of strs).
    After you insert B, strmap[B] = (row1, iterator to the beginning of strs).
    After these two steps, strs should look like [B, A]. But it seems to me that when you store the position of B, it overwrites the position of A, since both point to the beginning of strs.

            if (strmap.find(key) == strmap.end()) {
                if (matrix.empty() || matrix.back().val != 1) {
                    auto newrow = matrix.emplace(matrix.end(), key, 1);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
                else {
                    auto newrow = --matrix.end();
                    newrow->strs.push_front(key);
                    strmap[key] = make_pair(newrow, newrow->strs.begin());
                }
            }
    

  •

    @microRNA The iterator stored for A points to A's node, not to "the beginning of strs" in general. Inserting B creates a new node in front of A's node, so A's iterator is unchanged and still refers to A.
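
A small self-contained check of this point, as a sketch (the helper insertAThenB and the alias Strs are made up for illustration): each iterator is taken from begin() right after its own push_front, so A and B end up with different iterators, and erasing A through its iterator leaves B untouched.

```cpp
#include <cassert>
#include <list>
#include <string>
#include <utility>

using Strs = std::list<std::string>;

// Reproduces the scenario from the question: A is inserted first, then B is
// pushed to the front of the same row. Each iterator is read from begin()
// right after its own insertion, like strmap[key] in inc().
std::pair<Strs::iterator, Strs::iterator> insertAThenB(Strs& strs) {
    strs.push_front("A");
    Strs::iterator itA = strs.begin();  // node holding "A"
    strs.push_front("B");               // links a new node before A's node
    Strs::iterator itB = strs.begin();  // node holding "B"
    // itA is still valid: inserting into a std::list never invalidates iterators.
    return {itA, itB};
}
```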


  •

    I think this question is really similar to "LFU Cache". My implementation of "LFU Cache" is here. I used a similar data structure to meet the time complexity requirements. The structure graph is as follows:
    (The graph is provided by the link)

    head --- ValueNode1 ---- ValueNode2 ---- ... ---- ValueNodeN --- tail 
                  |               |                       |               
                first           first                   first             
                  |               |                       |               
               KeyNodeA        KeyNodeE                KeyNodeG           
                  |               |                       |               
               KeyNodeB        KeyNodeF                KeyNodeH           
                  |                                       |               
               KeyNodeC                                KeyNodeI           
                  |                                                       
               KeyNodeD                                                   
    

    My implementation is the following:

    import java.util.*;

    public class AllOne {
        class KeyNode {
            String key;
            int freq;
            public KeyNode (String key) {
                this.key = key;
                this.freq = 1;
            }
        }
        class FreqNode {
            int freq;
            FreqNode prev;
            FreqNode next;
            Set<KeyNode> set; // keep the insertion order
            public FreqNode (int freq, FreqNode prev, FreqNode next) {
                this.freq = freq;
                this.prev = prev;
                this.next = next;
                set = new LinkedHashSet<>();
            }
        }
        Map<String, KeyNode> keyMap;
        Map<Integer, FreqNode> freqMap;
        FreqNode head, tail;
        /** Initialize your data structure here. */
        public AllOne() {
            head = null;
            tail = null;
            keyMap = new HashMap<>();
            freqMap = new HashMap<>();
        }
    
        /** Inserts a new key <Key> with value 1. Or increments an existing key by 1. */
        public void inc(String key) {
            if (keyMap.containsKey(key)) {
                increase(key);
                return;
            }
            insertKeyNode(key);
        }
    
        /** Decrements an existing key by 1. If Key's value is 1, remove it from the data structure. */
        public void dec(String key) {
            if (!keyMap.containsKey(key))   return;
            decrease(key);
        }
    
        /** Returns one of the keys with maximal value. */
        public String getMaxKey() {
            if (tail == null)   return "";
            return tail.set.iterator().next().key;
        }
    
        /** Returns one of the keys with Minimal value. */
        public String getMinKey() {
            if (head == null)   return "";
            return head.set.iterator().next().key;
        }
        // helper function
        // increase freq of key, update val if necessary
        public void increase(String key) {
            KeyNode keynode = keyMap.get(key);
            FreqNode freqnode = freqMap.get(keynode.freq);
            keynode.freq += 1;
            FreqNode nextFreqNode = freqnode.next;
            if (nextFreqNode == null) {
                nextFreqNode = new FreqNode(keynode.freq, freqnode, null);
                freqnode.next = nextFreqNode;
                tail = nextFreqNode;
                freqMap.put(keynode.freq, nextFreqNode);
            }
            if (nextFreqNode != null && nextFreqNode.freq > keynode.freq) {
                nextFreqNode = insertFreqNodePlus1(keynode.freq, freqnode);
            }
            unlinkKey(keynode, freqnode);
            linkKey(keynode, nextFreqNode);
        }
        public void decrease(String key) {
            KeyNode keynode = keyMap.get(key);
            if (keynode.freq == 1) {
                keyMap.remove(key);
                freqMap.get(1).set.remove(keynode);
                if (freqMap.get(1).set.size() == 0) {
                    deleteFreqNode(freqMap.get(1));
                }
                return;
            }
            FreqNode freqnode = freqMap.get(keynode.freq);
            keynode.freq -= 1;
            FreqNode prevFreqNode = freqnode.prev;
            if (prevFreqNode == null) {
                prevFreqNode = new FreqNode(keynode.freq, null, freqnode);
                freqnode.prev = prevFreqNode;
                head = prevFreqNode;
                freqMap.put(keynode.freq, prevFreqNode);
            }
            if (prevFreqNode != null && prevFreqNode.freq < keynode.freq) {
                prevFreqNode = insertFreqNodeSub1(keynode.freq, freqnode);
            }
            unlinkKey(keynode, freqnode);
            linkKey(keynode, prevFreqNode);
        }
        // Inserts a new KeyNode<key, value> with freq 1.
        public void insertKeyNode(String key) {
            KeyNode keynode = new KeyNode(key);
            keyMap.put(key, keynode);
            if (!freqMap.containsKey(1)) {
                FreqNode freqnode = new FreqNode(1, null, head);
                freqnode.next = head;
                if (head != null)   head.prev = freqnode;
                if (tail == null)   tail = freqnode;
                head = freqnode;
                freqMap.put(1, freqnode);
            }
            linkKey(keynode, freqMap.get(1));
        }
        // insert a new freqnode with new freq after given "freqnode"
        public FreqNode insertFreqNodePlus1(int freq, FreqNode freqnode) {
            FreqNode newfnode = new FreqNode(freq, freqnode, freqnode.next);
            freqMap.put(freq, newfnode);
            if (freqnode.next != null)  freqnode.next.prev = newfnode;
            if (freqnode == tail)   tail = newfnode;
            freqnode.next = newfnode;
            return newfnode;
        }
        // insert a new freqnode with new freq before given "freqnode"
        public FreqNode insertFreqNodeSub1(int freq, FreqNode freqnode) {
            FreqNode newfnode = new FreqNode(freq, freqnode.prev, freqnode);
            freqMap.put(freq, newfnode);
            if (freqnode.prev != null)  freqnode.prev.next = newfnode;
            if (head == freqnode)   head = newfnode;
            freqnode.prev = newfnode;
            return newfnode;
        }
        // Unlink keyNode from freqNode
        public void unlinkKey(KeyNode keynode, FreqNode freqnode) {
            freqnode.set.remove(keynode);
            if (freqnode.set == null || freqnode.set.size() == 0)     deleteFreqNode(freqnode);
        }
        // Link keyNode to freqNode
        public void linkKey(KeyNode keynode, FreqNode freqnode) {
            freqnode.set.add(keynode);
        }
        // delete freqnode if there is no remaining keynode under this freq
        public void deleteFreqNode(FreqNode freqnode) {
            FreqNode prev = freqnode.prev, next = freqnode.next;
            if (prev != null)   prev.next = next;
            if (next != null)   next.prev = prev;
            if (head == freqnode)   head = next;
            if (tail == freqnode)   tail = prev;
            freqMap.remove(freqnode.freq);
        }
    }
    

    Please correct me if anything is wrong. Thanks.


  •

    auto lastrow = row;
    --lastrow;
    if (lastrow == matrix.end() || lastrow->val != row->val + 1) {

    It seems we're decrementing the begin iterator and then checking whether it equals matrix.end(). If row is the first row and lastrow is the one before it, how can lastrow equal matrix.end()?


  •

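    @jyang101 Note that lastrow is an iterator into a doubly-linked list. In common std::list implementations the nodes form a circle through a sentinel end node, so if lastrow == list.begin(), then --lastrow == list.end(). Strictly speaking, though, the standard does not allow decrementing begin(), so checking row == matrix.begin() first is the portable way to write this test.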
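
A minimal sketch of the portable check, assuming a hypothetical helper prevRowHolds and a plain std::list<int> of row values in place of the post's Row type: compare against begin() first, and only then look at std::prev(row).

```cpp
#include <cassert>
#include <iterator>
#include <list>

// Hypothetical helper: true when `row` has a predecessor whose value is exactly
// *row + 1. Checking begin() first means std::prev is never applied to begin(),
// which would be undefined behavior for std::list iterators.
bool prevRowHolds(const std::list<int>& vals, std::list<int>::const_iterator row) {
    return row != vals.begin() && *std::prev(row) == *row + 1;
}
```

For rows with values 4, 3, 1 (descending, as in the post), the first row has no predecessor, the row with 3 can reuse the row with 4, and the row with 1 cannot reuse the row with 3.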


  •

    @liqingfd
    Wondering whether it is acceptable to use so many data structures to achieve O(1) inc/dec. That said, this is the only method I can think of for now, too. By the way, what is the space complexity here? Is it already optimal? Can we save some more space?



  •

    Here is my Java solution: a similar idea, but with a simpler data structure.

    import java.util.*;

    public class AllOne {
        HashMap<String, Integer> map;
        HashMap<Integer, HashSet<String>> vals;
        String maxKey;
        String minKey;
        int max;
        int min;
    
        /** Initialize your data structure here. */
        public AllOne() {
            map = new HashMap<>();
            vals = new HashMap<>();
            maxKey = "";
            minKey = "";
            max = 0;
            min = 0;
        }
        
        /** Inserts a new key <Key> with value 1. Or increments an existing key by 1. */
        public void inc(String key) {
            map.put(key, map.getOrDefault(key, 0) + 1);
            int val = map.get(key);
            if(vals.get(val) == null) vals.put(val, new HashSet<>());
            vals.get(val).add(key);
            if(vals.containsKey(val - 1)){
                vals.get(val - 1).remove(key);
                if(vals.get(val - 1).size() == 0) vals.remove(val - 1);
            }
            if(map.get(key) > max){
                max = map.get(key);
                maxKey = key;
            }
            if(map.get(key) - 1 == min){
                if(vals.get(min) == null || vals.get(min).size() == 0){
                    min++;
                    minKey = key;
                }
                else minKey = vals.get(min).iterator().next();
            }
            if(map.get(key) == 1){
                min = 1;
                minKey = key;
            }
        }
        
        /** Decrements an existing key by 1. If Key's value is 1, remove it from the data structure. */
        public void dec(String key) {
            if(map.containsKey(key)){
                if(map.get(key) == 1){
                    map.remove(key);
                    vals.get(1).remove(key);
                    if(vals.get(1).size() > 0){
                        min = 1;
                        minKey = vals.get(1).iterator().next();
                        if(max == 1) maxKey = minKey;
                    }else{
                        vals.remove(1);
                        if(map.size() > 0){
                            int tempMin = Integer.MAX_VALUE;
                            for(Map.Entry<Integer, HashSet<String>> e : vals.entrySet()){
                                if(e.getValue().size() > 0)
                                    tempMin = Math.min(tempMin, e.getKey());
                            }
                            min = tempMin;
                            minKey = vals.get(min).iterator().next();
                        }else{
                            min = 0;
                            max = 0;
                        }
                    }
                }else{
                    map.put(key, map.get(key) - 1);
                    int val = map.get(key);
                    vals.get(val + 1).remove(key);
                    if(vals.get(val + 1).size() == 0) vals.remove(val + 1);
                    if(vals.get(val) == null) vals.put(val, new HashSet<>());
                    vals.get(val).add(key);
                    if(val + 1 == max){
                        if(vals.get(max) == null || vals.get(max).size() == 0) max--;
                        else maxKey = vals.get(max).iterator().next();
                    }
                    if(val + 1 == min){
                        min--;
                        minKey = key;
                    }
                }
            }
        }
        
        /** Returns one of the keys with maximal value. */
        public String getMaxKey() {
            if(map.size() == 0) return "";
            return maxKey;
        }
        
        /** Returns one of the keys with Minimal value. */
        public String getMinKey() {
            if(map.size() == 0) return "";
            return minKey;
        }
    }
    

  •

    @ShawYoungTang The for loop that finds the next min is not O(1), since it iterates over every entry in vals. That is the reason why we need a linked list of values instead of a map.
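
To illustrate the difference, here is a sketch assuming a hypothetical Row type and helper eraseFromMinRow: when the rows form a linked list kept in descending order, emptying the minimum row and popping it leaves the next minimum at back(), so no loop over the values is needed.

```cpp
#include <cassert>
#include <list>
#include <string>

struct Row {
    int val;
    std::list<std::string> strs;
};

// Precondition: `col` points into matrix.back().strs (the minimum row).
// Returns the new minimum value, or 0 once the matrix is empty.
int eraseFromMinRow(std::list<Row>& matrix, std::list<std::string>::iterator col) {
    assert(!matrix.empty());
    matrix.back().strs.erase(col);                      // O(1) with a stored iterator
    if (matrix.back().strs.empty()) matrix.pop_back();  // O(1), no scan over values
    return matrix.empty() ? 0 : matrix.back().val;      // next minimum is the neighbour row
}
```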

