/*
 * 线性时间选择第k大数
 * Linear-time selection of the k-th largest number.
 */
#include <stdio.h>
#include <stdlib.h>

/* Capacity of the input buffer a[]: the maximum number of integers kept from the file. */
enum { N = 1000000 };
/* Group size used by select() when computing the median of medians. */
enum { LEN = 5 };

/* Input data filled by main(); the selection routines reorder it in place. */
int a[N];

/* Exchange the ints pointed to by a and b. Also correct when a == b. */
void swap(int *a, int *b) {
    const int saved = *b;
    *b = *a;
    *a = saved;
}

/*
 * Partition a[low..high] around the value `pivot` (Lomuto scheme).
 *
 * On return: a[low..r-1] <= pivot value, a[r] holds the pivot value,
 * a[r+1..high] > pivot value, where r is the returned index.
 *
 * The pivot value is expected to occur somewhere in a[low..high]. If it does
 * not, whatever is already in a[high] is used as the pivot instead.
 */
int partition(int a[], int low, int high, int pivot) {
    /* Step 1: move an element equal to `pivot` into the last slot. Every match
     * in a[low..high-1] is swapped there in turn, so the last match wins. */
    int scan = low;
    while (scan < high) {
        if (a[scan] == pivot) {
            swap(&a[scan], &a[high]);
        }
        scan++;
    }
    const int key = a[high];

    /* Step 2: `boundary` is one past the end of the "<= key" region.
     * a[low..boundary-1] <= key and a[boundary..scan-1] > key throughout. */
    int boundary = low;
    for (scan = low; scan < high; scan++) {
        if (a[scan] <= key) {
            swap(&a[boundary], &a[scan]);
            boundary++;
        }
    }

    /* Put the pivot between the two regions; that slot is its final position. */
    swap(&a[boundary], &a[high]);
    return boundary;
}

/*
 * Sort a[low..high] (inclusive) in ascending order by insertion sort.
 * Equal keys keep their relative order. An empty range (low >= high) is a no-op.
 */
void insertSort(int a[], int low, int high) {
    for (int pos = low + 1; pos <= high; ++pos) {
        const int key = a[pos];
        int hole = pos;
        /* Shift strictly larger elements one slot to the right. */
        while (hole > low && a[hole - 1] > key) {
            a[hole] = a[hole - 1];
            --hole;
        }
        a[hole] = key;
    }
}

/*
 * Return the k-th smallest element (k is 1-based) of a[begin..end] using the
 * median-of-medians selection algorithm.
 *
 * Preconditions: begin <= end and 1 <= k <= end - begin + 1. k is not
 * range-checked here, so the caller must validate it.
 *
 * Side effect: a[begin..end] is reordered in place. No element leaves the range.
 *
 * NOTE(review): with glibc in its default (GNU) mode, <stdlib.h> may pull in
 * <sys/select.h>, whose POSIX select() clashes with this definition. Compiling
 * with -std=c99/-std=c11 avoids it. Renaming would be the real fix, but it
 * changes the interface.
 */
int select(int a[], int begin, int end, int k) {
    int length = end - begin + 1;   /* number of elements in a[begin..end] */

    /* Small inputs: just sort. The 140 cutoff presumably mirrors the constant
     * in the textbook (CLRS) linear-time bound analysis. */
    if (length <= 140) {
        insertSort(a, begin, end);
        return a[begin + k - 1];
    }

    /* Number of groups of LEN elements, rounded up; the last group may be short.
     * (The former (length + LEN) / LEN produced a bogus empty extra group
     * whenever length was a multiple of LEN.) */
    int groups = (length + LEN - 1) / LEN;
    for (int i = 0; i < groups; i++) {
        int left = begin + LEN * i;
        int right = (left + LEN - 1 > end) ? end : left + LEN - 1;
        insertSort(a, left, right);
        /* Move this group's (lower) median to a[begin + i] so all medians end up
         * contiguous in a[begin..begin+groups-1]. Position begin + i lies in group
         * i / LEN <= i, which is already processed, so no pending group is disturbed. */
        int mid = (left + right) / 2;
        swap(&a[begin + i], &a[mid]);
    }

    /* The median of the medians becomes the partitioning pivot. */
    int pivot = select(a, begin, begin + groups - 1, (groups + 1) / 2);

    /* After this: a[begin..p-1] <= pivot, a[p] == pivot, a[p+1..end] > pivot. */
    int p = partition(a, begin, end, pivot);

    /* Gather the other copies of pivot from the low side next to a[p], giving
     * a[begin..q-1] < pivot and a[q..p] == pivot. Without this, inputs with many
     * duplicates (e.g. all elements equal) leave almost everything on the low
     * side. Recursion then shrinks by one element per level: quadratic time and
     * a recursion depth proportional to n. */
    int q = p;
    for (int j = p - 1; j >= begin; j--) {
        if (a[j] == pivot) {
            q--;
            swap(&a[j], &a[q]);
        }
    }

    int lessNum = q - begin;            /* elements strictly less than pivot */
    int notGreaterNum = p - begin + 1;  /* elements <= pivot, including a[p] */

    if (k <= lessNum) {
        /* The answer is strictly smaller than pivot: search the low area. */
        return select(a, begin, q - 1, k);
    }
    if (k <= notGreaterNum) {
        /* The k-th smallest is one of the copies of pivot. */
        return pivot;
    }
    /* The answer is strictly greater than pivot: search the high area. */
    return select(a, p + 1, end, k - notGreaterNum);
}


int main() {
    FILE *fp = fopen("data_1022.txt","r");            //打开文件 
    if (fp == NULL) {
        printf("Can not open the file!\n");
        exit(0);
    }
    int i = 0;
    while (fscanf(fp, "%d\n", &a[i]) != EOF) {  //读取文件中的数据到数组a[]中 
        i++;
    }
    fclose(fp);                                 //关闭文件 
    int k;
    while (1) {
        printf("Please enter an integer k, and you will get the k-th largest element in the array!\n");
        printf("(Enter negative or zero to quit): ");
        scanf("%d", &k);
        if (k <= 0) {
            printf("Bye\n"); 
            break;
        }
        printf("The %dth largest element in the array is: %d\n", k, select(a, 0, i - 1, i - k + 1));
        printf("\n==================================================================================\n");
    }
    return 0;
}