We need to divide the array `nums` into `k` subsets such that the sum of each subset is the same.
$$
\begin{aligned}
k \times \text{sum of each subset} &= \text{total sum} \\
\text{sum of each subset} &= \text{total sum} \div k \\
\text{target} &= \text{total sum} \div k
\end{aligned}
$$
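As a quick sanity check, the target can be computed directly. This is a minimal sketch: `subset_target` is a hypothetical helper name, not something the solution itself defines. It returns `None` when the total is not divisible by `k`, since no equal-sum partition can exist in that case.

```python
from typing import List, Optional


def subset_target(nums: List[int], k: int) -> Optional[int]:
    """Return the sum each subset must reach, or None if total is not divisible by k."""
    total = sum(nums)
    if total % k != 0:
        return None  # total cannot be split into k equal parts
    return total // k


print(subset_target([4, 3, 2, 3, 5, 2, 1], 4))  # 20 / 4 = 5
```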
Now our task is to find all `k` subsets in `nums` whose sum is `target`.
My idea is to structure this as a multi-level recursion. We first try to find the $$k^{th}$$ subset, then the $$(k-1)^{th}$$, then the $$(k-2)^{th}$$, and so on until there is only one subset left. The last subset will naturally sum to `target`: the total is $$k \times target$$, so once the other $$k-1$$ subsets are fixed, the remaining elements add up to exactly `target`. You can only go to the $$(k-1)^{th}$$ level once you have successfully found the $$k^{th}$$ level subset.
For example,
```text
..
  ..
    find the 5th subset
      find the 4th subset
        find the 3rd subset
          find the 2nd subset
            return true
```
To see the recursive structure, consider a concrete input:
```text
nums = [4,3,2,3,5,2,1] k=4
target = 20 / 4 = 5

suppose you find the 4th subset as: {3,2} from [4,3,2,3,5,2,1]
Now your task reduces to finding 3 subsets in [4, , ,3,5,2,1]
that sum to target=5

recursive call:
nums=[4, , ,3,5,2,1] k=3
```
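The reduction in this example can be checked mechanically. The sketch below uses a hypothetical helper, `remove_subset`, which drops one found subset from the list; it only illustrates the bookkeeping (the remainder must sum to `(k - 1) * target`) and is not part of the solution.

```python
from collections import Counter
from typing import List


def remove_subset(nums: List[int], subset: List[int]) -> List[int]:
    """Return nums with one occurrence of each element of subset removed."""
    to_remove = Counter(subset)
    remaining = []
    for x in nums:
        if to_remove[x] > 0:
            to_remove[x] -= 1  # drop the first matching occurrence
        else:
            remaining.append(x)
    return remaining


nums, k, target = [4, 3, 2, 3, 5, 2, 1], 4, 5
remaining = remove_subset(nums, [3, 2])
print(remaining)                           # [4, 3, 5, 2, 1]
print(sum(remaining) == (k - 1) * target)  # True
```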
When finding the $$k^{th}$$ subset, you can simply use a pick/not-pick backtracking approach. It is important to understand the distinction between the same-level recursion (at any given k), which operates by deciding whether to pick or drop nums@i, and the k-level recursion, which succeeds or fails by checking whether we can form the $$k^{th}$$ subset.
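To isolate the same-level recursion, here is a minimal sketch of pick/not-pick for a single subset. `find_one_subset` and its parameters are hypothetical names; the sketch assumes all numbers are positive, and it only marks one subset in `used`, leaving the k-level recursion to the caller.

```python
from typing import List


def find_one_subset(nums: List[int], used: List[bool], target: int,
                    i: int = 0, cur_sum: int = 0) -> bool:
    """Mark in `used` one subset of unused elements (from index i on) summing to target."""
    if cur_sum == target:
        return True
    if i >= len(nums):
        return False
    # pick nums[i] if it is free and still fits
    if not used[i] and cur_sum + nums[i] <= target:
        used[i] = True
        if find_one_subset(nums, used, target, i + 1, cur_sum + nums[i]):
            return True
        used[i] = False  # undo the pick
    # do not pick nums[i]
    return find_one_subset(nums, used, target, i + 1, cur_sum)


nums = [5, 4, 3, 3, 2, 2, 1]
used = [False] * len(nums)
print(find_one_subset(nums, used, 5), used)
```

Calling it again with the same `used` list looks for a second subset among the elements that are still free, which is what moving to the next level of k amounts to. Unlike a full solver, this sketch does not backtrack across levels.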
```python
from typing import List


class Solution:
    def __init__(self):
        self.nums, self.target, self.used = None, None, None

    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        total = sum(nums)
        if total % k != 0:  # an equal split is impossible
            return False
        self.nums = nums
        self.target = total // k
        self.used = [False] * len(nums)
        # large numbers first, which tends to hit dead ends sooner
        self.nums.sort(reverse=True)
        return self.bt_search(0, k, 0)

    def bt_search(self, i: int, k: int, cur_sum: int) -> bool:
        if k == 1:  # last subset is sure to sum up to target
            return True
        # check completion before the bounds check, so a subset that
        # ends on the last element is still recognized
        if cur_sum == self.target:  # current subset found,
            return self.bt_search(0, k - 1, 0)  # explore the next level of k
        if i >= len(self.nums) or cur_sum > self.target:
            return False
        # try picking nums@i in current subset
        if not self.used[i] and cur_sum + self.nums[i] <= self.target:
            self.used[i] = True
            # explore in same level
            if self.bt_search(i + 1, k, cur_sum + self.nums[i]):
                return True
            self.used[i] = False
            # on failure, skip duplicates of nums@i
            while i + 1 < len(self.nums) and self.nums[i] == self.nums[i + 1]:
                i += 1
        # do not pick nums@i -> explore in same level
        return self.bt_search(i + 1, k, cur_sum)
```