目录:

  1. 暴力
  2. 折半搜索 + map 或是哈希表
  3. 折半搜索 + 双指针 + 和 dfs 不同奇怪的状态枚举方式
    1. 第一种双指针
    2. 第二种双指针
  4. 完整代码

题目链接

博客中观看体验更佳

1. 题意

给你 nn 个二维的向量,对于任意一个 k(1kn)k (1 \le k \le n),求出有多少选取的方案能满足在这 nn 个向量中选 kk 个,并且他们的和为 (xg,yg)(x_g, y_g)

2. 分析

2.1 暴力算法

看到这个题我们可以比较快的想到拿部分分的做法,就是暴力枚举所有的选取方案,然后看他们加起来是否等于目标向量,再把符合要求的方案累加到答案中。但是我们发现 n=40n = 40,并且这个算法的复杂度是 2n2^n 的,所以一定会超时。

2.2 折半搜索 + map 或者是哈希表

2.2.1 简要思路

这道题中的折半搜索指的是把 nn 种向量分成两部分,对这两个部分分别用暴力的方法求出所有可能的选取方案,再把这些选取的方案,以及这个方案的结果(他们的和)按照某种方案储存下来,最后匹配这两部分的方案,把符合题目要求的(和等于 (xg,yg(x_g, y_g)的)累加进答案里。

2.2.2 细节

要达到前面提到的效果,我们可以使用 STL 的 map 或者是 unordered_map (手写哈希表也可以,但是可能要花更多时间来实现)来储存每个选取的方案。这里推荐使用 unordered_map,因为 map 的时间复杂度是 O(logn)O(\log n) 的,在这道题中会被卡,而 unodered_map 和手写哈希表的理想时间复杂度是 O(1)O(1)(当然使用 unodered_map 的话也要确保哈希函数写得好才不会被卡,比如我现在写的就过不了)。

对于暴力枚举状态的部分,一般的方法是 dfs ,也比较好写,这篇题解的双指针部分讲了一个比较奇怪的方法,想看的可以跳到下面。

注:下文的 map 指 unodered_map 或是 map 或是 手写的哈希表

我们首先创建两个 map ,firsec,分别储存前半部分向量的选取方案和后半部分的选取方案。对于这两个 map 的键值,我们可以设成包含 (sumx,sumy,k)(\text{sum}_x, \text{sum}_y, k) 三个整数的结构体。其中 (sumx,sumy)(\text{sum}_x, \text{sum}_y) 表示当前这个方案下,选取的所有向量的和。kk 表示当前的这个方案一共选取了多少个向量。因为可能有很多方案的 x,y,kx,y,k 都完全一样,所以对于 map 的值,我们设置成 (sumx,sumy,k)(\text{sum}_x, \text{sum}_y, k) 这三个值都相同的方案数。

对于答案的储存,我们开一个 ans[n] 的数组,ans[i] 表示 k=ik = i 时的选取方案数。

找到两部分的方案后我们要把符合要求的方案组合累加到答案中。具体来说,每个在 map 中储存的方案都包含一个这个方案的向量和 (sumx,sumy)(\text{sum}_x, \text{sum}_y) 。 我们设一个从前半部分向量得到的方案的向量和为 v1\vec{v_1} , 一个后半部分方案的向量和为 v2\vec{v_2} ,那么如果 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g) ,我们就把这个方案记录进答案。

因为我们使用了 map ,所以并不需要真的用双层循环把每个方案都遍历一遍。我们知道对于一个可能的匹配方案 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g) 。那么我们可以使用一个双层循环,一层枚举 fir 这个 map,一维枚举当前处理的 kk 。然后在循环里,我们可以写:

ans[当前k] += it_fir.值 * sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}]

其中 it_fir 指的是 fir 的迭代器。而 sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}] 中的方案和 it_fir 遍历到的方案符合 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g) ,并且这两种方案选取的向量数的和也等于当前 k 。因为这两个 map 的值是所有符合这些条件的方案的数量,所以我们把这两个值相乘以求出所有符合要求的匹配数量。

2.3 折半搜索 + 双指针

双指针的理论复杂度似乎是和哈希表一样的,但是如果哈希函数有问题的话,哈希表的速度就会慢很多。而双指针就没有这个问题。

2.3.1 双指针a

我们可以首先开两个 vector ,fir 和 sec ,其数据类型和之前 map 的一样,也是 (sumx,sumy,k)(\text{sum}_x, \text{sum}_y, k) 。fir 储存由前半部分的向量得来的方案,sec储存后半部分的。

枚举完所有方案后我们对这两个 vector 进行排序,排序规则如下:

if(a.x != b.x) return x < b.x;
if(a.y != b.y) return y < b.y;
return a.k < b.k;

然后我们创建两个指针,p1p_1 的初值设成 1, p2p_2 的初值设成 sec.size() - 1 ,代表 sec 的最后一个元素。此时 p1p_1 指向的是 fir 中最小的元素,而 p2p_2 指向 sec 中最大的元素。我们可以想一下,如果想要让当前的这两个指针所指的 v1,v2\vec{v_1}, \vec{v_2}
的和等于 (xg,yg)(x_g, y_g) ,需要怎么做。如果我们想要增加 v1+v2\vec{v_1} + \vec{v_2} 的值,只能使 p1p_1 增加,因为 p2p_2 已经指向这个数组里最大的元素了。反过来也是一样的,如果我们想要减少 v1+v2\vec{v_1} + \vec{v_2} 的值,也只能使 p2p_2 减少。写成程序就是下面这样:

int p1 = 0, p2 = sec_half.size() - 1;
while(p1 < fir_half.size() && p2 >= 0){
   Instruct &f = fir_half[p1], &s = sec_half[p2];
   if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){
      //如果两个向量相加小于目标值,我们只能加 p1 的值,
      //因为 p2 指向的元素最开始就是最大的。
      p1++;
   }
   else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){
      //如果两个向量相加大于目标值,我们只能减 p2 的值,
      //因为 p1 指向的元素最开始就是最小的。
      p2--;
   }
   //下面后半段代码插入的位置
}

注:Instruct 是包含 {x, y, k} 三种整数的结构体。

通过这样的方法,我们最终就一定能找到 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g) 的情况。不过呢,这两个数组中都可能有连续的一段是完全一样的值,也就是有多个 p1p_1 , p2p_2 满足 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g) 。因此我们需要找出符合条件的这个连续段具体是什么。通过上面的代码,我们已经知道了,满足条件的最小 p1p_1 和最大的 p2p_2,因为我们希望找出连续段的具体范围,所以还需要找出最大的 p1p_1 和最小的 p1p_1 。那么如何找呢?很简单,因为连续的这一段值一定都完全相等,所以我们只需要判断当前元素是否和最开始的元素相等就可以了。

当然,因为我们还需要把符合的匹配统计进答案,而答案是按照 kk 来输出的。所以我们可以开两个数组 fir_same_ksec_same_kfir_same_k[i] 就表示,对于第一个数组,在符合条件的这一长段中, k=ik = i 的有多少。而 sec_same_k 是对于第二个数组的。

然后我们就可以得到下面的代码了:

注意这段代码是插入前面那段代码的 else if 后面的

else{
   int p1t, p2t;
   
   memset(fir_same_k, 0, sizeof(fir_same_k));
   memset(sec_same_k, 0, sizeof(sec_same_k));
   
   //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组

   for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){
      //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1
      fir_same_k[fir_half[p1t].k]++;
   }
   for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){
      //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2
      sec_same_k[sec_half[p2t].k]++;
   }

   //统计答案,对于前半段和后半段都枚举可能的
   for(int i = 0; i <= 20; i++){
      for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的
         ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j];
         //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] 
         //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同
      }
   }
   p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环
}

这个方法还是跑的相对比较快的,可以看下提交记录

2.3.2 双指针 b

我们发现双指针 a 的方法会需要在统计答案时开 fir_same_ksec_same_k 这两个数组来统计 kk 相同的情况。我们其实可以改进一下这个方法,直接在枚举状态的时候把 kk 相同的方案放到一起。

具体来说,我们把前面的 firsec 这两个 vector 改成 vector<Instruct> fir[20], sec[20]fir[i] 就储存前半部分 k=ik = i 时的所有方案,sec[i] 是后半部分的。既然把储存方案的方法改了,后面的双指针部分自然也要改。

这一次我们需要用一个双重循环来分别枚举不同的 fir[i]sec[j]。在循环内部我们再做和双指针 a 相似的事情。也就是说,我们已经知道了当前前半部分的方案和后半部分的方案的 kk,现在只需要通过双指针找出能满足 v1+v2=(xg,yg)\vec{v_1} + \vec{v_2} = (x_g, y_g)p1p_1p2p_2 范围。

for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){   //枚举当前前半部分的 k
   for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k
      int p1 = 0, p2 = sec_half[sec_k].size() - 1;
      while(p1 < fir_half[fir_k].size() && p2 >= 0){
         Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2];
         if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){
            p1++;
            //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1
         }
         else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){
            p2--;
            //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2
         }
         else{
            int p1t = p1, p2t = p2;
            while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){
               p1t++;
            }
            while(p2t >= 0 && sec_half[sec_k][p2t] == s){
               p2t--;
            }
            
            ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); 
            //把 p1 范围长度乘上 p2 范围的长度
            //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while
            //        在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果
            //        它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。
            p1 = p1t, p2 = p2t;
         }
      }
   }
}

那么双指针b的写法比a有什么好处呢?答案就是节省空间。如果我们使用的是双指针a,那储存方案的结构体必须包含 x,y,kx, y, k 三种整数。注意其中的 kk 最大只有 2020,而我们却必须开一个 int 或是 short 来储存这个值,考虑到 2020 这个值非常小,不管哪种数据类型都会浪费大量的空间。而采用双指针b后,我们的结构体中只会包含 x,yx, y 两种整数, kk 这个值储存在数组的下标中,只要你开的数组大小为 kk 的最大值,就不会有任何的浪费。

具体的对比可以参考这个提交记录,可以发现相比双指针 a 的做法,双指针 b 的内存占用大约少了 17MB 左右

当然,代价也是有的,双指针 b 会稍微慢一些。我估计这主要出在双指针的环节。排序的部分甚至还会快一点。当然,不管哪种双指针,他们的理论复杂度都是一样的,因为每一种选取方案最多会被遍历到一次。

2.4 奇怪的状态枚举方法

这道题中常见的状态枚举方法就是 dfs。这里提供一种比较奇怪的枚举方法。在一个选取方案中,对于每个向量,都有两种状态,选或者是不选。因为只有选或不选两种状态,我们可以想到通过二进制数字表示这个状态。数字的第 ii 位表示向量 ii 是选还是不选,例如:(101)2(101)_2 就表示选择第 1 ,3 个向量,不选第 2 个。要枚举所有的状态,我们只需要把一个数从 00 一直累加到 2n2^n 就可以了,并且每次累加的时候检查他每一位是 0 还是 1 。当然,因为我们这里采用的是折半搜索,所以只需要累加到 2n22^\frac{n}{2}

至于复杂度的话,可能比 dfs 还慢?而且码量更大?毕竟每次累加还要写一个循环把这个数字从第一位检查到第二十位。不过,因为不需要递归,所以不需要一直给递归的函数开栈,内存占用可能会少一些。

真在比赛的时候还是不建议这样写的,毕竟 dfs 写起来是真的方便,这里只是提供一种好玩的做法。

3. 完整代码

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define rg register
const int MAXN = 45;

struct Instruct{
    ll x, y;
    int k;
    const bool operator < (Instruct b) const{
        if(x != b.x) return x < b.x;
        if(y != b.y) return y < b.y;
        return k < b.k;
    }
    const bool operator == (Instruct b) const{
        return x == b.x && y == b.y;
    }
}ins[MAXN];

vector<Instruct> fir_half, sec_half;
ll ans[MAXN];
int mx_state;
int n;
int tar_x, tar_y;

void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){
    //根据当前提供的状态 cur_state, 把选中的向量累加起来
    //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。
    ll tot_x = 0, tot_y = 0;
    int k = 0;
    int len = ed - st + 1;
    for(int i = 1; i <= len; i++){
        if(cur_state & (1 << (i - 1))){
            tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y;
            k++;
        }
    }
    (*cur_half).push_back({tot_x, tot_y, k});
}

void state_generator(bool mode){
    //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分
    rg int cur_state = 0;//初始的状态,对应的就是什么都不选
    int st, ed;
    vector<Instruct> *cur_half;//fir_half和sec_half储存这前半部分的方案和后半部分的方案
                               //cur_half就是当前这次搜索要把方案存到哪里
    if(mode){
        cur_half = &sec_half;
        st = n / 2 + 1, ed = n;
        if(n & 1) mx_state = mx_state * 2 + 1;
        //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数
        //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1
    }
    else{
        cur_half = &fir_half;
        st = 1, ed = n / 2;
    }

    while(cur_state <= mx_state){
        vec_sum(st, ed, cur_state, cur_half);
        cur_state++;
    }
}

int main(){
    scanf("%d%d%d", &n, &tar_x, &tar_y);
    for(int i = 1; i <= n; i++){
        scanf("%lld%lld",&ins[i].x, &ins[i].y); 
    }
    for(int i = 0; i < n / 2; i++){
        mx_state = mx_state | (1 << i);
        //最大的状态是 n/2 位都是 1
    }
    state_generator(0); state_generator(1);
    
    sort(fir_half.begin(), fir_half.end());
    sort(sec_half.begin(), sec_half.end());

    rg int p1 = 0, p2 = sec_half.size() - 1;
    int fir_same_k[21], sec_same_k[21];

   while(p1 < fir_half.size() && p2 >= 0){
        Instruct &f = fir_half[p1], &s = sec_half[p2];
        if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){
            //如果两个向量相加小于目标值,我们只能加 p1 的值,
            //因为 p2 指向的元素最开始就是最大的。
            p1++;
        }
        else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){
            //如果两个向量相加大于目标值,我们只能减 p2 的值,
            //因为 p1 指向的元素最开始就是最小的。
            p2--;
        }
        else{
            int p1t, p2t;
            memset(fir_same_k, 0, sizeof(fir_same_k));
            memset(sec_same_k, 0, sizeof(sec_same_k));
            //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组
            for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){
                //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1
                fir_same_k[fir_half[p1t].k]++;
            }
            for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){
                //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2
                sec_same_k[sec_half[p2t].k]++;
            }
            //统计答案,对于前半段和后半段都枚举可能的
            for(int i = 0; i <= 20; i++){
                for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的
                    ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j];
                    //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] 
                    //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同
                }
            }
            p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环
        }        
    }

    for(int i = 1; i <= n; i++){
        printf("%lld\n", ans[i]);
    }
    system("pause");
}
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define rg register
const int MAXN = 45;

struct Instruct{
    ll x, y;
    const bool operator < (Instruct b) const{
        if(x != b.x) return x < b.x;
        return y < b.y;
    }
    const bool operator == (Instruct b) const{
        return x == b.x && y == b.y;
    }
}ins[MAXN];
vector<Instruct> fir_half[MAXN], sec_half[MAXN];
ll ans[MAXN];
int mx_state;
int n;
int tar_x, tar_y;

void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){
    //根据当前提供的状态 cur_state, 把选中的向量累加起来
    //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。
    ll tot_x = 0, tot_y = 0;
    int k = 0;
    int len = ed - st + 1;
    for(int i = 1; i <= len; i++){
        if(cur_state & (1 << (i - 1))){
            tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y;
            k++;
        }
    }
    cur_half[k].push_back({tot_x, tot_y});
}

void state_generator(bool mode){
    //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分
    rg int cur_state = 0;//初始的状态,对应的就是什么都不选
    int st, ed;
    vector<Instruct> *cur_half;
    if(mode){ 
        cur_half = sec_half;
        st = n / 2 + 1, ed = n;
        if(n & 1) mx_state = mx_state * 2 + 1;
        //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数
        //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1
    }
    else{
        cur_half = fir_half;
        st = 1, ed = n / 2;
    }

    while(cur_state <= mx_state){
        vec_sum(st, ed, cur_state, cur_half);
        cur_state++;
    }
}

int main(){
    scanf("%d%d%d", &n, &tar_x, &tar_y);
    for(int i = 1; i <= n; i++){
        scanf("%lld%lld",&ins[i].x, &ins[i].y); 
    }
    for(int i = 0; i < n / 2; i++){
        mx_state = mx_state | (1 << i);
        //最大的状态是 n/2 位都是 1
    }
    state_generator(0); state_generator(1);
    
    for(int i = 0; i <= n / 2 + 1; i ++){
        sort(fir_half[i].begin(), fir_half[i].end());
        sort(sec_half[i].begin(), sec_half[i].end());
    }

    for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){   //枚举当前前半部分的 k
        for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k
            int p1 = 0, p2 = sec_half[sec_k].size() - 1;
            while(p1 < fir_half[fir_k].size() && p2 >= 0){
                Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2];
                if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){
                    p1++;
                    //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1
                }
                else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){
                    p2--;
                    //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2
                }
                else{
                    int p1t = p1, p2t = p2;
                    while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){
                    p1t++;
                    }
                    while(p2t >= 0 && sec_half[sec_k][p2t] == s){
                    p2t--;
                    }
                    
                    ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); 
                    //把 p1 范围长度乘上 p2 范围的长度
                    //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while
                    //        在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果
                    //        它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。
                    p1 = p1t, p2 = p2t;
                }
            }
        }
    }
    for(int i = 1; i <= n; i++){
        printf("%lld\n", ans[i]);
    }
    system("pause");
}

这个做法一直在调,现在也没卡过时间和空间限制,应该是我哈希函数写挂了,但是我也不知道正确的写法是怎么样的,如果想用这个做法,可以看其它大佬的题解。如果你会这个方法,而且能帮我调一下的话我表示非常感谢,这是提交记录,代码我放在这里了。

3.1 折半搜索 + 哈希表

最后,希望这篇题解对你有帮助。有任何问题都可以在私信和评论区提出,我会尽量解决问题。