Framework Thinking - Neetcode 150 - Merge K Sorted Linked Lists
Here are the solutions for the Merge K Sorted Linked Lists problem.
1. Understand the problem
-
You are given an array of k linked lists, each list is sorted in ascending order.
-
Your task: merge all the linked lists into one sorted linked list and return its head.
2. Clarify constraints: ask 4 - 5 questions, including edge cases.
-
What is the maximum number of lists k and maximum nodes per list?
-
Can some lists be empty (None)?
-
Can all lists be empty?
-
Are list values bounded (for integer overflow concerns)?
-
Do we modify existing lists or create a new one?
3. Explore examples.
- Example 1:
lists = [
1 -> 4 -> 5,
1 -> 3 -> 4,
2 -> 6
]
Output: 1 -> 1 -> 2 -> 3 -> 4 -> 4 -> 5 -> 6
- Example 2:
lists = []
Output: []
- Example 3:
lists = [ [], [1] ]
Output: 1
4. Brainstorm 2 - 3 solutions
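4.1. Naive Solution: Merge One by One - Time O(N * K), Space: O(1) extra space
-
Merge list1 & list2 => result
-
Merge result with list3 => result, and so on until every list has been merged
-
Time: O(N * K), where K is the number of lists and N is the total number of nodes (each of the K - 1 merges may walk up to N nodes)
-
Space: O(1) (no extra data structure)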
4.2. Divide & Conquer (Optimized): Merge the lists pairwise - Time O(N * logK), Space O(1)
-
Merge lists in pairs (like merge sort)
-
Time: O(N * logK)
-
Space: O(1) extra when the pairs are merged iteratively in place (the recursive version needs an O(logK) call stack)
4.3. Min-Heap (Priority Queue) - Time O(N * logK), Space O(K)
-
Push the head of each list into a min-heap (value, index, node)
-
Pop smallest, append to result, push its next. N items, but each pop is logK.
-
Time: O(N * logK)
-
Space: O(k)
5. Implement solutions.
5.1. Very naive solution - Collect all values into an array, sort it, then rebuild the list - Time O(NlogN), Space O(N)
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
    """Brute-force merge: gather every value, sort, and rebuild a fresh list."""

    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Return the head of a newly built sorted list holding every input value.

        The input nodes are only read; every output node is a new ``ListNode``.
        The result is ``None`` when no list contributes a value.

        Time: O(N log N) for the sort, Space: O(N) for the value buffer,
        where N is the total number of nodes.
        """
        values = []
        for head in lists:
            walker = head
            while walker:
                values.append(walker.val)
                walker = walker.next

        # A sentinel in front of the result removes the empty-result special case.
        sentinel = ListNode(0)
        tail = sentinel
        for value in sorted(values):
            tail.next = ListNode(value)
            tail = tail.next
        return sentinel.next
-
Time: O(NlogN)
-
Space: O(N)
5.2. Intermediate naive solution - Scan the heads of the k lists to find the minimum at each step - Time O(N * K), Space O(1)
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
    """Merge k sorted linked lists by repeatedly taking the smallest head."""

    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Return the head of the merged list, relinking the input nodes in place.

        Each round scans all current heads, appends the smallest one to the
        result and advances that list. On equal values the list with the
        lower index is taken first.

        Side effect: the entries of ``lists`` are advanced in place, so every
        entry is ``None`` once the merge has finished.

        Time: O(N * K), Space: O(1).
        """
        sentinel = ListNode(0)
        tail = sentinel
        while True:
            best = None  # index of the list whose head is currently smallest
            for idx, head in enumerate(lists):
                if head and (best is None or lists[best].val > head.val):
                    best = idx
            if best is None:
                # Every list is exhausted.
                break
            chosen = lists[best]
            tail.next = chosen
            lists[best] = chosen.next
            tail = chosen
        return sentinel.next
-
Time: O(N * K)
-
Space: O(1)
5.3. Naive Solution - Merge Lists One By One - Time O(N * K), Space O(1)
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
    """Merge k sorted linked lists by folding them together one list at a time."""

    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Return the head of the merged list, or ``None`` when there is nothing to merge.

        The running result is held in a local variable, so the entries of the
        caller's ``lists`` are never overwritten. The nodes themselves are
        still relinked in place.

        Time: O(N * K), Space: O(1).
        """
        if not lists:
            return None
        merged = None
        for head in lists:
            # Merging with None simply yields ``head``, so the first list
            # needs no special case.
            merged = self.mergeList(merged, head)
        return merged

    def mergeList(self, l1, l2):
        """Merge two sorted lists in place and return the head of the result.

        Either argument may be ``None``. On equal values the node from ``l2``
        is linked first.
        """
        dummy = ListNode()
        tail = dummy
        while l1 and l2:
            if l1.val < l2.val:
                tail.next = l1
                l1 = l1.next
            else:
                tail.next = l2
                l2 = l2.next
            tail = tail.next
        # At most one list still has nodes; attach whatever is left.
        tail.next = l1 or l2
        return dummy.next
-
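Merge 2 lists: O(a + b), where a and b are the lengths of the two lists; the accumulated result can hold up to N nodes, so one merge costs O(N) in the worst case.
-
K - 1 merges => Time: O(N * K)
-
Space: O(1), because the nodes are relinked in place into the final result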
5.4. Min Heap - Time O(N * logK), Space O(K) for Heap
- Idea:
- Use a heap of size at most k, initialised with the head node of every non-empty list.
- Repeatedly pop the smallest node from the heap, append it to the result, and push that node's successor (if any).
- The heap therefore always holds the smallest not-yet-merged node of each list.
from typing import List, Optional
import heapq
# Definition for singly-linked list.
class ListNode:
    """Node of a singly-linked list."""

    def __init__(self, val=0, next=None):
        self.val, self.next = val, next


class Solution:
    """Merge k sorted linked lists using a min-heap of the current heads."""

    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Return the head of the merged list, relinking the input nodes in place.

        The heap holds at most one entry per input list, so each push/pop
        costs O(log K).

        Time: O(N * logK), Space: O(K).
        """
        # Entries are (value, list index, node). At most one node per list is
        # in the heap at a time, so the index settles every tie on value and
        # two nodes are never compared with each other.
        frontier = [
            (head.val, idx, head) for idx, head in enumerate(lists) if head
        ]
        heapq.heapify(frontier)

        sentinel = ListNode()
        tail = sentinel
        while frontier:
            _, idx, smallest = heapq.heappop(frontier)
            tail.next = smallest
            tail = smallest
            successor = smallest.next
            if successor:
                # Refill the heap from the list the popped node came from.
                heapq.heappush(frontier, (successor.val, idx, successor))
        return sentinel.next
-
Time: O(N * logK)
-
Space: O(K)
5.5. Divide And Conquer (Recursion) - Same idea as merge sort: split the lists in half, merge each half, then merge the two results - Time O(N * logK), Space O(logK)
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
    """Merge k sorted linked lists with recursive divide and conquer."""

    def mergeKLists(self, lists):
        """Return the head of the merged list, or ``None`` when ``lists`` is empty."""
        if not lists:
            return None
        return self.divide(lists, 0, len(lists) - 1)

    def divide(self, lists, l, r):
        """Merge ``lists[l..r]`` (inclusive) and return the head of the result."""
        if l > r:
            return None
        if l == r:
            # A single list is already merged.
            return lists[l]
        mid = (l + r) // 2
        return self.conquer(
            self.divide(lists, l, mid),
            self.divide(lists, mid + 1, r),
        )

    def conquer(self, l1, l2):
        """Merge two sorted lists in place; on equal values ``l1`` goes first."""
        sentinel = ListNode(0)
        tail = sentinel
        while l1 and l2:
            if l1.val <= l2.val:
                tail.next, l1 = l1, l1.next
            else:
                tail.next, l2 = l2, l2.next
            tail = tail.next
        # Attach whichever list still has nodes (or None if both are done).
        tail.next = l1 or l2
        return sentinel.next
-
Time: O(N * logK)
-
Space: O(logK)
5.6. Divide And Conquer (Iteration) - Merge the lists in pairs, round by round, instead of recursing - Time O(N * logK), Space O(K)
Idea:
- Instead of recursing, merge adjacent pairs of lists in each round and repeat on the merged results until only one list remains.
# Definition for singly-linked list.
# class ListNode:
# def __init__(self, val=0, next=None):
# self.val = val
# self.next = next
class Solution:
    """Merge k sorted linked lists by merging adjacent pairs round by round."""

    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Return the head of the merged list, or ``None`` when ``lists`` is empty.

        Every round merges neighbouring pairs into a new list of heads, so the
        number of lists roughly halves each round. A list left without a
        partner is merged with ``None``.
        """
        if not lists:
            return None
        pending = lists
        while len(pending) > 1:
            pending = [
                self.mergeList(
                    pending[i],
                    pending[i + 1] if i + 1 < len(pending) else None,
                )
                for i in range(0, len(pending), 2)
            ]
        return pending[0]

    def mergeList(self, l1, l2):
        """Merge two sorted lists in place; on equal values ``l2`` goes first."""
        sentinel = ListNode()
        last = sentinel
        while l1 and l2:
            if l1.val < l2.val:
                last.next, l1 = l1, l1.next
            else:
                last.next, l2 = l2, l2.next
            last = last.next
        # At most one list still has nodes; attach whatever is left.
        last.next = l1 or l2
        return sentinel.next
-
Time: O(N * logK)
-
Space: O(K) for the list of merged heads built in each round (no recursion stack is used)
6. Dry run testcases.
lists = [
1 -> 4 -> 5,
1 -> 3 -> 4,
2 -> 6
]
| Step | Heap | Result Linked List |
|---|---|---|
| Init | [1,1,2] | - |
| Pop 1 | [1,2,4] | 1 |
| Pop 1 | [2,3,4] | 1 → 1 |
| Pop 2 | [3,4,6] | 1 → 1 → 2 |
| Pop 3 | [4,4,6] | 1 → 1 → 2 → 3 |
| Pop 4 | [4,5,6] | 1 → 1 → 2 → 3 → 4 |
| Pop 4 | [5,6] | 1 → 1 → 2 → 3 → 4 → 4 |
| Pop 5 | [6] | 1 → 1 → 2 → 3 → 4 → 4 → 5 |
| Pop 6 | [] | 1 → 1 → 2 → 3 → 4 → 4 → 5 → 6 |