Framework Thinking - Neetcode 150 - Top K Frequent Elements
1. Understand the problem
- Given an integer array nums and an integer k, return the k most frequent elements.
- Follow-up: the algorithm should run in better than O(n log n) time, where n is the length of nums.
2. Ask clarifying questions
- Can nums contain negative numbers, or only positives?
- Can there be ties in frequency? If so, is any order acceptable?
- What is the expected output: a list of numbers or a list of (num, frequency) pairs?
- Does the output need a stable order?
- What time complexity is expected: can we use a heap or bucket sort?
3. Walk through examples
- Example 1: nums = [1,1,1,2,2,3], k = 2 → Output: [1,2]
- Example 2: nums = [1], k = 1 → Output: [1]
- Example 3: nums = [4,1,-1,2,-1,2,3], k = 2 → Output: [-1,2] (-1 and 2 each appear twice; 4, 1, and 3 appear once)
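A quick way to confirm the expected outputs is to rank the values by frequency. The sketch below assumes Python 3 and uses the standard-library collections.Counter.most_common, which is not one of the approaches brainstormed next; it is only a sanity check on the examples.

```python
from collections import Counter

# The three examples above, as (nums, k, expected) triples.
EXAMPLES = [
    ([1, 1, 1, 2, 2, 3], 2, [1, 2]),
    ([1], 1, [1]),
    ([4, 1, -1, 2, -1, 2, 3], 2, [-1, 2]),
]

for nums, k, expected in EXAMPLES:
    # most_common() lists (value, count) pairs from most to least frequent.
    ranking = Counter(nums).most_common()
    top_k = [value for value, _ in ranking[:k]]
    print(ranking, "->", top_k)
    # Output order is not significant, so compare as sets.
    assert set(top_k) == set(expected)
```

For Example 3 this prints [(-1, 2), (2, 2), (4, 1), (1, 1), (3, 1)] -> [-1, 2]. The check assumes no tie at the k-th position, which holds for these three examples.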
4. Brainstorm 2–3 solutions
- Naive (O(n log n)):
  - Count frequencies, sort by frequency, take the top k.
- Optimized 1 (O(n log k)):
  - Use a min-heap of size k: push (freq, num) for each distinct number; if the size exceeds k, pop the smallest.
- Optimized 2 (O(n)): bucket sort by frequency.
  - Count each element's frequency.
  - Create buckets where index = frequency, and store each number in the bucket for its count. A number can appear at most n times, so n + 1 buckets suffice.
  - Traverse the buckets from high to low frequency to collect the top k.
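To make the bucket layout concrete, here is a small sketch that traces Example 1 in Python 3. It uses collections.Counter for the counting step (an assumption for brevity; the implementations below use a plain dict).

```python
from collections import Counter

nums = [1, 1, 1, 2, 2, 3]  # Example 1
k = 2

counts = Counter(nums)  # counts: 1 -> 3, 2 -> 2, 3 -> 1

# One bucket per possible frequency 0..len(nums);
# bucket i holds the numbers that occur exactly i times.
buckets = [[] for _ in range(len(nums) + 1)]
for num, cnt in counts.items():
    buckets[cnt].append(num)

print(buckets)  # [[], [3], [2], [1], [], [], []]

# Walk from the highest frequency down, keeping the first k numbers seen.
top = []
for freq in range(len(buckets) - 1, 0, -1):
    for num in buckets[freq]:
        if len(top) < k:
            top.append(num)

print(top)  # [1, 2]
```

Buckets 4 to 6 stay empty here, but they are needed in general: an input like [7, 7, 7, 7, 7, 7] would put 7 in bucket 6.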
5. Implement solutions
5.1. Sorting
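from typing import List


class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        # Count how many times each number appears.
        count = {}
        for num in nums:
            count[num] = 1 + count.get(num, 0)
        # Pair each number with its count and sort by count (ascending).
        arr = []
        for num, cnt in count.items():
            arr.append([cnt, num])
        arr.sort()
        # The most frequent pairs sit at the end; pop until we have k numbers.
        res = []
        while len(res) < k:
            res.append(arr.pop()[1])
        return res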
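- Time complexity: O(n log n), dominated by sorting up to n distinct (count, num) pairs.
- Space complexity: O(n) for the count map and the list of pairs.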
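For comparison, the same sorting idea can be written more compactly with the standard library. This is a sketch assuming Python 3; top_k_frequent_sorted is a hypothetical name, and collections.Counter replaces the hand-rolled dict. The cost is still O(n log n) in the worst case, because every distinct value gets sorted.

```python
from collections import Counter
from typing import List


def top_k_frequent_sorted(nums: List[int], k: int) -> List[int]:
    """Sorting approach: count, sort distinct values by count (descending), slice."""
    count = Counter(nums)
    # Order the distinct values by frequency, most frequent first.
    ordered = sorted(count, key=count.get, reverse=True)
    return ordered[:k]


print(top_k_frequent_sorted([1, 1, 1, 2, 2, 3], 2))  # [1, 2]
```

Sorting only the distinct values (rather than (count, num) pairs) avoids building a second list by hand, but the asymptotic cost is unchanged.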
5.2. Min-Heap
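import heapq
from typing import List


class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        # Count how many times each number appears.
        count = {}
        for num in nums:
            count[num] = 1 + count.get(num, 0)
        # Keep a min-heap of (frequency, num) with at most k entries;
        # when it grows past k, evict the least frequent entry.
        heap = []
        for num in count.keys():
            heapq.heappush(heap, (count[num], num))
            if len(heap) > k:
                heapq.heappop(heap)
        # The heap now holds exactly the k most frequent numbers.
        res = []
        for _ in range(k):
            res.append(heapq.heappop(heap)[1])
        return res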
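- Time complexity: O(n log k): O(n) to count, then at most n pushes and pops on a heap that never exceeds k + 1 entries.
- Space complexity: O(n + k): O(n) for the count map and O(k) for the heap.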
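The heap approach also has a standard-library shortcut: heapq.nlargest(k, iterable, key=...), which the Python documentation describes as equivalent to sorted(iterable, key=key, reverse=True)[:k]. A minimal sketch, assuming Python 3, with top_k_frequent_heap as a hypothetical name:

```python
import heapq
from collections import Counter
from typing import List


def top_k_frequent_heap(nums: List[int], k: int) -> List[int]:
    count = Counter(nums)
    # Rank the distinct values by their frequency and keep the k largest,
    # returned most frequent first.
    return heapq.nlargest(k, count.keys(), key=count.get)


print(top_k_frequent_heap([1, 1, 1, 2, 2, 3], 2))  # [1, 2]
```

In CPython, nlargest keeps only k candidates in a heap while it scans, so this should match the O(n log k) bound above while hiding the push/pop bookkeeping.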
5.3. Bucket Sort
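from typing import List


class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        count = {}
        # freq[i] holds the numbers that appear exactly i times (0 <= i <= n).
        freq = [[] for _ in range(len(nums) + 1)]
        for num in nums:
            count[num] = 1 + count.get(num, 0)
        for num, cnt in count.items():
            freq[cnt].append(num)
        # Scan buckets from the highest frequency down, collecting k numbers.
        res = []
        for i in range(len(freq) - 1, 0, -1):
            for num in freq[i]:
                res.append(num)
                if len(res) == k:
                    return res
        # Only reached if k exceeds the number of distinct values.
        return res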
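- Time complexity: O(n): counting is O(n), and the scan visits n + 1 buckets and each distinct number at most once.
- Space complexity: O(n) for the count map and the buckets.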
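Because ties at the k-th position make the answer non-unique, a randomized check should validate answers rather than compare them against one fixed list. The sketch below assumes Python 3; top_k_bucket, is_valid_answer, and fuzz are hypothetical helper names. It restates the bucket-sort solution as a plain function and checks it on random inputs. An answer counts as valid when it has k distinct values that occur in nums and no value left out is more frequent than the least frequent value chosen.

```python
import random
from collections import Counter
from typing import List


def top_k_bucket(nums: List[int], k: int) -> List[int]:
    """Bucket-sort approach from 5.3, written as a plain function."""
    count = Counter(nums)
    freq = [[] for _ in range(len(nums) + 1)]
    for num, cnt in count.items():
        freq[cnt].append(num)
    res = []
    for i in range(len(freq) - 1, 0, -1):
        for num in freq[i]:
            res.append(num)
            if len(res) == k:
                return res
    return res


def is_valid_answer(nums: List[int], k: int, answer: List[int]) -> bool:
    """True if `answer` is some valid choice of k most frequent values."""
    counts = Counter(nums)
    chosen = set(answer)
    # Exactly k distinct values, all of which occur in nums.
    if len(answer) != k or len(chosen) != k or not chosen.issubset(counts):
        return False
    # Nothing left out may be more frequent than the weakest chosen value.
    cutoff = min(counts[x] for x in chosen)
    return all(c <= cutoff for x, c in counts.items() if x not in chosen)


def fuzz(trials: int = 1000, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        # Small value range so repeated values and ties are common.
        nums = [rng.randint(-5, 5) for _ in range(rng.randint(1, 30))]
        k = rng.randint(1, len(set(nums)))
        assert is_valid_answer(nums, k, top_k_bucket(nums, k)), (nums, k)


fuzz()
print("all random cases passed")
```

The same is_valid_answer check can be pointed at the sorting and heap versions, since all three may legitimately return different elements when frequencies tie at the cutoff.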