Programming Interview Questions 10: Kth Largest Element in Array

Given an array of integers, find the kth element in sorted order (not the kth distinct element). So, if the array is [3, 1, 2, 1, 4] and k is 3, then the result is 2, because it’s the 3rd element in sorted order (whereas the 3rd distinct element is 3).

The first approach that comes to mind is sorting the array and returning the kth element. The complexity is O(N log N), where N is the size of the array, and it’s clearly not optimal: this solution does more work than needed, because it finds the absolute ordering of all elements while we’re only looking for the kth one. We would ideally prefer a linear time solution.
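As a baseline, the sorting approach can be sketched in a few lines. The function names below are illustrative only and not part of the original post; the second helper is there just to show how the kth distinct element differs from the kth element in sorted order.

```python
def kth_by_sorting(arr, k):
    """Return the kth (1-based) element of arr in sorted order, or None if k is out of range."""
    if not 1 <= k <= len(arr):
        return None
    return sorted(arr)[k - 1]

def kth_distinct_by_sorting(arr, k):
    """Return the kth (1-based) distinct element in sorted order, or None if k is out of range."""
    distinct = sorted(set(arr))
    if not 1 <= k <= len(distinct):
        return None
    return distinct[k - 1]

arr = [3, 1, 2, 1, 4]
print(kth_by_sorting(arr, 3))           # 2, since sorted order is [1, 1, 2, 3, 4]
print(kth_distinct_by_sorting(arr, 3))  # 3, since the distinct values are [1, 2, 3, 4]
```

Both helpers cost O(N log N) because of the sort, which is the work the selection algorithm tries to avoid.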

We can use the selection algorithm (quickselect), which reuses the partition step of quicksort. It works as follows: select a pivot and partition the array into left and right subarrays such that the elements smaller than the pivot value end up in the left group, and the ones larger than or equal to the pivot end up in the right group. Now only the pivot is guaranteed to be in its sorted position. The remaining elements are not sorted, but their position relative to the pivot, whether they are on the left or right, is as in sorted order. Let’s say that after partitioning, the position of the pivot in the array (counting from 1) is m. If m is equal to k, then the pivot is exactly the kth element that we’re looking for, so we return the pivot value. If m is less than k, then the kth element is in the right subarray. Otherwise, if m is greater than k, the kth element is in the left subarray. So we can recursively call the same procedure on that subarray to find the kth element. The code will make everything clear:

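import random

def partition1(arr, left, right, pivotIndex):
    # Move the pivot to the end so it stays out of the way while scanning.
    arr[right], arr[pivotIndex]=arr[pivotIndex], arr[right]
    pivot=arr[right]
    swapIndex=left
    for i in range(left, right):
        # Elements smaller than the pivot are packed at the front.
        if arr[i]<pivot:
            arr[i], arr[swapIndex] = arr[swapIndex], arr[i]
            swapIndex+=1
    # Put the pivot into its final sorted position and return that index.
    arr[right], arr[swapIndex]=arr[swapIndex], arr[right]
    return swapIndex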
 
def kthLargest1(arr, left, right, k):
    # k is 1-based and relative to the current range [left, right].
    if not 1<=k<=right-left+1:
        return
    if left==right:
        return arr[left]

    # Pick a random pivot, partition, then recurse into the side that holds rank k.
    pivotIndex=random.randint(left, right)
    pivotIndex=partition1(arr, left, right, pivotIndex)
    rank=pivotIndex-left+1
    if rank==k:
        return arr[pivotIndex]
    elif k<rank:
        return kthLargest1(arr, left, pivotIndex-1, k)
    else:
        return kthLargest1(arr, pivotIndex+1, right, k-rank)

The partition function divides the array to two subarrays as described above. The pivot used in partition is selected uniformly at random to potentially avoid worst case performance. We search for the kth element within the indexes [left, right]. Left is initially 0 and right is length of array – 1. If the rank of the pivot is equal to k after partitioning, then the pivot itself is the kth element so we return. Otherwise, we recurse by adjusting the bounds. If the rank of the pivot is greater than k, then we should continue our search in the left subarray, so the new array bounds become [left, pivotIndex-1]. Else if the rank of the pivot is less than k, then we continue to search in the right subarray, so the bounds become [pivotIndex+1, right]. The value of k also gets adjusted if we’re searching the right subarray since we change the left index, so the new value of k becomes k-rank, meaning we count out rank number of elements that we’re eliminating at the left portion of the array.
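The same bound and rank adjustments can also be written as a loop instead of recursion. The following is only a sketch under a few assumptions: the function name is made up for illustration, it works on a copy of the input, and it returns None for an out-of-range k.

```python
import random

def kth_in_sorted_order(arr, k):
    """Iterative quickselect sketch: kth (1-based) element of arr in sorted order."""
    if not 1 <= k <= len(arr):
        return None
    a = list(arr)                      # copy, so the caller's list is not reordered
    left, right = 0, len(a) - 1
    while left < right:
        # Choose a random pivot and move it to the end of the current range.
        p = random.randint(left, right)
        a[p], a[right] = a[right], a[p]
        pivot = a[right]
        store = left
        for i in range(left, right):
            if a[i] < pivot:
                a[i], a[store] = a[store], a[i]
                store += 1
        a[right], a[store] = a[store], a[right]

        rank = store - left + 1        # rank of the pivot within [left, right]
        if rank == k:
            return a[store]
        elif k < rank:
            right = store - 1          # keep searching the left part, k unchanged
        else:
            left = store + 1           # discard rank elements on the left
            k -= rank
    return a[left]

print(kth_in_sorted_order([3, 1, 2, 1, 4], 3))  # 2
```

Because only one side is ever kept, the loop form needs no call stack; the adjustment `k -= rank` plays the same role as passing k-rank to the recursive call.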

The average time complexity of this approach is O(N). But the worst case complexity is unfortunately O(N^2), which occurs if we make poor pivot selections that don’t partition the array well and leave most of the elements on one side and very few on the other. In the extreme case, when the smallest or largest element is chosen as the pivot, the partition results in 0 elements on one side and all the others on the other side. As a result we can only eliminate 1 element at each step, leading to quadratic time complexity. Conversely, the best case performance occurs when the pivot divides the array into two equal sized partitions, which results in linear complexity. Here are the recurrence relations for the best and worst cases: the best case is linear and the worst case is quadratic. The average case also turns out to be linear (proof omitted).

Best case:

T(N) = T(\dfrac{N}{2}) + O(N) \rightarrow T(N) = O(N)

Worst case:

T(N) = T(N-1) + O(N) \rightarrow T(N) = O(N^2)
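To see where these bounds come from, both recurrences can be unrolled. The sketch below assumes the O(N) partitioning cost is written as cN for some constant c, which is an assumption made here for illustration.

```latex
% Best case: the range halves at every step, giving a geometric series.
T(N) \le cN + c\dfrac{N}{2} + c\dfrac{N}{4} + \dots + T(1)
     \le cN \sum_{i=0}^{\infty} \dfrac{1}{2^i} + T(1) = 2cN + T(1) = O(N)

% Worst case: only one element is eliminated per step, giving an arithmetic series.
T(N) = cN + c(N-1) + \dots + 2c + T(1)
     = c\left(\dfrac{N(N+1)}{2} - 1\right) + T(1) = O(N^2)
```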

There’s a very elegant algorithm with worst case linear time performance, called the Median of Medians Algorithm. It was discovered by 5 great computer scientists: Manuel Blum (Blum speedup theorem), Robert Floyd (Floyd-Warshall shortest path algorithm), Vaughan Pratt (Pratt primality certificate), Ron Rivest (RSA cryptography algorithm), and Robert Tarjan (graph algorithms and data structures). Median of medians is a modified version of the selection algorithm where we improve pivot selection to guarantee a reasonably good worst case split. The algorithm divides the array into groups of size 5 (the last group can be of any size < 5). It then calculates the median of each group by sorting and selecting the middle element (the cost of sorting 5 elements is constant). It finds the median of these medians by recursively calling itself, and selects the median of medians as the pivot for partition. Then it continues like the previous selection algorithm, recursing into the left or right subarray depending on the rank of the pivot after partitioning. The partition function is slightly different though: the partition1 function above takes the index of the pivot as input, while partition2 here takes the value of the pivot as input, which is only a slight modification. Here is the code:

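def partition2(arr, left, right, pivot):
    # Three-way partition by pivot value: [< pivot | == pivot | > pivot].
    # Grouping the duplicates of the pivot keeps the recursion from
    # stalling when no element is smaller than the pivot.
    # Returns (lt, gt), the bounds of the block equal to the pivot.
    lt, i, gt = left, left, right
    while i <= gt:
        if arr[i] < pivot:
            arr[i], arr[lt] = arr[lt], arr[i]
            lt += 1
            i += 1
        elif arr[i] > pivot:
            arr[i], arr[gt] = arr[gt], arr[i]
            gt -= 1
        else:
            i += 1
    return lt, gt

def kthLargest2(arr, left, right, k):
    length = right - left + 1
    if not 1 <= k <= length:
        return
    if length <= 5:
        return sorted(arr[left:right+1])[k-1]

    # Median of each full group of 5 (integer division; leftover elements are skipped).
    numMedians = length // 5
    medians = [kthLargest2(arr, left+5*i, left+5*(i+1)-1, 3) for i in range(numMedians)]
    pivot = kthLargest2(medians, 0, len(medians)-1, len(medians)//2+1)
    lt, gt = partition2(arr, left, right, pivot)
    if k <= lt - left:          # answer is among the elements smaller than the pivot
        return kthLargest2(arr, left, lt-1, k)
    elif k <= gt - left + 1:    # answer is the pivot value itself
        return pivot
    else:                       # answer is among the elements larger than the pivot
        return kthLargest2(arr, gt+1, right, k-(gt-left+1))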

The worst case complexity of this approach is O(N), because the median of medians chosen as the pivot is greater than roughly 30% of the elements or more, and also less than roughly 30% of the elements or more. So even in the worst case we can eliminate a constant proportion of the elements at each iteration, which is what we wanted but couldn’t achieve with the previous approach. We can also write the recurrence relation for the worst case and verify that it’s linear. The N/5 term comes from selecting the median of medians as the pivot, and the 7N/10 term is for when the pivot produces the worst split.

T(N) \le T(\dfrac{N}{5}) + T(\dfrac{7N}{10}) + O(N) \le c\dfrac{9N}{10} + O(N)

\le cN - (c\dfrac{N}{10} - O(N)) \le cN \in O(N)
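The roughly 30% guarantee behind the 7N/10 term can be checked empirically. The sketch below is an illustration, not the algorithm itself: the helper names are hypothetical, the input values are assumed distinct, only full groups of 5 are used, and the medians are simply sorted instead of being selected recursively.

```python
import random

def median_of_medians_pivot(values):
    """Median of the medians of the full groups of 5 (needs at least 5 values)."""
    full = len(values) - len(values) % 5
    groups = [values[i:i + 5] for i in range(0, full, 5)]
    medians = sorted(sorted(group)[2] for group in groups)
    return medians[len(medians) // 2]

def split_sizes(values):
    """Count how many values fall strictly below and strictly above the pivot."""
    pivot = median_of_medians_pivot(values)
    smaller = sum(1 for v in values if v < pivot)
    larger = sum(1 for v in values if v > pivot)
    return smaller, larger

random.seed(0)
for n in (25, 100, 1000, 10000):
    values = random.sample(range(10 * n), n)   # n distinct integers
    smaller, larger = split_sizes(values)
    # Each fraction should be at least about 0.3, so at most about 0.7 remains.
    print(n, smaller / n, larger / n)
```

With g = N/5 groups, about g/2 medians are no larger than the pivot and each of them brings 3 elements of its group with it, which gives the 3N/10 lower bound on each side.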

This is a great algorithm but admittedly kind of hard to come up with during an interview if you haven’t seen it before. After all, it took 5 great minds to figure it out. But it’s of course good to know and definitely worth the extra effort.

  • George

    What’s up with latex script?

    • Arden

      Thanks for the notice. They were broken for some reason, now it’s fixed.

  • hayro

    Hi Arden,

    Did you consider the duplicate values case? It seems when you check the kth element, the duplicate values cause problems because 7th element may not be the 7th largest element if it contains duplicates before it. For example, [44,23,19,19,13,19, 10(pivot) , 4,5,6,8], although 10 is the 7th element it is the 5th largest element.

    • Arden

      Yes you’re right. The problem definition should be kth element in sorted order, not kth largest element. Updated the problem definition. Thanks a lot for the comment..

  • Alex

    Hey Arden,

    Thank you for such a detailed explanation! I do have a question – doesn’t the first example, kthLargest1, find the kth smallest element in the array?

    • Arden

      Thanks, glad that you liked it. And I think we mean the same thing by saying “kth smallest” and “kth largest”. It’s basically the kth element in the sorted order.

  • Stephen

    What about using a Min Heap? Specifically

    Heapify first k elements in array
    Then, for each element after k in array{
    if element > MinHeap.root{
    delete MinHeap.root
    add element to MinHeap.
    }
    }

    O(N) time, where N is the size of input
    O(k) space, where k is the size of the heap

    Although it’s not as good as your median of medians approach, it seems more practical to give this as an answer in an interview (unless I’ve made a mistake somewhere).

    great approach , and great blog.

    • Stephen

      I forgot to say, MinHeap.root would be the k-th element at the end of the algorithm.

      • Stephen

        Oops, and it should be

        if element >= MinHeap.root{

        • Idan

          This is basically the algorithm that I was thinking about as well.

          Please mind that AFAIK, you need a max heap (as you need the maximum number, and not the minimum).
          In addition, a heap insertion takes O(logK), making the whole algorithm run at O(nlogk).

          Please correct me if I’m wrong.

          • Stephen

            I believe

          • Darius Liddell

            stephen you are right, it is definitely min heap, not max heap

          • Stephen

            You are correct in saying it is O(nlogk), which makes this algorithm fall short of Arden’s median of medians solution.

            However, I still think it requires a min heap. It is quite difficult to explain why. But imagine this: You are iterating through the array, adding an element only if it is larger than the smallest element in the heap. At the end of traversing your array, the k biggest elements will be in your min heap. Then you want the root of this heap, which is the lowest element amongst the k elements in your heap, making that the kth largest element.

          • Idan

            Now I see.
            You thought that the Kth element in the array is the Kth largest element, while I thought the Kth element is the Kth smallest element.

            Looking at the example above 2 is both the 3rd smallest number and the 3rd largest number.
            There is a small hint in the example after that that the author was referring to the Kth smallest element (with the distinct), but it would be nice to modify the example, just to keep this clear.

          • Stephen

            Great observation.

            Although I can’t see the hint that the author was referring to the kth smallest element.

        • Idan

          The question says “the 3rd distinct element is 3”
          If it was the 3rd largest element it would be 2 again (4,3,2).

    • Holden

      Time complexity will be O(nlogk), and not O(n)

  • AW

    I don’t understand your quadratic worst case

  • Ben

    Assuming I forget how to implement a sort during the interview and the interviewer won’t let me use the built-in sort method then this can be solved using a dictionary.

    While this might not be the most ideal solution, it still allows me to solve the problem with a different approach.

    def kth_map(elements, k):

    “””

    add each element to a dictionary
    k = element, v = occurence
    counter = 0
    while dictionary not empty:
    get minimum key from dictionary
    add to list number of occurence times
    remove current key, value from dictionary
    if counter == k, return key
    otherwise repeat
    “””
    counter = 0
    d = collections.defaultdict(int)
    keys = []
    for el in elements:
    d[el] += 1
    keys.append(el)

    while d:
    tmp = min(keys)
    counter += d[tmp]
    d.pop(tmp)
    if counter == k:
    return tmp

    • Ben

      I indented everything before posting it, but it’s not showing :(

      The documentation is a little hard to read, the basic algorithm I came up with is

      -add each element in the list to a dictionary where key=element and value=occurrence of the element (to account for duplicates)
      -in the same loop, also save each key to a list of lookup keys
      -exit the loop
      -create a counter and set to 0
      -while the dictionary is not empty
      -get the minimum key from the list of lookup keys
      -increment the counter by however many occurrences there are of that key
      -pop the current key from the list of lookup keys
      -if the current lookup key is equal to the parameter “k” then return the current lookup key because that’s the element we want (it doesn’t matter if there are duplicate elements because either way that is the key we need)
      -otherwise keep looping

    • ben

      just looked at this again. not sure what i was thinking, but it’s obviously very very wrong.

  • abc

    Can you please convert that to Java/C#?

    • abc

      Just the median part in the 2nd method.

    • Ankit Kapadia

      public static int findKLargest(int[] ar, int low, int high, int k){
          int pivotIndex = partition(ar,low,high);

          if(pivotIndex == k){
              return ar[pivotIndex];
          }
          else if(pivotIndex > k){
              return findKLargest(ar, low, pivotIndex-1, k);
          }
          else {
              return findKLargest(ar, pivotIndex+1, high, k);
          }
      }

      private static int partition(int[] ar, int low, int high) {
          int pivot = ar[low]; //first element as pivot

          int i = low + 1;
          for(int j= low+1;j<=high;j++){
              if(ar[j] < pivot){
                  swap(ar,i,j);
                  i++;
              }
          }

          swap(ar, low, i-1);

          return i-1;
      }

      private static void swap(int[] ar, int i, int j) {
          int temp = ar[i];
          ar[i] = ar[j];
          ar[j] = temp;
      }

      public static void main(String[] args) {
          int ar[] = {3,1,2,1,4};
          int k = 2; //K is from [1…n]
          int kLargest = findKLargest(ar,0,ar.length-1,k-1);
          System.out.println(kLargest);
      }

  • Christian Vielma

    Great explanation!

  • Vijay Ram

    Nice one

  • I[]

    Not sure why we need to compute “rank” and change the value of k in subsequent calls. The kth element has to be at index k-1 in the end, so it should be enough to compare pivotIndex and k-1 in every call without altering k. Or did I miss something?

  • Robert

    hmmmm what about repeated elements in your partition method???
    youre not handling them correctly

  • Sorrowfull Blinger

    Another Approach : Would be using Min Heap of capacity k , keep adding elements to k until full , after full replace the min with current unexplored element. After all elements are processed the min should be the kth largest . Guaranteed n*Log(n)

    • Holden

      No, the time complexity of your proposed solution would be O(k + (n-k)logk) = O(k + nlogk – klogk), which for large values of ‘n’ is equal to O(nlogk).
      Correct me if I am wrong

  • Kamal Chaya

    Why not use a min heap instead?

  • Urvish Mahida

    I miss the posts by you. You explain the problem, approach and solution very fluidly !