Clean C++ O(1) solution with explanation


  • 0
    L

    To implement an LRU strategy, we can represent the usage timeline as a doubly linked list with the most recently used elements at the beginning. Newly added or updated elements are inserted at the beginning of the list, pushing less recently used elements towards the end. The least recently used entry is therefore always the last element of the list, which makes eviction O(1).

    A naive LFU strategy would store the cache entries in a single doubly linked list sorted by frequency, with the most frequently used entries at the beginning. The last element of the list is then the one to evict, which is O(1). However, increasing an entry's frequency is not O(1) in such a list: the entry has to move past every other entry that shares its old frequency. For example, with frequencies [5, 3, 3, 3], the last entry, once bumped to 4, must move past the two other 3s. Grouping entries by frequency, as described next, removes that scan.

    To implement the LFU strategy as described in the task (ties broken by least recent use), we group cache entries into buckets by frequency. The buckets are kept in a list sorted by decreasing frequency. Each bucket is itself a linked list ordered by recency of use, following the LRU rules above. Incrementing an entry's frequency moves it to the front of the adjacent bucket for frequency + 1, creating that bucket if needed, which is O(1). The element to evict is the last element of the last bucket, so eviction is O(1) as well.

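    To make get O(1), we use a hash map from each key to iterators pointing at its entry and at the bucket containing it. std::list iterators stay valid when other elements are inserted or erased, so the stored iterators remain usable as the structure changes.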

    #include <cstddef>
    #include <iterator>
    #include <list>
    #include <set>
    #include <unordered_map>
    
    // LFU cache with average O(1) get/set.
    //
    // Layout:
    //   _buckets : list of Bucket ordered by strictly decreasing frequency
    //              (front = most frequently used, back = least). A bucket is
    //              erased as soon as it becomes empty, so every bucket in the
    //              list holds at least one record.
    //   Bucket   : list of Records ordered by recency (front = most recent).
    //   _records : key -> iterators to the record and to its bucket. std::list
    //              iterators stay valid when other elements are inserted or
    //              erased, which is what makes storing them safe.
    //
    // A new key enters the frequency-0 bucket at the back of _buckets; every
    // get/set on an existing key moves it to the bucket for frequency + 1.
    // The eviction victim is the least recently used record of the back bucket.
    class LFUCache {
    private:
        struct Record
        {
            int key;
            int value;

            Record(int k, int v) : key(k), value(v)
            { }
        };

        class Bucket
        {
        private:
            int _frequency;
            std::list<Record> _records;

        public:
            explicit Bucket(int frequency) : _frequency(frequency), _records()
            { }

            // Inserts a copy of `record` as the most recently used entry and
            // returns an iterator to it.
            std::list<Record>::iterator Add(const Record& record)
            {
                _records.push_front(record);
                return _records.begin();
            }

            void Remove(std::list<Record>::iterator record)
            {
                _records.erase(record);
            }

            // Least recently used record. The bucket must not be empty.
            std::list<Record>::iterator GetLruRecord()
            {
                return std::prev(_records.end());
            }

            int GetFrequency() const
            {
                return _frequency;
            }

            bool Empty() const
            {
                return _records.empty();
            }

            std::size_t Size() const
            {
                return _records.size();
            }
        };

        using RecordIterator = std::list<Record>::iterator;
        using BucketIterator = std::list<Bucket>::iterator;

        // Location of a cached key: its record and the bucket that holds it.
        class RecordInfo
        {
        private:
            RecordIterator _recordIterator;
            BucketIterator _bucketIterator;

        public:
            RecordInfo() : _recordIterator(), _bucketIterator()
            { }

            RecordInfo(RecordIterator recordIterator, BucketIterator bucketIterator)
                : _recordIterator(recordIterator), _bucketIterator(bucketIterator)
            { }

            RecordIterator GetRecord() const
            {
                return _recordIterator;
            }

            BucketIterator GetBucket() const
            {
                return _bucketIterator;
            }

            int GetValue() const
            {
                return _recordIterator->value;
            }

            void SetValue(int value)
            {
                _recordIterator->value = value;
            }
        };

        // Returns the bucket for `targetFrequency`, which must equal
        // currentBucket's frequency + 1. Buckets are sorted by strictly
        // decreasing frequency, so that bucket can only be the immediate
        // predecessor of currentBucket; if it does not exist yet it is
        // created right before currentBucket.
        BucketIterator FindBucket(BucketIterator currentBucket, int targetFrequency)
        {
            if (currentBucket != _buckets.begin())
            {
                auto previous = std::prev(currentBucket);
                if (previous->GetFrequency() == targetFrequency)
                    return previous;
            }

            return _buckets.emplace(currentBucket, targetFrequency);
        }

        // Moves a record to the front of the frequency + 1 bucket and returns
        // its new location. The old RecordInfo is invalid afterwards.
        RecordInfo UpdateFrequency(const RecordInfo& recordInfo)
        {
            auto record = recordInfo.GetRecord();
            auto bucket = recordInfo.GetBucket();

            auto targetBucket = FindBucket(bucket, bucket->GetFrequency() + 1);
            auto newRecord = targetBucket->Add(*record);

            // Inserting into another list leaves `record` and `bucket` valid,
            // so the old copy can be removed only now.
            Remove(record, bucket);

            return RecordInfo(newRecord, targetBucket);
        }

        // If the cache is full, evicts the least recently used record of the
        // least frequently used (back) bucket.
        void EnsureCapacity()
        {
            if (_records.empty() || _records.size() < static_cast<std::size_t>(_capacity))
                return;

            auto leastFrequentBucket = std::prev(_buckets.end());
            auto recordToEvict = leastFrequentBucket->GetLruRecord();

            _records.erase(recordToEvict->key);
            Remove(recordToEvict, leastFrequentBucket);
        }

        // Erases a record and drops its bucket if that leaves it empty,
        // keeping the "no empty buckets" invariant.
        void Remove(RecordIterator record, BucketIterator bucket)
        {
            bucket->Remove(record);

            if (bucket->Empty())
                _buckets.erase(bucket);
        }

        // Returns the frequency-0 bucket (always the back one), creating it
        // if absent.
        BucketIterator GetZeroFrequencyBucket()
        {
            if (_buckets.empty() || _buckets.back().GetFrequency() != 0)
                _buckets.emplace_back(0);

            return std::prev(_buckets.end());
        }

        int _capacity;
        std::list<Bucket> _buckets;
        std::unordered_map<int, RecordInfo> _records;

    public:
        // A capacity <= 0 yields a cache that never stores anything.
        LFUCache(int capacity) : _capacity(capacity), _buckets(), _records()
        { }

        // Returns the value for `key` and counts the access, or -1 if absent.
        int get(int key)
        {
            auto kv = _records.find(key);
            if (kv == _records.end())
                return -1;

            kv->second = UpdateFrequency(kv->second);
            return kv->second.GetValue();
        }

        // Inserts or updates `key`. Updating counts as an access; inserting
        // into a full cache first evicts the least frequently used key (ties
        // broken by least recent use).
        void set(int key, int value)
        {
            auto kv = _records.find(key);
            if (kv != _records.end())
            {
                kv->second.SetValue(value);
                kv->second = UpdateFrequency(kv->second);
                return;
            }

            if (_capacity <= 0)
                return;

            EnsureCapacity();

            auto bucket = GetZeroFrequencyBucket();
            auto record = bucket->Add(Record(key, value));

            // `key` is known to be absent, so emplace always inserts.
            _records.emplace(key, RecordInfo(record, bucket));
        }

        // Alias for set(); the current LeetCode 460 interface names it put.
        void put(int key, int value)
        {
            set(key, value);
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.