Java O(1) AC Solution Using Two HashMaps And Doubly Linked List, With Explanation



    We first create a ListNode class that encapsulates the key, the value, the frequency, and the previous and next pointers. Then we create a hash map (frequencies) that maps each frequency count to a doubly linked list of the nodes with that frequency. A second hash map (store) maps keys to nodes. We also keep track of the current smallest frequency (smallestFrequency).

    When we want to get a value, we check whether the key exists in the store. If it does, we "promote" the node by incrementing its frequency. The promote method works like this:
    (1) Remove the node from its current frequency list. If the list becomes empty, remove that frequency from the frequencies map. If the emptied frequency was also the smallest one, update smallestFrequency to the node's new frequency (current frequency + 1), since the promoted node now holds the minimum;
    (2) Add the node to the list for its new frequency, creating that list if it does not exist yet, and append the node at the tail. Appending keeps every list ordered from least recently to most recently accessed, so the tail is always the most recently accessed node.
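    To make the two promote steps concrete, here is a minimal sketch that tracks only keys and frequencies. It assumes java.util.LinkedHashSet as a stand-in for the post's DList (it keeps insertion order and removes in O(1) on average); PromoteSketch, insert and promote are hypothetical names used only for illustration.

```java
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;

// Hypothetical sketch of the promote step only; LinkedHashSet replaces the post's DList.
class PromoteSketch {
    final Map<Integer, Integer> freqOf = new HashMap<>();                  // key -> frequency
    final Map<Integer, LinkedHashSet<Integer>> buckets = new HashMap<>();  // frequency -> keys, oldest first
    int smallestFrequency = 0;

    // Assumes the key is new: it starts at frequency 1, which is always the minimum.
    void insert(int key) {
        freqOf.put(key, 1);
        buckets.computeIfAbsent(1, f -> new LinkedHashSet<>()).add(key);
        smallestFrequency = 1;
    }

    void promote(int key) {
        int f = freqOf.get(key);
        LinkedHashSet<Integer> old = buckets.get(f);
        old.remove(key);                               // step (1): leave the old bucket
        if (old.isEmpty()) {
            buckets.remove(f);
            if (f == smallestFrequency) {
                smallestFrequency = f + 1;             // the promoted key now holds the minimum
            }
        }
        freqOf.put(key, f + 1);
        // step (2): append to the tail of the new bucket (most recently accessed)
        buckets.computeIfAbsent(f + 1, g -> new LinkedHashSet<>()).add(key);
    }
}
```

    For example, after inserting keys 1 and 2 and promoting key 1, smallestFrequency stays at 1 because key 2 is still in the frequency-1 bucket; promoting key 2 afterwards empties that bucket and moves smallestFrequency to 2, with bucket 2 holding keys 1 and 2 in that order.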

    When we want to put a key, we first check whether it already exists in the store (if capacity is 0, put() simply returns).
    (1) If it exists, we promote the node, overwrite its value, and return;
    (2) If it does not exist and size has reached capacity, we remove the head node of the smallest-frequency list. Because every list is ordered by access time, that head is the least recently used node among the least frequently used ones. We also remove its key from the store and decrement size;
    (3) We then create a new node with frequency 1 and add it to the store and to the frequency-1 list. If smallestFrequency is 0 (meaning we just removed the last node of the previous smallest frequency, or nothing has been put into the cache yet) or is greater than 1, we reset it to 1. Finally, we increment size by 1.
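    For comparison, here is a compact sketch of the whole get/put flow built on the same frequency-bucket idea. It assumes java.util.LinkedHashSet buckets instead of the hand-written DList, so it is a swapped-in structure rather than the implementation in this post, and LinkedHashSetLFU is a hypothetical name. The three numbered put steps are marked in the comments.

```java
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;

// Hypothetical LFU sketch: LinkedHashSet buckets stand in for the post's DList.
class LinkedHashSetLFU {
    private final int capacity;
    private final Map<Integer, Integer> values = new HashMap<>();                  // key -> value
    private final Map<Integer, Integer> freqOf = new HashMap<>();                  // key -> frequency
    private final Map<Integer, LinkedHashSet<Integer>> buckets = new HashMap<>();  // frequency -> keys, LRU first
    private int smallestFrequency = 0;

    LinkedHashSetLFU(int capacity) {
        this.capacity = capacity;
    }

    int get(int key) {
        if (!values.containsKey(key)) {
            return -1;
        }
        promote(key);
        return values.get(key);
    }

    void put(int key, int value) {
        if (capacity == 0) {
            return;
        }
        if (values.containsKey(key)) {       // (1) existing key: promote and overwrite
            promote(key);
            values.put(key, value);
            return;
        }
        if (values.size() == capacity) {     // (2) full: evict the LRU key of the smallest bucket
            LinkedHashSet<Integer> lowest = buckets.get(smallestFrequency);
            Iterator<Integer> it = lowest.iterator();
            int victim = it.next();          // first element = least recently used
            it.remove();
            if (lowest.isEmpty()) {
                buckets.remove(smallestFrequency);
            }
            values.remove(victim);
            freqOf.remove(victim);
        }
        values.put(key, value);              // (3) insert the new key with frequency 1
        freqOf.put(key, 1);
        buckets.computeIfAbsent(1, f -> new LinkedHashSet<>()).add(key);
        smallestFrequency = 1;               // a brand-new key always has the minimum frequency
    }

    private void promote(int key) {
        int f = freqOf.get(key);
        LinkedHashSet<Integer> old = buckets.get(f);
        old.remove(key);
        if (old.isEmpty()) {
            buckets.remove(f);
            if (f == smallestFrequency) {
                smallestFrequency = f + 1;
            }
        }
        freqOf.put(key, f + 1);
        buckets.computeIfAbsent(f + 1, g -> new LinkedHashSet<>()).add(key);
    }
}
```

    With capacity 2, the sequence put(1,1), put(2,2), get(1), put(3,3), get(2), get(3), put(4,4), get(1), get(3), get(4) returns 1, -1, 3, -1, 3, 4 for the gets: key 2 is evicted first as the only frequency-1 key, and key 1 is evicted later because it ties with key 3 at frequency 2 but was accessed less recently.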

    Both get() and put() run in O(1) time because HashMap lookup, insertion, and removal (O(1) on average), doubly linked list insertion and removal, and the smallestFrequency update are all constant-time operations.

    Please let me know if you have any questions. Thanks guys!

    import java.util.HashMap;
    import java.util.Map;
    
    public class LFUCache {
    
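        private Map<Integer, ListNode> store;        // key -> node
        private Map<Integer, DList> frequencies;     // frequency -> nodes, least recently used at head
        private int capacity;
        private int size;
        private int smallestFrequency;               // 0 means "no current minimum" (just emptied or cache unused)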
    
        public LFUCache(int capacity) {
            store = new HashMap<>();
            frequencies = new HashMap<>();
            this.capacity = capacity;
            size = 0;
            smallestFrequency = 0;
        }
        
        public int get(int key) {
            ListNode node = store.get(key);
            if (node == null) {
                return -1;
            }
            promoteNode(node);
            return node.value;
        }
        
        public void put(int key, int value) {
            if (capacity == 0) {
                return;
            }
            ListNode node = store.get(key);
            if (node != null) {
                promoteNode(node);
                node.value = value;
                return;
            } else if (size == capacity) {
                DList leastFrequentList = frequencies.get(smallestFrequency);
                ListNode leastFrequentNode = leastFrequentList.head;
                remove(leastFrequentNode);
                store.remove(leastFrequentNode.key);
                --size;
            }
            node = new ListNode(key, value);
            if (smallestFrequency == 0 || smallestFrequency > node.frequency) {
                smallestFrequency = node.frequency;
            }
            add(node);
            store.put(key, node);
            ++size;
        }
        
        private void promoteNode(ListNode node) {
            remove(node);
            ++node.frequency;
            if (smallestFrequency == 0) {
                smallestFrequency = node.frequency;
            }
            add(node);
        }
        
        private void remove(ListNode node) {
            int currFrequency = node.frequency;
            DList currList = frequencies.get(currFrequency);
            if (currList.head == currList.tail) {
                // node is the only element: drop the whole frequency list
                frequencies.remove(currFrequency);
                if (currFrequency == smallestFrequency) {
                    smallestFrequency = 0;
                }
            } else {
                // detach from the predecessor, or advance the head
                if (node == currList.head) {
                    currList.head = node.next;
                } else {
                    node.prev.next = node.next;
                }
                // detach from the successor, or move the tail back
                if (node == currList.tail) {
                    currList.tail = node.prev;
                } else {
                    node.next.prev = node.prev;
                }
            }
            // clear stale links so add() always starts from a detached node
            node.prev = null;
            node.next = null;
        }
        
        private void add(ListNode node) {
            DList currList = frequencies.computeIfAbsent(node.frequency, key -> new DList());
            if (currList.head == null) {
                currList.head = node;
                currList.tail = node;
                return;
            }
            node.prev = currList.tail;
            currList.tail.next = node;
            currList.tail = node;
        }
        
        private static class ListNode {
            int key;
            int value;
            int frequency;
            ListNode next;
            ListNode prev;
            
            public ListNode(int key, int value) {
                this.key = key;
                this.value = value;
                frequency = 1;
            }
        }
        
        private static class DList {
            ListNode head;
            ListNode tail;
            
            public DList() {}
        }
    }
    
