之前尝试实现图形学一些论文的时候用到过最小二乘法,不过之前是从微积分的角度理解的。今年学了线性代数的课,发现这个东西也可以用线性代数的概念解释,并且更加直观(有几何角度的理解),所以准备用这篇文章记录一下最小二乘法的两种理解方式。
开始之前先吐槽一下这个东西的名字 “最小二乘”。就感觉很奇怪,最小化误差的平方为什么叫二乘呢?嗯。。。查了一下发现是从日语翻译过来的,那我只能说这个翻译水平是有点高的。。。虽然我很不想用这个迷惑的名字,但是因为都这么用的,那也没办法了。
学习最小二乘法之前先得理解这个算法在尝试解决什么。最小二乘法最常见的用途是用一个函数拟合数据,进而用这个函数来预测数据的趋势。为了拟合数据我们就需要用数学的方法来定义怎样的拟合是好的,然后我们尽量一个函数更好的拟合数据。
这里我们假设有 个数据点 。然后有一个函数 用于拟合这些数据点。我们定义单个点的误差为:
注意这里函数 的含义是使用 来预测 ,你可以把这两个值理解成标量,但实际上最小二乘法也可以用于拟合向量,因为两种情况下公式完全一样,这里简单起见就假设是标量。
那么误差的平方和就是
我们希望能调整这个函数的参数使得这个值最小。
具体来说,可以把函数 表示成如下形式:
其中, 是我们需要调整的参数,而 是一些线性无关的函数。
这部分内容的来源主要是这个视频,讲的内容还是很完整的但是有些快和笔误,我也看了好几遍才搞懂。
我们可以把误差表示成下面这样的形式,然后从微积分的角度最小化这个误差。
写成这样的形式可以更加清楚的表示出我们的目的:通过调整 这些参数使得误差最小化。
注意这里在给每个数据点的误差平方求和的时候多了一个权重 ,通过这个权重可以更加方便的调整每个数据点的重要性。
通过微积分我们知道,函数在达到其最值的时候导数一定为 0。我们通过这个性质可以从微积分的角度来调整参数使得误差最小。不过这个最值可以是最小值也可以是最大值。在计算误差 的时候,一定只存在一个最小值(可以想象把所有的参数设置到 的情况)。
因为 是一个多变量函数,所以要通过上面的方法求最小值,我们需要用偏导数:
在上面的偏导数中,除了 都可以作为常数处理。因为偏导数的定义就是改变参数 ,对误差 造成的影响。
在 的时候,我们可以说在只调整参数 的情况,误差已经达到了最小值,但是因为每个参数都是可以调整的,所以我们希望对于每个 。
这样一来,可以得到一个线性方程组。看到线性方程组第一个想到的肯定是用线性代数的方式来表示这个式子,这样在求解的时候可以极大的提升速度。
经过一定的变形,可以得到下面的公式:
因为我们希望往线性代数的方向上靠,所以可以把下面这个求和表示成点乘的形式:
其中,,,。
同理,式子的另外一部分也可以表示成点乘形式:
那么如果:
就符合误差的偏导为 0,也就是 。
当然,我们最终的目标是把这个式子写成一个矩阵乘法的等式来提升求解速度。仔细观察等式的右边,我们其实可以发现这个求和在本质上也是一个点乘。
为了方便理解,我们把式子中标红的部分记作 ,其是一个 维的向量,那么这个求和就是 。其中
同理可以把 记作 (标量)。
这样只要解决下面的方程组,就能解决最小二乘问题:
这个形式就非常的眼熟了,可以直接表示成矩阵形式。
写的简单点就是 。
待完成
]]>因为学校里的各种事情,以及复习考试,自从上次更新博客已经是四个多月的事情了,距离上次写图形学的博客那就更久了。
最近刚放暑假把之前只看了光线追踪部分的 GAMES101 完整的学习了一遍,还是非常惊喜的:很多之前不太清楚的概念(特别是数学方面的)过了一段时间再看又有了些新的理解。因为课程时间限制的问题,有有部分内容没有比较详细的讲,这里记录一些我自己的理解。
这个东西我已经在之前 RT: The Next Week 的文章中写过了,不过之前的解释比较。。。怪,很啰嗦,并且没有从坐标系变换的角度来解释,这里就重新写下(当然我还没有系统的学过线性代数,所以以下内容可能还是很扯)。
绕三个轴的三维旋转矩阵分别可以写作下列的形式:
不难发现,绕 y 轴旋转的矩阵中, 和 的位置似乎是反的,非常奇怪。原视频中,旋转矩阵是通过选取一些特殊点完成推导的,这里我感觉用坐标系变换的方法更加易于理解(虽然闫老师觉的这个更复杂)。
首先看绕 z 轴的,这个比较简单,基本上就是二维旋转矩阵的情况:
注意这里我虽然没有加入 z 轴,但是通过右手定则,可以发现 z 轴是朝着穿出屏幕方向的。
我们可以分别把新 x 轴 () 和新 y 轴()用向量形式表示出来,注意这里这两个向量都是单位向量:
观察 和原 y 轴的夹角,可以发现是 ,所以可以把 的向量形式“反”一下。还有一点需要注意的是: 的 x 分量是负数:
其实不管在什么坐标系中,任何坐标都是通过单位向量乘以一些长度得到的(可以想象成一个点在某个方向移动一些距离)。比如在平面直角坐标系中, 的坐标就可以理解成,把一个点向着 x 方向移动 1,向着 y 方向移动 2。
所以,在旋转过后的坐标系,对于一个点 ,他的新坐标就是 。相当于是向着 方向移动了 x 个单位,向着 方向移动了 y 个单位,以及向原来的 z 方向(虽然没有变换,但是这里我们还是把它标记成 )移动了 z 个单位。
所以新的坐标就是:
观察之前提供的旋转矩阵公式,能发现 确实能满足上式:
这里有一个比较有意思的地方,已经通过不同的颜色标记出来了:可以发现,旋转矩阵中的三列分别对应着 , 和 ,也就是新坐标系中的三个轴的方向。
这样一来,从坐标系变换的角度,就非常容易理解为什么旋转矩阵是这么写的了。我们可以直接把变换后的三个轴的方向写到矩阵中的三列,从而得到旋转矩阵。
我们可以用相同的方法来分析绕 y 轴的旋转,就是那个看起来“反”了的矩阵。
虽然看起来和刚刚的图很相似,但是可以发现图中的标签已经变过了。同样,通过右手定则,可以发现这时 y 轴是朝着穿出屏幕方向的。
那么有了刚刚的观察,现在我们只需要找到当前情况下的 就可以写出绕 y 轴的旋转矩阵了。
首先,对于新的 z 轴,也就是 ,我们可以类比绕 z 轴情况下的 ,它的向量形式如下:
这里 y 分量为 0 的原因很明显,因为我们考虑的是绕 y 轴旋转,所以 y 肯定没有变化。
然后,类比绕 z 轴情况下的 ,我们可以得到 的向量形式:
对于 ,因为和原来没有变化,所以可以简单的写作:
合并这些表示新 x,y,z 方向的向量,可以得到:
和 GAMES101 课程中讲的一样,重心坐标(特别是三角形的)在图形学中非常有用,可以很方便的把三角形节点上的信息插值到三角形面上。
我最初听课的时候还是有较多疑惑的,比如为什么一定要三个系数和相加等于 1,并且三个系数都不为负数,某个点才是在三角形内部的呢?
后来思考了一段时间,感觉用“顾名思义”的方式理解这个概念是直观的,也就是从重心的角度去理解。
假设有一个三角形,其三个顶点的质量分别为 ,并且除了顶点外,别的区域质量均为 0,那么根据重心的定义,三角形的重心就是:
把这三个项分开,可以得到:
这个形式简直和重心坐标的 太相似了,每项都是一个系数乘以一个顶点的坐标。
观察形如 的项,可以发现,计算重心时的系数完全符合重心坐标的要求,即 ,并且每项不是负数。
现在用物理上重心的思路来思考,我们在转换普通坐标到重心坐标时到底在转换什么?
把重心公式的系数对应到重心坐标的系数,也就是 ,我们能发现,这些系数实际上是三角形每个顶点的质量占总质量的比例。
也就是说,在转换的过程中(假设把 点从笛卡尔坐标系转换到重心坐标系),我们实际上计算的是这样一个问题:三角形的三个顶点如何分配重量,才能使得三角形的重心位于 点?
转换后的坐标, 其实就是在三角形三个顶点上分配的重量的比例。
这样一来,我们就很容易理解这三个数为什么符合相加为 1 并且非负才能使点在三角形内部的要求了。首先,从物理的角度来思考,如果质量非负,那么重心一定在物体的内部。其次,如果要符合重心的定义,那么 相加一定是等于 1 的。因为它们三个代表的是三个顶点的质量占总质量的比例,所以这三个数字相加一定是 1(对应总质量)。
当然,重心坐标是由物理上的定义推广出来的,所以用纯代数的方法来说明可能更有说服力。
我们知道,三角形的重心坐标就是三个顶点的线性组合来标记一个坐标,也就是 。现在我们可以尝试用代数的方式来说明,为什么这三个系数相加等于 1 并且非负才能使点在三角形内部。
首先,我们可以把 点用另一种方式表示出来,这对之后的推导会有帮助:
注意这里的 和 其实和重心坐标的三个系数没有关系,这个式子的含义是:从 A 点出发,沿着 AB 方向走 个单位,然后沿着 AC 方向走 个单位,就可以到达 点。
首先,如果要让点在三角形内部,有一点是可以确定的: 和 是非负的,因为考虑当前在 A 点上,只向着 或者 方向走任何距离,马上就会走出三角形。
对上式稍微做一点变形,可以得到:
可以这么理解这个式子:我们把包括 点在内的所有点都通过 移动了一个距离,这个时候 一定在原点。
这样这个式子就很直观了:先通过 这个向量从原点走到 点,再从 点用 走回 ,也就是原点。
因为这两个部分方向相反并且大小相同(一个从 走到 ,一个从 走到 ),所以这两个向量相加一定是 0。
拆开这个式子,再重新整理,可以得到:
这个形式完全和重心坐标一致,所以我们可以确定:,,。
把这些东西相加,可以发现 ,确实,重心坐标的三个系数相加一定是 1。
前面我们已经说明了, 和 是非负的,但如果 , 不就是负数了吗?
考虑下图, 是 到 的一条垂线( 不是 的中点,这里不用等边三角形可能更清晰,但是我懒):
回到最初的定义中,也就是 , 分别是朝 和 方向走过的距离。
那么 投影在 上的长度就是
观察上图可以发现,如果 为 1,也就是 有完整的长度,这个投影的长度一定等于 。同样的,如果 为 0,投影在 方向的长度也一定是 0。
因为投影,或者说点乘是一个线性的操作,所以如果 为 0.5,投影的长度一定是 的一半。
因此我们可以说, 在 上的投影长度是
相同的, 在 方向的投影长度符合一样的规则。
为了让 点留在三角形内部, 这个向量投影在 上的长度必须小于等于 本身的长度,不然 点就会从 这条边跑出三角形。
在 上的投影长度是 , 在 上的投影长度是 ,自然,这两个向量的和在 上的投影长度就是 。
前面说过,为了使得 点留在三角形内部, 这个向量投影在 上的长度必须小于等于 本身的长度,也就是 ,那么 。
因为 ,并且 ,所以我们可以说明, 也一定是非负的。
至此我们已经能说明,为什么一个在三角形内部的点就要符合 ,并且每个系数非负了。当然要把一个平面直角的坐标转换到重心坐标,还是需要一些相对复杂的计算的。这部分的内容,我觉得这篇博客中介绍的第一种方法相对易于理解并且很巧妙,有兴趣的可以看下。
]]>题意非常简洁,即问你通过一系列的字符替换,最少花多少步能把一个 串变成 串。
拿到题之后,可以先从样例开始分析。
从 这个样例可以发现,不可能同时把某个字符替换成两个字符(),会起冲突。
那直接统计 的个数(给串去重之后,即不存在 这种)就可以作为答案了吗?可以从最后一个样例发现不是这样的。
因为最后一个样例中, 的部分是一样的。我们直接考虑 的变换。如果直接执行 的操作,会得到一个 的串。这个时候就有了和前面一样的问题,不能将其转换成 。执行 也是同理。
解决的办法就是先执行 再处理 。( 是任意别的字符)
是否所有“相互依赖”的情况下,都可以通过这种方式解决呢?我们可以再思考一个大一点的样例 ,用图(创建 的边,并且去掉重边和自环)的方式表示出来会更加清晰:
graph LR A --> B B --> C C --> D D --> A
可以发现,这是一个环。无论我们先执行哪种 的变换,都会需要再执行 的变换。因为 希望能变成别的。这个时候,先前 会跟着一起被变成 。
不过,如果能“化环为链”,就可以解决问题了。比如我们可以先执行 ,这个链就会变成:
graph LR x --> B B --> C C --> D D --> A
这样,就有一个执行 后,不用再执行 的地方了。即 (执行完之后, 也符合这个条件,我们倒着的按照链的顺序就可以把整条串转换为目标)。
从这两个例子可以看出,在一般的情况下,一个操作能把环转化为链,或者把链的长度(边的数量)减少 1。
所以答案的数量就是(环的数量 + 链的长度)了吗?
首先,化环为链的操作需要一个不在环中出现的字符,假设环包含了字符集中所有的字符,我们是不能处理的。
假设我们的字符集只有 这四个字符,那处理下面这个例子时候,就会发现问题。
graph LR A --> B B --> C C --> D D --> A
不管先把 变成什么字符,这个字符之后都会再经历最少一次的变换,导致 不能被转换成目标字符 。
当然,我们处理不了的情况不一定要求整张图中只有一个环,只要符合:
就不能处理了,比如下面这个例子,有两个环还是不行(字符集为 ):
graph LR A --> B B --> A C --> D D --> C
考虑这样一个输入:
graph LR A --> B B --> C C --> D D --> A E --> B F --> E
我们可以在一个操作内即化环为链,又把链的长度减少 1。观察到 和 都希望能被转换成 。从字符转换的角度来说, 和 的最终结果和操作步数都是一样的。但是第二种方法在执行 时,也把环中的一个字符转换成了环外的字符,将环化成了链。
能这么做的前提条件是,有多个环外字符希望变成环内的一个字符。更严谨的说就是环中某个节点的入度大于等于 2。
到此为止,所有的情况都基本分析好了,可以写出以下的总结(括号中的为实际判断方法):
这里第二点的判断方法可以稍微解释一下:
没有选择使用出度是考虑到了环连着树的情况,参考上图。
实现的时候找环的部分需要注意一下,其他部分都比较简单。
我们知道 tarjan 算法就可以判环,不过这道题可以用“简化版”的 tarjan,不用记录访的时间戳。我们把 dfs 的时候把所有访问过的节点从队尾压入一个双向队列。
如果我们开始 dfs 的时候是从一个环上的点进入的,之后一定会访问到一个和队头一样的节点。这个时候把所有在队头和队尾之间的节点都弹出,就得到了环中的所有节点。
如果我们发现某个节点之前访问过,但是并不在队头,就可以确定队列中的节点都不是“绝对环”,因为有树连着他(参考上图,如果从 F 节点开始搜就会出现这种情况)。
#include <bits/stdc++.h>using namespace std;const int CHSZ = 52; // char set sizeint out[CHSZ + 1]; // 出度只能有一个int lpid[CHSZ + 1]; // 环的 id,不知道 -> -1,不是环 -> 0,是环 -> 1,2,3...enum LP_STAT { UNKNOWN = -1, NOT_ABS_LP = 0 };deque<int> vised_dq; // 用于在找环的时候储存信息bool vised[CHSZ + 1]; // 用于在找环的时候储存信息set<int> in_nds[CHSZ + 1]; // in nodes,入度可以有多个int in1_cnt = 0; // 入读为 1 的节点的数量int abs_lp_cnt = 0; // 绝对环数,即环不连树的环数int diff_chs = 0;void init() { memset(out, 0, sizeof(out)); fill(lpid, lpid + CHSZ + 1, UNKNOWN); vised_dq.clear(); memset(vised, 0, sizeof(vised)); for (int i = 0; i <= CHSZ; i++) in_nds[i].clear(); in1_cnt = 0; abs_lp_cnt = 0; diff_chs = 0;}inline int ch2id(char x) { // char to id if (x >= 'a' && x <= 'z') return x - 'a' + 1; if (x >= 'A' && x <= 'Z') return x - 'A' + 27; return -1;}bool check_loop_connect_to_tree() { for (int cur : vised_dq) if (in_nds[cur].size() >= 2) // 有树连这个环 return true; return false;}void fill_lpid_in_vised_dq(int val) { for (int cur : vised_dq) lpid[cur] = val; vised_dq.clear();}void mark_loop(int cur) { if (vised[cur] && vised_dq.front() != cur) { // 从一个树进入的环,不是绝对环 fill_lpid_in_vised_dq(NOT_ABS_LP); return; } vised[cur] = true; if (out[cur] == cur) { // 没有出度,找到一个链 fill_lpid_in_vised_dq(NOT_ABS_LP); return; } if (vised_dq.size() && vised_dq.front() == cur) { // 找到环 if (!check_loop_connect_to_tree()) { // 如果环不连树 abs_lp_cnt++; fill_lpid_in_vised_dq(abs_lp_cnt); } else { fill_lpid_in_vised_dq(NOT_ABS_LP); } return; } vised_dq.push_back(cur); mark_loop(out[cur]);}void solve(const string& origs, const string& tars) { // orig str -> tar str init(); for (int i = 0; i < origs.size(); i++) { int och = ch2id(origs[i]); int tch = ch2id(tars[i]); if (out[och] && out[och] != tch) { // 如果 o 串已经有要转换的字符,但是不是 t // 串的字符,那么会产生多对一 cout << -1 << '\n'; return; } if (!out[och]) { out[och] = tch; in_nds[tch].insert(och); if (och != tch) diff_chs++; } } for (int i = 1; i <= CHSZ; i++) { if (in_nds[i].size() == 1) in1_cnt++; // 统计入度为 1 的节点数量 } for (int i = 1; i <= CHSZ; i++) { if (out[i] && lpid[i] == UNKNOWN) { // 标记环 vised_dq.clear(); memset(vised, 0, sizeof(vised)); mark_loop(i); } } if (origs != tars && in1_cnt == CHSZ) { // 判断是否全部都在环中,用入度为 1 的数量来判断 cout << -1 << '\n'; return; } cout << diff_chs + abs_lp_cnt << '\n';}int main() { int t; cin >> t; while (t--) { string origs, tars; cin >> origs >> tars; solve(origs, tars); }}
]]>Lab4 的主要作用是把前面的 receiver 和 sender 结合起来,形成一个完整的 TCP 协议栈。所以熟悉 TCP 的状态流转图就很重要了。
下面是一个 TCP 的状态流转图:
参考上图,可以看到 TCP 有两种建立连接的方法。第一种是主动连接,即给对方发送一个 SYN 包。第二种是被动连接,即接收到一个 SYN 包后,回复 SYN+ACK 包。
对于主动连接,我们需要实现 connect()
函数:
void TCPConnection::connect() { _shutted = false; _sender.fill_window(); send_sender_segs();}
这里的 shutted
变量表示连接是否已经关闭,之后在 active()
函数中使用。因为我们之前实现过 TCPSender
的 fill_window()
函数,它会记录连接是否已经建立,如果没有会自动发送 SYN 包。
不过 TCPSender
的 fill_window()
仅仅会把要发送的 TCP 包推入它的 _segments_out
队列。我们需要把这里的包放到 TCPConnection
的 _segments_out
中,这样 sponge 才会用 IP 协议把它们发送出去。
所以 fill_window()
后面的 send_sender_segs()
中的一个作用就是把 _segments_out
中的包放到 TCPConnection
的 _segments_out
中。
当然,TCPSender
发送包的时候是不清楚一些报头中的信息的。比如 win
和 ackno
,前者代表 TCPReceiver
还能接收多少数据,后者代表 TCPReceiver
已经收到的数据。所以我们还需要在 send_sender_segs()
中把这些信息填好:
填写过程中有一个比较坑的地方,就是报头中 win
的范围。查看 TCPHeader
中 win
的定义:
uint16_t win = 0; //!< window size
可以发现这是一个 16 位的无符号整数。但是在 TCPReceiver
中,window_size()
返回的是一个 64 位的整数:
size_t TCPReceiver::window_size() const { // 从 ackno 开始,还能接收多少字节 return _capacity - _reassembler.stream_out().buffer_size(); // window_size() + buffer_size() = capacity}
如果强行调用 window_size()
给 win
赋值,可能会造成溢出,所以赋值的时候需要这样写:
seg.header().win = min(_receiver.window_size(), (size_t)numeric_limits<uint16_t>::max());
再参考状态流转图,如果现在在 LISTEN
状态。收到一个 SYN 包并且回复了 SYN+ACK 包后,连接就建立了。
但是如何确定这个 LISTEN
的状态呢?一个很方便的方法是直接使用 Sponge 提供的 TCPState
类。
其不仅可以判断 TCPConnection
整体的状态,也可以分别判断 TCPSender
和 TCPReceiver
的状态。
这里的 LISTEN
是整体的一个状态。
在 segment_received()
函数中这么写就可以判断当前是否要进行被动连接了:
bool passive_connect = (state() == TCPState::State::LISTEN && seg.header().syn);
如果发现是需要被动连接,那么直接这么写就行了:
// 如果是在 listen 状态,被动建立连接bool passive_connect = (state() == TCPState::State::LISTEN && seg.header().syn);// 对于 receiver, LISTEN// 对于 sender, CLOSED_receiver.segment_received(seg); // 先调用 segment_received 才知道要回复什么 acknoif (passive_connect) { connect(); return;}
现在为止,我们已经成功的建立了连接。对于每个新到达的包,只需要在 segment_received()
中调用 _receiver.segment_received()
和 _sender.ack_received()
(这样 sender 知道对方收到了哪些信息,可以重发没有收到的) 来更新信息就可以一直维持连接了。
相比建立连接,关闭连接会显得复杂一些,并且不能保证总是“完美”的关闭。
计算机网络学科中,有一个著名的思想实验来描述 TCP 不能完美关闭连接的问题 – 两军问题。维基百科的描述[1]如下:
两支军队由不同将军领导,准备进攻一座坚固的城市。军队在城市附近的两个山谷扎营。由于有另一个山谷将两山隔开,两名将军只能透过派信使穿越山谷通信,但这山谷由城市护卫占领,有可能俘虏途径山谷传递消息的任何信使。
虽然两军已约定要同时进攻,但尚未约定进攻时间。要顺利攻击,两军必须同时进攻。如果同一时间仅一支军队进攻就会战败,因此两名将军须约定攻击时间,并确保对方知道自己同意了进攻计划。
将军甲首先派信使向将军乙传递消息“在8月4日9时进攻”。然而,派遣信使后,将军甲不知道信使是否成功穿过敌方领土。由于担心自己成为唯一的进攻军队,将军甲可能会犹豫要否按计划进攻。
为了消除不确定性,将军乙可以向将军甲发送确认消息“我收到了您的消息,并会在8月4日9时进攻”,但传递确认消息的使者同样可能会被敌方俘虏。由于担心将军甲没有收到确认消息而退缩,将军乙会犹豫要否按计划进攻。
再次发送确认消息看来可以解决问题——将军甲再让新信使发送确认消息:“我已收到您确认在8月4日9时进攻”。但是,将军甲的新信使也可能被俘虏。显然,无论确认几次都无法满足该问题的条件二,即两方都必须确保对方已同意计划,两名将军总会怀疑他们最后派遣的使者有否顺利穿过敌方领土。
可以发现,TCP 关闭连接的时候,也存在同样的问题。当 A 发送断开连接的消息后,B 可以发送一个 ACK 包表明收到了断开的消息。然而,B 不知道 A 是否收到了 ACK 包,从而担心 A 是否会正常关闭。A 当然可以再回复一个 ACK 包,但这就陷入了两个将军的困境中。
多次的互相发送确认消息看起来可能能减少错误,但是 TCP 协议中是不会对一个 ACK (即不包含实际数据的包,只有 ACK)包回复 ACK 包的,所以我们还需要一些别的解决方案。
和建立连接类似,断开连接时我们也可以分为主动和被动两个方面去讨论。
和主动关闭相比,被动关闭相对比较简单。所以我们可以先讨论。
被动和主动关闭端点的唯一区别就是发送 FIN 包的先后。主动关闭在发送完所有自身出向字节流产生的 TCP 包后,会发送一个带 FIN 的包。
虽然这时连接的一方已经发送 FIN 了,但这并不代表连接就已经关闭了。因为被动的一方可能还有数据没发完。等到发完后,被动端也会发送一个 FIN 进入 LST_ACK 状态。
这个状态唯一的目的就是等待另一端发送对 FIN 的确认信息。如果主动方没确认,被动方还需要一直发送 FIN 来确保对方收到了。
待收到 ACK 后,就可以直接关闭了。
如果出向的字节流已经被完全发送出去了,连接的一方就会发送 FIN 并进入 FIN_WAIT_1 状态。表明 TCP 双向连接的其中一向已经关闭了(即当前端点只接收数据,不会再新发送)。对方确认该 FIN 消息后,当前端点会转换到 FIN_WAIT_2 状态,等待对方完全发送它想传输的数据。
收到对方的 FIN,并且确认后,端点就进入了 TIME_WAIT 状态。这个状态代表代表着:
虽然进入 TIME_WAIT 后,我们无法确定对方是否收到了对于其 FIN 的确认消息,但是如果对方没有收到,大概率是会在一定的时间内重发 FIN 的(TCP 的超时重传机制,TCPSender
有实现)。
虽然网络可能比较拥堵,但如果我们等待(linger)了比较长的一段时间对方都没有重发,那大概率是对方已经收到确认消息并且关闭连接了。
这个等待的时间在实验指导书中有写到:
it has been at least 10 times the initial retransmission timeout (
cfg.rt_timeout
) since the local peer has received any segments from the remote peer.[2]
如果采用的是默认的 cfg.rt_timeout
,那么总的等待时间最少是 10 秒。
前面提到了被动关闭的一方不需要等待,也就是 linger,用如下的代码就可以实现出来:
// 在 segment_received 中 // 后发 fin(先收到 fin)的端点不需要 linger // 这里是 ESTABLISHED 向 CLOSE_WAIT 的转换 if (TCPState::state_summary(_receiver) == TCPReceiverStateSummary::FIN_RECV && TCPState::state_summary(_sender) == TCPSenderStateSummary::SYN_ACKED) { // 不能直接用 state() == CLOSE_WAIT 是因为 CLOSE_WAIT 要求 linger_after 也是 false // 但是我们假设先 linger _linger_after_streams_finish = false; } // 这里是 LAST_ACK 向 CLOSED 的转换 if (TCPState::state_summary(_receiver) == TCPReceiverStateSummary::FIN_RECV && TCPState::state_summary(_sender) == TCPSenderStateSummary::FIN_ACKED && !_linger_after_streams_finish) { // 不能用 state() == LAST_ACK 是因为其代表 sender 发送了 FIN。并不是 FIN 被确认,即 FIN_ACKED _shutted = true; return; }
因为 _linger_after_streams_finish
这个变量是默认设为 true 的,所以只要在之前的判断中,这个变量没有被设置成 false,那么我们就是主动关闭的一方。
TCPConnection
类中。唯一一个能够获取当前时间的函数就是 tick()
了,为了实现超时直接断开连接的功能,我们可以在 tick()
中加入如下代码:
if (state() == TCPState::State::TIME_WAIT && _since_lst_rx_ms >= 10 * _cfg.rt_timeout) { _shutted = true; _linger_after_streams_finish = false; }
完成这些后再加上亿点点细节,就可以通过测试了(因为实验指导书上的合作政策写了不能公开代码,所以这里只放部分的代码片段),测试结果如下:
./tcp_benchmarkCPU-limited throughput : 0.37 Gbit/sCPU-limited throughput with reordering: 0.36 Gbit/s
说实话速度还是比较慢的,主要原因也能从之前的火焰图看出来,是字符串拷贝和处理的问题。我在优化完后应该还会再写一篇博客来介绍优化的过程和内容。
]]>注:因为实验指导书和课程文件[1]里都明确的写了不能公开代码,所以博客上的实验记录就主要记录思路以及一些核心代码片段,不会公开完整的仓库。
Lab 要求实现一个在内存层面上可靠的字节流(ByteStream
),感觉和 unix 中的管道挺像的。其实这样先进先出的结构完全可以直接使用 STL 的 queue<char>
实现,会非常简单。但是考虑到 lab 的要求是一个固定大小(capacity)的字节流,个人认为直接开个数组模拟更合适,速度也应该会更快。
具体来说,就是开一个 string
(没有直接使用字符数组是因为实验指导书提到了最好使用现代 C++ 的风格,避免使用 new
来手动分配内存)来储存数据,以及一个头指针和尾指针指向字节流的开始和结尾。这样就实现了一个环形队列,peek_output()
函数的实现大概是下面这样的:
string ByteStream::peek_output(const size_t len) const { size_t peek_size = min(buffer_size(), len); size_t i = 0; string ret = ""; ret.resize(peek_size); while (i < peek_size) { ret[i] = _data[(_head + i) % _capa]; i++; } return ret;}
不过这样的实现虽然看起来比较直观,其性能是比较差的。这主要是因为环形队列中大量的使用了取模运算,造成速度大幅下降。因为我现在还没开始做 Lab4,所以暂时没有太过考虑性能问题,Lab0 的测试结果如下(release 模式):
[100%] Testing Lab 0...Test project /mnt/e/ocourses/st_cs144/sponge/build Start 26: t_byte_stream_construction1/9 Test #26: t_byte_stream_construction ....... Passed 0.01 sec Start 27: t_byte_stream_one_write2/9 Test #27: t_byte_stream_one_write .......... Passed 0.01 sec Start 28: t_byte_stream_two_writes3/9 Test #28: t_byte_stream_two_writes ......... Passed 0.01 sec Start 29: t_byte_stream_capacity4/9 Test #29: t_byte_stream_capacity ........... Passed 0.22 sec Start 30: t_byte_stream_many_writes5/9 Test #30: t_byte_stream_many_writes ........ Passed 0.01 sec Start 31: t_webget6/9 Test #31: t_webget ......................... Passed 0.81 sec Start 53: t_address_dt7/9 Test #53: t_address_dt ..................... Passed 0.05 sec Start 54: t_parser_dt8/9 Test #54: t_parser_dt ...................... Passed 0.01 sec Start 55: t_socket_dt9/9 Test #55: t_socket_dt ...................... Passed 0.01 sec100% tests passed, 0 tests failed out of 9Total Test time (real) = 1.17 sec[100%] Built target check_lab0
这个 Lab 需要实现一个“重排器(reassembler)”,即把不同的数据片段根据提供的起始下标重新排列成连续的字节流。并且我们还需要保证尽可能快的把收到的数据放入字节流中(即如果 范围内的字符全部被收到了,就应该把这一段的数据立刻放入字节流中)。
先不提实验本身,实验指导书中的要求就挺难理解的,特别是 capacity
的概念。简单来说就是字节流中未读取数据的大小加上重排器的接收范围。
或者说,重排器的容量是有限的,如果某个数据段的 index
太大了,重排器可以直接抛弃。而字节流中的未读取数据越多,最小的,会被抛弃的 index
就会越小。
实现这个重排器有很多种方法,最简单的当然是把每个到达数据片段都复制一遍,然后一发现重排器的前面有连续的数据片段就放入字节流中。
但是很明显,这样的算法是非常低效的,对于每个新到达的数据段,都必须要完整的遍历一遍,即使是之前已经接收过完全一样的数据了。
这里我采用的避免重复复制的方法是实现一个专门维护“段的集合”的数据结构。
对于任何一个新到达的数据段,我们都可以把他的范围表示成 这样的形式。那么我们也可以维护一个段的集合,表示当前还没接收到的数据范围,用 表示 。对于新到达的段 ,如果我们能求出 和 相重合的部分,或者说 ,我们就只需要遍历这部分了(即 能覆盖的,还未填充的段),而如果我们发现 的长度为 (新到达的数据不包含任何未接收部分),就可以直接返回了,避免了前面提到的重复遍历问题。
在新数据段的 部分被写入后,我们也需要能够改变 ,让其去掉 这部分,表示该段已经接收到。
看这样的描述不太清晰,下面是一个例子:
假设我们的目标是接收一个 的数据段,那么在最开始的时候,还没有任何数据, 就是 这个范围。
现在接收到了一段新的数据,为 ,那可以求出 ,也就是 范围内的数据完全没有重复。
在填充完 后,我们进行 的操作(注意这里的 不表示差集,表示从 中移除一部分),代表 不再是未填充段。现在 就变成了 。
现在再接收一个新数据段 ,可以发现其完全覆盖了之前的 ,但是我们不需要重复的去遍历之前已经被填充的部分,而是根据 和 求交的结果 来填充。
到此为止,需要实现的数据结构就比较清晰了。我们应该先实现两个类,第一个表示单个的段(Seg
),第二个表示很多段的集合(Segs
)。
对于 Segs
,需要有以下几个功能:
Seg
的交集,即前面提到的 操作。Seg
,即前面提到的 操作。我们知道一个 Segs
里面可能有很多个 Seg
。如果我们要实现 Segs
和 Seg
的求交操作,就需要先找出 的一个子集 ,这个子集里的每一个 Seg
都和 有重合的部分,大概是下面这样的:
1 2(c1) 3(cn) 4Segs a : |---| |-----| |--------| |---|Seg b : |-----------------------|
图中 的 2 号和 3 号 Seg
就和 有重合部分,属于子集 。
可是一个一个的去遍历 中的小 Seg
是线性的复杂度,也没比朴素算法好多少。
这里我采用的优化方法是二分。
我们设子集 的第一个段为 (在上图中就是 2 号段),再设子集 的结束段为 (上图中的 3 号段)。
那么通过观察可以发现, 一定是第一个右端点比 的左端点大的段。而 一定是最后一个左端点比 的右端点小的段。很明显,这样如同“最大化最小值”的表达是可以通过二分解决的,但前提是 Segs
类里对于多个 Seg
的储存必须是有序的。
因为 Segs
类中会处理频繁的插入和删除,我实现的时候采用了 std::set<Seg>
来储存不同的段,同时把这些段维持在一个有序的状态里,方便查询。
这样一来,查询 和 的复杂度就降到了 。
这个查询 和 的函数可以说是整个数据结构里最核心的函数了,如下,别的部分因为不公开代码的规定还是不太方便展示。
template <integral T, bool REC_LEN>typename std::pair<typename Segs<T, REC_LEN>::s_iter_t, typename Segs<T, REC_LEN>::s_iter_t>Segs<T, REC_LEN>::intersect_iter(const Seg<T> &b) const { // return the first and last iterator of the intersected segments // 返回第一个和最后一个和 b 有重叠的段的迭代器 if (b.len() == 0) return {_segs.end(), _segs.end()}; auto fir = fir_GT_iter_r(b.l); // 前文讲的 c1,是第一个右端点比查询段大的段 if (fir != _segs.end() && ((*fir) ^ b).len() == 0) // if no intersection fir = _segs.end(); auto las = lst_LT_iter(b.r); // 前文的 cn,是最后一个左端点比查询段的右端点小的段 if (las != _segs.end() && ((*las) ^ b).len() == 0) las = _segs.end(); // 处理 c1 和 cn 没找到的一些情况 if (fir == _segs.end() && las != _segs.end()) fir = las; if (fir != _segs.end() && las == _segs.end()) las = fir; return {fir, las};}
然后在 StreamReassembler::push_substring
,就可以直接根据 Segs
提供的范围填充数据了:
…… // insert new arrival into _tmp const Seg coverage{index, index + data.size()}; // 新到达数据段的范围 auto &&unfilled_intersect = _unfilled_segs ^ coverage; // 这里我重载了 ^ 这个符号,表示求交 for (auto &s : unfilled_intersect) { // s 表示一个未填充的段 for (size_t i = s.l; i < s.r && (i - _fir_unpushed_idx) <= _capacity; i++) { _tmp[i - _fir_unpushed_idx] = data[i - index]; // 这里 _tmp[0] 对应的是 _fir_unpusehd_idx,即第一个没被放入字节流的位置 // 所以要加一个偏移量,同时 (i - _fir_unpushed_idx) <= _capacity 确保了 _tmp 不会越界 _unassembled_bt++; } } _unfilled_segs -= coverage; // find the first unfilled segment, before this segment, all data are filled……
这样实现的 push_substring
,性能还是比较令人满意的,如下:
[100%] Testing the stream reassembler...Test project /mnt/e/ocourses/st_cs144/sponge/build Start 18: t_strm_reassem_single 1/16 Test #18: t_strm_reassem_single ............ Passed 0.01 sec Start 19: t_strm_reassem_seq 2/16 Test #19: t_strm_reassem_seq ............... Passed 0.01 sec Start 20: t_strm_reassem_dup 3/16 Test #20: t_strm_reassem_dup ............... Passed 0.01 sec Start 21: t_strm_reassem_holes 4/16 Test #21: t_strm_reassem_holes ............. Passed 0.01 sec Start 22: t_strm_reassem_many 5/16 Test #22: t_strm_reassem_many .............. Passed 0.10 sec Start 23: t_strm_reassem_overlapping 6/16 Test #23: t_strm_reassem_overlapping ....... Passed 0.01 sec Start 24: t_strm_reassem_win 7/16 Test #24: t_strm_reassem_win ............... Passed 0.10 sec Start 25: t_strm_reassem_cap 8/16 Test #25: t_strm_reassem_cap ............... Passed 0.07 sec Start 26: t_byte_stream_construction 9/16 Test #26: t_byte_stream_construction ....... Passed 0.01 sec Start 27: t_byte_stream_one_write10/16 Test #27: t_byte_stream_one_write .......... Passed 0.01 sec Start 28: t_byte_stream_two_writes11/16 Test #28: t_byte_stream_two_writes ......... Passed 0.01 sec Start 29: t_byte_stream_capacity12/16 Test #29: t_byte_stream_capacity ........... Passed 0.20 sec Start 30: t_byte_stream_many_writes13/16 Test #30: t_byte_stream_many_writes ........ Passed 0.01 sec Start 53: t_address_dt14/16 Test #53: t_address_dt ..................... Passed 0.05 sec Start 54: t_parser_dt15/16 Test #54: t_parser_dt ...................... Passed 0.01 sec Start 55: t_socket_dt16/16 Test #55: t_socket_dt ...................... Passed 0.01 sec100% tests passed, 0 tests failed out of 16Total Test time (real) = 0.70 sec[100%] Built target check_lab1
后面我还用 perf 生成过火焰图尝试继续优化一下这个实现,生成的结果如下(这个 svg 图是可以交互的,不过需要在单独的一个窗口打开):
这里第一张是 debug 模式下的,第二张是 release 模式下的,可以看到,在 release 模式下,很多函数都被内联了,没法很好的分析。但是 debug 模式中,可以发现在 push_substring
这个函数里,Segs
的操作只使用了很少的时间,反倒是 deque
的字符串操作非常耗时,比如:
_ZNSt5dequeIcSaIcEEixEm -> std::deque<char, std::allocator<char> >::operator[](unsigned long)_ZNSt5dequeIcSaIcEE5frontEv -> std::deque<char, std::allocator<char> >::front()
这样的函数。
很显然,用 deque
去存临时数据不是一个很好的选择,不过鉴于 Segs
的性能是比较良好的,我现在就先不改了,等到 Lab4 优化性能的时候在专门去改善一下字符串拷贝的问题。
这个 Lab 有两个部分,第一个需要实现相对和绝对 seqno 的互相转换,第二部分才真正的使用之前实现的包装类来写 TCP receiver。
要写出这个 Lab 还得对 TCP 报头(header)有一些基本的了解。首先,一段消息在 TCP 协议中可能会被拆成很多小段传输,而每段都会有一个报头。其中的 SYN 和 FIN 分别标志着传输的开始和结束。
即,如果报头中的 SYN 标志位为真,表明这个 TCP 包是整段消息的第一个包,对于 FIN 也是同理(最后一个包)。
一般来说,我们把 0 作为一串数据中第一个的下标(比如字符数组),但是在 tcp 中不是这样的,这个第一个数据的下标是随机出来的。每个 TCP 报头都会包含一个 seqno,表示这个包中数据的启示下标,那我们知道含有 SYN 的包是整段数据的第一个包,这个包的 seqno 自就是整段数据的第一个下标,我们把这个第一个下标称为 ISN (initial sequence number)。
所以为什么要使用随机的 seqno 呢?这主要是因为防止和历史数据混淆,如果在前面的连接中,有些包发送的特别慢(在网络阻塞时),等到连接关闭了接收端才收到。这个时候,如果 seqno 不是随机出来的,刚刚历史数据的 seqno 有很大可能就在接收端的接收窗口中,被错误的接收了[2]。
虽然这个 TCP 数据包的下标是随机出来的,但是我们使用的时候(比如之前实现的 push_substring
函数),还需要转换成从 0 开始的下标,并且这个下标和 seqno 不一样,是 64 位的。
对于这个从 0 开始的下标,实验指导书称之为 abs seqno(即绝对 seqno),我们需要写一个类来专门转换这两种 seqno。
从 abs seqno 转换到 seqno 非常简单,只需要直接返回 ISN + abs_seqno 就行了,自然溢出后直接就能得到 seqno。
但是从 seqno 转换到 abs seqno 就没那么简单了。seqno 是 32 位的,而 abs seqno 是 64 位的。同一个 seqno 可以对应多个 abs seqno。所以要实现的 unwrap
函数里面多了一个 checkpoint
,转换出来的 abs seqno 需要是最接近 checkpoint
的那个。
其实这个问题还是用数学的语言来解释更加清晰一点。设 checkpoint
为 ,seqno 为 ,。
那么问题就转化为了:求一个 (asb seqno),使得 ,同时,最小化 。
我的实现是下面这样的,第一眼看上去可能有些迷惑(实际上下面解释也挺迷惑的,我试了好几种表达方法,但碍于本人的数学和语文水平,都没法把这个想法清晰的表达出来):
//! \param n The relative sequence number//! \param isn The initial sequence number//! \param checkpoint A recent absolute 64-bit sequence numberuint64_t unwrap(WrappingInt32 n, WrappingInt32 isn, uint64_t checkpoint) { WrappingInt32 wrapped_ckp = wrap(checkpoint, isn); // 模 2^32,同时 + isn // 实际上是把一个绝对的 ckp 变成一个在 isn 意义下的相对 ckp int32_t offset = n - wrapped_ckp; static constexpr uint32_t MX32 = numeric_limits<uint32_t>::max(); int64_t ret = offset + checkpoint; if (ret < 0) return ret + MX32 + 1; return ret;}
这里的 offset
代表的是 checkpoint + isn
到要转换的 seqno(在模 的意义下)的距离,可以是正的也可以是负的。
0 2^32 2*2^32 3*2^32| | | ||--------|--------|--------| | | |seqno ckp + isn ckp + isn(实际) |<--->| offset
为了得到一个和 ckp + isn 最近的 seqno,可以把刚刚得到的 offset 加到 ckp + isn 上。相当于是给 seqno 加上了某个 的倍数。
把这个 offset + ckp + isn 减去 isn 就得到了 abs seqno (因为 seqno 和 abs seqno 就差了个 isn)。
所以 abs seqno 就等于 offset + ckp。
不过,直接这样计算可能会有得不到最优解,下面就是直接采用这个方法的计算结果:
0 2^32 2*2^32 3*2^32| | | ||--------|--------|--------| | | seqno ckp + isn |<--->| offset
可以看到,如果直接给当前 seqno 加上 ,其与 ckp + isn 的距离会更近。同时也符合前面提到的 。
可以思考一下,这样得不到最优解的情况只会发生在 的情况下。
因为我们给 seqno 加上任意的 的倍数,其在模 的意义下是不变的。但是在 seqno 加这个倍数的同时,offset 是会变化的(而我们希望最小化 offset)。
比如 (这肯定符合 )。那么
像刚才那样的例子,直接给 seqno 加上 就变成了:
2^32 2*2^32 3*2^32 4*2^32| | | ||--------|--------|--------| | | ckp + isn seqno <--> offset
这时候,利用自然溢出,我们自己根本不用处理这个问题。
注意到在代码里面,储存 offset
的类型是 int32_t
,其有符号,储存的范围刚好是 。
所以一旦 ,offset 就会“自动”给自己加上或减去 的倍数,来最小化自己。
当然,这样的实现还是有 bug 的,比如下面这样:
0 2^32|-----------------------| | | ckp+isn seqno |<---------------->| offset
很明显,这里的 offset 是正数,并且大于 。虽然这个时候给 seqno 减去一个 会让 offset 的绝对值更小,但是这会让 seqno 变成负数,显然是不行的。所以写了下面这几行来防止出现负数,即,如果出现了负数就把这个 加回去。
static constexpr uint32_t MX32 = numeric_limits<uint32_t>::max(); int64_t ret = offset + checkpoint; if (ret < 0) return ret + MX32 + 1;
。
]]>吐槽一下,官方题解写的挺难看懂的,看了好久还是挺迷糊的(其实也是我太菜了)。搞懂之后感觉这题挺妙的,来写下题解。
我们首先需要有一个观察,就是对于 串,最后一个连续字串不会增加可能的获胜人数。比如 时,后面结尾的 就不会增加可能的获胜人数。
为啥呢,设我们设经过任意次数对战后,玩家可能组合的集合为 那么对于任意的 ,连续在环境 中对战任意次数后,最终的赢家一定是 中温度最高的(因为每个剩下的玩家都会需要连续在环境 中对战,唯一能胜出所有对战的玩家一定是最大的)。同理, 中的玩家连续在环境 中对战任意次数后,最终的赢家一定是 中温度最低的。
例如, 时,最后胜出的一定是 号玩家。
这样一来,如果结尾段是 ( 结尾同理,后面为了方便先用 的例子了),我们只需要算出前面的部分最多能构造出多少种最大值(玩家温度)不同的玩家组合,就能知道当前长度的 的答案了。
现在考虑如何构造出最多的最大值不同的玩家组合。如果玩家数量为 ,那么没有经过任何对战时,最大值就是 。想要让最大值不同,只能删除当前的最大值。
刚刚的描述可能比较抽象,考虑 这个例子就能较好的理解了。
对于第一个 ,除了玩家组合中温度最低的不能删掉(这个不管怎样都能赢),其他的都能删掉(让温度最低的玩家和其他任意玩家对战),共有如下几种情况:
观察发现,只有第一种情况改变了最大值(为啥呢?因为他删掉了最大值)。其他的情况中,必须要连续的删除结尾的一段数字,才能改变最大值。
这时候第二个 就起到作用了。对于第二种情况,其可以把 删掉,使得玩家组合的最大值变为 。我们按照这个规律可以进一步推广出这个结论:设结尾段前面连续段的长度为 ,能产生的最大值不同的玩家组合就为 ,具体来说,可能的最大值范围是 。(这里 是因为可以选择不改变最初的最大值)。
现在为止,我们已经能求出 只有两个连续段时的答案了。即 ,其中 ,表示结尾段的长度。
这个时候我们把例子换成 ,看看例子是否还成立( 不止一段)。同样,可以列出第一次对战后的可能玩家组合。因为第一个环境是 ,所以只能删除除了最大值以外的其他玩家:
虽然这些情况中,没有任何一种改变了玩家组合的最大值,但是我们只要在接下来的 环境中再对战一次,删除掉 ,就产生了 种新的最大值。对于第一种情况,最大值变为了 ,第二种变为了 。一共也是 种答案。
那如果段数再多一点呢?比如 。这个结论还是成立的。我们可以把 看作一组环境,其中 可以删除 范围内的任何玩家, 则为 中的任何玩家。把这两种环境组合起来就可以从 种任意挑选 个删除,构造出 种不同的最大值(取决于你删除前多少个温度最大的玩家)。
通过前面的例子我们已经分析出了,解决问题只需要知道 的 子串中,最后一个连续段的长度。不过对于每个 都扫一遍太慢了,需要采用类似动态规划的东西,具体我在代码注释里有解释:
#include<bits/stdc++.h>using namespace std;#define ll long longint main(){ int t; cin >> t; while(t--){ int n; string s; cin >> n >> s; int cur0len = 0; // 最后一个连续段如果是 0 的话 cur0len 表示其长度, // 如果不是的话 cur0len 就是 0 int cur1len = 0; // 和 1 相同 int curn = 2; // 最开始是两个玩家 for (char ch : s){ int x = ch - '0'; if (x == 0){ cur0len++; // 当前是 0 的话 0 结尾的连续段会比原来更长 cur1len = 0; // s 的最后一个不是 1 了 } else { cur1len++; cur0len = 0; } cout << curn - (x ? cur1len : cur0len) << " "; // 前文中的 n - k curn++; } cout << '\n'; }}
]]>upd@2022/11/5:添加了具体实现,修正了推导中的一些符号错误
反向传播算法的主要目的是计算出神经网络中误差对于偏置和权重等参数的偏导数,以此来进行梯度下降。本文的上半部分主要是算法的推导,后半部分使用全连接神经网络和 mnist 数据集实现手写数字识别。
这个算法对我来说还是很难理解的,为了防止自己忘掉,就写了这篇笔记(还有就是神经网络里这些公式的上标下标太多了,如果用真的笔记本写,稍微一不小心就写错了)。如果你对神经网络还没有基本的概念,推荐去看 3b1b 的神经网络系列视频
这里必须说一句 MqCreaple 大佬真的太巨了,看了视频之后直接手推了全部的公式更令人震惊的是居然把我这种人教会了。
先考虑一个最简单的全连接神经网络,其每层只有一个节点,那么可以画出下图,代表单个节点的输出值 的计算流程(通过箭头起点的变量以及对应的函数可以得到箭头指向的变量)。
graph TB alm1["a(l-1)"] & w & b --> z --> al["a(l)"] --> 误差值 y-->误差值
如果我们写成函数的形式,是下面这样的:
那如果我们想要根据误差值来对于权值 梯度下降,就需要求出误差对权值的偏导数,即:
使用偏导数是因为 的计算依赖于三个变量,而我们希望知道改变 后对误差值的影响。
求偏导时,我们假设其他变量都是常数,只有一个变量在变化(以及被这个变量直接影响的其他变量,这种情况下是上图中 的链),那可以写出如下的式子(常数下有下划线):
这个时候可以使用链式求导:
写成另一种形式(更方便之后使用)就是:
然后可以求出链式法则的各个中间偏导数,进一步还可以写作下面的形式(假设误差函数是平方误差函数):
注意 这里不能反了(举个例子, 过大的时候我们希望导数也大,这样可以给要调整的值减去导数)。
上面展示的是误差对于权值的偏导,对于偏置和上一层的输出,只需要替换掉 公式中的 即可。或者说让上一层的输出和这一层的偏置来影响 ,而不是权值。
对于 ,替换成:
对于 ,替换成:
现在考虑如下一个网络:
graph TB al["a(l)"] & wlp["w(l + 1)"] & blp["b(l + 1)"] --> zlp["z(l + 1)"] --> alp["a(l + 1)"] --> ...别的很多层 __["w(l + 2)"] & _["b(l + 2)"]-->...别的很多层
也就是 的下一层不直接连接误差函数,而是有多层。那 就不能直接求出了(也就不能直接求出 和 的偏导),因为 在很多层之后。这时候就需要用到反向传播的思想了。
我们知道:
观察式子可以发现我们能从后层推出前层的 ,所以在求权值和偏置的偏导前,我们需要先从输出层开始,一点一点的把 向前传。
在刚刚的例子中,反向传播算法的过程还是很清晰的,没有任何的线性代数。不过在真实的神经网络中,每层有多个节点,如下:
graph LR l1["(l-1)1"] & l2["(l-1)2"] & l3["(l-1)3"] ---> lp1["l1"] & lp2["l2"] & lp3["l3"]
表示一条从 层 节点连接到 层 节点的边。要如何求 呢?
我们其实还是可以把原来的公式带进来,毕竟多节点的层本质上还是由多个单节点的层组成的,不过要注意下标:
注意这里和 层有关的变量我们都使用的是 ,比如 (直观理解的话就是,改变单位权重,上一层的输入越大就对最终的误差函数影响越大),和 层有关的使用的都是 。
因为 的下标是一样的,我们为方便书写矩阵运算的公式,就叫他 。
重写一下刚才的公式:
写成矩阵形式的话, 随行增长, 随列增长。那上面的导数就是:
观察发现,这个矩阵其实就等于:
这样就可以使用矩阵运算库(如 numpy)来加速了。
这个就相对简单,因为 等于 (见前文),所以可以很方便的计算。
注意这里的误差对偏置的导数就等于前面用到的 ,所以实现的时候一般先计算这个,然后再把 带入到前文的式子中。
再观察一下前面的多节点神经网络,不过这次主要关注单个 节点对后面的影响:
graph LR l1["(l-1)1"] ===> lp1["l1"] & lp2["l2"] & lp3["l3"]
不难发现, 可以对每个 都产生影响。如果我们把 层当成一个接收 个 ,输出 个 的函数。那么现在每个输入的变量都在变化,求的就不是偏导数了(偏微分),而是全微分[1](total derivative)。
根据全微分的定义,应该把每个参数的偏导加起来,在我们的例子中,就是:
其中 这个部分需要比较小心的处理。我们需要清楚 是连接 层的 节点和 层 节点的边。
那么因为
可以推出
而
在前面已经解释过了,就等于 和误差对偏置的导数。
重写整个式子,可以得到:
现在可以思考如何以矩阵运算的形式得到 。
一个可行的方法是在 和 之间做乘法。
注意我们用 这个下标来累加,所以如果我们把 放在左边,其 坐标应该随着列数增加而增加( 的矩阵乘法中,会对 的行和 的列做向量的点乘)。 而把 放在乘法的右边,就需要让其的 下标随行数增长。
因为 的 本来是随行增加的,所以要对其进行转制。
最后可以得到:
其中 是一个列向量。
这个部分中会使用刚刚讲到的反向传播算法来实现一个简单的全连接神经网络,并且使用这个神经网络来识别 mnist 数据集中的手写数字。
说实话 mnist 这个数据集挺坑的,用的是二进制储存格式,所以想要读取数据集里的内容还得费点功夫。
代码如下[2]:
# 在项目中的位置:./src/utildef load_mnist(path: str, pref: str = "train"): """ path: 数据集路径 data_type: 数据集名称前缀(train or t10k) """ label_path = os.path.join(path, "{}-labels.idx1-ubyte".format(pref)) img_path = os.path.join(path, "{}-images.idx3-ubyte".format(pref)) with open(label_path, 'rb') as lfile: # rb 表示 read binary magic, n = struct.unpack('>II', lfile.read(8)) labels = np.fromfile(lfile, dtype=np.uint8) with open(img_path, 'rb') as ifile: # ifile 为 image file magic, num, rows, cols = struct.unpack('>IIII', ifile.read(16)) images = np.fromfile(ifile, dtype=np.uint8).reshape( len(labels), 28 * 28) label_one_hot = np.zeros((len(labels), 10), dtype=int) for i in range(len(labels)): label_one_hot[i] = np.eye(10)[labels[i]] return label_one_hot, images / 255.0
里面这个 struct
的包看起来可能比较迷,实际上他就是一个专门处理二进制数据的类。
struct.unpack('>II', lfile.read(8))
这句话的意思是就是从 lfile
里读取两个大端字节序的 4 字节无符号整数。>II
中的 >
表示了文件是以大端字节序储存的,而 I
则表示读取的是 4 字节无符号整数。
下面的 np.fromfile
也是一个作用,直接把二进制文件转换成了一个 np.array
,不用指定字节序应该是因为 numpy 默认的就是大端。
要注意 mnist 数据集中图片单个像素的范围是 的整数。而我们希望其变成 的浮点数,所以在输出时间除 255。
想要图片在 范围中主要是因为,如果把一个比较大的数字 sigmoid
函数就会出现溢出问题(虽然每层权值的初始值是 -1 到 1 之间随机生成的,但是有时候会输出较大值),sigmoid 的定义如下:
这里这个 过小那 就会变成一个特别大的数字,因为 numpy 实际上是调用 c 完成计算工作的,所以不像 python 那样自带高精,这样的数字自然就会造成溢出。
预处理的最后一部是把标签转换成 one-hot(中文翻译为独热)形式(方便最后求误差对整个神经网络的梯度),这里可以用 np.eye(x)
这个函数,它可以生成一个 x\timesx 的对角线矩阵,那么 np.eye(10)[labels[i]]
自然就是 labels[i]
的对应独热编码了。
单层神经网路本质上其实是一个函数,其接收一个向量,输出一个向量。不过这个函数是依赖于很多变量的,比如权重和偏置,所以我们希望用一个类将他们存起来。
同时,在反向传播的过程中,也需要用到类中储存的这些变量,所以最好能实现一个函数,其接收误差对当前层的导数,以其他必要的数据,返回误差对前层的导数(反向传播)。
最后,对于本层来说,我们还需要提供一个接口来更新其权重和偏置(如果不同层的数据不是权重和偏置,可以新建一个抽象类专门表示不同层的数据)。
根据这些需求,可以写出层类的抽象类:
注意每个函数的参数名都是符合之前的数学公式的,如果有不明白的可以看前文。
# 在项目中的位置:./src/layer.pyfrom typing import *import numpy as npfrom nptyping import NDArray, Shape, Floatfrom . import utilclass abs_layer(): def __init__(self, insize: int, outsize: int, activ: util.Dfunc = util.sigmoid): self.insize = insize self.outsize = outsize self.activ = activ def get_z(self, ipt: NDArray) -> NDArray: """ 根据输入返回一个没有经过激活函数的输出 """ pass def get_a(self, ipt: NDArray) -> NDArray: """ 根据输入返回经过激活函数的输出 """ pass def get_derivatives(self, prev_a : NDArray, DE_over_cur_a: NDArray, cur_z: NDArray) -> List[NDArray]: """ prev_a : 前面一层经过激活函数的输出 DE_over_cur_a : 误差对当前层输出的导数 cur_z : 当前层没经过激活函数的输出 """ pass def descent(self, w, b): """ w : 权重的梯度 b : 偏置的梯度 """ pass
这里的 util.Dfunc
表示的是一个可导的函数,定义如下:
# 在项目中的位置:./src/util.pyclass Dfunc(): """ 表示一个可导的函数,f 是原函数,df 是导数 如果 f 是多元函数,则 df 返回的应该是一个向量(不同输入参数的偏导数) """ def __init__(self, func: Callable, Dfunc: Callable): self.f = func self.Df = Dfuncsigmoid = Dfunc(lambda x: 1 / (1 + np.exp(-x)), lambda x: np.exp(-x) / ((1 + np.exp(-x)) ** 2))sq_err = Dfunc(lambda label, predict: np.sum((predict - label) ** 2), lambda label, predict: 2 * (predict - label))
对于一个全连接神经网络,可以有如下的实现:
# 在项目中的位置:./src/layer.pyclass dense_layer(abs_layer): def __init__(self, insize: int, outsize: int, activ: util.Dfunc = util.sigmoid) -> None: super(dense_layer, self).__init__(insize, outsize) self.wts = np.random.rand(outsize, insize) * 2 - 1 self.bias = np.random.rand(outsize) * 2 - 1 def get_z(self, ipt: NDArray) -> NDArray: return np.matmul(self.wts, ipt.reshape(ipt.size, 1)).reshape(self.outsize) + self.bias def get_a(self, ipt: NDArray) -> NDArray: return self.activ(self.get_z(ipt)) def get_derivatives(self, prev_a : NDArray, DE_over_cur_a: NDArray, cur_z: NDArray) -> List[NDArray]: if (DE_over_cur_a.size != self.outsize): raise Exception("size of DE_over_cur_a ({}) doesn't equal to number of node in this layer ({})".format(DE_over_cur_a.size, self.outsize), DE_over_cur_a ) Dbias : NDArray = DE_over_cur_a * self.activ.Df(cur_z) DE_over_prev_a: NDArray = np.matmul(self.wts.T, Dbias) Dweight = np.matmul( Dbias.reshape(Dbias.size, 1), prev_a.reshape(1, prev_a.size) ) return [DE_over_prev_a, Dweight, Dbias] # 返回三个变量,误差对上层输出,对当前层权重和偏置的偏导 def descent(self, w : NDArray, b : NDArray) -> None: self.wts -= w self.bias -= b
除了 get_derivatives
,其他几个函数都比较好理解,下面大概解释一下。
误差对上层偏导的公式如下:
对应到实现中,就是这一行:
DE_over_prev_a: NDArray = np.matmul(self.wts.T, Dbias)
这里的 Dbias
就等于 ,如下:
对应代码中的:
Dbias : NDArray = DE_over_cur_a * self.activ.Df(cur_z)
误差对权值导数的公式为:
对应如下代码:
Dweight = np.matmul( Dbias.reshape(Dbias.size, 1), prev_a.reshape(1, prev_a.size) )
网络类可以把不同的层连接在一起。把上一层的输出作为下一层的输入传递。也可以从误差函数开始反向传播:
# 在项目中的位置:./src/net.pydef __init__(self, layer_sizes: List[int] | None = None, layers: List[layer.abs_layer] | None = None) -> None: """ layer_sizes: 第一个是输入大小,最后一个是输出大小 """ if (layers != None and layer_sizes != None): raise Exception( "should only provide either layer_sizes or layers", self ) if (layers == None): layers: List[layer.abs_layer] = [] for i in range(0, len(layer_sizes) - 1): # 这一层的输入等于上一层的输出,等于下一层的输入 layers.append(layer.dense_layer( insize=layer_sizes[i], outsize=layer_sizes[i + 1])) self.lays = layers self.num_lay = len(layers) self.err = util.sq_err for i in range(1, self.num_lay): if (self.lays[i - 1].outsize != self.lays[i].insize): raise Exception( "layer {}'s output ({}) not equal to layer {}'s input ({})".format(i-1, self.lays[i-1].outsize, i, self.lays[i].insize), self.lays)
这里有两种方法可以初始化,可以直接提供不同的 layer
,让网络类把它们组合在一起,也可以输入一个表示不同层节点数量的类,让初始化函数自动创建对应的全连接网络。
def get_predict(self, ipt : NDArray): lay_z: List[NDArray] = [] lay_a: List[NDArray] = [] lay_z.append(self.lays[0].get_z(ipt)) lay_a.append(self.lays[0].activ.f(lay_z[0])) for i in range(1, self.num_lay): lay_z.append(self.lays[i].get_z(lay_a[i - 1])) lay_a.append(self.lays[i].activ.f(lay_z[i])) return [lay_z, lay_a]def get_simple_predict(self, ipt : NDArray): return self.get_predict(ipt)[1][-1]
这里神经网络的第一层比较特殊,不和上一层的输出相连,而是直接用的 ipt
,所以要特殊处理。
def bp(self, ipt: NDArray, label: NDArray, lrate: float): lay_z, lay_a = self.get_predict(ipt) # 每层的输出 lay_Dw: List[NDArray] = [np.zeros(0)] * (self.num_lay) # 对权值的导数 lay_Db: List[NDArray] = [np.zeros(0)] * (self.num_lay) # 对偏置的导数 DE_over_a: List[NDArray] = [np.zeros(0)] * (self.num_lay) # 误差对节点输出的导数 DE_over_a[-1] = self.err.Df(label, lay_a[-1]) for i in reversed(range(1, self.num_lay)): DE_over_a[i - 1], lay_Dw[i], lay_Db[i] = self.lays[i].get_derivatives( prev_a=lay_a[i - 1], DE_over_cur_a=DE_over_a[i], cur_z=lay_z[i] ) lay_Db[0] = self.lays[0].activ.Df(lay_z[0]) * DE_over_a[0] lay_Dw[0] = np.matmul( lay_Db[0].reshape(lay_Db[0].size, 1), ipt.reshape(1, ipt.size) ) for Dw, Db, lay in zip(lay_Dw, lay_Db, self.lays): lay.descent(Dw * lrate, Db * lrate)
这里主要的作用就是调用每层的 get_derivatives
,得到不同层输出,权值和偏置的导数。
不过有两个特殊的地方,首先误差对最后一层的导数需要通过误差函数和标签得到,如下:
DE_over_a[-1] = self.err.Df(label, lay_a[-1])
误差对于第一层权值和偏置的导数也只能通过输入的图片得到:
lay_Db[0] = self.lays[0].activ.Df(lay_z[0]) * DE_over_a[0] lay_Dw[0] = np.matmul( lay_Db[0].reshape(lay_Db[0].size, 1), ipt.reshape(1, ipt.size) )
可以看到准确率有 96%,还是很不错的(大概跑了一分多钟吧)。当然训练的方式还有很大优化空间,我也没怎么调参。
]]>一个多月终于断断续续搞完了第二本书里的内容。和前面两篇文章一样,这篇也会写一些我个人花了较长时间才搞懂的部分,以及一些我在原书基础上加的新功能。
对于原书就有的功能,会直接使用书上的代码,如果是我新加进去的功能,会使用自己的代码。因为我的代码在原书基础上做了较大幅度的变化(即使是原来就有的功能),所以只看一段代码可能不太明白,这里可以参考我的 GitHub 仓库: https://github.com/ttzytt/RTOW
这部分主要是一些小细节我当时没太理解。首先是
bvh_node::bvh_node( std::vector<shared_ptr<hittable>>& src_objects, size_t start, size_t end, double time0, double time1)
首先是这个构造函数的范围问题。每颗子树是不负责 end
位置的 hittable
的。也就是这个构造函数负责的是 [start, end)
这样的一个区间。
这也解释了代码中 sort
的用法:
std::sort(objects.begin() + start, objects.begin() + end, comparator);
std::sort()
会排序的其实是 [)
这样的一个区间(我之前居然没注意到这个)。所以这里的 objects.begin() + end
其实没有包括 end
。
在排序 vector
等容器时,使用的方法是 sort(vec.begin(), vec.end())
乍一看没有把 .end()
位置的元素包含进去,但其实 .end()
指向的是一个空的,或者说是最后一个元素更后面的位置(这我之前也没注意到),所以用这样的方法可以把整个 vector
都排一遍序。
class sphere : public hittable { ... private: static void get_sphere_uv(const point3& p, double& u, double& v) { // p: a given point on the sphere of radius one, centered at the origin. // u: returned value [0,1] of angle around the Y axis from X=-1. // v: returned value [0,1] of angle from Y=-1 to Y=+1. // <1 0 0> yields <0.50 0.50> <-1 0 0> yields <0.00 0.50> // <0 1 0> yields <0.50 1.00> < 0 -1 0> yields <0.50 0.00> // <0 0 1> yields <0.25 0.50> < 0 0 -1> yields <0.75 0.50> auto theta = acos(-p.y()); auto phi = atan2(-p.z(), p.x()) + pi; u = phi / (2*pi); v = theta / pi; }};
这里使用了一个 atan2
的函数,而不是普通的 atan
函数。我们知道 tan
这个三角函数会返回圆对应角度的切线的斜率。那 atan
就是返回某个斜率的对应角度。但是我们在求纹理坐标时实际上希望从圆上的某个坐标得到对应的角度,当然可以直接使用 atan(y/x)
来先求斜率再求角度。
但问题就出在描述圆的是一个方程而不是函数,一个 x 坐标可能对应多个 y 坐标。那么一个斜率就可能对应多个角度。具体来说,虽然 和 对应的角度不一样,但是他们的斜率是一样的。如果我们使用 atan
的话还需要自己再判断一遍坐标的符号,而 atan2
相当于做了这个工作。
virtual color value(double u, double v, const point3& p) const override { auto sines = sin(10*p.x())*sin(10*p.y())*sin(10*p.z()); if (sines < 0) return odd->value(u, v, p); else return even->value(u, v, p);}
这段代码把三个份量上的值加上 乘了起来,如果结果是正数就返回一种颜色,反之返回另一种。乍一看可能不太好理解,如果先画一个二维的版本就好很多了:
加入另一个轴后,因为 的符号周期性的变化,所以可以看到不同的层,每层之间的颜色会翻转一下,而单层内的话因为符号没变所以可以直接当成上面二维的版本:
不过说实话我认为周期性的函数也不止三角函数这一种。书中这么写只是为了获得正负号,而不是具体的值,使用 属实是有点浪费计算资源了。
一个很简单的例子就是让 模 ,如果结果小于 就返回正数,反之亦然。要简洁一点的话写成下面这样也可以:
不难发现书中的棋盘格是基于点在空间中的绝对坐标的。所以才会出现上图那样的分层。既然我们已经可以计算球面的纹理坐标了(其他 hittable
的纹理坐标在书中也有讲,比如长方形片),其实可以做一个基于物体表面的棋盘格纹理,如下:
class surface_checker : public texture { public: using text_array = std::vector<std::shared_ptr<texture>>; surface_checker() = default; surface_checker(const text_array& _texts, const std::pair<f8, f8> _siz = {514, 114}) : texts(_texts), polar_azim_siz(_siz) {} virtual color value(f8 polar, f8 azim, const pt3& p) const override { int x_idx = (i8)(azim * polar_azim_siz.first); int y_idx = (i8)(polar * polar_azim_siz.second / 2.0); // 极角只跨半球,所以想要整个球的垂直方向有 polar_azim_siz.second 这么多的格子,要先除以二 return texts[(x_idx + y_idx) % texts.size()]->value(polar, azim, p); } text_array texts; std::pair<f8, f8> polar_azim_siz; // 垂直方向和水平方向有多少格};
这里的 text_array
允许了棋盘中有多于两种颜色,而 (azim * polar_azim_siz.first)
会把原本 的纹理坐标范围放大到 polar_azim_siz.first
,确保球上有 polar_azim_siz
的颜色变化。最后就可以得到如下的效果:
生成该场景的代码如下:
scene surf_check_sc() { hittable_list world; auto checker1 = make_shared<surface_checker>( surface_checker::text_array{ make_shared<fixed_color>(color(0.2, 0.3, 0.1)), make_shared<fixed_color>(color(0.9, 0.9, 0.9)), make_shared<fixed_color>(color(0.3, 0.2, 0.15)), make_shared<fixed_color>(color(0.15, 0.3, 0.9))}, std::pair<f8, f8>{60, 60}); auto checker2 = make_shared<surface_checker>( surface_checker::text_array{ make_shared<fixed_color>(color(0.2, 0.3, 0.1)), make_shared<fixed_color>(color(0.9, 0.9, 0.9)), }, std::pair<f8, f8>{30, 30}); world.add(make_shared<sphere>(pt3(0, -10, 0), 10, make_shared<lambertian>(checker1))); world.add(make_shared<sphere>(pt3(0, 10, 0), 10, make_shared<lambertian>(checker2))); f8 asp_ratio = 1.0; vec3 lookfrom = pt3(13, 2, 3) * 2; vec3 lookat = pt3(0, 0, 0); f8 vfov = 40.0; auto dist_to_focus = 10.0; auto aperture = 0; vec3 vup(0, 1, 0); auto cam_ptr = make_shared<camera>(lookfrom, lookat, vup, vfov, asp_ratio, aperture, dist_to_focus, aperture, 1.0); return scene(make_shared<bvh_node>(world), blue_sky_back_ptr, cam_ptr);}
柏林噪声是书中一个比较难理解的点,不过柏林噪声是基于普通的值噪声的。值噪声其实就是在空间中的整数坐标上随机的生成一些随机数,再利用这些整数的坐标来给别的坐标线性插值(线性插值不懂的可以见这个链接,个人认为讲的很清楚)。
大概就是下面这样的[1]:
垂直和平行线交错(整数坐标)的点会随机的生成一个随机数,而图中的 p 点会基于周围四个关键点(也就是坐标为整数的点,这些点会产生随机数)做线性插值,最终 p 点的值取决于离周围四个关键点的距离和周围四个关键点的随机值。
下面就是一个二维值噪声的例子:
生成代码如下:
import numpy as npimport matplotlib.pyplot as pltfrom math import *XLEN = 25 # 产生多少个整数点YLEN = 25DIFF = 0.05ptsx = np.arange(0, XLEN, DIFF)ptsy = np.arange(0, YLEN, DIFF)xs, ys = np.meshgrid(ptsx, ptsy)z_orig = np.random.random((XLEN + 1, YLEN + 1))z_interped = np.zeros((round((XLEN) / DIFF), round(YLEN / DIFF)))def lerp(a, b, t): return a + t * (b - a)def lerp2(ld, rd, lu, ru, tx, ty): # 二维线性插值 # left down, right down, left up, right up upmid = lerp(lu, ru, tx) dnmid = lerp(ld, rd, tx) return lerp(dnmid, upmid, ty)for i in range(XLEN): for si in range(round(1 / DIFF)): # step i for j in range(YLEN): for sj in range(round(1 / DIFF)): z_interped[i * round(1 / DIFF) + si][j * round(1 / DIFF) + sj] = lerp2( z_orig[i][j], z_orig[i + 1][j], z_orig[i][j + 1], z_orig[i + 1][j + 1], DIFF * si, DIFF * sj)plt.imshow(z_interped, cmap=plt.cm.gray)plt.savefig("./2d.png", dpi = 150, format = 'png')plt.show()
很容易看出这种噪声不自然,你甚至可以从图中隐约的看出坐标轴。。。虽然整张图看起来比较随机,但仔细观察就能发现整张图都是由很多小的 “方形色块” 拼凑而成的。
这是因为每个关键点对于各个方向的影响是相同的,而线性插值会让这个影响变成类似菱形的形状。下图中中间的点就是一个关键点,这个点随机出来的值比较低,所以是黑色的,可以看出这个黑色向周围发散的形状是菱形。
要改变这种情况也很简单,让某个关键点对周围的影响在不同的方向上不同。既然需要表示方向,我们可以很自然的想到向量。
现在我们在每个关键点上产生一些随机的单位向量,记为 (关键点 上生成的随机向量)像下面这样[2]:
现在如何使用这些随机向量来达成不同方向影响就成了一个问题。一个比较自然的想法是考虑某个点相对于关键点的位置。我们可以把这个距离向量标记为 (对于关键点 的距离),像下图这样[2]:
如果 和 的方向相近,我们就可以让这个点更亮,相反,如果 和 的方向相反,那么这个点的颜色应该偏暗。
这样的效果可以通过点积来达到,其实就是把 投影到 后的长度。结果方向相反是负数,相同是正数,垂直的话是零。
我们把这个点乘记录下来:
接下来就可以用值噪声的方式对周围四个点做线性插值了。或者说我们把 当作了原来值噪声中关键点上的值。而现在这个值对于每个位置来说会变化。
下面这张图展示了柏林噪声的效果,其中不同的箭头代表不同的 ,越蓝值越小,越黄值越大 [2]:
注意看图中的三个框。
可以很明显的看出,柏林噪声的生成的噪声并没有值噪声的方块感。
观察下面实现湍流的代码:
double turb(const point3& p, int depth=7) const { auto accum = 0.0; auto temp_p = p; auto weight = 1.0; for (int i = 0; i < depth; i++) { accum += weight*noise(temp_p); weight *= 0.5; temp_p *= 2; } return fabs(accum);}
其中 turb
这个函数自身比较好理解,就是把很多频率的柏林噪声以一定的权重叠加在一起。最后的 fabs
看起来是为了让返回值符合 的范围,实际上还有别的目的。比如如果我们把最后一行换成 return (accum + 1) * 0.5
,虽然让返回值符合了范围,但是看起来的效果却和原写法非常不同。
下图是中的蓝线是 的函数图像,红线是 的函数图像,而绿线是 :
如果采用绿线的修正方法,原来暗的地方修正过后还是暗,反之亦然。如果采用红线的修正方式,则只有原本亮度中等或者说明暗过度的地方会变暗,不管是暗部还是亮部在修正过后都会变亮。对比书中两种材质的一个特征区域可以更明显的看出红色修正方式的特点:
左图的黑边像是给右图的黑色区域描了一个边,符合刚刚只有过渡部分会变暗的预测。
代码中的 noise(p)
返回的最大值是 1,而 weight
最开始的值也是 1。这样的话 abs(accum)
是有可能大于 1 的。这显然是没道理的,因为不可能光线打到某个物体后还变亮(除光源)了。我之前给这个博客的博主发过这个问题有关的邮件,不过他表示他也不知道,可能只是概率问题使得大于 1 的值很少见。
随后我又查看了 Ken Perlin 1985 年在 SIGGRAPH 上的论文[3],其中并没有很严格的描述,也没有实际的代码,不过基本的思路是清楚的。令我奇怪的一个点是整篇文章没有说新的噪声算法是用于改进值噪声的,主要关注的是柏林噪声的效果不受各种空间变换的影响(难道说他为了发明一个和空间变换无关的噪声算法,顺便把值噪声改进了,这也太离谱了):
Noise()
In order to get the most out of the PSE and the solid texture approach we have provided some primitive stochastic functions with which to bootstrap visual complexity. We now introduce the most fundamental of these.Noise()
is a scalar valued function which takes a three dimensional vector as its argument. It has the following properties :
- Statistical invariance under rotation (no matter how we rotate its domain, it has the same statistical character)
- A narrow bandpass limit in frequency (its has no visible features larger or smaller than within a certain narrow size range)
Appendix. Turbulence
A suitable procedure for the simulation of turbulence using the Noise() signal is :function turbulence(p) t = 0 scale = 1 while (scale > pixelsize) t += abs(Noise(p / scale) * scale) scale /= 2 return t
扰动的伪代码和书中的基本没有区别,但是对于 Noise()
函数 Perlin 只说了其接收一个点的位置,返回一个标量,没有标量的范围,所以还是比较令人疑惑的。
但是下文的一句话还是令人感觉他是想返回一个 范围内的值的(他提过使用的颜色是 这样的):
By evaluating Noise() at visible surface points of simulated objects we may create a simple “random” surface texture (figure Spotted.Donut) :
color = white * Noise(point)
这个问题实在是困扰了我比较久,如果你知道正确的解释是什么,欢迎在评论区提出,我过一段时间也准备去 stackoverflow 提个问,如果有结果我会更新这篇博客。
最初看到书中下面几个公式的时候我是比较懵逼的:
上网找了一圈后发现其实是旋转矩阵,公式的推导如下(前面这个公式是绕 z 轴旋转的,我们可以简单理解为二维平面上的旋转矩阵)[4]:
我们先把 和 用极坐标的方式表示出来:
在原来的角度上加上 :
使用如下两个两角和差公式:
带入 的极坐标形式得:
绕 轴的旋转基本和这个没区别,但是绕 轴的比较令人疑惑了。
其他两个轴的旋转都是 , 这种形式,唯独到了绕 轴这里变成了 和 这种形式。
因为绕 旋转中 的符号变了,所以很明显我们实际上旋转的不是 而是 。这是因为我们希望的旋转方向和右手坐标系中的旋转方向是“不同的”。
这么说很模糊,可以先一步一步来,搞清楚自己想要的旋转方向是怎么样的:
y+ | | |x- ------- z --------- x+ | | | y-
这是一个右手坐标系下我们从 轴方向观察的示意图,注意 轴的正方向是朝着观察者的。很明显,如果我说想要绕着 轴旋转 ,希望的就是把某个东西从 的正方向转到 的正方向。又或者是 ,,,总之就是逆时针旋转的。
再考虑绕 轴的旋转:
y+ | | |z+ ------- x --------- z- | | | y-
同样, 轴朝着观察者的方向,也是逆时针旋转,从 转到 。
现在我们再把公式加进来看一看是否符合我们的预期,也就是从 到 。
假设当前 (即在 上),旋转 后 就应该是在 上,即
我们先考虑 的公式
其次是 :
看起来没问题
现在考虑绕 轴的旋转:
z- | | |x- ------- y --------- x+ | | | z+
我们会发现如果还是逆时针旋转 并且起点在 上的话,那应该转到 上,如果我们这个时候还使用和其他两个轴的公式,就会转到 上,如下:
果然,把公式中 的符号改变一下,就能解决问题了。
那绕 旋转有什么特殊的呢?这里举一个例子:对于另外两个轴的旋转,如果旋转角度方向是逆时针,并且是从编号小的轴转到编号大的轴(如 ),那么这两个轴的方向都是相同的 ()。
对于绕 轴的旋转,如果逆时针从小编号轴转到大编号轴,那这两个轴的方向是不同的 ()。
毕竟三角函数一开始就是为了平面直角坐标系(xy 平面)设计的,现在应用到了一个符号不一样的平面,肯定得做些调整。
现在你可能会想,如果换成左手坐标系了是不是就能解决这个问题?对也不对,因为绕 轴的变换确实不用换 的符号,但是绕 轴的就需要了(换了 轴的方向,相当于从反方向看刚刚的 xy 平面,那么逆时针从 转到 就变成 或是 了)。
待更新
待更新
待更新
]]>除了朗伯体,RTOW 中还有个比较有趣的地方就是相机类的实现,特别是背景虚化这部分。
先来看一下相机类里一个相对简单的部分–相机的定位。只要通过三个参数就能确定相机的位置,分别是相机本身的位置(lookfrom),相机正在拍摄的位置(lookat)和表示相机上方位置的向量(vup),书里的图就能很好的解释:
在构造函数中,我们需要把这三个参数转换成表示相机朝向的三个参数,以及做一些对焦距,光圈和 fov 的处理,书中没有在这部分花很多的篇幅,我当时想明白也花了挺久的,下面是我对书中实现的一些思考。
因为我对书中的代码稍作了一些修改(主要是命名?)所以先贴一下代码:
#pragma once#include "rtow.h"// 这里的 f8 就是 double (八个字节的 float)class camera { public: camera(vec3 lookfrom, vec3 lookat, vec3 vup = vec3(0, 1, 0), f8 vfov = 90, f8 asp_ratio = 16.0 / 9.0, f8 aperture = 0, f8 foc_len = 1) { f8 deg_fov = deg2rad(vfov); f8 half_hei = tan(deg_fov / 2); // 对边比临边,但是临边是 1 f8 half_wid = half_hei * asp_ratio; cam_z = (lookfrom - lookat).unit_vec(); // z 和镜头指向位置是反的 cam_x = cross(vup, cam_z).unit_vec(); // 同时和 vup,z 垂直 cam_y = vup.unit_vec(); horizon = 2 * half_wid * cam_x * foc_len; // 焦平面的横竖边框 vertic = 2 * half_hei * cam_y * foc_len; orig = lookfrom; lower_left_corner = orig - horizon / 2 - vertic / 2 - cam_z * foc_len; // 焦平面的左下角 len_radius = aperture / 2; } inline ray get_ray(f8 x, f8 y) const { // x 和 y 的范围:[0, 1] // 相机传感器的像素点坐标 vec3 rd = len_radius * rand_unit_disk(); vec3 offset = cam_x * rd.x() + cam_y * rd.y(); ray r; r.orig = orig + offset; r.dir = lower_left_corner + x * horizon + y * vertic - orig - offset; // 产生一个从 orig + offset 到对应像素的向量 // 因为 ray 对应的就是 orig + t * dir return r; } vec3 orig; // 摄像机的位置 vec3 lower_left_corner; // 画面的左下角 vec3 horizon, vertic; // 画面的尺寸(或者说离相机 foc_len 的平面大小) vec3 cam_x, cam_y, cam_z;// 相机朝向 f8 len_radius; // 光圈半径};
下面这张图描述了代码段中各个变量的关系:
按照这张图来理解代码中的内容就比较容易了。
下面这段代码首先计算出了两个变量 half_hei
和 half_wid
:
f8 deg_fov = deg2rad(vfov);f8 half_hei = tan(deg_fov / 2); // 对边比临边,但是临边是 1f8 half_wid = half_hei * asp_ratio;
其表示相机前方 1 个单位距离的位置上,看到的画面的大小。随后需要计算出 cam_x, y, z
三个向量,方法如下:
cam_z = (lookfrom - lookat).unit_vec();// z 和镜头指向位置是反的cam_x = cross(vup, cam_z).unit_vec(); // 同时和 vup,z 垂直cam_y = vup.unit_vec();
cam_z
表示一个从 lookat
到 loofrom
的方向,这个方向和相机实际拍摄的位置是相反的。cam_x
的计算用到了向量的叉乘,在三维空间中,如果 那么 就是同时垂直于 和 的,当然符合这个条件的向量有两个,可以用右手定则确定,这里就不赘述了。根据前面的这个定义,可以得出 cam_x
同时和 cam_z
和 vup
(也就是 cam_z
) 垂直。cam_y
就是 vup
的单位向量。虽然我大概知道三维向量叉乘的几何意义,不过以前没完全理解是如何推导出来的,感觉下面这篇博客写还是非常清晰的,连我这种蒟蒻也看懂了:
https://www.cnblogs.com/qilinzi/archive/2013/05/09/3068158.html
接下来 horizon
,vertic
以及 lower_left_corner
变量的计算相对比较简单,这里就不解释了,图中都有标注。
要理解计算机是如何模拟实现景深效果,还是需要对相机镜头的结构有一定基本的了解,如下:
可以发现,在没有镜头的情况下,从 A 点出发的光线可以通过各种方向传播,每个方向又会到达成像面的不同位置。最终,成像面上每个点的颜色会由很多不同的光线贡献,得到的自然是模糊的影像。
加上镜头后再考虑 A 点,能观察到,从 A 点出发的每个方向的光线,最终都会汇聚在成像面的一个特定点上,也就是 A’。这样得到的影像就是清晰的了。
更宽泛的说,镜头能满足以下两个条件:
这里有一个前提条件,就是这个点必须在相机的焦平面上,如果某个点和相机成像面的距离不是焦距,就会有下面的情况:
如果成像平面是绿色的那个,那么 A1 就在正确的焦平面上。如果成像平面是红色的那个,那么 A2 就在正确的焦平面上。
为了方便分析,我们观察 A1 的情况,发现在红色的成像平面上,从两个方向(平着和斜着)出发的光线被汇聚到了不同的点。而在绿色成像面上,只被汇聚到了一个点。
虽然被汇聚到了不同的点,但是这个不同的程度有大有小,可以想象一下,如果我们把 A1 的位置继续向左移动。那么 A1’ 在红色成像面的位置一定会更高。反过来,如果把 A1 向右移动,A1’ 在红色成像面的位置也会随之下降,最终汇聚在正确的点上。如果继续向右移动,A1’ 在红色成像面的位置还会继续下降。最终造成从 A1 平着出发的光线和斜着出发的光线,在成像面的距离增大。
或者我们增大镜头的尺寸,就有更多从 A1 出发的,不同角度的光线可以进入镜头中,进而到达成像面。这种情况下,A1’ 在红色成像面的位置会更高,可以想象镜头被拉高了,这里光线构成的三角形也被拉高了(我实在是懒的自己画图了,就用网上的图这么解释一下吧)。
看前面的图可以发现,理论上能被相机清晰成像的距离只有一个,多一点少一点都不清晰了。但实际上,人眼的分辨能力没有这么强。我们把相机成像时,能清晰成像(人眼认为是清晰的)的距离范围称作景深。如下:
我们可以以景深的角度来思考前面提到的,镜头大小,或者说半径的影响。实际上,镜头的半径是不会改变的,通常的做法是给镜头加上一个可变的“闸门”,也就是光圈,来控制进入镜头的光线,如下:
可以发现,大光圈会让景深减少,反之亦然。
前面考虑过,从一个点出发的不同光线在不正确的焦距会被汇聚在成像面的不同点上。不过在实际渲染的时候,我们考虑的是不同的光线对于成像面某个像素的贡献。
那么在光圈大的时候,理应有更多方向的光线同时对成像面上一个点做出贡献,造成模糊的效果。具体可以见下图,也就是 RTOW 中对景深的实现:
代码中我们会随机的在光圈上取点,然后追踪从光圈到焦平面上对应像素的光线。最后把采样光线的贡献平均一下。这样光圈越大,景深也就越小。并且因为任何的光线都需要穿过焦平面上对应的点,所以可以确保焦平面上一定是清晰的。
对比上面实际镜头的工作原理还是非常不同的,但是达到了相同的效果。不过这也是因为光线追踪的特点,及从像素开始 “逆向” 的追踪。所以我们不关注实际镜头中,一个点发出的光线会被汇聚在成像面不同位置的问题。而换了一个角度思考,及有多少不同点发出的光线会对一个像素造成影响。不得不说书里的这个实现真的牛皮。
参考资料:
]]>最近(距离搞完 RTOW 已经过去一周了,我现在才把这笔记写出来,属实是懒狗)花了一些时间看完了 Ray Tracing in One Weekend (以下简称 RTOW)果然还是我太菜了,这玩意 One Weekend 没搞完,也跟着把代码写出来了。
本书写的非常不错,最后渲染出的效果也是出乎我的意料(封面图)。但是因为我以前对计算机图形学没有任何的认识,很多基本的知识都不了解。
而书上有时会把这些基本知识(或者数学推导和证明)一笔带过,因此准备写个博客把自己的思考过程写一下。
在书中,创建一个朗伯体漫反射材质的方法是下面这样:
class lambertian : public material { public: lambertian(const color& alb) : albedo(alb) {} virtual optional<pair<ray, color>> get_ray_out(const ray& r_in, const hit_rec& rec) const override { vec3 ref_dir = rec.norm + rand_unit_vec(); // 注意这里 if(ref_dir.near_zero()) // 如果 rand_unit_vec() 等于 -rec.norm ref_dir = rec.norm; ray ref_ray(rec.hit_pt, ref_dir); return make_pair(ref_ray, albedo); } color albedo; // 反射率};
也就是,击中漫反射材质后,发散光线的起点(rec.hit_pt
)会是击中的点,而发散光线的方向是一个随机的单位向量加上击中点的法向量。
但为什么要加上法向量呢,不能直接在一个半球形里随机一个向量吗?
要回答这个问题,需要对辐射度量学(radiometry)有一些认识。下面首先介绍一下一些辐射度量学的基本单位。
在光线追踪中,我们希望考虑相机(或者人眼)接收到的光照,所以下面的解释会以相机的视角进行。
首先需要考虑相机传感器接收的到底是什么物理量,显然,是能量,或者说是到达传感器上的光子数量,那么我们认为传感器接收到的物理量是辐射能量(radiant energy)用符号 表示,单位为焦耳。
不过能量并不能很好的反应一个物体的亮度。毕竟我们拿着相机拍同一个画面,曝光一分钟和 秒的效果肯定是不一样的。
虽然传感器最终接收的是能量,但只要我们拿着相机不同的曝光(积分),就可以一直得到更多的能量。
自然而然的,我们会想到,把得到的能量除以收集能量的时间,那就有了辐射通量(radiant flux)这个单位:
也就是传感器在单位时间内能收到的能量。
反过来,这也可以表述某个光源在单位时间内传输的能量。
不过这还是不能完全的表示物体的亮度。如果我们在相机中使用更大的传感器,那么单位时间内更大的传感器能接收到更多的能量。
我们在观测时用更大的传传感器并不能改变物体本身的亮度。因此还需要把接收到的辐射通量除以面积,也就是单位面积下的辐射通量。这个单位被称为辐照度(irradiance)。
对于光源来说,使用一个更大的光源,也能提供更多的辐射通量,但是单位面积能提供的通量是不变的。
考虑下面这样一张图[1]:
我们会发现,观测距离变远,要收集到相同的光通量,所需的面积就要越大。那么辐照度就会越小。这显然是不符合常理的,现实中随着距离变远,我们所观察到的亮度并不会显著的减小(有衰减主要还是因为光线在传播中会碰到很多细小的颗粒)。
那这是怎么一会事呢?直观上讲,虽然观测距离更远了,收到的光通量更少了,但是人眼看到的物体也变小了。
比如有一个面积很大的灯,以及一个面积很小的灯,如果它们两个发出的光通量相同,显然是面积小的灯更亮。
因此,人眼直接接收到的光通量小了,但是观测物体的面积也对应的小了,这两个变化相互抵消,会造成观测到的亮度不变。那么我们就需要引入一个物理量,描述人眼观测到的物体大小,随后把辐射照度除以这个量,就能真正的描述亮度。而这个量就是立体角。
我们可以把人眼的视线想象成一个球,这个球的球心是人眼,因此球面上的每个点到人眼的距离都是一样的。也因此,如果我们在这个球面上放置很多大小一样物体,因为他们到人眼的距离一样,人眼看起来的大小也是一样的。
那对于距离不同的物体,都可以将其投影到这个球上面,这样在球面上占的面积大,人眼看起来也就大。
从光源的角度来说,有时我们会希望关注光源对某个方向的影响(把那个方向照亮了多少,提供了多少的辐射通量),那么这个时候也可以引入立体角来分析。
所以立体角的定义就是,某个物体在单位球(半径为 1 )上的投影面积。
立体角的计算方法如下,单位为球面度(steradian, sr):
其中 是投影在某个球上的面积(不一定是单位球), 是球半径。
那么有了立体角后,我们就能真实的描述人眼所看见的物体的大小了,进一步修改辐照度就可以得到辐亮度(radiance)这个物理量了:
这个公式中的 是感光面元的面积, 是球面度。而 其实是用来计算某个物体平行于球面的面积的,可以见下图:
这里的 就是物体表面法线和球面法线的夹角, 为 的时候 最大, 为 时,物体表面和球面垂直,因此球面发出的光线和物体完全不相交, 也就为 。
辐亮度已经足够完美的描述大部分物体的亮度特征了。不过我们前面讨论的都是面光源,或者是有一定面积的传感器。一个点光源是没有面积的,这个时候辐亮度就没有意义了(因为要除以面积)。
同时,有的时候我们可能不关注光源和传感器的面积,单纯就是想知道某个发射或接收到的辐射通量,这个时候就需要有一种新的物理量——辐射强度(radiant intensity),它其实就是把辐亮度中除以面积的部分去掉了:
要理解朗伯余弦定律,可以看下面这张图:
用数学公式表述的话就是:
其中, 表示观察表面的法线完全平行于光线时的辐射强度。
对于观察者, 是观察者表面的法线和光线的夹角,这个夹角越大,收到的辐射通量也就越小。而朗伯余弦定理选择的是辐射强度就是因为辐射强度规定了方向,这样就能计算出光线和表面法线的夹角(要不然光线可以从四面八方射过来)。
至于用的为什么是 ,其实就是为了计算出当前观察表面投影到垂直于光线的表面后的面积。
有了这些知识,就可以介绍朗伯体的性质了,以下是维基百科对朗伯体的介绍。
余弦辐射体,也称为朗伯辐射体(Lambert radiator),指的是发光强度的空间分布符合馀弦定律的发光体(不论是自发光或是反射光),其在不同角度的辐射强度会依馀弦公式变化,角度越大强度越弱
该规律以约翰·海因里希·朗伯的名字命名,因首次提出自他1760年出版的《光度学(Photometria)》。[2]遵循朗伯定律的表面被称为兰伯特表面,并表现出朗伯反射率。这样的表面从任何角度看都具有相同的辐射度。这意味着,例如,对人眼而言,它具有相同的视亮度(或亮度)。因为功率和实心角之间的比例是恒定的,所以辐射度(单位实心角单位投射源面积的功率)保持不变。
乍一看这两段话好像是反的。其中一个说强度符合照余弦定律,不同角度观察的强度不同,另一个亮度在任何角度都相同。
我们先根据定义分析一下,符合余弦定律也就是符合下面这个公式:
回忆一下辐强度和辐亮度的定义:
尝试推导出 和 的关系。
可以看到,分子和分母的 被消掉了,也就是角度对辐强度有影响,但是对辐亮度没影响。
直观上讲,这也是对的,我们从观察者的角度思考。如果 角大,那么观察者看到的发光表面是倾斜的,自然看到的面积也就小了。
虽然总体的辐射通量变少了,但是辐射通量从更集中的区域发出来,两者互相抵消,造成亮度没有变化(和介绍辐亮度时提到的很相)。
那么为什么发光表面在不同角度的辐强度不一样呢,假设发光表面每个区域的辐照度是一样的(单位面积的辐通量), 角大的话,投影到观察者上的面积就小了,而这个投影面积的系数就是 。
所以完美的漫反射体在不同角度看到亮度都是一样的。
了解辐射度量学后我们就可以分析上面朗伯体的光线追踪代码了。
在光线追踪的时候,我们其实是在反方向(也就是从相机到光源)的追踪。但是对于一条从物体到相机的光线,可能有不同的光线对这条光线做出了贡献。
或者说,我们设物体上的点为 ,而相机上的点是 ,那么可能有很多条光线打到 上,造成了 最终的亮度和色彩。
所以在追踪时,光线从 到达 后,决定下一个追踪的方向就成了问题。
我个人认为,光线追踪时追踪的是辐射强度,也就是带方向的辐射通量。这是因为,相机传感器上每个像素(每个像素的面积一样,因此不用考虑面积)最后的颜色都取决于某个方向上的光通量。那么,不考虑面积,只有方向,就是辐射强度了。
这样在分析其他光线 对点的贡献时,就要考虑朗伯余弦定理。
我们可以把 点当作一个面积无限小的观察者,那么别的光线(单位辐射通量)和观察面的法线夹角越小,对该面贡献的也会按照 的系数衰减。
对于相机得到的每个像素点,我们都会进行多次采样,书里的代码如下:
……color pixel_color(0, 0, 0);for (int s = 0; s < samples_per_pixel; ++s) { auto u = (i + random_double()) / (image_width-1); auto v = (j + random_double()) / (image_height-1); ray r = cam.get_ray(u, v); pixel_color += ray_color(r, world);}write_color(std::cout, pixel_color, samples_per_pixel);……
这样的多次采样可以模拟不同光线对 点的贡献。为了模拟 的衰减,我们有两个选择,第一个是每次随机的选择一个 点上单位半圆的表面作为光线的方向,继续追踪,大概和下图一样:
不过对于随机选出来的光线,需要计算其和 点法线的夹角( 角),然后加上衰减。
还有一种选择是,让 作为概率密度函数来随机的选取光线的方向,这样就不用加上衰减了,如下:
显然书里选择的是第二种方法,让 作为概率密度函数。这里有个比较神奇的事情,如果我们把 作为和 的法线夹角为 的线段的长度,并把线段的一段固定在 点上,就会得到下面的图像,即一个和 点相切的圆,或者在三维空间里,球:
这里我暂时不知道如何证明,但这是一个正确的结论,如果你知道可以在评论区提出。RTOW 显然是利用了这一性质,让击中点的法向量加上一个随机的单位向量(单位球球面上的随机一点)作为光线的方向,如下:
vec3 ref_dir = rec.norm + rand_unit_vec();
这里还有个小问题,即为什么我们能保证, 点收到的光照就会向周围“均匀的”发散。前面我们说了,现在讨论的朗伯体的定义如下:
我们追踪的也是辐射强度,那么不应该把表面的法线和摄像机的夹角算出来,然后加上 的衰减吗?
可以结合下面这张图理解:
可以观察到,随着夹角的增加,一个像素对应的物体表面积也相应的增大了,所以和 的衰减抵消了。
而对于每个像素,每次的采样是在一个像素的范围内任意选取坐标,所以可以覆盖到单个像素对应的物体表面。
如果真的要按照余弦定律加上衰减,我们也相应的对夹角更大的区域做更多的采样(单个像素对应的面积更大)。
参考资料:
1: https://www.cnblogs.com/ludwig1860/p/13930745.html
2: https://zh.wikipedia.org/zh-hans/余弦辐射体
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
最后一个 lab 了,终于搞完了!!
实现一个 UNIX 操作系统中常见系统调用 mmap()
和 munmap()
的子集。此系统调用会把文件映射到用户空间的内存,这样用户可以直接通过内存来修改和访问文件,会方便很多。
mmap()
的定义如下:
void *mmap(void *addr, size_t length, int prot, int flags, int fd, off_t offset);
意思是映射描述符为 fd
的文件,的前 length
个字节到 addr
开始的位置。并且加上 offset
的偏移量(即不从文件的开头映射)。
如果 addr
参数为 0,系统会自动分配一个空闲的内存区域来映射,并返回这个地址。
在实验中我们只需要支持 addr
和 offset
都为 0 的情况,也就是完全不用考虑用户指定内存和文件偏移量。
prot
和 flags
都是一些标志位,具体说,prot
有以下的选项:
规定了能对映射后文件做的操作。
flags
则决定,如果在内存映射文件中做了修改,是否要在取消映射时,把这些修改更新到文件中。
有 MAP_SHARED 和 MAP_PRIVATE 两个选项。
unmap()
的定义如下:
int munmap(void *addr, size_t length);
意思是取消从 addr
开始的,长度为 length
的文件映射。不过需要注意的一点是,这个函数不支持在映射范围的中间“挖洞”,只能从开始或者结尾取消部分(或全部)的映射。
这样说可能有点不清晰,假设我们有一个 的映射范围,那么如果我们想要取消 范围的映射,需要符合 或者 。
首先我们要考虑把内存映射的文件放在用户进程的哪个地方。用户进程的内存布局如下:
起初我想的是直接参考 sbrk()
的方式来分配映射内存的,如下:
uint64sys_sbrk(void){ int addr; int n; if(argint(0, &n) < 0) return -1; addr = myproc()->sz; if(growproc(n) < 0) return -1; return addr;}
也就是给进程分配更多的堆区,然后把文件放在这里。虽然实现很方便,但是仔细一想会造成很多问题,我们默认 myproc()->sz
以下的内存都是可以给用户自由使用的, malloc()
分配的就是内存。
那么如果我们把映射的文件放在这里,完全可能会被 malloc()
分配出去,再被覆盖掉。
同时,取消文件映射后(这个时候会设置映射位置的 PTE 为 0),如果用户访问了对应位置的内存,还会引发缺页错误,这又需要去处理,显然是比较复杂的。
所以我们完全可以“倒过来”的分配文件映射的内存,来避免和用户进程的堆冲突。也就是说,我们可以从 trapframe 的位置开始,向下分配文件映射的内存。
根据给的提示,可以在内核的进程结构体中加入一个 VMA (virtual memory area, 虚拟内存区域) 结构体,这个结构体储存了文件映射的元数据,比如,映射开始的地址,长度,以及映射的文件等。有了这些元数据才能更方便的管理。
想要同时支持映射多少个文件,就需要在 struct proc
中放多少个 VMA,这里提示给的推荐是 16 个。
文件的映射还必须是懒分配的,要不然一次性拷贝大文件会很耗费时间,只有用户进程触发了缺页错误后,我们才实际的把文件拷贝过去。
最后一点,我们还需要支持在 fork()
的时候也把映射的文件 fork()
过去。当然这点比较简单,只要拷贝 VMA 就行了。因为子进程的页表中没有对应的映射,如果访问 VMA 中记录的地址会引发缺页错误,这个时候只需要把需要的文件拷贝过去就好了。
注意:这个 lab 没有帮我们注册系统调用和 mmaptest
,直接按照 Lab2 的方法来就好了,这里不赘述,如果你不会,可以看这篇文章。
struct mmap_vma
:
// in proc.hstruct mmap_vma{ int in_use; // 该 vma 结构体是否代表了一个正在使用的文件映射 uint64 sta_addr; // 起始地址 uint64 sz; // 映射大小 int prot; struct file* file; // 映射的文件 int flags; // map_shared or map_private};#define VMA_SZ 16struct proc { …… struct mmap_vma mmap_vams[VMA_SZ];}
sys_mmap()
:
这个调用不实际的分配内存。其调用 get_mmap_space()
找到一个没被使用的 mmap_vams
,以及用于映射文件的空间,再给 vma 结构体初始化。
还需要增加被映射文件的引用计数(如果不增加,引用计数为 0 后,文件会被关闭,然后我们在懒分配的时候就无法拷贝对应文件内容到内存了)
// in sysfileuint64 sys_mmap(){ uint64 addr, length, offset; // addr 和 offset 都只有 0 int prot, flags, fd; struct file* file; //void *mmap(void *addr, size_t length, int prot, int flags, int fd, off_t offset); // 这参数是真的多。。 try(argaddr(0, &addr), return -1) try(argaddr(1, &length), return -1) try(argint(2, &prot), return -1) try(argint(3, &flags), return -1) try(argfd(4, &fd, &file), return -1) // 同时取得文件和描述符 try(argaddr(5, &offset), return -1) // 读入参数 struct proc* p = myproc(); if(addr || offset) // 我们实现的是 mmap 的子集,不支持自定内存和偏移量 return -1; if(!file->writable && (prot & PROT_WRITE) && (flags & MAP_SHARED)) return -1; // 如果文件本身不允许写入,但 PROT_WRITE 还是设置了 int unuse_idx = -1; uint64 sta_addr = get_mmap_space(length, p->mmap_vams, &unuse_idx); if(unuse_idx == -1) return -1; if(sta_addr <= p->sz) // 没内存来 mmap 了 return -1; struct mmap_vma* cur_vma = &p->mmap_vams[unuse_idx]; cur_vma->file = file; cur_vma->in_use = 1; cur_vma->prot = prot; cur_vma->flags = flags; cur_vma->sta_addr = sta_addr; cur_vma->sz = length; filedup(file); // 增加引用计数 return cur_vma->sta_addr;}
get_mmap_space()
:
此函数需要给新的文件映射找到一个可用的内存区域,那么我们需要思考一下这个策略。最稳的方法肯定是找到所有 vma 中使用到的最低虚拟地址。然后把这个位置作为新映射区域的结尾。这样永远不会造成冲突,不过也有一定问题,如下:
首先可以看到,为了方便取消映射,我们不允许同一个页帧上有两个文件的映射(要不然 kfree()
就一起释放了)。
其次,如果我们使用了找最低虚拟地址的方法来分配,就会造成实际内存够用,却还要向下增长文件映射空间的情况。这样的策略可能在某些情况下(较少)会造成用户堆内存的缩减,在极端情况下(非常极端,因为大部分时候 MAXVA 都是很大的,至少比物理内存大),是会出问题的。
但不管怎么样,我闲的没事干还是写了一个应对这种情况的代码。大概就是搞个双层循环,每层都遍历所有的 vma,具体的可以见注释。
// in sysfile.cuint64get_mmap_space(uint64 sz, struct mmap_vma* vmas, int* free_idx){ *free_idx = -1; // 返回一个可以储存新文件映射的地址(开始地址) // 优先查看 vma 槽中的“空隙”,如果没有,那就映射到最下面 // 其实可以写一个快速排序,但是我懒。。。 uint64 lowest_addr = TRAPFRAME; struct mmap_vma tmp; // 作为上边界,可能和上图一样,最上方没有任何映射区域 tmp.sta_addr = TRAPFRAME, tmp.sz = 0; for(int i = 0; i <= VMA_SZ; i++){ // 假设 vmas[i] 的 PGROUNDDOWN(sta_addr) 是新文件映射的结束位置 if(vmas[i].in_use == 0 && i != VMA_SZ){ *free_idx = i; continue; } uint64 ed_pos = i != VMA_SZ ? PGROUNDDOWN(vmas[i].sta_addr) : tmp.sta_addr; lowest_addr = ed_pos < lowest_addr ? ed_pos : lowest_addr; // 取 min for(int j = 0; j < VMA_SZ; j++){ // 假设 vmas[j] 的 sta_addr + sz(vma[j] 的结束位置) 往上是新映射的起始位置 if(vmas[j].in_use == 0 && i != VMA_SZ) continue; uint64 st_pos = i != VMA_SZ ? vmas[j].sta_addr + vmas[j].sz : tmp.sta_addr + tmp.sz; // 这个位置一定是页对齐的 if (ed_pos <= st_pos) continue; // 这里直接跳过,不在下面判断是因为无符号类,如果做下面的减法会出错 if (ed_pos - st_pos >= sz){ // [st_pos, ed_pos) 的区间 return st_pos; } } } return lowest_addr - sz;}
到此为止我们所有的映射都是懒分配的,所以需要一个处理缺页错误的函数:
mmap_fault_handler()
:
注意这里有个比较坑的地方。就是用户要求映射的大小超过了文件本身的大小,这个时候我们需要把剩下的映射区域设成 0,要不然 mmaptest()
就通不过了。
还有一点就是,触发缺页错误后我们只分配和映射一页内存,而不是把整个文件都映射过去。
// in trap.cint mmap_fault_handler(uint64 addr){ struct proc* p = myproc(); struct mmap_vma* cur_vma; if((cur_vma = get_vma_by_addr(addr)) == 0){ // 找到这个地址属于哪个文件的映射 // 等于零说明不属于任何一个 return -1; } if(!cur_vma->file->readable && r_scause() == 13 && cur_vma->flags & MAP_SHARED){ DEBUG("mmap_fault_handler: not readable\n"); return -1; } // 读错误 if(!cur_vma->file->writable && r_scause() == 15 && cur_vma->flags & MAP_SHARED){ DEBUG("mmap_fault_handler: not writable\n"); return -1; } // 写错误 uint64 pg_sta = PGROUNDDOWN(addr); uint64 pa = kalloc(); if(!pa){ DEBUG("mmap_fault_handler: kalloc failed\n"); return -1; } memset(pa, 0, PGSIZE); int perm = PTE_U | PTE_V; if(cur_vma->prot & PROT_READ) perm |= PTE_R; if(cur_vma->prot & PROT_WRITE) perm |= PTE_W; if(cur_vma->prot& PROT_EXEC) perm |= PTE_X; // 在 mmap 的时候已经排除了不可能的情况了 uint64 off = PGROUNDDOWN(addr - cur_vma->sta_addr); // 这个 off 代表文件拷贝时要跳过多少个页帧 ilock(cur_vma->file->ip); int rdret; if((rdret = readi(cur_vma->file->ip, 0, pa, off, PGSIZE)) == 0){ iunlock(cur_vma->file->ip); return -1; } iunlock(cur_vma->file->ip); // 没有 put 是这个文件之后还需要使用 // 在 unmap 中应该可以 put mappages(p->pagetable, pg_sta, PGSIZE, pa, perm); return 0;}
get_vma_by_addr()
:
此函数是前面的处理函数用到的,返回对应地址所在的 vma:
struct mmap_vam* get_vma_by_addr(uint64 addr){ struct proc* p = myproc(); for(int i = 0; i < VMA_SZ; i++){ if(p->mmap_vams[i].in_use && addr >= p->mmap_vams[i].sta_addr && addr < p->mmap_vams[i].sta_addr + p->mmap_vams[i].sz){ // 判断该地址是否在文件映射区的中间 return p->mmap_vams + i; } } return 0;}
usertrap()
:
// in trap.c……if(r_scause() == 8){ // system call if(p->killed) exit(-1); // sepc points to the ecall instruction, // but we want to return to the next instruction. p->trapframe->epc += 4; // an interrupt will change sstatus &c registers, // so don't enable until done with those registers. intr_on(); syscall();} else if((which_dev = devintr()) != 0){ // ok} else if ((r_scause() == 13 || r_scause() == 15)){ try(mmap_fault_handler(r_stval()), bad = 1)}else{ bad = 1;}if (bad){ printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid); printf(" sepc=%p stval=%p\n", r_sepc(), r_stval()); p->killed = 1;}……
接下来我们就可以来尝试实现 munmap()
了,如果 vma 的 flag 设置为了 MAP_SHARED,就需要在取消映射的时候拷贝内存中修改过的内容到文件。
因为这个过程相对复杂,所以专门写了一个 mmap_writeback()
函数来处理这个。其中,我们利用了 PTE 的标志位 PTE_D 来判断文件映射的某个页帧是否被修改过,如果修改过,就需要拷贝回去。
这个标志位没被定义,需要参考 risc-v 手册在 riscv.h
中定义:
#define PTE_D (1L << 7)
如果说 unmap 的 addr 和长度不是 PGSIZE
的倍数,那么这个函数会变得特别复杂,如下:
可能也是考虑到了这个复杂度,mmaptest.c
中所有 munmap()
和 mmap()
调用的 addr
和 len
都是 PGSIZE
的倍数。实验提示中也说只要支持 mmaptest.c
使用到的特性就行了。所以下面的版本是不支持非 PGSIZE
倍数的。当然我也写了一个支持的版本,只是没有经过任何测试(我懒的再写一个加强版的 mmaptest.c
,当然以后有时间,可能会)。
正常版本:
// in vm.cintmmap_writeback(pagetable_t pt, uint64 src_va, uint64 len, struct mmap_vma* vma){// 把带脏位的页帧写回文件中,并且取消映射// 写回的是 src_va 开始的,长度为 len uint64 a; pte_t *pte; for(a = PGROUNDDOWN(src_va); a < PGROUNDDOWN(src_va + len); a += PGSIZE){ if((pte = walk(pt, a, 0)) == 0){ panic("mmap_writeback: walk"); } if(PTE_FLAGS(*pte) == PTE_V) panic("mmap_writeback: not leaf"); if(!(*pte & PTE_V)) continue; // 懒分配 if((*pte & PTE_D) && (vma->flags & MAP_SHARED)){ // 写回 begin_op(); ilock(vma->file->ip); uint64 copied_len = a - src_va; writei(vma->file->ip, 1, a, copied_len, PGSIZE); iunlock(vma->file->ip); end_op(); } kfree(PTE2PA(*pte)); *pte = 0; } return 0;}
支持非 PGSIZE
倍数版本(未经测试):
//in vm.cintmmap_writeback_na(pagetable_t pt, uint64 src_va, uint64 len, struct mmap_vma* vma){ uint64 a; pte_t *pte; a = PGROUNDDOWN(src_va); if(a == PGROUNDDOWN(src_va + len)){ // 如果 unmap 的部分在一个页帧的范围内 begin_op(); ilock(vma->file->ip); writei(vma->file->ip, 1, src_va, 0, src_va - a); iunlock(vma->file->ip); end_op(); } for(; a < PGROUNDDOWN(src_va + len); a += PGSIZE){ // 这部分只处理整页 // 如果结尾停在页中间不会处理 if((pte = walk(pt, a, 0)) == 0){ panic("mmap_writeback: walk"); } if(PTE_FLAGS(*pte) == PTE_V) panic("mmap_writeback: not leaf"); if(!(*pte & PTE_V)) continue; // 懒分配 if((*pte & PTE_D) && (vma->flags & MAP_SHARED)){ // 写回 begin_op(); ilock(vma->file->ip); // 第一次的时候,a 会比 src_va 小 uint64 copied_len = a - src_va; if(a < src_va){ // 第一个页帧,不是完整的 // 这种情况也需要 kfree,因为跨过了一个页帧 writei(vma->file->ip, 1, src_va, 0, src_va - a); } else { writei(vma->file->ip, 1, a, copied_len, PGSIZE); } iunlock(vma->file->ip); end_op(); } kfree(PTE2PA(*pte)); *pte = 0; } uint64 copied_len = a - src_va; uint64 len_left = vma->sz - copied_len; if (len_left){ // 处理 unmap 结尾在页帧中间的情况 begin_op(); ilock(vma->file->ip); writei(vma->file, 1, a, copied_len, len_left); if(len_left + a == vma->sz + src_va){ // 如果停在的页帧刚好是最后一个 pte_t *pte; if((pte = walk(pt, a, 0)) == 0){ panic("mmap_writeback: walk"); } kfree(PTE2PA(*pte)); } iunlock(vma->file->ip); end_op(); } return 0;}
相比之下,munmap()
就比较简单了,但需要注意,如果 unmap 好了之后整个映射区都没了,就代表我们不需要再用到对应的文件了,所以调用 fileclose()
来减少引用计数和关闭文件。
同时,还不能忘记 munmap()
取消映射区时的限制,只能从头取消或者是结尾,不能中间挖洞(见本文开头)
。
// in sysfile.cuint64munmap(uint64 addr, uint64 len){ struct proc* p = myproc(); struct mmap_vma* cur_vma = get_vma_by_addr(addr); if(!cur_vma) return -1; if(addr > cur_vma->sta_addr && addr + len < cur_vma->sta_addr + cur_vma->sz){ // 从中间挖洞 return -1; } mmap_writeback(p->pagetable, addr, len, cur_vma); if(addr == cur_vma->sta_addr){ // 从起始位置删除的 cur_vma->sta_addr += len; } cur_vma->sz -= len; if(cur_vma->sz <= 0){ // 如果整个映射区都没了 fileclose(cur_vma->file); cur_vma->in_use = 0; } return 0; }
可能你会发现这个函数不是系统调用的形式,这是因为我们之后还需要在内核中调用。系统调用的形式如下:
uint64sys_munmap(){ // int munmap(void *addr, size_t length); uint64 addr; uint64 len; try(argaddr(0, &addr), return -1) try(argaddr(1, &len), return -1) return munmap(addr, len);}
内核需要调用 munmap()
是因为有些进程在退出后还没有取消它的文件映射,那我们就需要帮它强制清理掉这些映射,要不然会造成内存泄露,这个清理可以放在 exit()
中。
这里讲一下为为什么放在 exit()
中而不是真正释放进程号的 freeproc()
。我们可以观察一下,一个进程被 freeproc()
是在 wait()
函数中,如下:
// in proc.c wait():…… for(;;){ // Scan through table looking for exited children. havekids = 0; for(np = proc; np < &proc[NPROC]; np++){ if(np->parent == p){ // make sure the child isn't still in exit() or swtch(). acquire(&np->lock); havekids = 1; if(np->state == ZOMBIE){ // Found one. pid = np->pid; if(addr != 0 && copyout(p->pagetable, addr, (char *)&np->xstate, sizeof(np->xstate)) < 0) { release(&np->lock); release(&wait_lock); return -1; } freeproc(np); // 注意这里,只有父进程 wait 的时候才会去 freeproc。 release(&np->lock); release(&wait_lock); return pid; } release(&np->lock); } } …… }……
那么如果父进程不调用 wait()
这些映射的文件就一直放着不会被写会文件中。当然,父进程是应该调用 wait()
的,这里放在 exit()
中主要还是实验的提示,但实验提示这么写可能是这个原因。
// in proc.c exit():voidexit(int status){ struct proc *p = myproc(); if(p == initproc) panic("init exiting"); // 释放和写回 mmap 数据需要在关闭文件之前 for(int i = 0; i < VMA_SZ; i++){ if(p->mmap_vams[i].in_use){ try(munmap(p->mmap_vams[i].sta_addr, p->mmap_vams[i].sz), panic("exit: munmap")); } } // Close all open files. for(int fd = 0; fd < NOFILE; fd++){ if(p->ofile[fd]){ struct file *f = p->ofile[fd]; fileclose(f); p->ofile[fd] = 0; } }……}
实验的最后一步就是在 fork()
之后也能让子进程访问到映射的文件。前面提到过我们只需要拷贝 vma 就行了。vma 中的 sta_addr
是一个虚拟地址,那么子进程尝试访问的时候会造成缺页错误,因为这个虚拟地址没有映射到物理地址上。
因此在 mmap_fault_handler()
中,我们会发现触发缺页错误的这个地址属于一个文件映射区。因此会给这个虚拟页帧分配一个物理页,然后把对应文件拷贝过去。
当然 fork()
之后意味着有另外一个进程也在使用被映射的文件,所以需要调用 filedup()
来增加引用计数。
fork()
:
// in proc.c…… for (int i = 0; i < VMA_SZ; i++){ if(p->mmap_vams[i].in_use){ np->mmap_vams[i] = p->mmap_vams[i]; filedup(p->mmap_vams[i].file); // 复制 vma } }……
我最初在这里有个小问题,就是前面的 uvmcopy()
已经复制过内存了,那难道不会把 vma 也复制了吗,我们后面再复制是否会造成重复复制。
看了代码之后就解决了,uvmcopy()
只会复制 myproc()->sz
以下的内存:
// in vm.c for(i = 0; i < sz; i += PGSIZE){ // 注意这里范围 if((pte = walk(old, i, 0)) == 0) panic("uvmcopy: pte should exist"); if((*pte & PTE_V) == 0) panic("uvmcopy: page not present"); pa = PTE2PA(*pte); flags = PTE_FLAGS(*pte); if((mem = kalloc()) == 0) goto err; memmove(mem, (char*)pa, PGSIZE); if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){ kfree(mem); goto err; } }
做好之后就可以 AC 了,也祝正在做这个 lab 的人尽快 AC:
这里我一定要吐槽一下 (我都不知道是哪的 bug,xv6?qemu?还是 Makefile?)的一个 bug。
大概就是我在用 gdb 调试的时候希望能使用宏(主要是 PGROUNDDOWN()
和 PGROUNDUP()
),所以在 Makefile 文件中加入了 -g3
编译选项,像下面这样:
CFLAGS = -Wall -O -g3 -fno-omit-frame-pointer -ggdb -UFDEBUG
而这就会导致 usertest.c
中的一个测试通不过,直接 panic 了,如下:
$ usertests writebigusertests startingtest writebig: panic: balloc: out of blocks
去掉这个 -g3
居然就正常了???我是怎么也想不到一个编译选项居然可以影响虚拟磁盘的块数。然后就因为这个东西调了一天没调出来,毕竟谁会想到一个编译选项有这种效果,后来我是直接用 git 看这个分支和别的分支文件的区别,然后一个一个试才试出来的。
这个问题我已经发在 xv6-riscv 的 github 上了,然后在 issue 区逛了一圈后还发现了更离谱的:
https://github.com/mit-pdos/xv6-riscv/issues/59
就是在编译选项里加一个 -O3
也会造成这个问题。。。我不李姐。。。
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
在 xv6 的底层实现中,文件是由 struct dinode
来描述的,如下:
struct dinode { short type; // File type short major; // Major device number (T_DEVICE only) short minor; // Minor device number (T_DEVICE only) short nlink; // Number of links to inode in file system uint size; // Size of file (bytes) uint addrs[NDIRECT + 1]; // Data block addresses};
这里我们主要关注结构体中的 addrs
属性。维护了此文件的实际储存位置。其中有 addrs
的前十二个直接指向文件储存的块,最后一个是间接的块,即其指向的块中储存了别的指针,这些指针才指向了实际储存数据的块。听起来有点绕,大概是下面这个示意图的样子:
那我们可以计算一下 xv6 能支持的最大文件大小。已知一个 struct dinode
有 64B 的大小,一个磁盘块能储存 1024B 的数据。
那么前 12 个直接指向数据块的 addrs
就能储存
而最后一个间接的数据指针指向一个存满了指针(指向别的磁盘块)的块,这个块能存放 个地址。
而这里的每个地址都是一个块,因此,这个间接 addrs
能提供 的储存空间。
那么他们加起来就是 的储存空间,等于
这样的储存空间显然是非常小的,所以在这个 lab 中我们需要给 inode 加入一个二级的间接块指针来解决这个问题。
一个一级的间接块指针就是刚刚提到的,inode 中 addrs
的最后一个,其指向一个块,而这个块中储存的块指针又指向别的数据块。
在二级块指针中,addrs
中指针指向的块中的指针会指向另外的,储存指针的块,以此加大储存空间(有点像是多级页表了),如下:
可以计算一下这个二级间接块指针能提供的空间,已知一个块能储存 256 个块指针,那么 addrs
指向的那个块能储存 256 个块指针块的块号,所以总数就是 个块。除以 1024 为 64MB,可见提升非常巨大。
需要修改 bmap()
和 itrunc()
这两个函数,不过没有什么特别难以思考的地方,所以具体的解释还是放到代码部分。
因为加入了二级间接索引,所以要先对一些宏定义进行修改:
#define NDIRECT 11 // 减少一个直接索引,增加一个二级间接索引#define NINDIRECT (BSIZE / sizeof(uint))#define NBI_INDIRECT NINDIRECT * NINDIRECT // 二级间接索引提供的块#define MAXFILE (NDIRECT + NINDIRECT + NBI_INDIRECT) //
同时也需要修改 struct dinode
和 struct inode
。其中,dinode
是实际储存在磁盘上的,而 inode
在 dinode
的基础上加入了很多方便处理 inode
的元数据:
//in fs.h// On-disk inode structurestruct dinode { short type; // File type short major; // Major device number (T_DEVICE only) short minor; // Minor device number (T_DEVICE only) short nlink; // Number of links to inode in file system uint size; // Size of file (bytes) uint addrs[NDIRECT + 2]; // Data block addresses 这里修改成了 + 2};
// in file.h// in-memory copy of an inodestruct inode { uint dev; // Device number uint inum; // Inode number int ref; // Reference count struct sleeplock lock; // protects everything below here int valid; // inode has been read from disk? short type; // copy of disk inode short major; short minor; short nlink; uint size; uint addrs[NDIRECT+2];// 这里修改成了 + 2};
bmap()
:
这个函数接收 inode
指针和 bn
,表示 inode
中的第几个块,返回对应的块号。
我们需要在这个函数中添加对二级间接块的支持。为了取得二级的间接块,我们可以先获取到一级的间接块。
代码中很多写法可以参考前面对一级间接块的处理。
// in fs.c…… bn -= NINDIRECT; // bn 代表还剩多少个 if(bn < NBI_INDIRECT){ if((addr = ip->addrs[NDIRECT + 1]) == 0) // 如果之前没分配这个 block ip->addrs[NDIRECT + 1] = addr = balloc(ip->dev); bp = bread(ip->dev, addr); // buf pointer 的简称 a = (uint *)bp->data; uint idx_b1 = bn / NINDIRECT; // 取得 bn 对应的,一级间接块在 addr 中的下标 if((addr = a[idx_b1]) == 0){ // 一个一级块负责 256 个二级块,这里检测对应一级块是否存在 a[idx_b1] = addr = balloc(ip->dev); log_write(bp); // 标志这个块被修改了,随后会更新到磁盘的日志区 // 修改是因为,我们给这个储存块指针的块添加了一个新的块指针 } brelse(bp); // 释放块缓存 bp2 = bread(ip->dev, addr); // bp2 为二级块的缓存 a = (uint *)bp2->data; uint idx_b2 = bn % NINDIRECT; if((addr = a[idx_b2]) == 0){ a[idx_b2] = addr = balloc(ip->dev); log_write(bp2); } brelse(bp2); return addr; }……
itrunc()
:
此函数会清理 inode 中的所有块,或者可以理解成删除一个文件。这个函数内实际上是在不停的调用 brelse()
和 bfree()
。
其中 brelse()
释放一个块缓存,而 bfree()
则通过修改磁盘上 bitmap 块的数据来释放磁盘上的一个块。
和 bmap()
相同,很多地方可以参考一级间接索引的实现。主要的思路类似递归,先遍历每个一级块,检查里面是否有数据,如果有,就去遍历这个一级块里的二级块。
// in fs.c…… if(ip->addrs[NDIRECT + 1]){ // 判断 inode 是否使用了二级间接索引 bp = bread(ip->dev, ip->addrs[NDIRECT + 1]); a = (uint*)bp->data; for (i = 0; i < NINDIRECT; i++){ // 遍历一级块 if(a[i]){ // 如果有数据,就遍历这个一级块里的二级块 struct buf* bp2 = bread(ip->dev, a[i]); // 获取这个块的对应缓存 uint *a2 = bp2->data; for(j = 0; j < NINDIRECT; j++){ if(a2[j]) bfree(ip->dev, a2[j]); // a2[j] 存的是块号,这里把磁盘中这个块的内容清空了。或者说释放 } brelse(bp2); // 释放块缓存 bfree(ip->dev, a[i]); // 释放磁盘中的块 // 和 a[i] 对应的是 bp2 // a[i] 是块号,bp2 是实际的块缓存 } } brelse(bp); // 释放缓存 bfree(ip->dev, ip->addrs[NDIRECT + 1]); // 释放磁盘块 ip->addrs[NDIRECT + 1] = 0; }……
这个实验需要我们实现符号链接,或者说软链接(说实话我现在还不是很清楚软硬链接的本质区别),有点像 windows 中的快捷方式。
实现起来其实很简单,不过这个 lab 中的提示给的(对我来说)不是很足,所以做的时候还是有点懵逼的,最后看了别人的博客才做出来。
首先软链接就像是一个文件的“指针”,如果我们打开某个软链接,实际打开的是那个链接指向的文件,这样就可以实现某个目录打开实际储存在不同目录的文件。
那么我们要如何实现这个软链接呢?软链接的本质其实也是一个文件,我们只要在这个文件(其实是 inode)中储存此链接指向的文件的路径就行了。
为了实现链接的效果,在 open()
函数中,需要去根据链接中储存的路径,递归的找到最终指向的文件(可能会有一个软链接指向另一个软链接)。
可是万一我们想打开的是这个软连接本身呢?这就需要新定义的 open()
标志位了,这些标志位用于指定打开文件描述符的一些设置。那我们可以添加一个 O_NOFOLLOW
的标志位,意味不去递归打开软连接里的路径,而打开软连接本身。
//in fcntl.h#define O_RDONLY 0x000#define O_WRONLY 0x001#define O_RDWR 0x002#define O_CREATE 0x200#define O_TRUNC 0x400#define O_NOFOLLOW 0x800
同时 inode 本身是对磁盘中储存的各种数据的一种“抽象”,为了得知 inode 里面具体放的是什么,需要定义一个新的 inode 类型:
//in stat.h#define T_DIR 1 // Directory#define T_FILE 2 // File#define T_DEVICE 3 // Device#define T_SYMLINK 4 // 软连接
注意这个实验中比较烦人的一点是,sys_symlink()
这个系统调用是没有注册好的,需要和 lab2 一样,在各种文件中加入这个系统调用,我假设看这个文章的人都是做过 lab2 的,所以不赘述,如果你没有,可以看我的这篇文章。
sys_symlink()
:
前面说软连接的本质其实是一种文件,不过这个文件其实又是一个 inode,那么在写代码的时候就需要注意各种操作都是对 inode 进行的。然后还有就是在各种文件相关的系统调用中,我们都需要使用 open_op()
和 end_op()
把这些系统调用包裹起来。其代表,在这个区间内的任何操作会先被记录到日志系统中(不熟悉可用参考 xv6 的书以及 lecture)。
uint64 sys_symlink(){ char tar_path[MAXPATH], path[MAXPATH]; try(argstr(0, tar_path, MAXPATH), return -1); try(argstr(1, path, MAXPATH), return -1); struct inode* ip; begin_op(); ip = create(path, T_SYMLINK, 0, 0); // 创建一个文件,返回其 inode(因为没注释,我其实不是很确定这个函数 // 的用法,只是根据其实现猜测的) if(ip == 0){ end_op(); return -1; } try(writei(ip, 0, tar_path, 0, strlen(tar_path)), end_op(); return -1); // writei 其实就是往某个 inode 中写数据,这里把软链接想要指向的路径写进去了 iunlockput(ip); // 使用完 inode 后的标准操作 // 先释放了锁,然后释放这个 inode // 这里对于 inode 的 iput() 和对于块缓存的 brelse() 很相似 // 都是先减少了引用计数,然后判断是否可用真正的释放 end_op(); return 0;}
sys_open()
:
下面这段 sys_open()
开头的代码打开或者创建了用户传进来路径所对应文件的 inode,记录在 ip
中。而 sys_open()
的后续代码会处理这个 ip
来完成打开的操作,我们先不用管。
\\ in sysfile.c if(omode & O_CREATE){ ip = create(path, T_FILE, 0, 0); if(ip == 0){ end_op(); return -1; } } else { if((ip = namei(path)) == 0){ end_op(); return -1; } ilock(ip); if(ip->type == T_DIR && omode != O_RDONLY){ iunlockput(ip); end_op(); return -1; } }
那对于一个符号链接来说,用户传进来路径对应的 ip
并不是其想要打开的 ip
,所以我们需要递归的跟随符号链接中指向的文件来修改这个 ip
。注意最终这个 ip
必须是上锁的。
如下(这部分代码添加在上面代码的后面):
\\ in sysfile.c if(!(omode & O_NOFOLLOW)){ int rec_left = 10; // 递归次数限制,软链接可能成环 struct inode* next_file; while(rec_left && ip->type == T_SYMLINK){ if(readi(ip, 0, path, 0, MAXPATH) == 0){ iunlockput(ip); end_op(); return -1; } if((next_file = namei(path)) == 0){ // namei 可用从一个路径获得 inode iunlockput(ip); end_op(); return -1; } iunlockput(ip); // 储存链接的文件已经使用完了 ip = next_file; rec_left--; ilock(ip); // 在这里加锁而不在 while 的下面是因为如果这个 inode 不是一个软链接 // 我们还是需要持有这个锁的,因为后面的处理代码会修改 inode } if(rec_left <= 0){ iunlockput(ip); end_op(); return -1; } }
这里要特别特别注意一个点,在递归跟随软链接时,我们碰到一个不是软链接的文件需要停下来。这也要求我们去访问 inode 的 type 属性。那么判断这个属性一定要在 ilock(ip)
的后面,我调了好久才发现这个 bug。
我们先看下 ilock()
的代码:
// Lock the given inode.// Reads the inode from disk if necessary.voidilock(struct inode *ip){ struct buf *bp; struct dinode *dip; if(ip == 0 || ip->ref < 1) panic("ilock"); acquiresleep(&ip->lock); if(ip->valid == 0){ bp = bread(ip->dev, IBLOCK(ip->inum, sb)); dip = (struct dinode*)bp->data + ip->inum%IPB; ip->type = dip->type; ip->major = dip->major; ip->minor = dip->minor; ip->nlink = dip->nlink; ip->size = dip->size; memmove(ip->addrs, dip->addrs, sizeof(ip->addrs)); brelse(bp); ip->valid = 1; if(ip->type == 0) panic("ilock: no type"); }}
可以发现,会先检查 ip->valid
,这个 valid
属性表示当前 inode 的数据是否从磁盘中加载过。如果是没有,那么会先读取磁盘,然后把数据加载进这个 inode 中。
也就是说,如果在执行 ilock()
之前先访问了 inode,意味着这个 inode 很可能是空的,自然读到的东西也没意义(这也再一次提醒了我们访问线程间共享数据时,一定要加锁)。
做完这些后,就可以愉快的 AC 了,也祝在做这个 lab 的人尽快 AC:
提醒一点,如果你发现你的程序在 qemu 中跑测试没问题,但是 make grade 过不了的话,很可能是因为超时了(估计是我电脑性能太垃了),这个时候需要去 python 的计分程序 grade-lab-fs
中改下时限。
数组越界,内存泄漏实在是非常可怕的事情——实际的错误和系统报的错没有任何的相关性,调都调不出来。
这里大概讲下我做这个 lab 时犯的一些傻逼到极致的错误吧,关键是调了两个下午才调出来。
最开始我在进行 symlinktest 的时候,会报 panic,信息是 virtio_disk_intr status
。那这种跟虚拟磁盘有关的东西我肯定是不会处理的,于是单步了以下,找到了 symlinktest 中具体是哪一步出了问题。结果如下:
r = symlink("/testsymlink/4", "/testsymlink/3");if(r) fail("Failed to link 3->4");close(fd1);close(fd2); // 问题fd1 = open("/testsymlink/4", O_CREATE | O_RDWR);if(fd1<0) fail("Failed to create 4\n");
这里,symlinktest 调用 close(fd2)
之后就直接 panic 了。
然后我又单步了以下,大概发现,发生问题时的调用过程是这样的:
sys_close() -> fileclose() -> iput() -> itrunc() -> bread():
我一想是 itrunc()
写错了,还直接新开了个分支,抄了别人的 itrunc()
然后还是不行。
后来又想,不会是什么玄学问题把,于是直接把那个 panic()
给注释掉了,又发现有新的 panic()
,这次报的错是 freeing free block
:
static voidbfree(int dev, uint b){ struct buf *bp; int bi, m; bp = bread(dev, BBLOCK(b, sb)); bi = b % BPB; m = 1 << (bi % 8); if((bp->data[bi/8] & m) == 0) panic("freeing free block"); // 这里 bp->data[bi/8] &= ~m; log_write(bp); brelse(bp);}
后来又发现,在 itrunc()
中,根本没有释放一级间接索引的块,而是直接释放了二级间接索引(因为 addrs[12]
非零)。这肯定是不合理的,一定是一级的用完了再用二级的。结合 freeing free block
的 panic()
信息,我基本确定了问题可能是由某种越界引起的。
最后发现,居然是 struct inode
这里出了问题:
struct inode { uint dev; // Device number uint inum; // Inode number int ref; // Reference count struct sleeplock lock; // protects everything below here int valid; // inode has been read from disk? short type; // copy of disk inode short major; short minor; short nlink; uint size; uint addrs[NDIRECT+2];};
我把 dinode
的 addrs[NDIRECT + 1]
改成了 addrs[NDIRECT + 2]
,但是忘了改 inode
的。。。
这就造成了,我在访问 addrs[12]
时,访问的实际是下一个 inode 的 dev
属性。那么事情就离谱起来了,你说一个 inode 的二级间接索引块怎么可能会在一号块(超级块)呢。。。我其实还挺好奇的,itrunc()
的时候怎么没有把超级块给释放了,又是如何引起虚拟磁盘的 panic()
的。我是懒得调了,有兴趣的可以试试看。
不说了,破大防了。。。
]]>upd@2022/8/18: 本文的第二个实验不完全正确,并且还有很多其他的做法,具体可以见这篇博客中我和博主的讨论。以及博主根据讨论新写的代码。
如果接下来有时间,会把第二部分的代码改掉并添加注释。
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
这 lab 的描述也是非常长,所以就不截图了。下面描述一下大概的题意:
在原本的 kalloc()
中,只有一个大锁,我们会维护一个 freelist
链表,如果有任何程序申请内存,都需要竞争 freelist
的锁,以修改 freelist
的内容。具体可见 freelist
和 kalloc()
的实现:
struct run { struct run *next;};struct { struct spinlock lock; struct run *freelist;} kmem;……// Free the page of physical memory pointed at by v,// which normally should have been returned by a// call to kalloc(). (The exception is when// initializing the allocator; see kinit above.)voidkfree(void *pa){ struct run *r; if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP) panic("kfree"); // Fill with junk to catch dangling refs. memset(pa, 1, PGSIZE); r = (struct run*)pa; acquire(&kmem.lock); r->next = kmem.freelist; kmem.freelist = r; release(&kmem.lock);}// Allocate one 4096-byte page of physical memory.// Returns a pointer that the kernel can use.// Returns 0 if the memory cannot be allocated.void *kalloc(void){ struct run *r; acquire(&kmem.lock); r = kmem.freelist; if(r) kmem.freelist = r->next; release(&kmem.lock); if(r) memset((char*)r, 5, PGSIZE); // fill with junk return (void*)r;}
可以发现,不可能同时有多个核心去调用 kalloc()
函数以及 kfree()
函数,大大降低了内存分配的效率。
经测试,可以发现这个大锁就是一个很大瓶颈(kmem 这个锁是所有锁中等待次数最多,竞争最激烈的):
$ kallocteststart test1test1 results:--- lock kmem/bcache statslock: kmem: #fetch-and-add 83375 #acquire() 433015lock: bcache: #fetch-and-add 0 #acquire() 1260--- top 5 contended locks:lock: kmem: #fetch-and-add 83375 #acquire() 433015lock: proc: #fetch-and-add 23737 #acquire() 130718lock: virtio_disk: #fetch-and-add 11159 #acquire() 114lock: proc: #fetch-and-add 5937 #acquire() 130786lock: proc: #fetch-and-add 4080 #acquire() 130786tot= 83375test1 FAIL
在这个 lab 中,我们就需要解决这个问题。实验提示中给出的提示是给每个处理器核心都分配一个 freelist
,那如果某个核心想要分配一页内存,就无需等待耗时的锁操作,直接分配就行了(其实也要加锁,但是竞争显著的变少了)。
这也带来了一个新的问题,有的时候某些核心会有充足的待分配页帧,而某些核心已经没有了,那么就算总的空闲页帧是足够的,也不能分配新的页帧。
所以,如果当前核心没有页帧可以分配了。我们需要去从别的核心“偷”一些新的页帧。
那我们大概可以写出下面的伪代码:
struct { struct spinlock lock; struct run *freelist;} kmems[NCPU];void kalloc(){ struct run* r = 0; push_off(); int cpu = cpuid(); pop_off(); acquire(&kmems[cpu].lock); int stealed = 0; if(!kmems[cpu].freelist){ for (i : kmems){ acquire(&i.lock); while (i 中还有页帧 && stealed < STEAL_CNT) { 释放 i 中 feelist 的页帧; 把释放的页帧加入 kmems[cpu].freelist; } if(stealed >= STEAL_CNT){ break; } releae(&i.lock); } } r = kmems[cpu].freelist; if (r) { kmems[cpu].freelist = r->next; } release(&kmems[cpu].lock); return r;}
看起来还是比较合理的,其实这样的代码也能通过测试。不过这个代码其实是可能发生死锁的(其实是几乎不可能)。
注意 for (i : kmems)
这个循环,可以发现,在循环中,会持有两个锁,或者说是尝试获得两个锁:第一个是本核心的锁,也就是 kmems[cpu].lock
第二个是尝试偷页帧时,获得的锁 i.lock
。
假设我们的处理器只有两个核心,a 和 b,那如果这两个线程现在都没有空闲页帧了,就会先拿到自己的锁,然后去尝试偷对方的页帧。
在偷的过程中,都会先尝试拿到对方的锁,但是之前 a 和 b 都已经拿到自己的锁了。这就造成了死锁。
当然死锁不止会发生在只有两个核心的情况下,这里使用两个核心只是为了方便说明。
要解决这个问题,我们可以让每个核心不能同时持有本核心和别的核心的锁。
当然这也引出了别的问题,比如我们在偷页帧,并且加入本核心 freelist
的时候,另一个核心可能试图从我们这里偷页帧。这样两个核心同时修改 freelist
的时候,就会出现奇怪的问题。
下面解释下我的解决方案:
首先在发现没有空闲页帧后,立刻释放掉本核心的锁,然后尝试偷页。需要同时持有两个锁是因为可能有多个核心同时修改 freelist
,那我们不如就让本核心不去修改 freelist
,而是把可以偷的页从别的核心那里释放掉,然后把这个页加入一个候选队列。随后取得本核心的锁后,再扫描候选队列,然后把这些页加入 freelist
。
同时,因为我们并没有在本核心的 freelist
中加入偷到的页,而只是记录在候选队列,如果别的核心尝试去偷本核心的页帧,就会发现已经没有空闲页了,不会更改本核心的 freelist
。这样在偷页过程中没有任何核心修改 freelist
,自然也不需要加锁。
不过这里需要注意一个点,就是中断。因为在偷页过程中可能是不持有任何锁的,xv6 会把中断打开。那当前核心可能会跳出去处理别的进程,而别的进程可能又会导致调用 kalloc()
,会造成重复的偷页。
然后就可以写出如下代码:
kinit()
:
struct { struct spinlock lock, stlk; struct run *freelist; uint64 st_ret[STEAL_CNT]; // 候选队列} kmems[NCPU];const uint name_sz = sizeof("kmem cpu 0");char kmem_lk_n[NCPU][sizeof("kmem cpu 0")];voidkinit(){ for(int i = 0; i < NCPU; i++){ snprintf(kmem_lk_n[i], name_sz, "kmem cpu %d", i); initlock(&kmems[i].lock, kmem_lk_n[i]); } freerange(end, (void*)PHYSTOP);}
kfree()
:
voidkfree(void *pa){ struct run *r; if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP) panic("kfree"); push_off(); uint cpu = cpuid(); pop_off(); // Fill with junk to catch dangling refs. memset(pa, 1, PGSIZE); r = (struct run*)pa; acquire(&kmems[cpu].lock); r->next = kmems[cpu].freelist; kmems[cpu].freelist = r; release(&kmems[cpu].lock);}
这里相当于是哪个核心在运行当前进程,就把这个页帧分配到当前核心的 freelist
。也是一个比较简单的分配策略,可能有更好的策略,不过我懒。
steal()
:
这个函数是新添加的,其实就是扫描所有核心的 freelist
,然后把空闲的加入当前核心的候选队列,也就是 st_ret[STEAL_CNT]
:
int steal(uint cpu){ // 返回偷到了几个 uint st_left = STEAL_CNT; int idx = 0; memset(kmems[cpu].st_ret, 0, sizeof(kmems[cpu].st_ret)); for(int i = 0; i < NCPU; i++){ if(i == cpu) continue; acquire(&kmems[i].lock); while(kmems[i].freelist && st_left){ kmems[cpu].st_ret[idx++] = kmems[i].freelist; kmems[i].freelist = kmems[i].freelist->next; st_left--; } release(&kmems[i].lock); if(st_left == 0) { // 一共偷 STEAL_CNT 个 break; } } return idx;}
kalloc()
:
如果没有空闲的页帧了,会调用 steal()
,之后会把偷来的真正的加到 freelist
中。注意整个 kalloc()
都是关闭中断的,因为开中断可能造成同时有两个进程执行 steal()
,造成重复偷页。
void *kalloc(void){ struct run *r = 0; push_off(); uint cpu = cpuid(); acquire(&kmems[cpu].lock); r = kmems[cpu].freelist; // r 是之后要返回的页帧 if(r){ kmems[cpu].freelist = r->next; release(&kmems[cpu].lock); } else { release(&kmems[cpu].lock); int ret = steal(cpu); // steal 过程中不可能 kfree,因为关闭中断 // ret 是偷到了多少页 if(ret <= 0){ pop_off(); return 0; } acquire(&kmems[cpu].lock); for(int i = 0; i < ret; i++){ if (!kmems[cpu].st_ret[i]) break; ((struct run*)kmems[cpu].st_ret[i])->next = kmems[cpu].freelist; // 把偷来的页加到 freelist 的前面 kmems[cpu].freelist = kmems[cpu].st_ret[i]; } r = kmems[cpu].freelist; kmems[cpu].freelist = r->next; release(&kmems[cpu].lock); } if(r){ memset((char*)r, 5, PGSIZE); // fill with junk DEBUG("kalloc 成功\n"); } pop_off(); return r;}
首先说下:这部分的思路很大程度参考抄了这位大佬的博客。
在 xv6 中,我们是不能直接访问硬盘设备的,如果想要读取硬盘中的数据,需要先把数据拷贝到一个缓存中,然后读取缓存中的内容。
在 xv6 中,磁盘数据的最小单位是一个块,大小为 1024 kb。或者说我们每次从硬盘中最少能读出 1024kb 的数据。
在读写硬盘的时候,需要通过 bread()
函数得到相应的缓存(缓存中已经存放了硬盘对应块中的数据):
// 文件位于 bio.c// Return a locked buf with the contents of the indicated block.struct buf*bread(uint dev, uint blockno){ struct buf *b; b = bget(dev, blockno); if(!b->valid) { virtio_disk_rw(b, 0); b->valid = 1; } return b;}
注意这里先调用了 bget()
函数。这个 bget()
会首先判断是否之前已经缓存过了硬盘中的这个块。如果有,那就直接返回对应的缓存,如果没有,会去找到一个最长时间没有使用的缓存,并且把那个缓存分配给当前块。如下:
// Look through buffer cache for block on device dev.// If not found, allocate a buffer.// In either case, return locked buffer.static struct buf*bget(uint dev, uint blockno){ struct buf *b; acquire(&bcache.lock); // Is the block already cached? for(b = bcache.head.next; b != &bcache.head; b = b->next){ if(b->dev == dev && b->blockno == blockno){ b->refcnt++; release(&bcache.lock); acquiresleep(&b->lock); return b; } } // Not cached. // Recycle the least recently used (LRU) unused buffer. for(b = bcache.head.prev; b != &bcache.head; b = b->prev){ if(b->refcnt == 0) { b->dev = dev; b->blockno = blockno; b->valid = 0; b->refcnt = 1; release(&bcache.lock); acquiresleep(&b->lock); return b; } } panic("bget: no buffers");}
可以看到,所有的缓存被串到了一个双向链表里。链表的第一个元素是最近使用的,最后一个元素是很久没有使用的。
每次 bget()
的时候会先遍历一遍链表,检查当前块是否已经被存到缓存里了。如果没有,那就会从后到前遍历链表(意味着是从最久没有使用的开始找),找到第一个引用计数为 0 (代表没有程序正在使用这个块)的缓存作为当前块的缓存。
这就造成了,在任何时候想要分配缓存,都需要竞争这个链表的锁。
可能你会想到使使用前一个实验的方法来优化,但把缓存分配到不同核心的方法是行不通的。因为分配页帧和回收页帧的时候,只需要有一个核心参与,而且分配后某个页帧只会被一个进程访问。
而分配出去的块缓存可能会被不同进程访问。比如不同的进程可以访问和写入同一个块缓存。如果预先按照核心分配缓存,有很大概率进程需要的缓存不属于当前核心。那就需要去一个一个的遍历别核心的块缓存,造成性能下降。(不过如果每个块缓存单独持有一个锁,粒度更小了会不会性能更好点)。
实验描述中给我们的提示是实现一个散列表。散列表会把块号映射到块缓存的桶,那么只有两个进程试图操作同一个桶中的块缓存,才会造成竞争。而且在查找所需块缓存时页不需要遍历所有的缓存,只需要遍历对应的桶。
当然,在对应桶中没有足够缓存时,我们可以像在 kalloc()
中一样,从别的桶中偷缓存。
这个实验中的散列表还是比较容易理解的。不过散列表中也有涉及页表分配实验中“偷”的过程,这样会陷入一种两难的境地。
在“偷”的过程中,我们会需要同时获得当前桶的锁,也需要检查别的桶,所以需要拿到别的桶的锁。这样就不可避免的同时持有了两把锁。
而这两把锁可能会造成死锁,如下[1]:
假设块号 b1 的哈希值是 2,块号 b2 的哈希值是 5并且两个块在运行前都没有被缓存----------------------------------------CPU1 CPU2----------------------------------------bget(dev, b1) bget(dev,b2) | | V V获取桶 2 的锁 获取桶 5 的锁 | | V V缓存不存在,遍历所有桶 缓存不存在,遍历所有桶 | | V V ...... 遍历到桶 2 | 尝试获取桶 2 的锁 | | V V 遍历到桶 5 桶 2 的锁由 CPU1 持有,等待释放尝试获取桶 5 的锁 | V桶 5 的锁由 CPU2 持有,等待释放!此时 CPU1 等待 CPU2,而 CPU2 在等待 CPU1,陷入死锁!
这里有一个办法就是,如果发现没有需要的缓存,就在开始偷之前把自己的锁释放掉。
当然这就造成了新的问题。假设在某一时刻我们放弃了自己的锁,然后开始找别的桶里空闲的缓存。这时候另一个进程调用了 bget()
函数,并且 blockno 还是同一个。那么这另一个个进程也会进入到找空闲缓存的状态。
在两个进程都找到了空闲缓存后,它们会把两个缓存都加到当前 blockno 的桶中,这样一个 blockno 对应的缓存就有了两个。
所以我们需要对添加缓存的操作加锁,然后得到锁之后再检查一遍是否已经有了对应缓存(可能有别的进程在相同时间调用了 bget()
并且块号还是一样的)。
除了锁相关的问题,我们还需要考虑如何找出最长时间没用过的缓存(LRU, least recent used)。因为 LRU 缓存通常在短时间之内不会再用到,所以在缓存不够的时候一般会回收这些缓存。
在原来的设计中,我们维护了一个双向链表,如果有新释放的缓存就加到链表的前面。所以链表尾部的缓存是最久没使用的,反之亦然。
但是在新设计中,我们维护了好几条链表(桶)没有办法在这些链表之间做比较。那么我们可以给 buf
结构体新加一个 lst_use
属性,表示最后一次使用的时间。而这个最后使用的时间可以从 ticks
全局变量获得,这个变量是由计时器中断维护的。代码如下:
//trap.c……voidclockintr(){ acquire(&tickslock); ticks++; wakeup(&ticks); release(&tickslock);}if(cpuid() == 0){ clockintr();}……
binit()
:
#define BUCK_SIZ 13#define BCACHE_HASH(dev, blk) (((dev << 27) | blk) % BUCK_SIZ) // 支持多个 dev // 其实也可以直接模 BUCK_SIZ// or 13, 1009, 10007struct { struct spinlock bhash_lk[BUCK_SIZ]; // buf hash lock struct buf bhash_head[BUCK_SIZ]; // 每个桶的开头,不用 buf* 是因为我们需要得到某个 buf 前面的 buf // 用了指针会比较麻烦,见后文 struct buf buf[NBUF]; // 最终的缓存 // Linked list of all buffers, through prev/next. // Sorted by how recently the buffer was used. // head.next is most recent, head.prev is least.} bcache;voidbinit(void){ for (int i = 0; i < BUCK_SIZ; i++){ initlock(&bcache.bhash_lk[i], "bcache buf hash lock"); bcache.bhash_head[i].next = 0; } for(int i = 0; i < NBUF; i++){ // 最开始把所有缓存都分配到桶 0 上 struct buf *b = &bcache.buf[i]; initsleeplock(&b->lock, "buf sleep lock"); b->lst_use = 0; b->refcnt = 0; b->next = bcache.bhash_head[0].next; // 往 0 的头上插 bcache.bhash_head[0].next = b; }}
bget()
:
这个就是我们主要修改的函数
// Look through buffer cache for block on device dev.// If not found, allocate a buffer.// In either case, return locked buffer.static struct buf*bget(uint dev, uint blockno){ struct buf *b; uint key = BCACHE_HASH(dev, blockno); acquire(&bcache.bhash_lk[key]); for(b = bcache.bhash_head[key].next; b; b = b->next){ // 查看 blockno 是否在对应的桶里被缓存 if(b->dev == dev && b->blockno == blockno){ b->refcnt++; release(&bcache.bhash_lk[key]); acquiresleep(&b->lock); return b; } } release(&bcache.bhash_lk[key]); int lru_bkt; struct buf* pre_lru = bfind_prelru(&lru_bkt); // pre_lru 会返回空闲缓存前一个(链表中前一个)缓存的地址 // 并且确保拿到了缓存对应的桶锁 // 我们会传进去一个 lru_bkt,函数执行好后,这个值会储存缓存对应的桶 if(pre_lru == 0){ panic("bget: no buffers"); } struct buf* lru = pre_lru->next; // lru (lru 是最久没有使用的缓存,并且 refcnt = 0)是 pre_lru 后面的一个 pre_lru->next = lru->next; // 让 pre_lru 的后面一个直接变成 lru 的后面一个,相当于删除 lru release(&bcache.bhash_lk[lru_bkt]); acquire(&bcache.bhash_lk[key]); for(b = bcache.bhash_head[key].next; b; b = b->next){ // 拿到锁之后要确保没有重复添加缓存 if(b->dev == dev && b->blockno == blockno){ b->refcnt++; release(&bcache.bhash_lk[key]); acquiresleep(&b->lock); return b; } } lru->next = bcache.bhash_head[key].next; // 把找到的缓存添加到链表头部 bcache.bhash_head[key].next = lru; lru->dev = dev, lru->blockno = blockno; lru->valid = 0, lru->refcnt = 1; release(&bcache.bhash_lk[key]); acquiresleep(&lru->lock); return lru;}
bfind_prelru()
:
比较关键的一个函数,接收一个 lru_bkt
的指针,然后返回最久没使用的,ref_cnt
为 0 的缓存的前一个缓存的地址。注意我们需要一直持有 lru
所在的桶的锁。要不在然释放掉这个锁后,把缓存添加近当前桶前,这个缓存(指 lru)可能会被修改。
传进 lru_bkt
指针是因为我们希望给 lru_bkt
赋值,这样函数返回后我们能知道去释放哪个桶的锁。
struct buf* bfind_prelru(int* lru_bkt){ // 返回 lru 前面的一个,并且加锁 struct buf* lru_res = 0; *lru_bkt = -1; struct buf* b; for(int i = 0; i < BUCK_SIZ; i++){ acquire(&bcache.bhash_lk[i]); int found_new = 0; for(b = &bcache.bhash_head[i]; b->next; b = b->next){ if(b->next->refcnt == 0 && (!lru_res || b->next->lst_use < lru_res->next->lst_use)){ lru_res = b; found_new = 1; } } if(!found_new){ // 没有更好的选择,就一直持有这个锁(需要确保一直持有最佳选择对应桶的锁) release(&bcache.bhash_lk[i]); }else{ // 有更好的选择(有更久没使用的) if(*lru_bkt != -1) release(&bcache.bhash_lk[*lru_bkt]); // 直接释放以前选择的锁 *lru_bkt = i; // 更新最佳选择 } } return lru_res;}
brelse()
:
// Release a locked buffer.// Move to the head of the most-recently-used list.voidbrelse(struct buf *b){ if(!holdingsleep(&b->lock)) panic("brelse"); releasesleep(&b->lock); uint key = BCACHE_HASH(b->dev, b->blockno); // 改成散列表后要先得到 key acquire(&bcache.bhash_lk[key]); b->refcnt--; if (b->refcnt == 0) { // no one is waiting for it. b->lst_use = ticks; } release(&bcache.bhash_lk[key]);}
bpin
和 bunpin
:
voidbpin(struct buf *b) { uint key = BCACHE_HASH(b->dev, b->blockno); acquire(&bcache.bhash_lk[key]); b->refcnt++; release(&bcache.bhash_lk[key]);}voidbunpin(struct buf *b) { uint key = BCACHE_HASH(b->dev, b->blockno); acquire(&bcache.bhash_lk[key]); b->refcnt--; release(&bcache.bhash_lk[key]);}
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
这个 lab 的描述属实是长,不过很多的篇幅都在介绍 E1000 网卡。最终的任务其实很简单,就是实现 E1000 网卡驱动中的 transmit()
和 recv()
函数。
这个 lab 的代码不复杂,但写出来需要对 lab 中的提示有很好的理解。同时,也需要查阅 E1000 的文档。
下面先介绍处理器和 E1000 交互的方法,随后再介绍两个函数的具体实现方法。
E1000 使用了 DMA(direct memory access)技术,可以直接把接收到的数据包写入计算机的内存,这在数据量大的时候非常有用,可以当作缓存。
在发送时也可以把描述符(见下文)写入内存的特定位置,这样 E1000 就会自己去找到待发送的数据,然后发送。
不管是接收还是发送,数据包都是以描述符数组描述的。在下面的接收和发送部分,会分别介绍接收描述符和发送描述符的格式。
如果网卡收到了数据,会产生一个中断,然后调用对应的中断处理程序去处理这个新到达的数据。
接收描述符的格式如下:
在 xv6
中,这个描述符的定义如下:
// [E1000 3.2.3]struct rx_desc{ uint64 addr; /* Address of the descriptor's data buffer */ uint16 length; /* Length of data DMAed into data buffer */ uint16 csum; /* Packet checksum */ uint8 status; /* Descriptor status */ uint8 errors; /* Descriptor Errors */ uint16 special;};
我们会在内存中放一个数组的描述符,然后这个数组会被解读成一个环形队列。
如果网卡接收到了一个新的数据包,会检查环形队列 head
位置的描述符。然后把数据写入 head
描述符的缓冲区,也就是 addr
记录的地址。
这里比较重要的还有 status
和 length
属性。网卡在写入的时候就会设置这些属性。
其中,length
表示写入 addr
的数据包长度。status
则可以代表下列状态:
其中,我们需要用到的主要是 DD (Descriptor Done) 这个标志位。其表示网卡已经接收好了这个包。
在编写驱动的过程中,我们需要注意判断这个标志位,如果还没有完全接收好,我们就应该继续等待一段时间。
上面我们提到了,如果网卡收到了新的数据,会往环形队列 head
位置描述符的缓冲区写入数据,下面来讨论网卡和驱动程序是如何具体管理这个缓冲区的。
下图展示了接收描述符环形队列的结构:
初始化时,head
为 0,tail
为队列缓冲区减一。
其中,head
到 tail
的这段浅色的区域是空闲的(图好像有点问题,其实 tail
指向的位置也时空闲的)。也就是说,这个区域内的数据包都已经被软件处理好了,那么如果有新的数据包到达,网卡会把数据写入这个区域的开始,也就是 head
,把老的数据覆盖掉。网卡把老的数据覆盖掉后会把 head
的值加一。
而软件会按照顺序处理深色的区域。读取环形队列时,读取的是 tail + 1
位置描述符缓冲区的数据(这个位置是所有未处理数据中等待时间最长的),处理完这个缓冲区后会把 tail
增加一。
发送描述符的格式如下:
在 xv6
中,这个描述符的定义如下:
// [E1000 3.3.3]struct tx_desc{ uint64 addr; uint16 length; uint8 cso; // checksum offset uint8 cmd; // command field uint8 status; // uint8 css; // checksum start field uint16 special; // };
其中 addr
和 length
的作用和接收描述符的作用相同,这里不赘述。
除了这两个,我们主要还需要用到 cmd
和 status
这两个属性。
和接收标志位一样,在 status
中我们需要用到 DD 标志位,表示当前标志位指向的数据是否发送完成。
而 cmd
描述了传输这个数据包时的一些设置,或者说对于网卡的命令。
有以下的命令可以选择:
这里需要用到的命令有如下几个:
和接收描述符的环形队列略有不同,发送描述符的 head
到 tail
这段区域(途中浅色区域)表示我们希望发送,但是网卡还没发送出去的数据。
其中 head
指向等待时间最长的待发送数据,网卡会从这里开始发送。完成后会把 tail
加一而如果我们要新加入一个描述符,是从 tail
这个方向加入的,也会把 tail
加一。
为了方便网络数据的处理,xv6 还定义了一个结构体,即 struct mbuf
,如下:
struct mbuf { struct mbuf *next; // the next mbuf in the chain char *head; // the current start position of the buffer unsigned int len; // the length of the buffer char buf[MBUF_SIZE]; // the backing store};
在 e1000_transmit()
函数中,我们就需要接收一个 mbuf
类型的网络数据,然后写入 DMA 对应的内存地址,进而让网卡发送这个数据。
mbuf
的结构大致是下面这样的:
// The above functions manipulate the size and position of the buffer:// <- push <- trim// -> pull -> put// [-headroom-][------buffer------][-tailroom-]// |----------------MBUF_SIZE-----------------|//// These marcos automatically typecast and determine the size of header structs.// In most situations you should use these instead of the raw ops above.#define mbufpullhdr(mbuf, hdr) (typeof(hdr)*)mbufpull(mbuf, sizeof(hdr))#define mbufpushhdr(mbuf, hdr) (typeof(hdr)*)mbufpush(mbuf, sizeof(hdr))#define mbufputhdr(mbuf, hdr) (typeof(hdr)*)mbufput(mbuf, sizeof(hdr))#define mbuftrimhdr(mbuf, hdr) (typeof(hdr)*)mbuftrim(mbuf, sizeof(hdr))----------------MBUF_SIZE-----------------|
其中的 headroom 可以被 push 进去,用来储存网络协议的包头。在接收网络数据后也可以把中间 buffer 的部分 pull 进去来转换成如下的包头:
// an Ethernet packet header (start of the packet).struct eth { uint8 dhost[ETHADDR_LEN]; uint8 shost[ETHADDR_LEN]; uint16 type;} __attribute__((packed));
转换的部分可以在 net_rx()
函数找到:
struct eth *ethhdr;uint16 type;ethhdr = mbufpullhdr(m, *ethhdr);
而 buffer 部分是数据正文,剩下的 tailroom 是 char buf[MBUF_SIZE]
这个缓存除去前两部分的剩下部分。
在 struct mbuf
结构体中,len
表示正文的长度,head
表示 headroom 的结束位置。
在 net.c
中有很多和 mbuf
相关的函数,最主要的就是 mbufalloc()
和 mbuffree()
分别对应着 mbuf
的分配和释放。
我们可以通过特定的内存映射访问到 E1000 的控制寄存器。具体来说,是通过 e1000.c
中的 regs
全局变量加上一些偏移量。在 e1000_dev.h
中定义了额这些偏移量。
思路大概是这样的(其实就是 lab 中的提示)。
首先通过内存映射的控制寄存器得到当前环形队列的 tail(第一个没在发送的描述符位置)。然后取得 tail 对应的描述符,如下:
acquire(&e1000_lock); // 可能多个线程同时发送,所以要加锁uint idx = regs[E1000_TDT]; // transmit tail,表明第一个空闲的环形描述符struct tx_desc *desc = &tx_ring[idx];
然后检测当前描述符的状态。如果没有 E1000_TXD_STAT_DD
这个标志位,说明这一整个队列已经没有空闲的位置了(或者说这个 tail 已经碰到了环形队列的浅色区域了,也就是整个队列都储存了待发送的描述符)。在这种情况下,我们需要直接返回。
if(!(desc->status & E1000_TXD_STAT_DD)){ // 是否传输完成,没传完的话说明环形缓冲区没了,是错误 release(&e1000_lock); return -1;}
接下来需要检测这个描述符对应的 mbuf
的状态。描述符的 addr
属性会指向这个 mbuf
,如果这个描述符中的数据(也就是对应的 mbuf
)已经发送完了,那就可以把这个 mbuf
释放掉。
if(tx_mbufs[idx] != NULL){ // 这里的 buf 指向要发的数据包 // 因为前面的判断,这里肯定是发送完了 // tx_mbufs 是不需要分配的,直接指向 m 这个参数 mbuffree(tx_mbufs[idx]); tx_mbufs[idx] = NULL;}
老的释放掉之后就可以让描述符的 addr
指向当前要发送的数据了。并且还需要更新数据长度,如下:
desc->addr = m->head;desc->length = m->len;
这里有个地方我花了很久才搞懂,就是为什么要写 desc->addr = m->head
,而不是 desc->addr = m->buf
。
我一开始以为 mbuf
的 headroom 就是储存数据包头的。实际上,真正储存包头的部分是 mbuf
中间 buffer 的开头。而 headroom 只是一个“缓冲区”。比如如果我们需要把当前的包头换成另一个占用空间更大的包头,就可以先调用 mbufpullhdr()
再调用 mbufpushhdr()
。
我们可以看一个别函数调用 e1000_transmit()
的例子来了解 headroom 的作用。整个 net.c
中只有 net_tx_eth()
一个函数调用了 e1000_transmit()
。如下:
// sends an ethernet packetstatic voidnet_tx_eth(struct mbuf *m, uint16 ethtype){ struct eth *ethhdr; ethhdr = mbufpushhdr(m, *ethhdr); // 注意这里 memmove(ethhdr->shost, local_mac, ETHADDR_LEN); // In a real networking stack, dhost would be set to the address discovered // through ARP. Because we don't support enough of the ARP protocol, set it // to broadcast instead. memmove(ethhdr->dhost, broadcast_mac, ETHADDR_LEN); ethhdr->type = htons(ethtype); if (e1000_transmit(m)) { mbuffree(m); }}
这个函数的主要作用就是给以太网的数据包加上包头。ethhdr = mbufpushhdr(m, *ethhdr);
这句话缩小了 headroom 的大小,增加了 buffer 的大小。并且把增加出来的这部分空间赋值到了 ethhdr
上。
然后接下来的 memmove(ethhdr->shost, local_mac, ETHADDR_LEN);
和 memmove(ethhdr->dhost, broadcast_mac, ETHADDR_LEN);
就把数据头复制到了这个新在 headroom 中开辟出来的空间。这样 mbuf
的 buffer 部分就包括了数据头。
如果之后有更大的数据头,还可以缩小 headroom 增加 buffer 来存放。
回到 e1000_transmit()
函数的实现,在更新好描述符的 addr
和 len
后,还需要设置对这个描述符的命令:
desc->cmd = E1000_TXD_CMD_RS | E1000_TXD_CMD_EOP;
这里的两个命令在前面发送描述符的部分已经解释过了,这里不赘述。
e1000_transmit()
的最后一点代码如下:
tx_mbufs[idx] = m; // 方便之后清理regs[E1000_TDT] = (idx + 1) % TX_RING_SIZE; // 更新 tail 的位置release(&e1000_lock);return 0;
这里主要解释 tx_mbufs[idx] = m;
这句话。回想我们在该函数的前面部分检查了描述符的 E1000_TXD_STAT_DD
标志位,其表明网卡是否发送完成了这个描述符的数据。如果没有,我们会直接退出。如果有则清理这个数据缓存。
那么我们设置 tx_mbufs[idx] = m
就是为了方便检测这个标志,由此跟踪数据发送的状态。
e1000_transmit()
的完整代码如下:
inte1000_transmit(struct mbuf *m){ acquire(&e1000_lock); uint idx = regs[E1000_TDT]; struct tx_desc *desc = &tx_ring[idx]; if(!(desc->status & E1000_TXD_STAT_DD)){ release(&e1000_lock); return -1; } if(tx_mbufs[idx] != NULL){ mbuffree(tx_mbufs[idx]); tx_mbufs[idx] = NULL; } desc->addr = m->head; desc->length = m->len; desc->cmd = E1000_TXD_CMD_RS | E1000_TXD_CMD_EOP; tx_mbufs[idx] = m; regs[E1000_TDT] = (idx + 1) % TX_RING_SIZE; release(&e1000_lock); return 0;}
首先要注意的一点是,在 e1000_recv()
中,我们需要一次性读出所有的待读取数据包。也就是需要加一个循环,然后一直读取 tail
位置的描述符,直到描述符的状态为未完成接收。
对于接收到的数据包,E1000 网卡有很多种不同的中断策略。一般最常用的是 RDTR (Receive Interrupt Delay Timer 接收中断延迟计时?) 。大概就是收到一个包,并且用 DMA 写入宿主的内存后,会开启计时器,在到达设定的事件后发生中断。
这个策略的主要好处是可以减少大量包在短时间内到达时发生的中断次数。但是 xv6 中没有采用这个策略,而是每次写入宿主内存后都产生一次中断,相关的代码如下:
regs[E1000_RDTR] = 0; // interrupt after every received packet (no timer)regs[E1000_RADV] = 0; // interrupt after every packet (no timer)
那如果使用了这样的终端策略,每次中断就只需要读取一个描述符啊,为什么需要循环的读取 tail。
我个人的理解是因为在处理这样外部设备中断的时候,我们会先关闭中断。
假设大量包在短时间内到达,那么产生第一个中断后,我们会去处理这个中断。处理过程中,可能又会产生很多中断,在这样的情况下我们是接收不到这些中断的,因为处理单个描述符的速度赶不上中断的速度。
所以就需要每次处理中断时再检查是否有别的到达的包,如果有就继续读取。
回到这个函数的实现,我们还是需要先读取 tail 的位置,然后取得对应的描述符:
uint idx = (regs[E1000_RDT] + 1) % RX_RING_SIZE; // head 到 tail 是一个空的缓冲区struct rx_desc *desc = &rx_ring[idx];
要注意的是 tail 本身也是一个空的缓冲区,其数据已经在之前被处理过,所以我们需要将 tail 加一。
接下来判断,是否读完了所有待读取的描述符,方法还是使用 DD 标志位:
if(!(desc->status & E1000_RXD_STAT_DD)){ return;}
重新设置 mbuf
的长度:
rx_mbufs[idx]->len = desc->length;
和发送函数不同,这里的 mbuf
和描述符是一一对应的。也就是每个描述符的缓存都是一个之前设置好的 mbuf
。这里描述符的 addr
已经被设置过了,具体的代码在初始化函数中(这是第一次的 mbuf
,之后会覆盖掉):
// [E1000 14.4] Receive initializationmemset(rx_ring, 0, sizeof(rx_ring)); for (i = 0; i < RX_RING_SIZE; i++) { rx_mbufs[i] = mbufalloc(0); if (!rx_mbufs[i]) panic("e1000"); rx_ring[i].addr = (uint64) rx_mbufs[i]->head;}
随后需要调用 net_rx()
函数把这个 mbuf
转发到相应的网络协议栈进行处理。
net_rx(rx_mbufs[idx]);
因为上层的协议栈还需要使用这个 mbuf
,所以我们不能将其覆盖,需要给当前描述符分配一个新的 mbuf
:
rx_mbufs[idx] = mbufalloc(0);desc->addr = rx_mbufs[idx]->head;desc->status = 0;
最后一步是更新 tail 指向的位置(注意 tail 本身是已经被软件处理过的描述符):
regs[E1000_RDT] = idx;
e1000_recv()
的完整代码如下:
static voide1000_recv(void){ while(1){ uint idx = (regs[E1000_RDT] + 1) % RX_RING_SIZE; struct rx_desc *desc = &rx_ring[idx]; if(!(desc->status & E1000_RXD_STAT_DD)){ return; } rx_mbufs[idx]->len = desc->length; net_rx(rx_mbufs[idx]); rx_mbufs[idx] = mbufalloc(0); desc->addr = rx_mbufs[idx]->head; desc->status = 0; regs[E1000_RDT] = idx; }}
搞好了之后就可以顺利 AC 了:
]]>upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
实现用户态线程。
因为我们要实现用户态的多线程机制,所以很大程度上可以参考内核态中多线程的实现。
查看 user/uthread.c
后可以发现,基本的框架已经给我们写好了,我们只需要实现一些函数的内容就行了。
那不如先把函数中要实现的内容写出来:
thread_switch()
: 这个函数和内核中的 swtch()
完全一样,用于切换处理器的上下文。和内核中相同(参考这篇文章),因为执行这个函数的过程是一个正常的函数调用,所以我们不需要保存和交换调用者保存的寄存器。thread_create()
:这个函数是用于创建新的用户线程的。参考内核态多线程的实现。我们调用 swtch()
后,决定跳转位置的是 ra 寄存器,决定恢复出来的被调用者保存寄存器的是 sp 寄存器。所以,在这个函数中,我们应该合理的设置 ra 寄存器,使得第一次执行用户函数时,是这个函数的第一条语句。thread_schedule()
:参考内核中的实现,这个函数和内核中的 scheduler()
的作用相同。也就是在当前进程调用 yield()
后,找到一个 RUNNABLE 的进程,然后执行这个进程。在 thread_schedule()
中,我们会需要调用 thread_switch()
来切换处理器的上下文。这样我们就大概的把各个函数的功能和实现思路理清楚了,接下来可以从第一个函数开始实际的实现。
首先我们要注意到,utrhead.c
原本的文件中并没有给 struct thread
加上一个上下文的属性,所以我们给他加上,上下文保存的寄存器和内核态多线程中完全相同:
struct Context{ uint64 ra; uint64 sp; // callee-saved uint64 s0; uint64 s1; uint64 s2; uint64 s3; uint64 s4; uint64 s5; uint64 s6; uint64 s7; uint64 s8; uint64 s9; uint64 s10; uint64 s11;};struct thread { char stack[STACK_SIZE]; /* the thread's stack */ int state; /* FREE, RUNNING, RUNNABLE */ struct Context ctx;};
然后 thread_switch()
差不多就可以直接把 swtch()
中的东西抄过来了:
.text /* * save the old thread's registers, * restore the new thread's registers. */ .globl thread_switch // a0 是老的上下文,a1 是新的thread_switch: /* YOUR CODE HERE */ sd ra, 0(a0) sd sp, 8(a0) sd s0, 16(a0) sd s1, 24(a0) sd s2, 32(a0) sd s3, 40(a0) sd s4, 48(a0) sd s5, 56(a0) sd s6, 64(a0) sd s7, 72(a0) sd s8, 80(a0) sd s9, 88(a0) sd s10, 96(a0) sd s11, 104(a0) ld ra, 0(a1) ld sp, 8(a1) ld s0, 16(a1) ld s1, 24(a1) ld s2, 32(a1) ld s3, 40(a1) ld s4, 48(a1) ld s5, 56(a1) ld s6, 64(a1) ld s7, 72(a1) ld s8, 80(a1) ld s9, 88(a1) ld s10, 96(a1) ld s11, 104(a1) ret /* return to ra */
那这个函数我们就写完了。
接下来是 thread_create()
。实现这个函数主要需要思考如何设置 ra 和 sp 寄存器。因为用户进程一开始的时候是没有使用寄存器的,所以如何设置上下文中的其他寄存器是无所谓的。
首先,在 thread_create()
之后,如果我们调用了 thread_schedule()
,应该执行的是线程函数的第一个语句。所以我们可以这么设置 ra:
t->ctx.ra = (uint64) func;
对于 sp,我们需要注意的是栈是从高地址到低地址增长的(我一开始没想到),那么 sp 应该被设置在栈的最高地址:
t->ctx.sp = (uint64) &t->stack + (STACK_SIZE - 1);
那么这个 thread_create()
就写完了:
void thread_create(void (*func)()){ struct thread *t; for (t = all_thread; t < all_thread + MAX_THREAD; t++) { if (t->state == FREE) break; } t->state = RUNNABLE; // YOUR CODE HERE t->ctx.ra = (uint64) func; t->ctx.sp = (uint64) &t->stack + (STACK_SIZE - 1);}
接下来可以处理 thread_schedule()
:
观察原来函数的代码可以看到,最开始的循环找到了第一个为 RUNNABLE 的线程,然后把这个线程赋值到 next_thread()
。所以很明显,我们应该交换 current_thread
和 next_thread()
的上下文。
不过这个函数有个比较坑的地方,就是在交换前写了这个东西:
t = current_thread;current_thread = next_thread; // 当前线程变成下一个线程了
那我们就需要交换 t 和 next_thread 了:
thread_switch((uint64) &t->ctx, (uint64) &next_thread->ctx);
完整代码如下:
void thread_schedule(void){ struct thread *t, *next_thread; /* Find another runnable thread. */ next_thread = 0; t = current_thread + 1; for(int i = 0; i < MAX_THREAD; i++){ if(t >= all_thread + MAX_THREAD) t = all_thread; // 循环 if(t->state == RUNNABLE) { next_thread = t; break; } t = t + 1; } if (next_thread == 0) { printf("thread_schedule: no runnable threads\n"); exit(-1); } if (current_thread != next_thread) { /* switch threads? */ next_thread->state = RUNNING; t = current_thread; current_thread = next_thread; // 当前线程变成下一个线程了 /* YOUR CODE HERE * Invoke thread_switch to switch from t to next_thread: * thread_switch(??, ??); */ thread_switch((uint64) &t->ctx, (uint64) &next_thread->ctx); } else next_thread = 0;}
看了别人的一些博客[1]后发现,这里实现的用户态多线程其实更接近协程。因为这里的线程是自愿交出处理器资源的,而不是靠定时器中断,同时,使用的核心也只有一个。
或者说,这里的函数可以把自己挂起,然后过一段时间再通过 thread_schedule()
来恢复执行。
以前看了一些协程的东西,基本上只能理解为什么协程被称作“可以被挂起的函数”,而不能理解,为什么协程是“用户态线程”,更搞不懂协程是怎么实现的。
这个感觉还是挺奇怪也挺爽的,就是在学另一个知识的时候,把以前一直都搞不懂的,看似不相关的东西给搞懂了。所以花了很久时间没学懂的时候可以先放一放,说不定以后不知道什么时候就搞懂了。
这个 lab 的描述还是挺长的,所以我就不放图片了。大概就是让我们阅读一个散列表(哈希表)的程序,然后做一些更改,使得这个程序在多线程的环境下也可用。
可以尝试运行下提供给我们的程序,如果只使用一个线程,那么一切正常。如果改成两个及以上,就会发现某些在散列表中插入的键值对直接消失不见了。
为了解决这个问题,我们可以先看一遍这个散列表,找一找问题出现的地方。这个程序中,最关键的有三个函数 insert()
,put()
和 get()
。我们可以一个接一个看:
首先是 insert()
:
static void insert(int key, int value, struct entry **p, struct entry *n){ struct entry *e = malloc(sizeof(struct entry)); e->key = key; e->value = value; e->next = n; *p = e; // 把 p table[i] 的起始点改成 e}
我们知道,在散列表中,如果哈希函数把多个不同的键映射到了同一个位置,就会需要把这个当作链表的形式,在查找时遍历这个链表来找到正确的键值对。
这个 insert()
函数做的就是在链表中插入元素的工作。其中,e
是一个新被插入链表 *p
中的元素,我们先利用参数初始化了 e
的各个属性。
特别需要注意的是 e->next = n
这句话,这里的 n
是链表 table[i]
或者说 *p
的第一个元素,那么 e->next = n
就意味着现在把 e
插入在 *p
的前面。
下一个函数是 put()
:
static void put(int key, int value){ // is the key already present? struct entry *e = 0; for (e = table[i]; e != 0; e = e->next) { if (e->key == key) break; } if(e){ // update the existing key. e->value = value; } else { // the new is new. insert(key, value, &table[i], table[i]); // 在 table[i] 的最前面插入一个 key val 对 }}
其实就是尝试在散列表中添加一个键值对。这个函数会先尝试查找散列表中是否存在某个 key
如果存在,就用 value
替代掉原来和 key
对应的值。
如果不存在,就调用 insert()
函数插入该键值对。
最后一个重要的函数是 get()
:
static struct entry*get(int key){ int i = key % NBUCKET; struct entry *e = 0; for (e = table[i]; e != 0; e = e->next) { if (e->key == key) break; } return e;}
也就是说,遍历散列表中的对应链表,来查找值对应的键。
总的来说,这是一个比较常规的散列表实现,看似没有任何问题,但是在多线程环境下会出现一些 bug。
考虑这样一种情况[1]:
有两个键 k1 和 k2,他们属于散列表中的同一链表,并且链表中都还不存在这两个键值对。现在有两个线程 t1 和 t2,它们分别尝试在该链表中插入这两个键值。
那么有如下的可能情况:
t1 先检查了链表中不存在 k1,于是准备调用 insert()
在链表前插入键值对。
这个时候,线程调度器切换到了 t2(也可能是在多核环境下,两个线程并行执行,但是 t2 比 t1 快)。
然后 t2 也发现了链表中不存在 k2,所以调用 insert()
插入。插入之后,k2 成了链表的第一个元素。
随后 t1 也真正的插入了 k1。但是,因为 t1 并不知道 t2 已经把 k2 插入到了开头,于是在其认为的链表开头(k2 所处位置)插入了 k1,k2 就被覆盖掉了,于是造成了键值对丢失。
这样的情况下,我们需要通过加锁来解决问题。
观察前面的情况,可以发现,对于每一个散列表,在每一个时刻,只能由一个线程来操作,这里的操作包括了读取和修改。因为如果有多个线程,可能会造成某些线程获到的信息是滞后的(如前面的情况)。
所以我们可以对于散列表中的每个链表都创建一个互斥锁,然后在 put()
和 get()
的开头和结尾加锁和解锁。
那为啥不在 insert()
里加锁呢?因为 insert()
都是 put()
调用的,对于一个互斥锁,这样就会造成死锁。
所以就可以这样修改 put()
和 get()
:
pthread_mutex_t bkt_lock[NBUCKET];static void put(int key, int value){ int i = key % NBUCKET; pthread_mutex_lock(&bkt_lock[i]); // is the key already present? struct entry *e = 0; for (e = table[i]; e != 0; e = e->next) { if (e->key == key) break; } if(e){ // update the existing key. e->value = value; } else { // the new is new. insert(key, value, &table[i], table[i]); // 在 table[i] 的最前面插入一个键值对 } pthread_mutex_unlock(&bkt_lock[i]);}static struct entry*get(int key){ int i = key % NBUCKET; pthread_mutex_lock(&bkt_lock[i]); struct entry *e = 0; for (e = table[i]; e != 0; e = e->next) { if (e->key == key) break; } pthread_mutex_unlock(&bkt_lock[i]); return e;}
实现同步屏障。
先简单解释一下同步屏障是个什么东西。根据维基百科:
同步屏障(Barrier)是并行计算中的一种同步方法。对于一群进程或线程,程序中的一个同步屏障意味着任何线程/进程执行到此后必须等待,直到所有线程/进程都到达此点才可继续执行下文。
那么一个朴素的实现方法就是在一个线程到达屏障时把某个变量 +1,最后如果这个变量等于线程总数量,就可以执行了。
当然,在变量到达总数量前,我们需要让线程阻塞在屏障的位置。同时,当变量符合条件后,阻塞的线程就可以越过屏障了。
我们当然可以使用互斥锁加上轮询的方式来检查变量是否符合条件,但是这样对性能的损失是比较大的。
这样轮询的方法是被动的,也就是每个线程都去询问,那为何不让最后一个到达屏障的线程去通知其他线程呢?
pthread 库函中的条件变量实现的就是这样的功能。
举个例子,如果我们调用了 pthread_cond_wait(&cond, &mutex)
,那么在最后一个线程调用 pthread_cond_broadcast(&cond)
之前,程序就会一直阻塞。
更具体的,pthread_cond_wait(&cond, &mutex)
按照顺序干了下面的事情:
注意 1 和 2 是原子的操作。
如果有线程用条件变量发出了信号,那么:
pthread_cond_wait()
会返回。mutex
再次被锁住至于为什么条件变量一定要和一个互斥锁配合,在这里把我自己目前的认识写一下。
条件变量通常是要和一个别的变量配合着使用的,我们这里就叫这个变量 x 吧。
在调用 wait()
之前,我们肯定会先判断以下 x 是否符合一定的条件,如果符合了,那我们也没必要用 wait()
了。
如果不符合,我们会调用 wait()
,这样一旦 x 符合了条件,我们就会知道。
但是这里这个普通变量 x 一定是在多线程的环境下被使用的。那么我们在调用 wait()
之前,检查 x 的时候,就要确保我们拿到了一个保护 x 的锁。
然后调用 wait()
后,发现 x 不符合条件,那肯定是要把锁释放出来的,要不然,别的线程也没办法修改 x 使其符合条件。
相同的,如果 x 符合了条件,wait()
会返回,这个时候会拿到保护 x 的锁。因为我们也许会修改 x ,或者使用 x,如果这个时候 x 被改变了,会出问题。
那为啥要把解锁和加入等待队列做成原子操作呢?
假设有这样一个使用条件变量的程序,并且其使用的条件变量没有把解锁和加入等待队列做成原子操作[2]:
lock(x_lock) // 拿到保护 x 的锁if (x 满足条件){ unlock(x_lock); // 释放保护 x 的锁 pthread_cond_wait(&cond); // 等待信号 lock(x_lock); // dosomething 可能会更改 x dosomething();}unlock();
那么万一,在 unlock(x_lock)
之后,把当前线程放入 cond
的等待队列之前。有一个线程更改了 x 的值,并且发出了信号,当前线程就因为没被加入到等待队列,错过了这个信号。
所以必须要把放入队列和解锁做成原子操作。
艹,没想到写着写着光条件变量就扯了这么多,同步屏障倒是一点没讲。现在进入正题,来具体实现同步屏障。
我们观察一下 barrier.c
中提供的 barrier
结构体:
struct barrier { pthread_mutex_t barrier_mutex; pthread_cond_t barrier_cond; int nthread; // Number of threads that have reached this round of the barrier int round; // Barrier round} bstate;
可以看到这里的 nthread
就是之前我们提到的 “x”,因为只有不符合 nthread
,我们才会调用条件变量的 wait()
。
然后,对应的,保护 x 的锁就是 barrier_mutex
。这样的话,就可以写出下面的程序了:
static void barrier(){ // YOUR CODE HERE // // Block until all threads have called barrier() and // then increment bstate.round. // pthread_mutex_lock(&bstate.barrier_mutex); bstate.nthread++; if(bstate.nthread < nthread){ pthread_cond_wait(&bstate.barrier_cond, &bstate.barrier_mutex); // 如果没有全部到达 barrier 的位置,就等待 // 在收到信号之前,这里是阻塞的 }else{ // 如果这是最后一个线程。 bstate.nthread = 0; bstate.round++; pthread_cond_broadcast(&bstate.barrier_cond); } pthread_mutex_unlock(&bstate.barrier_mutex);}
这里需要注意一个细节,就是 pthread_cond_broadcast()
和 pthread_cond_signal()
的区别。
如果我们用了 broadcast()
,那所有在等待列表中的线程都会被唤醒,反之,signal()
只会唤醒列表中的一个线程。
在我们的情况中,如果最后一个线程执行到了屏障,那所有的线程都可以继续往下执行,所以用了 broadcast()
。
然后我们就可以愉快的 AC 了,也祝在做这个 lab 的人尽快 AC:
发现写博客还是挺重要的。有的时候把代码搞出来了不一定代表完全懂了。比如最后一个 lab 的条件变量。写的时候只是懂了他干的事情,感觉没问题。但是写博客时,就发现不知道如何解释,于是只能去查更多的资料。这大概说明了,如果想给别人讲清楚某个知识,需要对这个知识有更深刻的理解。
其次,这个 lab 的代码量是比较小的(说实话到目前为止还没做到过码量特别多的 lab)。如果没有完全理解 xv6 中线程调度和切换的原理,也能做出来。但完全理解后再做这个 lab,就能有更好的理解(特别是 uthread 那个实验,剩下两个还是跟 pthread 库的关系更多点)。
]]>upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
这个 lab 的描述属实是简洁,其实他主要的描述在前面:
The problem
The fork() system call in xv6 copies all of the parent process’s user-space memory into the child. If the parent is large, copying can take a long time. Worse, the work is often largely wasted; for example, a fork() followed by exec() in the child will cause the child to discard the copied memory, probably without ever using most of it. On the other hand, if both parent and child use a page, and one or both writes it, a copy is truly needed.
The solution
The goal of copy-on-write (COW) fork() is to defer allocating and copying physical memory pages for the child until the copies are actually needed, if ever.
COW fork() creates just a pagetable for the child, with PTEs for user memory pointing to the parent’s physical pages. COW fork() marks all the user PTEs in both parent and child as not writable. When either process tries to write one of these COW pages, the CPU will force a page fault. The kernel page-fault handler detects this case, allocates a page of physical memory for the faulting process, copies the original page into the new page, and modifies the relevant PTE in the faulting process to refer to the new page, this time with the PTE marked writeable. When the page fault handler returns, the user process will be able to write its copy of the page.
COW fork() makes freeing of the physical pages that implement user memory a little trickier. A given physical page may be referred to by multiple processes’ page tables, and should be freed only when the last reference disappears.
大概就是说我们需要实现 UNIX 中的写时复制技术 (copy on write)。在没有写时复制的系统中,调用 fork()
时,我们会把父进程的所有的内存都拷贝到子进程的空间,自然,这个耗时是巨大且不可接受的。
并且在实际应用中,fork()
时拷贝的大部分内存都时不会被用到的,比如,在 UNIX 中新建一个进程的通常会先调用 fork()
,然后调用 exec()
。那么原先复制过来的数据就全部没用了。
在 fork()
时,只有一种情况是需要复制内存的。就是写入数据时,如果父进程或子进程尝试往某个地址写入值,那么为了确保写入的这个值不会影响别的进程,我们需要复制这个页帧。
而写时复制就是这样的一个技术,我们会把父进程和子进程共享页帧的 PTE 标为不可写的。那么有任何一个进程尝试往这个页帧写入时,就会产生缺页错误。在 usertrap()
函数中,我们可以处理这样的情况,也就是把共享页帧复制一份给尝试写入的进程,这个被复制的页帧会被标记为可写的。
实现写时复制后,可能会有多个进程同时共享一个页帧,那么只有所有的进程都不需要这个共享页帧时,我们才能真正的释放这个页帧。
然后就可以根据提示一点一点实现了:
Modify
uvmcopy()
to map the parent’s physical pages into the child, instead of allocating new pages. Clear PTE_W in the PTEs of both child and parent.
修改uvmcopy()
,把父进程的物理内存直接映射到子进程的虚拟内存上,而不是去分配新的内存。清除父进程和子进程 PTE 的 PTE_W。
修改 uvmcopy()
后,子进程和父进程相当于共享内存了,然后我们希望任何一方试图写入共享内存时都会引发缺页错误,所以要清楚 PTE_W:
// Given a parent process's page table, copy// its memory into a child's page table.// Copies both the page table and the// physical memory.// returns 0 on success, -1 on failure.// frees any allocated pages on failure.intuvmcopy(pagetable_t old, pagetable_t new, uint64 sz){ pte_t *pte; uint64 pa, i; uint flags; char *mem; for(i = 0; i < sz; i += PGSIZE){ if((pte = walk(old, i, 0)) == 0) panic("uvmcopy: pte should exist"); if((*pte & PTE_V) == 0) panic("uvmcopy: page not present"); pa = PTE2PA(*pte); *pte &= (~PTE_W); // 这里清除了 PTE_W *pte |= PTE_C; // 添加 PTE_C 代表这是一个 COW 页,之后会讲 flags = PTE_FLAGS(*pte); // if((mem = kalloc()) == 0) 这里都是实际分配内存的,需要删除 // goto err; // memmove(mem, (char*)pa, PGSIZE); if(mappages(new, i, PGSIZE, (uint64)pa, flags) != 0){ // 这里并没有把虚拟地址 i 映射到新分配的物理地址 mem // 而是映射到了父进程的物理内存 pa 上 printf("uvmcopy failed\n"); kfree(mem); goto err; } refcnt_inc(pa); // 这个东西之后会讲 } return 0; err: uvmunmap(new, 0, i / PGSIZE, 1); return -1;}
Modify
usertrap()
to recognize page faults. When a page-fault occurs on a COW page, allocate a new page with kalloc(), copy the old page to the new page, and install the new page in the PTE with PTE_W set.
修改usertrap()
来处理缺页错误。如果缺页错误发生在 COW 页上,就分配一个新的物理页,拷贝原页帧的数据到新页,并设置新页的 PTE_W。
和页表懒分配那个 lab 类似,我们也需要有一个函数判断某个虚拟地址是否是合法的,未分配的 COW 页。这个提示中说到了只有缺页错误发生在 COW 页上才能分配新的物理页。那么我们如何判断当前页是否是一个合法的 COW 页呢?这就可以利用 riscv PTE 中的保留位了。我们知道每个 PTE 中有 10 个标志位,其中已经定义了的有 8 个,剩下 10 个就是保留位,如下:
其中的 RSW 位,也就是 8 和 9 位就是保留位。
我们可以定义第 8 位为 1 的就说明当前页帧是 COW 页,所以可以在 kernel/riscv.h
中加入如下的宏定义,同时,这也解答了为什么我们之前要在 uvmcopy()
中给子进程的 PTE 设置 PTE_C:
#define PTE_V (1L << 0) // valid#define PTE_R (1L << 1)#define PTE_W (1L << 2)#define PTE_X (1L << 3)#define PTE_U (1L << 4) // 1 -> user can access#define PTE_C (1L << 8) // 这里是新加的
然后判断是否为未分配 COW 页的函数如下,和懒分配页表那个 lab 一样,我放在了 vm.c
这个文件中:
int uncopied_cow(pagetable_t pgtbl, uint64 va){ if(va >= MAXVA) return 0; pte_t* pte = walk(pgtbl, va, 0); if(pte == 0) // 如果这个页不存在 return 0; if((*pte & PTE_V) == 0) return 0; if((*pte & PTE_U) == 0) return 0; return ((*pte) & PTE_C); // 有 PTE_C 的代表还没复制过,并且是 cow 页}
接下来就可以修改 usertrap()
了:
…… syscall(); } else if((which_dev = devintr()) != 0){ // ok } else if(r_scause() == 15 && uncopied_cow(p->pagetable, r_stval())){ if(cowalloc(p->pagetable, r_stval()) < 0){ p->killed = 1; } } else { printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid); printf(" sepc=%p stval=%p\n", r_sepc(), r_stval()); p->killed = 1; }……
注意这里有一个和页表懒分配 lab 不一样的点,就是我们只会处理 scause 寄存器为 15 的情况,根据 riscv 的文档:
scause 为 15 代表尝试写入引发的缺页错误。
然后我们发现当前页是合法的 COW 页之后,就需要给这个 COW 页分配物理内存,这里也和上一个 lab 一样,我封装了一个 cowalloc()
函数:
int cowalloc(pagetable_t pgtbl, uint64 va){ pte_t* pte = walk(pgtbl, va, 0); uint64 perm = PTE_FLAGS(*pte); if(pte == 0) return -1; uint64 prev_sta = PTE2PA(*pte); // 这里的 prev_sta 就是这个页帧原来使用的父进程的页表 // 这里写 sta 是因为这个地址是和页帧对齐的(page-aligned) // 所以写个 sta 表示一个页帧的开始 uint64 newpage = kalloc(); if(!newpage){ return -1; } uint64 va_sta = PGROUNDDOWN(va); // 当前页帧 perm &= (~PTE_C); // 复制之后就不是合法的 COW 页了 perm |= PTE_W; // 复制之后就可以写了 memmove(newpage, prev_sta, PGSIZE); // 把父进程页帧的数据复制一遍 uvmunmap(pgtbl, va_sta, 1, 1); // 然后取消对父进程页帧的映射 if(mappages(pgtbl, va_sta, PGSIZE, (uint64)newpage, perm) < 0){ kfree(newpage); return -1; } return 0;}
这里需要注意一点,我们这个 memmove()
必须在 uvmunmap()
的前面(我当时调了好久)因为 uvmunmap()
之后这个父进程的物理页可能就被释放了,这个时候 memmove()
得到的是无效的数据。
看完这段程序之后,你可能会发现一个问题,就是这个父进程的页表可能被不止一个子进程共享,那我们调用 uvmunmap()
,并且 do_free
参数还是 1,这个父进程页帧不就可能会被释放吗,然后其他使用这个页帧的进程就会出问题。
这就引出了 lab 的下一个提示:
Ensure that each physical page is freed when the last PTE reference to it goes away – but not before. A good way to do this is to keep, for each physical page, a “reference count” of the number of user page tables that refer to that page. Set a page’s reference count to one when
kalloc()
allocates it. Increment a page’s reference count when fork causes a child to share the page, and decrement a page’s count each time any process drops the page from its page table.kfree()
should only place a page back on the free list if its reference count is zero. It’s OK to to keep these counts in a fixed-size array of integers. You’ll have to work out a scheme for how to index the array and how to choose its size. For example, you could index the array with the page’s physical address divided by 4096, and give the array a number of elements equal to highest physical address of any page placed on the free list bykinit()
in kalloc.c.
也就是说,我们需要使用引用计数来解决这个问题。对于每个页帧,都有一个引用计数,代表有多少个 COW 页正在使用这个页。那如果没有任何 COW 页还在使用这个页帧,我们就可以真正的释放这个页了(有点类似 close()
函数)。在 kalloc()
函数中,我们会把一个页的引用计数设为 1。然后在 kalloc()
函数中,我们需要先减少这个页的引用计数,如果减少后为 0,就可以直接释放这个页。
然后我们可以思考下如何储存这些引用计数,因为每个页帧的起始位置肯定都是能被 4096 整除的,所以我们可以直接把每个页帧的地址除以 4096 作为其编号。
那就可以写出如下的宏:
#define PG2REFIDX(_pa) ((((uint64)_pa) - KERNBASE) / PGSIZE)#define MX_PGIDX PG2REFIDX(PHYSTOP)#define PG_REFCNT(_pa) pg_refcnt[PG2REFIDX((_pa))]int pg_refcnt[MX_PGIDX];
最好照着下面这张图来理解:
里面的 PHYSTOP 和 KERNBASE 代表着内存物理地址的起始和结束,所以我们要把 pa 减去 KERNBASE 后再除以 PGSIZE。
我刚开始还很疑惑,我们在内核中开了这个数组,是存在哪里的。其实可以看下 kinit()
的实现:
voidkinit(){ initlock(&kmem.lock, "kmem"); freerange(end, (void*)PHYSTOP); // 注意这里}
这里的 end
是上图中 Free memory 的开始,定义在 kernle.ld
中,也就是说,对于内核自己的数据和代码(包括这个数组),是存在 kernel text 和 kernel data 中的,而 kalloc()
函数只会去分配 end ~ PHYSTOP 中的内存。
接下来就可以基于引用计数开始修改 kalloc.c
中的各种函数了:
首先是 kalloc()
:
void *kalloc(void){ struct run *r; acquire(&kmem.lock); r = kmem.freelist; if(r){ kmem.freelist = r->next; } release(&kmem.lock); if(r){ memset((char*)r, 5, PGSIZE); // fill with junk PG_REFCNT(r) = 1; // 注意这里,分配时总共有一个进程使用这个页帧,所以置为 1 。 } return (void*)r;}
接下来是 kfree()
:
voidkfree(void *pa){ struct run *r; if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP) panic("kfree"); acquire(&refcnt_lock); if(--PG_REFCNT(pa) <= 0){ // 先减少引用计数,如果小于等于 0 就真的释放 memset(pa, 1, PGSIZE); // Fill with junk to catch dangling refs. r = (struct run*)pa; acquire(&kmem.lock); r->next = kmem.freelist; kmem.freelist = r; release(&kmem.lock); } release(&refcnt_lock);}
其中的 refcnt_lock
是一个锁,其初始化在 kinit()
中:
voidkinit(){ initlock(&kmem.lock, "kmem"); initlock(&refcnt_lock, "ref cnt"); // here freerange(end, (void*)PHYSTOP);}
这里加锁是因为可能有多个引用某个页的进程同时 kfree()
这个页,那么他们同时减少引用计数就会造成错误的结果。
然后在 uvmcopy()
中,我们需要增加父进程页帧的引用计数(多一个进程在共享这个页帧),所以在 mappages()
后面写了 refcnt_inc()
,其定义如下:
void refcnt_inc(void* pa){ acquire(&refcnt_lock); PG_REFCNT(pa)++; release(&refcnt_lock);}
然后我们就完成了实现了引用计数的部分。
最后,还有一个提示:
修改 copyout()
的原因和上一个 lab 很类似,主要是因为有些系统调用也会去往 COW 页上写数据。因为 COW 页的 PTE_W 没有设置,就会引发缺页错误。在 trap.c
中,我们规定了如果异常是从系统调用发生的,就会直接 panic。所以在 copyout()
的时候,如果我们发现了当前页是 COW 页,就直接给他分配一个新的页。
这个 lab 不需要和上一个 lab 一样,修改 copyin
是因为,我们 copyin()
时,实际上读取的是父进程共享给我们的页帧,但是在页表懒分配的 lab 中,copyin()
时的页帧根本就没有分配一个物理地址,当然是无法读入的。
所以可以这样修改 copyout()
:
// Copy from kernel to user.// Copy len bytes from src to virtual address dstva in a given page table.// Return 0 on success, -1 on error.intcopyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len){ uint64 n, va0, pa0; while(len > 0){ va0 = PGROUNDDOWN(dstva); if(uncopied_cow(pagetable, va0)){ // 注意这里是新加的 try(cowalloc(pagetable, va0), return -1); } pa0 = walkaddr(pagetable, va0); if(pa0 == 0) return -1; n = PGSIZE - (dstva - va0); if(n > len) n = len; memmove((void *)(pa0 + (dstva - va0)), src, n); len -= n; src += n; dstva = va0 + PGSIZE; } return 0;}
然后写这个函数的时候一定要注意一个点,就是 cowalloc()
和 walkaddr()
的顺序。我之前就写错了,然后调了好久才找到问题。如果我们在 cowalloc()
之前用 walkaddr()
来查找虚拟地址对应的物理地址,查到的物理地址其实是父进程的共享页帧。
那么到时候就会往这个地址里写东西,造成错误(别的进程也会使用这个页帧)。
而在 cowalloc()
之后查找物理地址,查到的就是新分配的物理地址,写入的也是当前进程独有的页帧,不会影响别的进程。
然后写完这个,lab 就能 AC 了,如下,也祝在做这个 lab 的人尽快 AC:
真不知道为什么一些傻逼错误用 gdb 调了那么久还没发现………… 都开始怀疑编译器出错了。以后写之前还是得先想明白了再写,要不然你写了错的东西,debug 的时候也往错的方向想,那这个 bug 就永远找不出来了。
]]>upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
删除sbrk()
系统调用里实际分配内存的部分。
这个没啥好说的,直接按照提示信息,删掉对growproc()
的调用就好了,如下:
uint64sys_sbrk(void){ int addr; int n; if(argint(0, &n) < 0) return -1; addr = myproc()->sz;// if(growproc(n) < 0) <- 这里删掉实际申请内存的部分// return -1; myproc()->sz += n; // 但是把当前进程占用空间扩大 return addr;}
然后很自然的,当我们去输入 echo hi
的时候,就报 panic 了。
实现页表的懒分配,如果发现在陷入过程中产生了缺页错误,就给这个发生错误的地址新分配一页。
查询 riscv 的手册,以及实验提示,可以找到 scause 寄存器中储存 13 和 15 代表缺页错误(试图写入或者试图读出):
那么我们在 trap.c
这个文件中可以查询 scause 寄存器,如果是 13 或 15 就进行下一步的处理:
…… } else if((which_dev = devintr()) != 0){ // ok } else if((r_scause() == 13 || r_scause() == 15)){ // do something here } else { printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid); printf(" sepc=%p stval=%p\n", r_sepc(), r_stval()); p->killed = 1; }……
这里的处理其实就是给用户分配这一页页表,我们可以把它封装成一个函数,叫做 lazy_alloc()
:
注意虽然发生缺页错误的是一个地址,但是我们需要把这个地址所在的页帧映射到物理地址上,所以要先用 PGROUNDDOWN
找到这个地址所在的页帧。
int lazy_alloc(uint64 va){ struct proc *p = myproc(); uint64 page_sta = PGROUNDDOWN(va); uint64* newmem = kalloc(); if(newmem == 0){ return -1; } memset(newmem, 0, PGSIZE); if(mappages(p->pagetable, page_sta, PGSIZE, (uint64)newmem, PTE_W|PTE_R|PTE_X|PTE_U) != 0){ kfree(newmem); return -1; } return 0;}
并且,在调用 mappages()
映射的时候,需要注意这个页表的权限,因为是允许在用户态使用的,所以要把 PTE_U
设置上。
改好这些代码,我们再去执行 echo hi
,会发现 uvmunmap()
这个函数会报 panic。
这是因为,我们采取页表懒分配之后,有些页可能一直都没被使用就被 uvmunmap()
了,这个时候,因为想要 unmap 的页根本就没有实际的分配,就会 panic,所以我们需要去修改一下 uvmunmap()
这个函数:
voiduvmunmap(pagetable_t pagetable, uint64 va, uint64 npages, int do_free){ uint64 a; pte_t *pte; if((va % PGSIZE) != 0) panic("uvmunmap: not aligned"); for(a = va; a < va + npages*PGSIZE; a += PGSIZE){ if((pte = walk(pagetable, a, 0)) == 0) continue; // 从 panic 改成 continue // panic("uvmunmap: walk"); // 释放进程的时候会用到 uvmunmap,但是有可能释放的时候这个页根本就没实际被分配 if((*pte & PTE_V) == 0) continue; // 从 panic 改成 continue // panic("uvmunmap: not mapped"); if(PTE_FLAGS(*pte) == PTE_V) panic("uvmunmap: not a leaf"); if(do_free){ uint64 pa = PTE2PA(*pte); kfree((void*)pa); } *pte = 0; }}
然后这个 lab 就可以顺利完成了。
让前面写出来的 Lazy allocation 通过 usertests 和 lazytests。
我们刚刚写出来的懒分配实际上是有些 bug 的,这个 lab 就是让我们修复这些 bug,然后通过 lazytests 和 usertests。
可以根据提示一个一个的改,首先需要处理 sbrk()
函数的参数为负数的情况。
对于正数的情况,我们只是改变进程的大小属性,并不会去实际分配空间。但如果是负数(减少当前进程空间),我们需要实际的释放空间,要不然就没法把这些内存分配给别的需要的进程,所以可以这样写:
uint64sys_sbrk(void){ int addr; int n; struct proc *p = myproc(); if(argint(0, &n) < 0) return -1; addr = p->sz; if(n < 0){ if(p->sz + n < 0){ // 一个进程不能释放比自己大的空间 return -1; } if(growproc(n) < 0){ // 注意这里是实际调用 growproc 去释放空间的。 printf("growproc err\n"); return -1; } }else{ myproc()->sz += n; } // if(growproc(n) < 0) // return -1; return addr;}
下一个提示是:
Kill a process if it page-faults on a virtual memory address higher than any allocated with
sbrk()
.
大概就是说,如果一个进程出现缺页错误的地址以前并没有被分配过(通过调用 sbrk()
)。那么我们就不应该去分配这个页,而是直接把进程 kill 了。
可以写一个函数,用来判某个虚拟地址是否属于合法的懒分配页:
int is_lazy_addr(uint64 va){ struct proc *p = myproc(); if(va < PGROUNDDOWN(p->trapframe->sp) && va >= PGROUNDDOWN(p->trapframe->sp) - PGSIZE ){ // 防止 guard page,这个之后会提到 return 0; } if(va > MAXVA){ return 0; } pte_t* pte = walk(p->pagetable, va, 0); if(pte && (*pte & PTE_V)){ return 0; } if(va >= p->sz){ return 0; } return 1;}
首先,很明显的一点是,如果一个页有 PTE_V
的标志,那么一定不是懒分配的,因为已经分配了。
然后,如果 va >= p->sz
,就说明这个地址之前没有通过 sbrk()
申请,所以也不是懒分配。
之后再把这个函数加到 trap.c
的判断中,就变成了:
…… } else if((which_dev = devintr()) != 0){ // ok } else if((r_scause() == 13 || r_scause() == 15) && is_lazy_addr(r_stval())){ // 这里加了一个 is_lazy_addr // 如果是 page fault,那就直接分配内存 uint64 fault_addr = r_stval(); if(lazy_alloc(fault_addr) < 0){ p->killed = 1; } } else { printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid); printf(" sepc=%p stval=%p\n", r_sepc(), r_stval()); p->killed = 1; }……
接下来要解决的是:
Handle the parent-to-child memory copy in fork() correctly.
大概是说需要正确的处理 fork()
中从父进程到子进程的内存拷贝。
阅读 fork()
的代码后可以发现,执行这个内存拷贝的函数是 vm.c
中的 uvmcopy()
。其在懒分配中出现问题的原因是,父进程的某些页帧是没有实际分配的,这个时候再试图去拷贝这个页帧,uvmcopy()
函数就会报 panic。和之前处理 uvmunmap()
函数一样,这里我们只需要跳过那些懒分配的页就行了,所以直接把 panic
改成 continue
:
intuvmcopy(pagetable_t old, pagetable_t new, uint64 sz){ pte_t *pte; uint64 pa, i; uint flags; char *mem; for(i = 0; i < sz; i += PGSIZE){ if((pte = walk(old, i, 0)) == 0) continue; // 注意这里,panic 改成了 continue。 // panic("uvmcopy: pte should exist"); if((*pte & PTE_V) == 0) continue; // panic("uvmcopy: page not present"); pa = PTE2PA(*pte); flags = PTE_FLAGS(*pte); if((mem = kalloc()) == 0) goto err; memmove(mem, (char*)pa, PGSIZE); if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){ kfree(mem); goto err; } } return 0; err: uvmunmap(new, 0, i / PGSIZE, 1); return -1;}
Handle the case in which a process passes a valid address from sbrk() to a system call such as read or write, but the memory for that address has not yet been allocated.
这个提示说说实话挺难理解的,我当时在网上查了好久才搞懂。这大概就是说,有些系统调用会在用户态的虚拟地址上写值,比如说 write()
。那万一这个地址是一个懒分配的地址,就会出问题,会引起缺页错误。当然,如果是用户态引起的缺页错误(像之前的一样)就完全没问题。但是如果我们发现内核态出现了异常,会直接 panic (见 xv6 学习笔记那篇文章)。
如果系统调用想要往用户态的虚拟地址写值(或者读值),是需要调用 copyin()
和 copyout()
的。可以观察一下这两个函数:
// Copy from user to kernel.// Copy len bytes to dst from virtual address srcva in a given page table.// Return 0 on success, -1 on error.intcopyin(pagetable_t pagetable, char *dst, uint64 srcva, uint64 len){ uint64 n, va0, pa0; while(len > 0){ va0 = PGROUNDDOWN(srcva); pa0 = walkaddr(pagetable, va0); // 注意这里 if(pa0 == 0) return -1; n = PGSIZE - (srcva - va0); if(n > len) n = len; memmove(dst, (void *)(pa0 + (srcva - va0)), n); len -= n; dst += n; srcva = va0 + PGSIZE; } return 0;}
能发现,它们都会调用 walkaddr()
来找到用户态虚拟地址对应的物理地址,而 walkaddr()
的实现如下:
// Look up a virtual address, return the physical address,// or 0 if not mapped.// Can only be used to look up user pages.uint64walkaddr(pagetable_t pagetable, uint64 va){ pte_t *pte; uint64 pa; if(va >= MAXVA) return 0; pte = walk(pagetable, va, 0); if(pte == 0) return 0; if((*pte & PTE_V) == 0) return 0; if((*pte & PTE_U) == 0) return 0; pa = PTE2PA(*pte); return pa;}
可以发现 walkaddr()
会调用 walk()
,而如果得到的结果是 0,就会直接返回 0。
我们还可以从 walkaddr()
函数作用的角度去理解。因为这个函数是用于查找虚拟地址对应的物理地址的,那一个懒分配的页帧并没有实际的物理地址,就自然找不到物理地址,所以会返回一个 0 。
也就是,如果 va
属于一个懒分配的页帧,这个 walk()
一定是会返回 0 的,具体可以看下面的代码:
pte_t *walk(pagetable_t pagetable, uint64 va, int alloc){ if(va >= MAXVA) panic("walk"); for(int level = 2; level > 0; level--) { pte_t *pte = &pagetable[PX(level, va)]; if(*pte & PTE_V) { // 这里会判断是否为分配过的地址, // 如果没分配过并且 alloc 参数还为 0,就会返回 0 pagetable = (pagetable_t)PTE2PA(*pte); } else { if(!alloc || (pagetable = (pde_t*)kalloc()) == 0) return 0; memset(pagetable, 0, PGSIZE); *pte = PA2PTE(pagetable) | PTE_V; } } return &pagetable[PX(0, va)];}
那我们可以在 walkaddr()
中判断,当前 va
是否属于懒分配的页帧,如果是的话就先别返回 0,而是先给它分配一个物理页,然后再进行后面的操作。(分配完物理页后就能查询到物理地址了)。
// Look up a virtual address, return the physical address,// or 0 if not mapped.// Can only be used to look up user pages.uint64walkaddr(pagetable_t pagetable, uint64 va){ pte_t *pte; uint64 pa; if(va >= MAXVA) return 0; if(is_lazy_addr(va)){ // 注意这里,如果是懒分配的会先分配物理地址。 lazy_alloc(va); } pte = walk(pagetable, va, 0); if(pte == 0) return 0; if((*pte & PTE_V) == 0) return 0; if((*pte & PTE_U) == 0) return 0; pa = PTE2PA(*pte); return pa;}
再看第五个提示:
Handle out-of-memory correctly: if kalloc() fails in the page fault handler, kill the current process.
也就是如果分配物理页的时候,没有足够内存了,应该把当前进程 kill 了。
其实这个东西我们已经完成了,在 trap.c
中,是这样写的:
uint64 fault_addr = r_stval();if(lazy_alloc(fault_addr) < 0){ p->killed = 1;}
如果 lazy_alloc()
不成功(没内存)就会把进程 kill 了。
然后是最后一个提示:
Handle faults on the invalid page below the user stack.
也就是正确处理发生在用户栈下面地址的缺页错误。
这个就需要复习下页表那章的内容了,下图是用户态下的内存布局:
可以看到,栈下面是一个保护页,这个页的 PTE_V
是没有设置的,如果用户访问,就会触发缺页错误。本来这个机制是没啥问题的,但是我们现在搞了懒分配,也就是触发缺页错误的时候不会 kill 掉这个进程,而是给这个地方分配物理地址。
那显然这个保护页是用于防止内存溢出的,不能去再分配物理页。所以需要在 is_lazy_addr()
这个函数中加入这个判断,如果某个地址属于保护页,那就不是一个合法的懒分配的地址,然后就有了下面的代码:
if(va < PGROUNDDOWN(p->trapframe->sp) // 这里使用了用户栈的栈指针 sp 来判断用户栈的虚拟地址 // 因为用户栈的下面就是保护页,所以把 // PGROUNDDOWN(p->trapframe->sp) 当作保护页的上界&& va >= PGROUNDDOWN(p->trapframe->sp) - PGSIZE){ return 0;}
这样写完之后就可以成功 AC 了,也祝在做这个 lab 的人尽快 AC:
感觉要提升下 debug 的能力,这个 lab 真的调了好久……
]]>前言:今天是 2022/7/25 先庆祝一下博客运行 100 天了。
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
先鸽了
实现一个backtrace()
的函数,如果某个程序调用了这个函数,该函数应该输出这个程序的 “函数调用顺序”,也就是把当前栈中的函数地址按照先后顺序全部打印出来。
做这个实验最主要的还是需要了解函数调用的过程,具体可以参考我之前写的这篇文章。
这里我把那篇文章中最重要的图和视频放在下面(绝对不是水字数),如果你之前比较熟悉函数调用的过程,但是现在忘了,看了之后应该比较容易回忆起来。
实验中,我们需要把函数调用的一个 “链条” 打印出来。
比如有下面这个程序:
int third(int x){ backtrace(); return x;}int second(int x){ return third(x); // 假设地址为 114}int first(int x){ return second(x); // 假设地址为 514} int main(){ int test = first(114514); // 假设地址为 1919}
那么调用 backtrace()
后的正确输出应该是
1145141919
其实就是让我们把函数调用者的地址递归的打印下去。
那我们知道,每个栈帧中都储存了当前函数的返回地址。(也就是,这个函数执行好了,应该返回到哪里)。
所以可以直接把每个栈帧中的返回地址打印出来。还应该开一个变量储存当前帧指针的位置,通过这个帧指针加上一些偏移量,获取上一个函数的帧指针,就可以打印上一个函数的返回地址了。
不过要注意的是,在我原来那篇文章中,使用的是 x86 (x64) 架构的处理器,其帧指针的名称为 bp (base pointer) 寄存器,在 riscv 中,fp (frame pointer) 寄存器做了相同的工作。
并且 riscv 中的 fp 指向的位置也和 x86 中的略有不同,具体可以看下面这张图[1]:
高地址Stack . . +-> . | +-----------------+ | | | return address | | | | previous fp ------+ | | saved registers | | | local variables | | | ... | <-+ | +-----------------+ | | | return address | | +------ previous fp | | | saved registers | | | local variables | | +-> | ... | | | +-----------------+ | | | return address | | | | previous fp ------+ | | saved registers | | | local variables | | | ... | <-+ | +-----------------+ | | | return address | | +------ previous fp | | | saved registers | | | local variables | | $fp --> | ... | | <-- 注意这里!!! +-----------------+ | | return address | | | previous fp ------+ <-- 如果是 x86,那么 bp 指针会指向这里 | saved registers | $sp --> | local variables | +-----------------+低地址(增长方向)
可以发现,在 riscv 中,fp 指向的是当前栈帧返回地址前面的一个位置(地址更高)。但是在 x86 中,bp 指向的是前一个栈帧的 bp 寄存器。
这个大概是因为 x86 和 riscv 对于一个栈帧定义的不同。在 riscv 的定义中,返回地址也是属于当前栈帧的的一部分(说实话我觉得这个设计更合理)。
虽然我们总是可以通过 fp 获得函数的返回地址,但是还需要获得到当前的 fp,这就需要用到 c 语言的内嵌汇编了,我们可以把这个函数放到 kernel/riscv.h
里:
static inline uint64r_fp(){ uint64 x; asm volatile("mv %0, s0" : "=r" (x) ); return x;}
GCC 拓展内联汇编的基本格式是:
asm asm-qualifiers ( AssemblerTemplate : OutputOperands [ : InputOperands [ : Clobbers ] ])
其中,asm 代表着内联汇编的开始,asm-qualifiers 表示这个内联汇编的一些性质,比如我们这里加了 volatile 就表示不希望 GCC 把这个汇编优化掉。
在上面的 ("mv %0, s0" : "=r" (x) )
中,mv %0, s0
是一个汇编的模板,并不是真正的汇编,有点类似于 C++ 中的模板,在编译的时候会把类型替换掉。GCC 编译的时候也会把 %0
这个东西替换成后面 : "=r" (x)
规定的变量(这里是 x)所在的寄存器。
而这个 "=r"
代表了一种限制条件,里面的 r
表示这个 x 变量可以在任何的通用寄存器中,而等于号表明该变量是被写入的。
除了 r
,还有很多种限制条件[2],比如 m 代表了该变量可以储存在内存中。如果你还想了解更多的限制条件,可以参考 GCC 的文档。
GCC 的文档中对拓展内联汇编也有非常详细的解释。
所以,总的来说,r_fp()
这个函数读出了 s0
这个寄存器的值,然后储存在 x
中,最后又把 x
返回了。
但是我们要读取的明明是 fp 这个寄存器,为什么这个函数里写的是 s0
呢,具体可以看看下面这个表[3]:
在 ABI Name 那一列,可以看到 s0 其实就是 fp 的别名。
有了这些知识,就可以写出 backtrace()
这个函数了:
void backtrace(){ printf("in bt\n"); // 帧指针下面的是返回地址 // 再下面一个是上一个栈帧的帧指针 uint64* cur_frame = (uint64 *)r_fp(); uint64* top = PGROUNDUP((uint64)cur_frame); uint64* bot = PGROUNDDOWN((uint64)cur_frame); while(cur_frame < top && cur_frame > bot){ printf("%p\n", cur_frame[-1]); // 先打印当前的返回地址 cur_frame = cur_frame[-2]; // 然后把当前栈帧变成上一个栈帧 }}
可以看到这里用了一些很奇怪的写法,好像是负数下标的数字,其实这个 cur_frame[-1]
等价于 *(cur_frame - 1)
。并且,因为这里 cur_frame
是六十四位的指针,所以 *(cur_frame - 1)
是读取 cur_frame
前八个字节位置的数据。
这里使用 PGROUNDDOWN
和 PGROUNDUP
是因为,一连串的函数调用最多放在一个页中。那么如果我们在递归打印的时候,超出了这一页的范围,就可以说明已经是最底层的函数,可以停止了。
最后我们按照要求在 sys_sleep()
这个系统调用里添加一下 backtrace()
,就完成这个 lab 了。
实现一个sigalarm(interval, handler)
的系统调用。及每过 interval 个时钟周期,就执行一遍 handler 这个函数。此外还要实现一个sigreturn()
系统调用,如果 handler 调用了这个系统调用,就应该停止执行 handler 这个函数,然后恢复正常的执行顺序。如果说sigalarm
的两个参数都为 0,就代表停止执行 handler 函数。
其实理解这个 lab 还是挺难的,特别是 sigreturn
,具体可以看看 alarmtest.c
这个程序,然后就是,需要对陷入的过程有比较好的理解,如果不熟悉,可以看看我的这篇文章:
voidperiodic(){ count = count + 1; printf("alarm!\n"); sigreturn();}// tests whether the kernel calls// the alarm handler even a single time.voidtest0(){ int i; printf("test0 start\n"); count = 0; sigalarm(2, periodic); for(i = 0; i < 1000*500000; i++){ if((i % 1000000) == 0) write(2, ".", 1); if(count > 0) break; } sigalarm(0, 0); if(count > 0){ printf("test0 passed\n"); } else { printf("\ntest0 failed: the kernel never called the alarm handler\n"); }}
这个 sigreturn
的意思就是,我们本来可能在执行这个 for 循环中的代码,然后突然开始执行 periodic()
这个函数(因为时间到了)。如果在 periodic()
函数中调用了 sigreturn()
。就应该停止执行 periodic()
里的东西,然后回到 for 循环中执行。(可以看这个 up 主讲的,比较清晰)
这里我们可以依次查看 alarmtest.c
中的几个 test(或者说就是测试点),然后按照这些测试点的要求去实现这个系统调用。
Get started by modifying the kernel to jump to the alarm handler in user space, which will cause test0 to print “alarm!”. Don’t worry yet what happens after the “alarm!” output; it’s OK for now if your program crashes after printing “alarm!”. Here are some hints:
大概就是说,我们可以先尝试去正确的跳转到用户态去执行 handler 函数(为了保持隔离性,不能在内核里直接把这个函数执行了),如果跳转之后报错了也没关系。
首先可以回忆下 xv6 发生陷入的过程,我们是根据 epc 这个寄存器来判断陷入之后返回的地址的。如果直接改变了 epc 的地址,就可以在返回的之后跳转到 handler 的地址。
那如何判断时候到了要跳转的时间呢?
riscv 的硬件(其实我不太确定是哪个硬件)会每过一个时钟周期都产生一个时钟中断,而 trap.c
会处理这个中断。
我们可以依靠这个中断出现的次数去判断是否应该跳转,如果需要,就直接在 trap.c
中把 trapframe 里 epc 的值改了(改成 handler 的)。
因此需要在 struct proc
给每个进程加入如下的属性:
uint64 alarm_tks;
用于记录执行 handler 的间隔,如果为 0 代表不执行void (*alarm_handler)();
handler 的地址uint64 alarm_tk_elapsed;
距离上次执行 handler 过去的时间并且在 sys_sigalarm()
把获取到的这些参数存入这些属性中,对于 sys_sigreturn()
,我们先不做任何操作,直接返回一个 0:
uint64 sys_sigalarm(void){ int ticks; struct proc* p = myproc(); uint64 handler; try(argint(0, &ticks), return -1); try(argaddr(1, &handler), return -1); p->alarm_tks = ticks; p->alarm_handler = handler; p->alarm_tk_elapsed = 0; return 0;}
相应的,我们创建了这些属性,就需要在进程的初始化函数 allocporc()
和释放函数 freeproc()
中做相应的初始化和释放。
首先是 allocporc()
的改动:
…… p->alarm_tk_elapsed = 0; p->alarm_state = 0; p->alarm_tks = 0; return p;}
然后是 freeproc()
:
…… p->alarm_handler = 0; p->alarm_tk_elapsed = 0; p->alarm_tks = 0;}
接下来就可以在 trap.c
的 usertrap()
中函数实现跳转了:
…… if(which_dev == 2){ // 时钟中断的编号为 2 if(p->alarm_tks > 0){ p->alarm_tk_elapsed++; // 距离上次执行 handler 经过的时间 if(p->alarm_tk_elapsed > p->alarm_tks){ // 如果超过了规定的时间 p->alarm_tk_elapsed = 0; p->trapframe->epc = p->alarm_handler; // 直接改 epc,这样回用户态的时候就会执行地址为 epc 的指令 } } yield(); }
这样我们就能顺利的跳转到 handler,并且通过 test0,当然也毫无悬念的报错了。
报错的主要原因是还没实现 sys_sigreturn()
,这样在执行完 handler 函数之后就不知道返回哪里了。
而要通过 test1 和 test2 就必须解决这个问题:
Chances are that alarmtest crashes in test0 or test1 after it prints “alarm!”, or that alarmtest (eventually) prints “test1 failed”, or that alarmtest exits without printing “test1 passed”. To fix this, you must ensure that, when the alarm handler is done, control returns to the instruction at which the user program was originally interrupted by the timer interrupt. You must ensure that the register contents are restored to the values they held at the time of the interrupt, so that the user program can continue undisturbed after the alarm. Finally, you should “re-arm” the alarm counter after each time it goes off, so that the handler is called periodically.
大概的意思是,我们需要在执行完 handler 后返回到正确的位置。
需要注意的是,我们跳转到内核去响应陷入和系统调用时,寄存器的值是会改变的,这样就算通过改变 epc 的值回到了正确的位置,也不能正确的执行(没有把寄存器的环境备份下来)。
因此我们在 struct proc
再加一个 struct trapframe
类的属性,用于备份执行 handler 前的环境:
……struct trapframe *trapframe; // data page for trampoline.Sstruct trapframe *alarmframe; // 新增的备份 trapframe……
当然,在 allocproc()
和 freeproc()
中的初始化和释放也是少不了的:
allocproc()
:
……if((p->alarmframe = (struct trapframe *)kalloc()) == 0){ freeproc(p); release(&p->lock); return 0;}……
freeproc()
:
if(p->alarmframe) kfree((void*)p->alarmframe);p->alarmframe = 0;
alarmframe
可以在 trap.c
里的 usertrap()
获取,也就是需要执行 handler 的时候,我们先备份一下环境,然后再执行:
if(which_dev == 2){ if(p->alarm_tks > 0){ p->alarm_tk_elapsed++; if(p->alarm_tk_elapsed > p->alarm_tks){ p->alarm_tk_elapsed = 0; *p->alarmframe = *p->trapframe; // 注意这里 p->trapframe->epc = p->alarm_handler; } } yield();}
在 sys_sigreturn()
里面,我们应该去按照 alarmframe
恢复 trapframe
,这样包括 epc 在内的所有通用寄存器都会被恢复,自然也就会跳出 handler,按照原来的顺序执行程序了:
uint64sys_sigreturn(void){ struct proc* p = myproc(); *p->trapframe = *p->alarmframe; return 0;}
到这里,我们再去运行 alarmtest,会发现还是不能完全过。
试想这样一个情况,如果 handler 执行的特别慢,自从上次调用 handler 已经过去了规定的时钟周期,但是 handler 还没执行好,这个时候我们又去改一遍 epc,这个 handler 又从头开始执行了,那着不就出大问题了,因为我们每次都会去改 epc,然后就永远执行不完 handler 了。
测试程序里就包括了这个情况:
voidslow_handler(){ count++; printf("alarm!\n"); if (count > 1) { printf("test2 failed: alarm handler called more than once\n"); exit(1); } for (int i = 0; i < 1000*500000; i++) { // 超慢的 handler asm volatile("nop"); // avoid compiler optimizing away loop } sigalarm(0, 0); sigreturn();}
所以我们需要在 struct proc
里再加一个属性,就是 alarm_state
。如果这个属性为 1,就表示,handler 程序正在执行,这个时候就算又过了 tick 个时钟周期,我们也不能去改 epc 让 handler 重复执行。
因为新添加了一个属性,所以 allocproc
和 freeproc
也需要改,这里就不细讲了。
更重要的还是要更改 usertrap()
函数中的东西:
if(which_dev == 2){ if(p->alarm_tks > 0){ p->alarm_tk_elapsed++; if(p->alarm_tk_elapsed > p->alarm_tks && !p->alarm_state){ // 注意这里必须是 p->alarm_state 为 0 p->alarm_tk_elapsed = 0; *p->alarmframe = *p->trapframe; p->trapframe->epc = p->alarm_handler; p->alarm_state = 1; // 注意这里:改了 epc 就代表开始执行了 } } yield();}
同时,sys_sigreturn()
函数里的东西也要改,因为调用了这个函数就代表 handler 不再执行了:
uint64sys_sigreturn(void){ struct proc* p = myproc(); *p->trapframe = *p->alarmframe; p->alarm_state = 0; // 更改 alarm_state 的值为 0,代表 handler 停止执行 return 0;}
改完之后就能成功 AC 了,也祝现在做这个实验的人尽快 AC:
比起这里的实验,其实更重要的还是理解 xv6 中陷入的过程,就算没有完全理解陷入过程,也能一步一步的照着实验指导做出这些实验。当然,要理解这里的陷入机制也属实是令人头疼,毕竟有很多以前从来没接触过的 riscv 汇编和底层的知识。虽然难理解,但理解和完成实验后,会让人不由自主的感叹操作系统设计的巧妙。
做完这个实验后,以前很多对操作系统的疑问也解决了,比如像 alarm 实验的原理。同时,也发现自己对汇编的理解还很浅。具体可以看 xv6 笔记那篇文章,一直理解不了为什么 userret
和 uservec
里要交换 sscratch
寄存器,后来问了才知道这个是特权级寄存器,不能用 ld,和 sd 这样的指令操作(实际上现在也没理解这样设计的原因)。
这个题需要使用到一个异或的性质。我们可以发现,对多个 0 或 1 连续的异或时,只有出现奇数个 1 才能使运算结果为 1。
因为如果出现了偶数个 1,那么对于每一个 1,总是能找到另一个 1 让它们的异或值变为 0 。而 0 的出现不会影响最终的结果,所以如果出现了偶数个 1,最后的结果一定是 0 。
The Xor-value of a node is defined as the bitwise XOR of all the binary values present in the subtree of that node.
题面中的这一句话表明,一个树的异或值被定义为该树下每个节点的异或和。
或者说,设当前树的根节点为 , 有 个子节点(包括不直接的,比如其子树的孩子),这些子节点的值是 。那么 的异或值就是:
因为每个子节点的值要么是 1 要么是 0 。我们根据上面提到的性质就可以知道,如果当前树的异或值为 1,那么其所有子树中,一定有奇数个的值为 1 ,反之亦然。
也就是说如果树 的异或值为 1,那么:
题目要求有 个子树的异或值为 1。那么我们就可以确定,对于这 个子树中的每个子节点,它们的值的和必须是奇数。
我们设 为树 中有少个值为 1 的子节点,当前树为 ,并且现在还需要 个树的异或值为 1(也就是说已经有些树的异或值为 1 了)。
那么如果 ,并且 ,也就是其所有子节点的值为 1 的有偶数个。那么我们应当把这个节点的值设成 1。
这是因为 ,我们还需要更多的树的异或值为 1,而当前这个树,因为其子节点的值为 1 的有偶数个,所以其异或值不是 1 。如果我们把这个树本身的值改为 1,其异或值就变为了 1,达到了我们让更多树的异或值为 1 的目标。
反过来讲,如果 ,我们不需要更多的树的异或值为 1 了,但是 ,也就是其所有子节点的值的和为奇数,那么我们应该把 设为 1 。
这是因为我们不想要产生更多异或值为 1 的树了,把 设成 1 就可以把其所有节点的值的和变为偶数, 的异或值也会变为 。
有了这两点结论,就可以使用 dfs 来找到答案了。
// tzyt#include <bits/stdc++.h>using namespace std;#define ll long longconst int MAXN = 2e5 + 10;vector<int> e[MAXN];// k 个奇数大小的子树int od_cnt[MAXN];int n, k;void dfs(int cur, string& ans) { for (int nex : e[cur]) { dfs(nex, ans); od_cnt[cur] += od_cnt[nex]; } if (k) { if ((od_cnt[cur] & 1) == 0) { // 子树里节点为 1 的是偶数个 // 将其变为奇数个 ans[cur] = '1'; od_cnt[cur]++; } k--; } else { // 已经满足条件了,但是可能多一个出来 if(od_cnt[cur] & 1){ // 子节点里为 1 的是奇数个 ans[cur] = '1'; od_cnt[cur]++; } }}int main() { int t; cin >> t; while (t--) { cin >> n >> k; for_each(e + 1, e + 1 + n, [](vector<int>& a) { a.clear(); }); string ans; ans.resize(n + 1); for_each(ans.begin(), ans.end(), [](char &a){a = '0';}); fill(od_cnt + 1, od_cnt + 1 + n, 0); // 重置数据 for (int i = 2; i <= n; i++) { int tmp; cin >> tmp; e[tmp].push_back(i); } dfs(1, ans); for (int i = 1; i <= n; i++) { cout << ans[i]; } cout << '\n'; }}
我原来想的是,把每个限制按照位置排序,如果位置一样,就按照值排序。
然后再遍历每个限制,交错的插入每个限制和没被限制的值(根据它们的值,因为题目要求字典序最小)。这里说的估计不清楚,下面是我之前的代码:
/*Date: 22 - 07-20 20 10PROBLEM_NUM: */#define FDEBUG#if (defined FDEBUG) && (!defined ONLINE_JUDGE)#define DEBUG(fmt, ...) fprintf(stderr, fmt, ##__VA_ARGS__)#define DWHILE(cnd, blk) \ while (cnd) blk#define DFOR(ini, cnd, itr, blk) \ for (ini; cnd; itr) blk#else#define DEBUG(fmt, ...)#define DWHILE(cnd, blk)#define DFOR(ini, cnd, itr, blk)#endif#include <bits/stdc++.h>using namespace std;#define ll long long#define pause system("pause")#define IINF 0x3f3f3f3f#define rg register// keywords:struct Constrain { int val, pos; bool operator<(Constrain b) const { if (pos != b.pos) return pos < b.pos; return val < b.val; } bool operator>(Constrain b) const { return b < *this; }};int main() { int t; cin >> t; while (t--) { int n, m; cin >> n >> m; priority_queue<Constrain, vector<Constrain>, greater<Constrain>> pq; vector<int> ans; ans.reserve(n); set<int> ncons; for (int i = 1; i <= n; i++) { ncons.insert(i); } for (int i = 0; i < m; i++) { Constrain tmp; cin >> tmp.val >> tmp.pos; pq.push(tmp); ncons.erase(tmp.val); } while (pq.size()) { auto tp = pq.top(); pq.pop(); bool used = false; if (ans.size() >= tp.pos) { goto FAIL; } while (ans.size() < tp.pos - 1) { int ist = *ncons.begin(); if (tp.val < ist) { ans.push_back(tp.val); used = true; } else { ans.push_back(ist); ncons.erase(ist); } } if (!used) { ans.push_back(tp.val); } } while (ncons.size()) { int ist = *ncons.begin(); ans.push_back(ist); ncons.erase(ist); } SUCC: for (int cur : ans) { cout << cur << ' '; } cout << '\n'; continue; FAIL: cout << "-1\n"; } pause;}
这么瞎搞会造成一个问题,假设我们把每种限制按照之前说的方法排序,并且设这些限制为
那么 中的数字只会在 这个区间中出现,不符合题目要求。所以才会疯狂 WA。
正确的解法是从后往前的计算。
我们维护一个大根堆 ,然后后往前遍历每个位置(就是题目的排列的位置)。
如果有些限制的位置就是当前遍历到的这个,那么我们就把这些限制的值加入 。然后对于每个遍历到的位置,就可以直接从 中取出栈顶的元素,放入答案中。
这样,只有当前的位置小于某个限制的位置,我们才可能从 中拿到这个限制的值,因此每个从 中拿到的元素都是合法的。
同时,在满足合法的同时,这些元素还是最大的,那么因为我们是从后往前遍历的,就确保了最后得到的排列字典序是最小的。
最后还需要考虑什么情况下输出 -1。因为 存的是所有这个位置合法的元素,那么如果 中拿不出任何东西了,就说明不能产生一个合法的排列。
最后,还有一点需要注意,对于那些没有任何限制的数字,我们可以在一开始就直接把他们加入 中,或者说这些数字的限制位置就是 。
// tzyt#include <bits/stdc++.h>using namespace std;// keywords:int main() { int t; cin >> t; while (t--) { int n, m; cin >> n >> m; vector<int> lim(n + 1, n), ans(n + 1); // 默认就是只要 n 前面就行(没有任何限制) vector<vector<int>> lislim(n + 1); // lislim[i] 储存所有限制位置为 i 的值 for (int i = 1; i <= m; i++) { int val, pos; cin >> val >> pos; lim[val] = pos; } for (int i = 1; i <= n; i++) { lislim[lim[i]].push_back(i); } priority_queue<int> pq; for (int i = n; i >= 1; i--) { for (int cur : lislim[i]) { // 到了某个限制的点,就会有新的数字可用 pq.push(cur); } if (pq.empty()) { // 空的话就是没有合法元素了 goto FAIL; } ans[i] = pq.top(); pq.pop(); } SUCC: for (int i = 1; i <= n; i++) { cout << ans[i] << ' '; } cout << '\n'; continue; FAIL: cout << "-1\n"; }}
]]>观察题面上给第一个样例提供的图:
可以发现,如果我们要让某种颜色形成一个塔,除非多个相同颜色在 数组中挨在一起,可以直接向上排布。就一定需要在排布该颜色后,向两侧放一些其他颜色,然后又往相反方向放置,最后使得两个颜色相同的块在一条直线上,大概是下面这样:
⬆->->->A⬆<-<-<-A<-<-<-⬆ A->->->⬆ 1 2 ... z
其中, 表示一个颜色的塔,而箭头表示放置颜色块的路径。
观察发现,在放置两个 之间,需要放置偶数个其他颜色块,下面是解释:
假设第一个 的位置是 ,并且我们往右侧放置的其他颜色块的数量是(也可以是左侧) 。
那么为了把第二个 搞到 上,就需要在 和 这些位置上放置颜色块,共计 个块,因此是偶数(直接网上堆的话是 个,也是偶数)。
这就意味着,假设有两个相同的颜色块 和 ,它们在 数组中的位置分别是 和 。只有 为奇数时,才可能把 叠到 上面,或是 叠到 上。
并且,只有 和 的奇偶性不同, 才可能为奇数。
然后就可以使用 dp 的方法来解决这个题目,我们对每种颜色都重复一遍相同的 dp 过程(其实更像是递推)。设 为 数组中,使用 个该颜色的块最高能垒成多高的塔。
那么 就可以从 中转移而来( ),并且如前面所说 和 的奇偶性应该不同。
同时,我们需要找的是,最近的奇偶性不同的块,要不然可能造成浪费,或者是在前面已经放过的位置又放了一个块。
//author: tzyt#include <bits/stdc++.h>using namespace std;#define ll long longint main() { int t; cin >> t; while (t--) { int n; cin >> n; int c[n + 1]; vector<int> cpos[n + 1], ans(n + 1); set<int> unqc; // 储存所有不同的颜色 for (int i = 1; i <= n; i++) { cin >> c[i]; cpos[c[i]].push_back(i); unqc.insert(c[i]); } int dp[n + 1]; for (int cur : unqc) { fill(dp, dp + cpos[cur].size(), 1); // 不管怎样,只要有块,总能垒成高度为 1 的塔 int mx = 1; // dp[0 ~ cpos[cur].size()] 中最大的 int lstod = -1, lstev = -1; // 最近的奇数位置和偶数位置,-1 为初始值 cpos[cur][0] & 1 ? lstod = 0 : lstev = 0; // 判断第一个的奇偶性 for (int i = 1; i < cpos[cur].size(); i++) { int lst = cpos[cur][i] & 1 ? lstev : lstod; if (lst != -1) dp[i] = dp[lst] + 1; // lst 为第一个奇偶性不同的位置 mx = max(dp[i], mx); cpos[cur][i] & 1 ? lstod = i : lstev = i; // 更新最近的奇数位置和偶数位置 } ans[cur] = mx; } for (int i = 1; i <= n; i++) { cout << ans[i] << ' '; } cout << '\n'; }}
能发现,不管怎么样,城市中酷的房子最多有 个。
如果是奇数个的话,只有一种排布方法能达到这么多个酷的房子。也就是第一个样例展示的。
从第二个房子开始,把每个偶数位置的房子都搞成酷的,也就是酷和不酷的房子隔着出现。
计算把一个普通房子变成酷的房子的代价可以用如下方法:
inline ll calc_cost(int i, int* h) { if (h[i] <= h[i - 1] || h[i] <= h[i + 1]) return max(h[i - 1], h[i + 1]) - h[i] + 1; else return 0;}
也就是把当前的房子搞的比相邻的最高的房子还要高一格。
但是偶数个房子的情况就比较复杂了。这种情况下 一定等于 。
那么就会有 个不酷的房子,也就一定有两个连在一起出现的不酷的房子,而这两个连续的不酷的房子可以出现在任何位置,我们需要考虑所有的情况。
比如 那么有如下几种排布方式。
但是如果从头到尾的把所有情况都计算一遍,时间就不够了。
所以我们可以只计算从一种情况到另一种情况之间代价的变化量。
比如:
这个过程中,第六个房子从酷变为不酷,第七个房子从不酷变为酷。
假设我们当前正在把第 个房子从酷变为不酷,第 个房子从不酷变为酷。我们只需要调用前面的 calc_cost
减去 的价格再加上 的价格就行了。
//author: tzyt#include <bits/stdc++.h>using namespace std;#define ll long longinline ll calc_cost(int i, int* h) { if (h[i] <= h[i - 1] || h[i] <= h[i + 1]) return max(h[i - 1], h[i + 1]) - h[i] + 1; else return 0;}int main() { int t; cin >> t; while (t--) { int n; cin >> n; int h[n + 1]; for (int i = 1; i <= n; i++) { cin >> h[i]; } ll ans = 0, tmp = 0; // 奇数情况的解法 for (int i = 2; i < n; i += 2) { ans += calc_cost(i, h); } if (n & 1) { cout << ans << '\n'; continue; } tmp = ans; for (int i = n - 2; i >= 2; i -= 2) { // 枚举连续 0 的位置 tmp -= calc_cost(i, h); tmp += calc_cost(i + 1, h); ans = min(ans, tmp); } cout << ans << '\n'; }}
我们尝试设最小的 为 。那么 ,因为 最小为 ,那 就是 了。
在这个的基础上,我们再贪心的尝试让每个 都尽可能的接近 ,这样就可以尽可能的让最大的 更小。
这样我们就可以算出 ,因为 ,所以 。当然, 不能大于 ,并且如果 ,我们就让 。
然后我们去枚举每个可能的 ,并且计算该情况下的最大的 ,就能得到答案了。思路好像挺简洁,但是真的挺难想的,
// author: tzyt#include <bits/stdc++.h>using namespace std;#define IINF 0x3f3f3f3fint main() { int t; cin >> t; while (t--) { int n, k; cin >> n >> k; int a[n + 1]; for (int i = 1; i <= n; i++) { cin >> a[i]; } int ans = IINF; int mxv = 0; for (int mnv = 0; mnv <= a[1]; mnv++) { // 枚举 mn for (int i = 1; i <= n; i++) { int p = min(k, (mnv ? (a[i] / mnv) : k)); // mnv ? (a[i] / mnv) : k 是因为 mnv 为 0 的情况 mxv = max(mxv, a[i] / p); } ans = min(ans, mxv - mnv); } cout << ans << '\n'; }}
]]>这个数据范围显然不能真的去复制字符串,所以要找到一些别的方法。
可以发现,每次在字符串末尾新加上的那一段,都可以通过一个偏移量在前面找到一模一样的。
比如我们可以观察样例中第一个数据的最后一次插入:
把这个 中的每个字母都向前移动 个位置,就是另一个 ,如下:
所以我们可以维护一个三元组 ,表示 这段区间内的字符,和 这段区间内的字符完全一样。
这样每次在查询位置 的时候,我们就可以一直减去相应的 ,直到 被包含在初始字符串的范围内。
这里再说明一下
// ttzytt#include <bits/stdc++.h>using namespace std;#define ll long longstruct Seg { ll l, r, diff; // [l, r] 这个范围内的每一个点,都和上一段有 diff 的偏移};int main() { int t; cin >> t; while (t--) { int n, c, q; cin >> n >> c >> q; string str; cin >> str; vector<Seg> a(c + 1); a[0].l = 0, a[0].r = n - 1; for (int i = 1; i <= c; i++) { ll l, r; cin >> l >> r; l--, r--; a[i].l = a[i - 1].r + 1; // 左端点和上一段的右端点一样 a[i].r = a[i].l + (r - l); // 右端点就加上长度 -1 a[i].diff = a[i - 1].r - l + 1; /* | 第一段 | 上一段 | 新插入的段 | |--------| \ / ⬆ \ / 正在复制的段 \ l a[i - 1].r 所以偏移量为 a[i - 1].r - l + 1 */ } while (q--) { ll x; cin >> x; x--; for (int i = c; i >= 1; i--) { if (x < a[i].l) // 如果 x 的位置不属于当前段 continue; else x -= a[i].diff; // 那就减去偏移量 } cout << str[x] << '\n'; } }}
先来模拟一下第四个样例:
000101010011
注:标红的位置表示发生了变化。
可以发现,这个过程中我们只能将一个由 组成的段,比如 ,或者 (当然反过来看的话可以说是 组成的段)延长或者是缩短一点,而不能凭空创造出一个新的“ 段”。这是因为,只有从 变成了 或者从 变成了 , 和 才会是不一样的,我们才能改变 。
所以我们可以知道,如果 串和 串的段数量不一样,那么一定是不可能从 转换成 的。
可以发现,每一次操作中,我们能将一个 “ 段” 的开头或者结尾移动一个位置,那么用这个方式就可以计算出从 到 的变换需要多少步了。
也就是,对于 和 中的每一个段,我们计算出段开始和结尾的位置,然后再算出 中和 中的段端点的差,把这些差累加起来就是答案了。
那如何判断段的开始和结尾呢?无非就是 变成 和 变成 。所以我们开两个数组 和 ,输入 和 之后遍历一遍这两个字符串。只要 ,就把 放入 中(对于 和 相同)。这样 和 中就存了两个串的所有段端点。
#include <bits/stdc++.h>using namespace std;#define ll long longint main() { int t; cin >> t; while (t--) { string s, t; int n; cin >> n >> s >> t; vector<int> sdiff, tdiff; // 题解中的 a 和 b ll ans = 0; if (s.front() != t.front() || s.back() != t.back()) { // 因为我们不能改变 s[0] 和 s[n - 1],所以 s 和 t 的第一和最后一位必须一样 goto FAIL; } for (int i = 0; i < s.size() - 1; i++) { if (s[i] != s[i + 1]) sdiff.push_back(i); // 如果相邻两位变化了,说明是端点 if (t[i] != t[i + 1]) tdiff.push_back(i); } if (sdiff.size() != tdiff.size()) { goto FAIL; } else { for (int i = 0; i < sdiff.size(); i++) { // 计算端点差的和 ans += abs(sdiff[i] - tdiff[i]); } } SUCC: cout << ans << '\n'; continue; FAIL: cout << "-1\n"; }}
]]>upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
注:和页表相关的基础知识在这篇文章中有说,可以参考。
为了加速系统调用,很多操作系统都会在用户空间内开辟一些只读的虚拟内存,内核会把一些数据分享在这里。这样就可以减少来回在用户态和内核态中切换的操作。我们需要用这个方法给getpid()
加速。
这个 lab 的大概思路是,在创建进程时,就直接把进程的 pid 放入共享空间中,然后用户查询 pid 时,就不必通过 ecall 跳转到内核了,省去了保存现场等开销。
首先我们需要在用户态的虚拟内存中多添加一页,专门用于储存和内核共享的数据。
创建一个新的虚拟内存到物理内存的映射需要用到 mappages()
函数,这个函数在 kernel/vm.c
中实现:
// Create PTEs for virtual addresses starting at va that refer to// physical addresses starting at pa. va and size might not// be page-aligned. Returns 0 on success, -1 if walk() couldn't// allocate a needed page-table page.intmappages(pagetable_t pagetable, uint64 va, uint64 size, uint64 pa, int perm){ // pagetable 是根页表,va 和 pa 分别是虚拟地址起始位置和物理地址起始位置 // perm 是标志位 uint64 a, last; pte_t *pte; if(size == 0) panic("mappages: size"); a = PGROUNDDOWN(va); last = PGROUNDDOWN(va + size - 1); // PGROUNDOWN 实际上是把一个数字的后 12 位全部都设成了 0 // 所以 a 表示新映射的起始地址,last 为最后一个要映射的页帧 for(;;){ if((pte = walk(pagetable, a, 1)) == 0) return -1; if(*pte & PTE_V) panic("mappages: remap"); *pte = PA2PTE(pa) | perm | PTE_V; if(a == last) break; a += PGSIZE; pa += PGSIZE; // 每次新分配一页 } return 0;}
所以我们可以在 kernel/proc.c
这个文件中的 proc_pagetable()
中调用 mappages()
创建新的一页映射。
这个 proc_pagetable()
会在创建新进程时被调用,符合我们的要求。
我们先观察 proc_pagetable()
是如何使用 mappages()
来创建 trampoline 和 trapframe 页的:
if(mappages(pagetable, TRAMPOLINE, PGSIZE, (uint64)trampoline, PTE_R | PTE_X) < 0){ uvmfree(pagetable, 0); return 0;}// map the trapframe just below TRAMPOLINE, for trampoline.S.if(mappages(pagetable, TRAPFRAME, PGSIZE, (uint64)(p->trapframe), PTE_R | PTE_W) < 0){ // 没映射成功的话会把之前的 unmap,而不是这个本身 uvmunmap(pagetable, TRAMPOLINE, 1, 0); uvmfree(pagetable, 0); return 0;}
可以发现,如果当前这一页没有映射成功,我们需要把之前成功映射的 uvmunmap()
了。并且把映射失败的这一页 uvmfree()
。
这是因为,如果想要使用 uvmunmap()
,必须要确保我们 unmap 的页是存在的,如果不存在就会崩溃(毕竟这都没映射你咋取消呢)。
所以,因为我们没有成功映射当前页,就只能 uvmfree()
去释放内存,而不是取消映射。
uvmfree()
的源码如下:
// Free user memory pages,// then free page-table pages.voiduvmfree(pagetable_t pagetable, uint64 sz){ if(sz > 0) uvmunmap(pagetable, 0, PGROUNDUP(sz)/PGSIZE, 1); freewalk(pagetable);}
可以发现如果 sz
为 0,就只会调用 freewalk
去释放一整个页表的内存。包括之前所有映射过的页。
还有一个小细节是,调用 freewalk()
时,我们必须确保映射是已经取消了的,所以我们会先调用 uvmunmap()
。具体可以看 freewalk()
的实现:
// Recursively free page-table pages.// All leaf mappings must already have been removed.voidfreewalk(pagetable_t pagetable){ // there are 2^9 = 512 PTEs in a page table. for(int i = 0; i < 512; i++){ pte_t pte = pagetable[i]; if((pte & PTE_V) && (pte & (PTE_R|PTE_W|PTE_X)) == 0){ // this PTE points to a lower-level page table. uint64 child = PTE2PA(pte); freewalk((pagetable_t)child); pagetable[i] = 0; } else if(pte & PTE_V){ // 重点:PTE_V 为 1,说明映射没取消,会 panic panic("freewalk: leaf"); } } kfree((void*)pagetable);}
根据这些信息,我们就能写出给 USYSCALL (也就是共享页) 的映射,这个 USYSCALL 的位置在 trampoline 和 trapframe 的下面:
if(mappages(pagetable, USYSCALL, PGSIZE, (uint64)(p->usyscall), PTE_R | PTE_U) < 0){ // 映射完成后,我们访问 USYSCALL 开始的页,就会访问到 p->usyscall uvmunmap(pagetable, TRAMPOLINE, 1, 0); uvmunmap(pagetable, TRAPFRAME, 1, 0); uvmfree(pagetable, 0); return 0; }
需要注意的是,因为这一页是和用户共享的,我们需要把 PTE_R
和 PTE_U
的标志位设置成 1,分别代表允许读,和允许用户访问。
和前面调用 mappages()
时相同,如果不成功,需要先把前面映射成功的取消,随后清空该页表的所有数据。
写完这些代码后,我们在用户态访问 USYSCALL 这个页中的地址,就能访问到内核中储存的 p->usyscall
了。
和 lab2 中给 proc
结构体加 trace_mask
属性一样,我们创建进程时多创建了一页映射,就需要在销毁进程时也取消这个映射。
因此在 kernel/proc.c
中,还需要更改一下 proc_freepagetable()
函数:
// Free a process's page table, and free the// physical memory it refers to.voidproc_freepagetable(pagetable_t pagetable, uint64 sz){ uvmunmap(pagetable, USYSCALL, 1, 0); // 新添加的 uvmunmap(pagetable, TRAMPOLINE, 1, 0); uvmunmap(pagetable, TRAPFRAME, 1, 0); uvmfree(pagetable, sz);}
现在还有个问题,我们已经成功创建了从虚拟内存到物理的映射,但是并没有在创建进程的时候申请这个物理内存。如果不去申请这个物理内存,我们就会尝试把一个虚拟内存映射到空指针上,自然会出问题。
所以还需要改一下 allocproc()
这个函数。
观察 allocproc()
中给 trapframe 分配物理内存的过程:
if((p->trapframe = (struct trapframe *)kalloc()) == 0){ freeproc(p); release(&p->lock); return 0;}
还是比较好理解的,那我们直接抄一波参考一下不就好了。
// 分配 usyscall 页if((p->usyscall = (struct usyscall *)kalloc()) == 0){ freeproc(p->usyscall); release(&p->lock); return 0;}p->usyscall->pid = p->pid;// 创建完了顺便把 pid 直接放进去
现在内核态这边的东西已经搞好了,用户态的函数就不需要我们自己写了,根据实验提示,已经在 user\ulib.c
中实现了:
intugetpid(void){ struct usyscall *u = (struct usyscall *)USYSCALL; return u->pid;}
和前面说的一样,我们直接访问 USYSCALL 这个虚拟地址,就能得到 p->usyscall
这个物理地址(其实也是虚拟的,但是内核中大部分页虚拟地址直接映射到物理地址)中的东西。
这样我们就完成了这个任务。
实现一个vmprint()
函数,该函数接收一个 pagetable_t 的参数,然后打印该页表,具体格式参考图片中的样式。在创建init
进程时,调用这个函数打印页表。
我们先别管在创建 init
进程时调用这个函数,先在 kernel/vm.c
中把这个函数写出来。
因为 xv6 的页表是多级的,所以是一个树的结构(不懂的话可以看我的这篇文章),那么本质上我们就是需要写一个通过 dfs 打印树的函数。
如下:
void vmprint(pagetable_t pagetable, uint dep){ if(dep == 0) printf("page table %p\n", pagetable); for(int i = 0; i < 512; i++){ pte_t pte = pagetable[i]; if(pte & PTE_V){ for(int j = 0; j < dep; j++) printf(".. "); uint64 child = PTE2PA(pte); printf("..%d: pte %p pa %p\n", i, pte, child); if(dep < 2) // 如果层数等于 2 就不需要继续递归了,因为这是叶子节点 vmprint((pagetable_t) child, dep + 1); } } }
这个函数接收两个参数,要打印的页表(可以理解为要打印的树的根节点)和当前的深度,多出来一个深度是因为根据图片中的格式,我们需要根据当前的深度打印出不同数量的点。而且我们需要通过深度知道是否到达了叶子节点。
对于每个 pagetable
,最多有 512 个节点,所以我们就依次遍历它们。如果发现这个页表是已分配的,也就是符合 pte & PTE_V == 1
的,我们就继续递归。
在打印的时候,我们先需要打印出 dep + 1
个 ..
,然后再打印出 pte 和 pa。
这里指的 pte 指的是直接读取页表项的结果,而 pa 是去掉页表项中的标志位后得到的物理地址,我们通过这个物理地址可以找到下一层的页表项或是页帧。
注意可以这么 pte_t pte = pagetable[i];
写是因为,pa 指向的实际上是这个子页表的第一个元素,而 pagetable[i]
和 *(pagetable + i)
是等价的,也就是去访问第 i 个页表。
这样这个 lab 中的主要部分就搞好了,下面我们可以去 kernel/exec.c
中的结尾插入以下代码:
if(p->pid == 1) vmprint(p->pagetable, 0);
因为 init
是系统创建的第一个进程,所以 init
的 pid 是 1,那么在创建 init 时,我们就会打印这个页表。
然后我们就完成了。
实现一个pgaccess()
函数,这个函数的申明为:int pgaccess(void *base, int len, void *mask);
。这个函数的主要作用就是检测从上次调用这个函数开始,页表是否被访问过。其中base
参数是要检测的第一个页表,len
从这个页表开始,要检测多少个页表,而我们需要把每个页表的访问情况写到mask
上。这个mask
的作用和 lab2 中的 trace_mask 相同,如果当前页表被访问,那么mask
中对应的位应该是 1。
因为这个 lab 的主要目的和 lab2 不一样,不是让我们熟悉系统调用的过程,所以这个系统调用已经注册好了,我们就不需要去注册一遍了。
接下来我们直接尝试在 kernel/sysproc.c
中实现这个函数。
首先我们的第一步一定是使用 arg
系列函数从用户态获取到传进来的参数(原因在 lab2 那篇文章中有讲),因此有如下的代码:
pagetable_t u_pt = myproc()->pagetable;uint64 fir_addr, mask_addr;uint ck_siz; uint mask = 0;try(argaddr(0, &fir_addr), return -1);try(argint(1, &ck_siz), return -1);try(argaddr(2, &mask_addr), return -1);
其中,fir_addr
,ck_siz
和 mask_addr
分别对应函数申明中的三个参数。
接下来我们要考虑如何确认某个页表是否被访问过。这个就需要用到 PTE 中的标志位(xv6 学习笔记那篇文章中有解释),具体如下[1]:
Each leaf PTE contains an accessed (A) and dirty (D) bit. The A bit indicates the virtual page has been read, written, or fetched from since the last time the A bit was cleared. The D bit indicates the virtual page has been written since the last time the D bit was cleared.
翻译:每个叶子 PTE 有一个 accessed (a) 和 dirty (D) 标志位,标志位 A 表示从上次标志位被重置,这个虚拟地址被读写或是被使用了。标志位 D 表示自上次被重置,这个虚拟地址被写过了。
注意以上的标志位都是 risc-v 处理器去设置的,并不需要任何软件上的操作,所以我们在实现函数的时候只需要去读取标志位的信息并重置就好了。
因为我们需要检测的是这个地址是否被访问过,而不是单纯的读取,我们需要使用的是标志位 A。而 PTE_A
在 xv6 中还没被定义过,所以我们在 kernel/riscv.h
中定义一下:
#define PTE_A (1L << 6) // 左移六位是看上图决定的
然后我们就可以在 sys_pgaccess
中这么写:
if(ck_siz > 32){ return -1;}pte_t* fir_pte = walk(u_pt, fir_addr, 0);for(int i = 0; i < ck_siz; i++){ if((fir_pte[i] & PTE_A) && (fir_pte[i] & PTE_V)){ mask |= (1 << i); fir_pte[i] ^= PTE_A; // 复位 }}
ck_siz
大于 32 的话我们就没有那么多位去在 mask 中储存,所以要返回。
下面的 walk()
函数就比较重要了,这里不介绍具体的细节,其作用为:对于一个给定的页表和虚拟地址,walk()
函数会返回对应这个虚拟地址的叶子 PTE。
所以我们通过这个函数得到了第一个需要检测的页表的 PTE 的地址,fir_pte
。
那么接下来只需要检测这个 PTE 后面 ck_siz
个 PTE 的 PTE_A 标志位就行了。
也就是:
for(int i = 0; i < ck_siz; i++){ if((fir_pte[i] & PTE_A) && (fir_pte[i] & PTE_V)){ mask |= (1 << i); fir_pte[i] ^= PTE_A; // 复位 }}
接下来我们需要把计算出来的 mask
传回用户态。需要用到 copyout()
函数,这个函数在 lab2 那篇文章解释过。
大概的用处就是,给定一个用户页表和虚拟地址,就可以把一些数据从内核态中拷到用户态中。
因此我们可以这么写:
try(copyout(u_pt, (uint* )mask_addr, &mask, sizeof(uint)), return -1);
也就是把 mask
的数据拷贝到基于用户态页表的 mask_addr
这个地址上。
然后这个 lab 就做完了。
页表和虚拟地址的这些概念,说实话还是比系统调用难的。要做出这个 lab,还是得对 risc-v 中的页表实现非常清楚。我花了很久时间才弄明白。也只有做了这个 lab 才能理解页表和虚拟地址的设计的巧妙。
祝在做这个 lab 的人尽快 AC:
]]>观察题目中的不等式 。可以发现,对于数组中的任意元素,只要不符合 ,那就绝对不会和任何的元素组成一个合法的数对。所以我们可以直接跳过不符合 的元素。
我们可以把这个不等式拆成三个部分 和 。
那对于所有符合 的元素,第一个和最后一个不等式已经满足了,只要找到满足 的元素,就可以构成一个合法数对了。
我们设 去掉不满足 的数组为 (听起来有点奇怪,但是 中的每个元素的下标是跟 一样的)。
比如我们说数组中的一个元素有值和下标两个属性,并且用这样的方式标记:,那么如果 数组是:
那去掉所有 的,就能得到 :
那只要我们对于每个 找到所有符合的 ,就可以解决本题。
其中不难发现 是单调递增的,所以可以使用二分来找最大的,小于 的 。那么 中所有 小于 的元素(以及 自身)都可以跟 构成一个合法数对。
除了二分法,我们还可以用树状数组来找到 中所有 小于一个特定值的元素的数量。
具体来说,我们可以用树状数组维护一个前缀和数组,然后遍历 中的元素,每次都做 upd(id)
的操作。也就是使查询所有大于等于 的数时,查到的值都增加 。
这样,在树状数组中查询某个 的时候,就会返回比 小的所有 了。
当然,用差分的方法,也可以得到和树状数组相同的前缀和数组。而且本题不需要我们在得到这个前缀和数组后做别的更新,所以差分可以用 的复杂度解决本题。
复杂度:
#include <bits/stdc++.h>using namespace std;#define ll long long// author: ttzytt (ttzytt.com)// ref: https://codeforces.com/blog/entry/104786int main() { int t; cin >> t; while (t--) { int n; cin>>n; int a[n + 1]; for (int i = 1; i <= n; i++) { cin >> a[i]; } ll ans = 0; vector<int> valid; // 前文所说的 b 数组,但是只储存了下标 // 因为我们只需要找最大的,小于 b_i 的下标 j for (int i = 1; i <= n; i++) { if (a[i] >= i) continue; // 不符合就直接跳过 // 这里可能算是一个优化吧,可以发现 valid 中的下标 i 都是小于 j 的 // 我们并没有把全部的 b 的下标都塞进 valid 中,因为 a[i] < i < a[j] < j // 所以只有 i < j 才可能符合。 ans += (ll)(lower_bound(valid.begin(), valid.end(), a[i]) - valid.begin()); // lower_bound 会找到 valid 中第一个大于等于 a[i] 的元素。 // 那么这个元素**之前**的全部是可用的。一个区间的长度为 r - l + 1 // 因为只有这个元素之前的才是可用的,所以这个 1 我们就不加了 valid.push_back(i); } cout << ans << '\n'; }}
我们能发现,对于所有的盒子,交错的使用好钥匙和坏钥匙总是更加不合算的。
并且,连续的,在前面的盒子使用好钥匙会更合算。(或者说使用好钥匙作为前缀)。
假设我们在一个好钥匙之前使用了一个坏钥匙,那么我们获得的收益是:
但如果先使用好钥匙,后使用坏钥匙,获得的收益是:
很明显,先使用好钥匙更合算。
更直观一点的解释是,不管先使用哪种钥匙,都会减去 的收益,但是如果先使用坏钥匙,我们会把两个盒子的收益减半,但如果先使用好钥匙,就只会把一个盒子的收益减半。
所以我们只会在最后的部分使用坏钥匙,在某些 比较大的情况下,可能相比减去 ,把 减半更合算。
所以只需要枚举一个使用好钥匙和坏钥匙的分割点,在这个点前面全部使用好钥匙,后面全部使用坏钥匙。
我们设这个分割点为 。
那么在分割点 后面使用坏钥匙,每个盒子的收益就会变成:
能发现这个 会增长的很快,在某个点之后, 就会变成 。那么在这个点之后我们就没必要再计算了。
因为最大的 为 ,所以 之后就没必要计算了。(或者你可以想,一直右移一个数,那么过了某个点整个数的二进制形式就没有 了)。
#include <bits/stdc++.h>using namespace std;#define ll long long// author: ttzytt (ttzytt.com)// ref: https://codeforces.com/blog/entry/104786int main() { int t; cin >> t; while (t--) { int n, k; cin >> n >> k; vector<int> a(n); for (int i = 0; i < n; i++) { cin >> a[i]; } ll psum = 0, ans = 0; // psum 是使用好钥匙获得的收益 for (int i = -1; i < n; i++) { if (i != -1) psum += (ll)(a[i] - k); ll cur = psum; // 枚举 i 这个分割点 for (int j = i + 1; j < min(n, i + 32); j++) { // 过了 i + 32 就没必要继续计算了 int bkval = a[j]; bkval >>= (j - i); // i + 1 要 / 2, i + 2 要 / 4 ... cur += bkval; } ans = max(ans, cur); } cout << ans << endl; }}
]]>upd@2022/7/14:添加了 sysinfo 这个 lab,至此为止,lab2 已经全部写完。
upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
跟名字一样,这个 lab 需要我们往内核里增加两个系统调用。而要增加这些系统调用,我们首先需要了解系统调用的过程。
首先,用户态的系统调用函数被声明(没有实现)在 user/user.h
中。
// system callsint fork(void);int exit(int) __attribute__((noreturn));int wait(int*);int pipe(int*);int write(int, const void*, int);int read(int, void*, int);int close(int);int kill(int);int exec(char*, char**);int open(const char*, int);int mknod(const char*, short, short);int unlink(const char*);int fstat(int fd, struct stat*);int link(const char*, const char*);int mkdir(const char*);int chdir(const char*);int dup(int);int getpid(void);char* sbrk(int);int sleep(int);int uptime(void);
这些函数其实都是由汇编实现的,在 user/usys.S
这个文件中(其实语言不是 nasm,是 risc-v 的汇编,但是好像只有我输 nasm 才能比较好的高亮):
fork:#include "kernel/syscall.h".global fork li a7, SYS_fork ecall ret.global exitexit: li a7, SYS_exit ecall ret.global waitwait: li a7, SYS_wait ecall ret.global pipepipe: li a7, SYS_pipe ecall ret.global readread: li a7, SYS_read ecall ret……
注意到li a7, SYS_fork
这个命令。其中 li
这个命令(load immediate)的形式是这样的:
li, rd, imm
它把一个立即数 imm 加载到rd寄存器中。[1]
那上面的 li a7, SYS_fork
中的 SYS_fork
就是一个立即数。它被定义在 kernel/syscall.h
中,这也是为什么这个汇编的开头要 #include
。
// System call numbers#define SYS_fork 1#define SYS_exit 2#define SYS_wait 3#define SYS_pipe 4#define SYS_read 5#define SYS_kill 6#define SYS_exec 7#define SYS_fstat 8#define SYS_chdir 9#define SYS_dup 10#define SYS_getpid 11#define SYS_sbrk 12#define SYS_sleep 13#define SYS_uptime 14#define SYS_open 15#define SYS_write 16#define SYS_mknod 17#define SYS_unlink 18#define SYS_link 19#define SYS_mkdir 20#define SYS_close 21
可以看到,这个文件定义了不同系统调用的编号,我们暂且叫他调用号吧。所以 li a7, SYS_fork
的意思就是,把 fork
的调用号赋值到 a7 寄存器内,这样进入内核之后,我们就知道之前调用的是哪个系统调用。
汇编的下一行是 ecall
。这是一个 risc-v 架构里比较神奇的指令,我也不是很了解,不过我从网上[2]查到了一些资料:
ECALL instruction does an atomic jump to a controlled location (i.e. RISC-V 0x8000 0180)
- Switches the sp to the kernel stack
- Saves the old (user) SP value
- Saves the old (user) PC value (= return address)
- Saves the old privilege mode
- Sets the new privilege mode to 1
- Sets the new PC to the kernel syscall handler
大概是说,ecall 这个指令会让我们跳转到一个特定的地址,而这个地址就是存放内核服务的地方(内核栈)。同时,和普通的函数调用一样,ecall 会保存现场,这样结束系统调用的时候我们就可以顺利的恢复到当前状态。比如,保存栈指针(sp),和程序计数器(pc)的值。
ecall 把我们跳到内核之后,会先进入一个内核的处理函数,syscall()
。
static uint64 (*syscalls[])(void) = { [SYS_fork] sys_fork, [SYS_exit] sys_exit, [SYS_wait] sys_wait, [SYS_pipe] sys_pipe, [SYS_read] sys_read, [SYS_kill] sys_kill, [SYS_exec] sys_exec, [SYS_fstat] sys_fstat, [SYS_chdir] sys_chdir, [SYS_dup] sys_dup, [SYS_getpid] sys_getpid, [SYS_sbrk] sys_sbrk, [SYS_sleep] sys_sleep, [SYS_uptime] sys_uptime, [SYS_open] sys_open, [SYS_write] sys_write, [SYS_mknod] sys_mknod, [SYS_unlink] sys_unlink, [SYS_link] sys_link, [SYS_mkdir] sys_mkdir, [SYS_close] sys_close, [SYS_trace] sys_trace, [SYS_sysinfo] sys_sysinfo,}; // 指向函数的指针的数组voidsyscall(void){ int num; struct proc *p = myproc(); num = p->trapframe->a7; if (num > 0 && num < NELEM(syscalls) && syscalls[num]) { p->trapframe->a0 = syscalls[num](); } else { printf("%d %s: unknown sys call %d\n", p->pid, p->name, num); p->trapframe->a0 = -1; }}
这个 syscall
会根据 a7 寄存器中存的调用号,去调用相应的服务。那如何去通过调用号来得到相应的函数呢?答案就是一个指向函数指针的数组,
这里的 [SYS_fork] sys_fork
是 C 语言数组的一个语法,表示以方括号内的值作为元素下标。比如 int arr[] = {[3] 2333, [6] 6666}
代表 arr 的下标 3 的元素为 2333,下标 6 的元素为 6666,其他元素填充 0 的数组。(该语法在 C++ 中已不可用)[3]
这些系统服务的具体实现都不在这个文件中,在 kernel/sysproc.c
中。比如 get_pid()
的实现:
uint64sys_getpid(void){ return myproc()->pid;}
调用完成之后,系统调用的返回值会在返回用户态时,被赋到 a0 寄存器上,也就是 p->trapframe->a0 = syscalls[num]();
这句话的用处。
实现一个追踪特定进程系统调用的系统调用,叫做 trace。比如有个进程调用了这个 trace,那么 trace 就会以特定格式输出这个进程调用过的系统调用。其中,有一个 mask 作为参数,指定有哪些调用需要被追踪。
具体来说,这个 mask 的每一位都代表一个系统调用,如果这个 mask 的第 位为 ,我们就需要去追踪编号为 的系统调用。
在实现之前,我们需要先顺着系统调用的过程,在各种文件中“注册”一遍这个新的系统调用。
首先是在用户态的 user/user.h
中申明一下,使得用户能通过调用这个接口去调用汇编代码,从而进入内核:
……int getpid(void);char* sbrk(int);int sleep(int);int uptime(void);int trace(int)//新加入的调用,有一个 int 的参数是 mask
如前文所讲,我们需要使用汇编去实现这个跳转函数。不过,这个汇编是 perl 的脚本自动生成的,所以需要去更改这个脚本(user/usys.pl
)。
print "# generated by usys.pl - do not edit\n";print "#include \"kernel/syscall.h\"\n";sub entry { my $name = shift; print ".global $name\n"; print "${name}:\n"; print " li a7, SYS_${name}\n"; print " ecall\n"; print " ret\n";}entry("fork");entry("exit");……entry("sleep");entry("uptime");entry("trace"); # 加在这里!
之后我们 make qemu
的时候,在脚本中新加的这个 entry
就会在 user/usys.S
中输出:
.global tracetrace: li a7, SYS_trace ecall ret
到此为止已经完成了在用户态的注册。接下来需要在内核中注册。
现在我们需要在 kernel/syscall.h
给这个新的调用注册一个调用号,这样才能通过调用号找到函数。
// System call numbers#define SYS_fork 1#define SYS_exit 2……#define SYS_mkdir 20#define SYS_close 21#define SYS_trace 22 // 这里 !
然后,就像之前介绍的,内核中的中转函数 syscall()
需要通过一个函数指针数组来查找需要调用的函数,所以我们需要去在这个数组中新加一个元素,并且申明一下这个 trace 函数。
kernel/syscall.c
:
extern uint64 sys_chdir(void);extern uint64 sys_close(void);……extern uint64 sys_write(void);extern uint64 sys_uptime(void);extern uint64 sys_trace(void); // 加在这里!static uint64 (*syscalls[])(void) = { [SYS_fork] sys_fork, [SYS_exit] sys_exit, [SYS_wait] sys_wait, [SYS_pipe] sys_pipe, [SYS_read] sys_read, [SYS_kill] sys_kill, [SYS_exec] sys_exec, [SYS_fstat] sys_fstat, [SYS_chdir] sys_chdir, [SYS_dup] sys_dup, [SYS_getpid] sys_getpid, [SYS_sbrk] sys_sbrk, [SYS_sleep] sys_sleep, [SYS_uptime] sys_uptime, [SYS_open] sys_open, [SYS_write] sys_write, [SYS_mknod] sys_mknod, [SYS_unlink] sys_unlink, [SYS_link] sys_link, [SYS_mkdir] sys_mkdir, [SYS_close] sys_close, [SYS_trace] sys_trace, // 加在这里}; // 指向函数的指针的数组
如前文所讲,像 extern uint64 sys_trace(void);
这样的申明是在 kernel/syscall.c
中的,而实现在 kernel/sysproc.c
中,我们需要到这个文件中随便添加一个实现(具体的实现在下文讲)。
……uint64sys_uptime(void){ uint xticks; acquire(&tickslock); xticks = ticks; release(&tickslock); return xticks;}uint64 sys_trace(){ // 新加的 printf("hello from trace\n"); return 0;}
这个时候,我们重新 make qemu
一下,然后在 shell 中随便输入一个 trace 命令,比如 trace 32 grep hello README
。就可以看到 了,说明我们成功注册上了。
想要了解使用了哪些系统调用,其实可以直接在系统调用的中转函数中做一些手脚,因为用户程序想要使用任何的系统服务都需要经过这个函数。那么就可以直接在这个函数中输出 trace 的信息了。
但是可能同时有很多个进程都在使用系统调用,而直接在 syscall()
函数中输出的话,就不只是输出一个进程使用的系统调用了。
而且直接输出的话也不符合 lab 中对 mask 的要求(也就是指定输出哪些系统调用)。
所以我们必须要有一种方法来确定当前的进程是否希望 trace,如果希望,那是希望 trace 哪些系统调用(也就是 mask)。要达到这个要求我们可以直接去给描述进程的结构体加一个 mask 属性。而定义进程的结构体就是 struct proc
,在 kernel/proc.h
这个文件中:
struct proc { struct spinlock lock; // p->lock must be held when using these: enum procstate state; // Process state struct proc *parent; // Parent process void *chan; // If non-zero, sleeping on chan int killed; // If non-zero, have been killed int xstate; // Exit status to be returned to parent's wait int pid; // Process ID // these are private to the process, so p->lock need not be held. uint64 kstack; // Virtual address of kernel stack uint64 sz; // Size of process memory (bytes) pagetable_t pagetable; // User page table struct trapframe *trapframe; // data page for trampoline.S struct context context; // swtch() here to run process struct file *ofile[NOFILE]; // Open files struct inode *cwd; // Current directory char name[16]; // Process name (debugging) int trace_mask; // 加在这里!};
这样,在中转函数 syscall()
中,我们只需要检测当前进入内核的这个进程的 trace_mask
就行了。如果发现这个进程希望追踪现在它调用的这个系统调用,我们就可以直接输出了。这样一来,就不会随便碰到一个进程就输出信息了。
下面是修改过的 syscall()
函数,在 kernel/syscall.c
中。
const static *syscall_names[] = { "fork", "exit", "wait", "pipe", "read", "kill", "exec", "fstat", "chdir", "dup", "getpid", "sbrk", "sleep", "uptime", "open", "write", "mknod", "unlink", "link", "mkdir", "close", "trace", "sysinfo"};voidsyscall(void){ int num; struct proc *p = myproc(); // myproc() 会给出当前调用系统调用的进程 num = p->trapframe->a7; // 当前进程希望调用的系统调用 if (num > 0 && num < NELEM(syscalls) && syscalls[num]) { p->trapframe->a0 = syscalls[num](); // 通过 num 找到需要调用哪个函数 // 这个 a0 储存了系统调用的返回值 int trace_mask = p->trace_mask; // 检查这个进程的 trace mask if ((trace_mask >> num) & 1) { // 如果当前这个系统调用是进程希望追踪的,那就输出 // 3: syscall read -> 1023 是 lab 中要求的格式,所以我们也按照这个格式输出 // 这里的 3 是进程号,read 是调用的系统调用的名字,1023 是调用过后的返回值。 printf("%d: syscall %s -> %d\n", p->pid, syscall_names[num - 1], p->trapframe->a0); } } else { printf("%d %s: unknown sys call %d\n", p->pid, p->name, num); p->trapframe->a0 = -1; }}
不过,每个进程的 trace_mask
也不是凭空出现的,只有调用了 trace 这个系统调用,我们才会给进程增加一个 trace_mask
。
所以肯定不能像刚才那样在实现 sys_trace()
时,直接输出一个 。
下面就是修改后的 sys_trace
的实现。
uint64 sys_trace(){ int mask; if(argint(0, &mask) < 0){ //从用户态读取第 0 个 32 位的数据 return - 1; } struct proc *cur_proc = myproc(); // 进行系统调用的这个进程 cur_proc->trace_mask = mask; return 0;}
本质上很简单,我们在用户态调用 trace()
时,会传进去一个 mask
,而现在这个系统调用实际上就是把传进来的这个 mask 赋值到当前的 struct proc
上。这样之后经过中转函数时,就可以知道要追踪哪些系统调用了。
注意这里的 argint(0, &mask)
这句话,其用处是读取第一个 位的参数。
我们不适用 C 语言的形式传参,而是用这样方式,是因为内核与用户进程的页表不同,所以需要使用 argaddr()
、argint()
、argstr()
等系列函数[3]。
这些函数最后都会调用到一个叫做 argraw()
的函数,实现如下,其参数 n
代表现在希望读取的是第几个参数:
static uint64 argraw(int n) { struct proc *p = myproc(); switch (n) { case 0: return p->trapframe->a0; case 1: return p->trapframe->a1; case 2: return p->trapframe->a2; case 3: return p->trapframe->a3; case 4: return p->trapframe->a4; case 5: return p->trapframe->a5; } panic("argraw"); return -1;}
可以看到其读取了 trapframe
中的数据。其实这个 trapframe
就是用来给系统调用保存现场的,它记录了发生系统调用时的寄存器状态,以及当前进程内核栈的位置,内核的页表等数据,在完成系统调用后,我们可以根据这里储存的数据,来恢复到调用之前的状态(和函数调用很像,可以参考我的这篇文章)。
那为什么我们想要取第几个参数,就返回 trapframe
的 a 几呢?虽然我不是很清楚,但大概是因为 risc-v 的函数调用约定(我的这篇文章 有讲函数调用规则)。
gcc 对于 risc-v 使用的部分函数调用约定有下面几点[4]:
这样看来,似乎和 argraw()
的实现是比较符合的(我们把系统调用的返回值放在 a0 也挺符合这个规则的)。不过。我还是不太清楚为什么不能使用 a6,a7 的话因为要储存系统调用号所以肯定不能放参数,a6 就不知道了,如果你知道,可以在评论区中讨论。
到此为止,如果你再尝试输入 trace 32 grep hello README
这个命令,就会看到正确的输出了。
不过,如果你再输入一个 grep hello README
(不带 trace 命令),你会发现还是输出了 trace 的信息。
仔细一想,这也是合理的,xv6 中会维护一个进程的列表(总共 个),我们新开一个进程时,系统给我们分配的是第一个没被使用的进程号。
具体的实现可以看 kernel/proc.c
文件中的 allocproc()
函数:
// Look in the process table for an UNUSED proc.// If found, initialize state required to run in the kernel,// and return with p->lock held.// If there are no free procs, or a memory allocation fails, return 0.static struct proc*allocproc(void){ struct proc *p; for(p = proc; p < &proc[NPROC]; p++) { acquire(&p->lock); if(p->state == UNUSED) { // 可以看到新创建进程时,总是会按顺序找到第一个没被使用的进程号 goto found; } else { release(&p->lock); } } return 0; // …… 下面还有一堆,就先省略了}
所以说,我们输入 grep hello README
时,因为没执行其他的命令,系统给这个命令分配的进程号是之前 trace 32 grep hello README
使用的。
那么,因为 trace 32 grep hello README
用的进程中的 trace_mask
已经被更改过了,并且没有改回来,所以我们 grep hello README
时,自然就会输出追踪的信息。
要解决这个问题,我们需要了解,在一个进程结束时,是由哪个函数来释放资源并且清空信息的,如果我们在这个函数中添加一行重置 trace_mask
的代码,就可以避免“明明没有 trace,但却输出信息了”的情况。
这个做最后收尾工作的函数(感觉有点像是 C++ 里的析构函数)就是 freeproc()
,也和 allocproc()
一起,在 kernel/proc.c
这个文件中:
那么我们直接在最后来一句 p->trace_mask = 0;
就可以了。
// free a proc structure and the data hanging from it,// including user pages.// p->lock must be held.static voidfreeproc(struct proc *p){ if(p->trapframe) kfree((void*)p->trapframe); p->trapframe = 0; if(p->pagetable) proc_freepagetable(p->pagetable, p->sz); p->pagetable = 0; p->sz = 0; p->pid = 0; p->parent = 0; p->name[0] = 0; p->chan = 0; p->killed = 0; p->xstate = 0; p->state = UNUSED; p->trace_mask = 0;}
现在再去尝试一下刚刚出 bug 的操作,就会发现没问题了。
到这里,离完成这个 lab 就只剩最后一步了。
The trace system call should enable tracing for the process that calls it and any children that it subsequently forks, but should not affect other processes.
也就是实现这句话说的功能,如果我们的父进程有 trace_mask
,子进程也需要有相同的。因为创建子进程都需要用 fork()
,那直接去改 fork
的源码就好了:
fork()
的具体实现和上面的两个函数一样,还是在 kernel/proc.c
中(毕竟和进程有关)。
可以看到,第一行定义了两个 struct proc
,一个是 np
,一个是 p
。因为代码中的注释,所以很明显可以看出来,这个 np
就是新的进程,那我们就完全不用管这里一堆看不懂的东西了,直接在中间来一个 np->trace_mask = p->trace_mask
。
然后就……结束了,现在去跑提供的单元测试就可以顺利 AC 了!!
fork(void){ int i, pid; struct proc *np; // new process struct proc *p = myproc(); // Allocate process. if((np = allocproc()) == 0){ return -1; } // Copy user memory from parent to child. if(uvmcopy(p->pagetable, np->pagetable, p->sz) < 0){ freeproc(np); release(&np->lock); return -1; } np->sz = p->sz; np->parent = p; // copy saved user registers. *(np->trapframe) = *(p->trapframe); // Cause fork to return 0 in the child. np->trapframe->a0 = 0; // 复制 trace mask np->trace_mask = p->trace_mask; // 在这 !!!!!! // increment reference counts on open file descriptors. for(i = 0; i < NOFILE; i++) if(p->ofile[i]) np->ofile[i] = filedup(p->ofile[i]); np->cwd = idup(p->cwd); safestrcpy(np->name, p->name, sizeof(p->name)); pid = np->pid; np->state = RUNNABLE; release(&np->lock); return pid;}
实现一个系统调用,用于收集当前系统的空闲内存,和运行进程的数量。系统调用接收一个struct sysinfo*
,在系统调用中需要把信息写进这个结构体里。
和前面一样,需要先在各种文件中把这个系统调用注册上,然后才能开始实现。因为过程和前面的完全一样,这里就不赘述了,唯一要注意的是需要在 user/user.h
申明用户态函数时,加上 struct sysinfo*
这个参数,而不是之前 trace 的参数。
内核中并没有提供给我们获取可用内存和当前进程数的函数,所以我们需要自己实现。
首先我们去实现一下获取可用内存的函数。根据 lab 的要求,应该实现在 kernel/kalloc.c
这个文件里。
可以看到该文件内定义了一个结构体 kmem
,如下:
struct run { struct run *next;};struct { struct spinlock lock; struct run *freelist;} kmem;
以及一些函数比如 kalloc()
,如下:
// Allocate one 4096-byte page of physical memory.// Returns a pointer that the kernel can use.// Returns 0 if the memory cannot be allocated.void *kalloc(void){ struct run *r; acquire(&kmem.lock); r = kmem.freelist; if(r) kmem.freelist = r->next; release(&kmem.lock); if(r) memset((char*)r, 5, PGSIZE); // fill with junk return (void*)r;}
通过代码中的注释以及变量命名和这个 kalloc
函数等,大概可以推断出这个 kmem
是一个链表,而链表中的每一个元素都指向一个可用的内存页(大小为 4KB)。
所以我们就可以遍历这个链表来得到空闲的空间。
uint64 get_fremem(){ // 返回空闲内存,用字节作为单位 uint64 ret = 0; acquire(&kmem.lock); // 先加锁 struct run *free_pagelist = kmem.freelist; while(free_pagelist){ // 遍历这个链表 free_pagelist = free_pagelist->next; ret++; } release(&kmem.lock); return ret * PGSIZE; // 返回时,需要乘以一个页的大小}
接下来我们还需要正在运行的进程数,按照 lab 的要求,要把这个函数实现在 kernel/proc.c
中。
观察之前讲过的 allocproc
函数:
// Look in the process table for an UNUSED proc.// If found, initialize state required to run in the kernel,// and return with p->lock held.// If there are no free procs, or a memory allocation fails, return 0.static struct proc*allocproc(void){ struct proc *p; for(p = proc; p < &proc[NPROC]; p++) { acquire(&p->lock); if(p->state == UNUSED) { // 可以看到新创建进程时,总是会按顺序找到第一个没被使用的进程号 goto found; } else { release(&p->lock); } } return 0; // …… 下面还有一堆,就先省略了}
然后参考这个遍历的过程,遍历所有的进程,并且计算出哪些的 state
不是 UNUSED
。我们就能得到正在使用的进程了。
uintget_proc_cnt(){ struct proc* cur_proc; //proc 是一个数组,定义为:struct proc proc[NPROC]; uint ret = 0; for(cur_proc = proc; cur_proc < &proc[NPROC]; cur_proc++){ acquire(&cur_proc->lock); if(cur_proc->state != UNUSED) ret++; // 如果这个进程是正在使用的 release(&cur_proc->lock); } return ret;}
现在,我们已经能够获得剩余内存和进程的数量了,接下来就可以在 kernel/sysproc.c
中实现 sys_sysinfo
了。
和 trace 上我们获得参数的方法一样,因为用户态和内核态的页表不一样,我们只能通过记录用户调用系统调用时的寄存器状态,并且存在 trapframe
里面,来获取用户传进来的参数。
因为这次需要接收的是一个结构体的指针,所以我们可以使用 argaddr
函数。
uint64 sys_sysinfo(){ struct sysinfo info; struct proc *cur_proc = myproc(); uint64 usr_addr; info.freemem = get_fremem(); // 这两行是获取系统信息 info.nproc = get_proc_cnt(); try(argaddr(0, &usr_addr), return -1); // 记录用户态的 sysinfo 地址 try(copyout(cur_proc->pagetable, usr_addr, (char *)&info, sizeof(info)), return -1); return 0;}
但是这个指针指向的是基于用户态页表的虚拟地址,所以我们获取了系统信息,也就是 info
后,需要用 copyout
函数去把我们的 info
复制到这个用户页表的地址上。
copyout
的申明是:int copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
。
根据源码中的注释:
Copy from kernel to user.
Copy len bytes from src to virtual address dstva in a given page table.
Return 0 on success, -1 on error.
可以看出第一个参数是这个虚拟地址 dstva
基于的页表,我们这个情况下要填的肯定是用户的页表,也就是 cur_proc->pagetable
。
下一个参数,dstva
是我们拷贝数据的目的地,这是一个基于第一个参数的页表的虚拟地址。我们可以填 usr_addr
,也就是我们通过 argaddr
从用户态获取到的参数。
而 src
就添数据来源,也就是 info
。最后一个参数就很明显了,复制数据的长度,也就是 sizeof(info)
。
写完这些就可以愉快的 AC 了,也祝在做这个 lab 的人尽快 AC。
做这个 lab 真的让我搞清楚了之前对系统调用的很多疑惑,只能说这个课是真的牛逼。比如之前一直不能理解普通的函数调用和系统调用有什么区别,然后这次因为要实现系统调用,要先顺着系统调用的过程把一个新的系统调用在各种文件中注册一遍。这个过程中就对系统调用清楚了很多。
]]>题目链接(CF,洛谷)
给你一个长度为 的数组 。代表所有的火车站。火车只能从左边的站台开到右边的站台。也就是从 开始,再到 ,最后到 。
现在给你 个询问,每个包含两个整数 和 ,问你是否可以从 这个站台开始,坐火车到 。
比如: 数组为 ,有以下三个询问:
我们只需要知道某个站台第一次出现的位置和最后一次出现的位置就行了。假设站台 第一次出现的位置为 ,最后一次出现的位置为 。并且这时有询问 。
那么只要 就一定可以从站台 坐车到站台 了。因为我们知道第一个 号站台在最后一个 号站台的左边,而火车只能从左向右开,所以可以到达 。
因为要形成一个站台编号到位置的映射,并且站台的编号比较大(),站台编号的数量相对较少()。用平常的数组肯定不行,因为需要的空间过大 ()。所以有两种办法,离散化(用排序离散化)和使用 map
。
这里我开了两个 map
,其中一个是站台编号到第一次出现位置的映射,还有一个,和前面讲的一样,是编号到最后一次出现位置的映射。
然后我们就可以得到如下代码:
因为使用的是 cin
和 cout
,所以可能会因为输入速度比较慢造成 TLE,所以可以取消一下同步。
// author: ttzytt (ttzytt.com)#include <bits/stdc++.h>using namespace std;#define ll long longint main() { int t; cin >> t; while (t--) { int n, k; cin >> n >> k; int a[n + 1]; map<int, int> v2pos_frt, v2pos_bk; //编号->第一次出现, 编号->第二次出现 for (int i = 1; i <= n; i++) { cin >> a[i]; if (!v2pos_frt[a[i]]) v2pos_frt[a[i]] = i; // 只有第一次才会赋值 v2pos_bk[a[i]] = i; } while (k--) { int l, r; cin >> l >> r; int lp = v2pos_frt[l]; int rp = v2pos_bk[r]; if (lp <= rp && lp != 0 && rp != 0) { // 如果根本没有这个站台,那 lp 或 rp 就会为 0 cout << "YES\n"; } else { cout << "NO\n"; } } }}
设 为一个由小写拉丁字母组成的字符串。它的价格被定义为,字符串中每个字母在字母表中的位置的和。
比如,字符串 的价格是 。
现在给你一个字符串 ,和一个整数 ,请你从字符串中尽量少的移除字母,使得 的价格小于或等于 。注意移除的字母数量可以是 个,也可以是字符串中全部的字母。
这道题的难度其实跟上一个差不多。因为题目让你删除尽量少的字母,所以我们直接挑对价格贡献大的字母删,直到整个字符串的价格小于等于 。
具体的实现上,我们还是可以用 map
建立一个字符到出现次数的映射(或者说桶)。
然后我们倒着遍历这个 map
,这样先遍历到的字符就对价格有更大的贡献。然后在遍历时如果发现当前的价格大于 ,就删除这个字符。并且如果我们删除了这个字符,那也相应的给字符的出现次数 。
最后输出时,我们遍历原来的字符串,如果发现对应的字符在桶里有出现,就输出,然后把出现次数 ,否则就不输出了。
// author: ttzytt (ttzytt.com)#include <bits/stdc++.h>using namespace std;#define ll long longint main() { int t; cin >> t; while (t--) { string str; int p; cin >> str >> p; map<char, int> bkt; // 桶 ll price = 0; for (char ch : str) { bkt[ch]++; price += (ch - 'a' + 1); //计算初始价格 } map<char, int>::reverse_iterator it = bkt.rbegin(); //倒着遍历 map,所以需要用反向迭代器 while (price > p) { // 如果价格没有小于等于 p,就一直删 (*it).second--; // 减少桶代表的出现次数 price -= ((*it).first - 'a' + 1); // 维护价格 if ((*it).second <= 0) { // 如果说这个字母已经被删光了 if (it != bkt.rend()) it++; // 并且这不是字符串中最小的字符 // 我们就开始删比当前字符小的字符 } } string ans; for (char ch : str) { if (bkt[ch] > 0) { //如果发现这个字符还没被删除 ans.push_back(ch); bkt[ch]--; } } cout << ans << endl; }}
给你 ( 为偶数,)个,数对。数对中的每个数字都是从 到 的。
现在问你是否能将这些数对分到两个集合中。使得每个集合中没有任何一个重复的数字。
比如有下面这四个数对:。
那么可以这样分配这些数对:
看起来是个贪心,能放一个集合的就放,不能就放另一个,另一个还不行就输出 ,但毕竟是个 E 题,所以没那么简单。(别学我直接交了个贪心上去,还半天都想不明白为什么错)。
要证明这个贪心是错的,只需要举一个反例,顺便吐槽一下,这个题的样例还是挺坑的,你用贪心完全能过。
比如给你下面这样一个数据:
61 2 5 42 3 4 3 5 66 1
如果我们用贪心做,设第一个集合为 ,第二个为 ,就可以把前两个,也就是 和 放到 中。到第三个,就会发现 中的 和 的 重复了,于是放到 中。
而对于第四个数对 ,可以发现不管放到哪里都有重复的。
然而,这个数据是可以合法的分到两个集合的:
我们可以把数对拆称每个数字来看。
从 开始,所有数对中,包含 的有两个: 和 。那么我们知道,因为两个数对都有 ,所以肯定不能放到一个集合里。
按照相同的方式来看 。包含 的数对有两个: 和 。所以这两个也一定在不同的集合中。
按照这样的方法从 到 的列出包含这些数字的集合,可以得到:
然后我们检查这些条件,发现似乎没有矛盾的,并且你可以根据这些条件得到我之前给出的分配方法。
这样一看,告诉你两个东西在不同的集合中,并且让你判断这些规则是否能满足,那不就是一种带逻辑关系的并查集吗?
如果你不熟悉,可以去看看这些题目:
的确,这个题是可以用带逻辑关系的并查集来做的,tourist 就是这么做的。
不过,我们还可以从图论的角度来思考。
如果我们给一个数对中的两个数字连上一条边,就可以得到下面这样的图:
1 <--> 2 <--> 3| |6 <--> 5 <--> 4
可以发现,因为和之前一样的原因,对于一个数字,比如 。我们不可能把包含 的两个数对,也就是 和 ,放到一个集合里。
也就是说可以从边的角度思考, 这个节点连了两个边,而我们不能同时选 的两条边放到一个集合里。
那么唯一能满足这个要求的办法就是交替的把边分配到集合中。
比如:
1 <--> 2 <==> 3 or 1 <==> 2 <--> 3 || | <---> | ||6 <==> 5 <--> 4 6 <--> 5 <==> 4
其中 <-->
这样的边和 <==>
这样的边代表边上的两个节点会被放到不同的集合中。
接下来我们可以分类讨论一下,不同的图是否能满足要求。
首先,如果一个节点连了三个及以上的边,那么一定是不能满足交替放入不同集合中的。
比如:
A /|\ / | \B C D
因为如果要把 放入两个集合中,, ,或是 就一定会被放入一个集合中,然后就不能满足交替出现的要求了,因为 不可避免的出现了两次。
其次,如果图中只有一个链,那么交替的放入不同集合中是一定能满足的。
最后,如果图是一个环,并且有偶数的边(就像上面那样),那是一定可以满足交替出现的要求的。而奇数就不行了。
判断环奇偶的办法其实比较直观,我们给每个边设置一个颜色的属性,共有两种颜色,然后用 dfs 去遍历一遍这个环。
遍历时尝试给边交错的染上颜色,如果我们不能成功的交错染色,那一定是奇环,反之亦然。(如果能交错的染色,那么两种颜色的数量一定是相等的,因此一定是偶环)。
还有一点在具体实现时需要注意,我们建出来的图不一定是联通的,所以需要尝试对每一个节点 dfs,同时,之间按照输入建图可能有重边,而我们需要避免。
整体来说,代码还是比较简洁的。
// author: ttzytt (ttzytt.com)#include <bits/stdc++.h>using namespace std;#define ll long longstruct E { int to, color;};const int MAXN = 2e5 + 10;vector<E> e[MAXN];set<int> have_e[MAXN];bool iseven_cycle(int cur, int fa, bool cur_color) { if (e[cur].size() < 2) return true; // 小优化,size 小于 2 说明是一个链的终点。 // 那么一个链是一定可以交错的染色的,这时候直接返回 true for (E &nex : e[cur]) { if (nex.to == fa) continue; if (nex.color == -1) // -1 是初始值,所以直接给它染和当前边不同的颜色 nex.color = !cur_color; else if (nex.color == cur_color)// 如果发现下一个边和当前边同色,那肯定是不能成功染色的 return false; else if (nex.color == !cur_color)// 有颜色了,但是是我们想染的。 return true; if (!iseven_cycle(nex.to, cur, !cur_color)) return false; } return true;}int main() { int t; cin >> t; while (t--) { int n; cin >> n; for_each(e + 1, e + 1 + n, [](vector<E> &a) { a.clear(); }); for_each(have_e + 1, have_e + 1 + n, [](set<int> &a) { a.clear(); }); // 每次清空一下数据。 bool isable = true; map<int, int> bkt; // 记录每个节点的度,如果大于 2 那一定不行(原因如上文) for (int i = 1; i <= n; i++) { int x, y; cin >> x >> y; bkt[x]++, bkt[y]++; if (bkt[x] > 2 || bkt[y] > 2 || x == y) isable = false; // 发现度大于 2 if (!have_e[x].count(y)) { //用于避免重边 e[x].push_back({y, -1}); have_e[x].insert(y); } if (!have_e[y].count(x)) { e[y].push_back({x, -1}); have_e[y].insert(x); } } for (int i = 1; i <= n && isable; i++) { if (e[i][0].color == -1) isable = iseven_cycle(i, 0, 1); // 建出来的图不一定联通,所以尝试对每个节点 dfs } if (isable) cout << "yes\n"; else cout << "no\n"; }}
前言:本题解的解法参考了这个视频。
多重集是一种特殊的集合,其元素可以重复,并且,和集合一样,元素的顺序不重要。如果两个多重集中,每个元素的出现次数都一样,那么这两个多重集就是相等的。
如, 和 是相同的。而 和 不是相同的。
现在给你两个多重集 和 ,每个包含 个整数。
在一次操作中, 中的一个元素可以被翻倍或是减半(向下取整)。或者说,对于一个 中的元素 ,你可以做下面两种操作。
注意你不能对多重集 做任何操作。
请问你是否能使多重集 在经过任意数量的操作后和 相等(也可以是 个操作)。
这个 和 可以联系到位运算的左移和右移。如 的二进制形式为 , 的二进制形式就为 。可以看到相比 , 的二进制形式在最后加了一个 。而 就是 ,二进制形式下的 在最后一位比 少了一个 。
所以左移和乘二的运算是等价的,右移和向下取整的除二是等价的。
那么我们就可以发现一个性质,也就是集合(实为多重集,这里为了方便称为集合) 和 中元素的后缀 是不重要的。
这里我来解释一下什么是后缀 ,以及“不重要”。
现在有一个数,比如 ,其二进制形式为 。可以看到二进制下的 在尾部有 个 。那么这三个 就是 的后缀 。
而不重要的意思是:
如果我们设 。再设 和 分别为 和 去掉后缀 的后的数字。那么如果我们能通过提供的两个操作,把 转换成 就一定能把 转换为 。
这是因为可以通过左移和右移操作,在 的尾部增加和删去任意数量的 。
这样就可以让 变成 。而对于 , 我们已经知道了可以将其转换成 。现在我们再在当前数字上减去一些 ,就可以变成 。
所以为了计算的方便,可以直接在输入的时候去掉元素的后缀 。
接下来,还有一个性质:
当且仅当 在二进制形式下是 的前缀,我们可以将 转换为 。
这里先解释一下,什么是二进制形式下的前缀。有两个数字, 和 。其二进制形式分别是 和 。
那么从字符串的角度来看, 就是 的前缀。而能将 转换为 是因为右移操作,我们可以把 的尾部去掉使其变成自己的任意二进制下的前缀。
并且,显而易见的,如果 , 一定不是 二进制形式下的前缀。那就自然不能将 转换为 。
有了这些性质,我们就可以搞出一些奇怪的方法了。
首先我们把集合 的元素存到一个数组里,把集合 的元素存到一个优先队列里。在存之前,需要先去掉后缀 。
vector<int> a(n);priority_queue<int> b;for (int i = 0; i < n; i++) { cin >> a[i]; while ((a[i] & 1) == 0) { // 如果最后一位是 0,那就一直右移来消除后缀 0 a[i] >>= 1; }}for (int i = 0; i < n; i++) { int temp; cin >> temp; while ((temp & 1) == 0) { temp >>= 1; } b.push(temp);}
然后再对 升序排序,之后就可以搞出一些骚操作了:
sort(a.begin(), a.end());while (b.size()) { int lb = b.top(); b.pop(); int la = a.back(); if (la > lb) { goto FAIL; } else if (la < lb) { lb /= 2; b.push(lb); } else { // la == lb a.pop_back(); }}
可以看到,在这个 while
中,我们每次取出的 和 都分别是 和 中最大的元素。
那么有三种情况。
a.pop_back();
这句话。对于第三种情况,如果说直接把 右移了然后放入优先队列中,那是否会造成: 本来是可以跟 中别的元素匹配,但现在不行了的情况呢?
答案是不会的,因为 中最大的元素已经小于 了,那其他元素一定也小于它,所以不会有别的元素等于 了。
#include <bits/stdc++.h>using namespace std;// author: tzyt// ref: https://www.youtube.com/watch?v=HIiX3r5n27Mint main() { int t; cin >> t; while (t--) { int n; cin >> n; vector<int> a(n); priority_queue<int> b; for (int i = 0; i < n; i++) { cin >> a[i]; while ((a[i] & 1) == 0) { // 如果最后一位是 0,那就一直右移来消除后缀 0 a[i] >>= 1; } } for (int i = 0; i < n; i++) { int temp; cin >> temp; while ((temp & 1) == 0) { temp >>= 1; } b.push(temp); } sort(a.begin(), a.end()); while (b.size()) { int lb = b.top(); b.pop(); int la = a.back(); if (la > lb) { goto FAIL; } else if (la < lb) { lb /= 2; b.push(lb); } else { // la == lb a.pop_back(); } } SUCC: cout << "YES\n"; continue; FAIL: cout << "NO\n"; }}
最后那个 G2,现在还没完全搞懂,我太菜了。。
最后,希望这篇题解对你有帮助,如果有问题可以通过评论区或者私信联系我。
]]>upd@2022/9/14:最近把实验的代码放到 github 上了,如果需要参考可以查看这里:
https://github.com/ttzytt/xv6-riscv
里面不同的分支就是不同的实验。
开始之前先吐槽一句,为什么 xv6 源码的码风这么怪啊???函数的返回类型居然跟函数名不在同一行??
intmain(int argc, char* argv[]){}
像这样……
然后就是建议阅读时关闭暗黑模式(右下角齿轮标),因为有些图片上的字是黑的,开了暗黑模式就看不清了。
实验说明地址:https://pdos.csail.mit.edu/6.828/2020/labs/util.html
实现一个sleep
命令,唯一的参数是休眠的时间。
因为有系统调用,所以实现起来还是比较简单的,可以直接调用提供的 sleep
系统调用。
唯一需要注意的是要在 #include user/user.h
之前先 #include kernel/types.h
。这个文件里面包含了一些类型的定义,而 user.h
需要用到这些定义。
#include "kernel/types.h"#include "kernel/stat.h"#include "user/user.h"#include "kernel/fd_types.h"int main(int argc, char *argv[]){ if (argc != 2){ fprintf(STDERR, "usage: sleep <tick count>"); exit(1); } int tm = atoi(argv[1]); // 字符串 -> 整数 sleep(tm); exit(0);}
其中的 kernel/fd_types.h
是我自己加的,源码如下,就是简单定义了输入输出的文件标识符,防止自己忘了:
#pragma onceconst char STDIN = 0;const char STDOUT = 1;const char STDERR = 2;
需要创建一个子进程,然后使用管道进行进程间通信。子进程和父进程互相通过管道发送一次信息。父进程收到后在终端打印 “ping”,子进程收到后打印 “pong”。
创建子进程后,先让父节进程发送一些信息。然后父进程就可以调用 wait()
了。而子进程会先输出 “pong”,然后向父进程发送信息。最后父进程会收到消息,然后输出一个 “ping”。
这个过程看着比较简单,但是因为我一开始不清楚管道的特性,所以没有正确的使用。一般来说管道是用于单向通信的,因为这个 lab 需要父进程和子进程互相通信,所以应该创建两个管道。
这个知乎回答比较清晰的解释了管道的实现:
数据只能单向移动的意思是FIFO,于是linux中实际构建了一个循环队列。具体一点则是,申请一个缓冲区,作为
pipe()
操作中匿名管道文件实体,缓冲区设俩指针,一个读指针,一个写指针,并保证读指针向前移动不能超过写指针,否则唤醒写进程并睡眠,直到读满需要的字节数。同理写指针向前也不能超过读指针,否则唤醒读进程并睡眠,直到写满要求的字节数。
并且,我一开始没有加 wait()
,就会出问题,比如会输出一些乱码。因为我们不知道系统会先执行子进程还是父进程,可能两个进程同时输出 “ping” 和 “pong”,然后这两个词就会混在一起了。
#include "kernel/fd_types.h"#include "kernel/types.h"#include "user/user.h"enum PIPE_END { REC = 0, SND = 1 };int main(int argc, char* argv[]) { if (argc != 1) { fprintf(STDERR, "usage: pingpong (no parameter)"); exit(114514); //(悲 } int p[2]; pipe(p); int cur_pid = fork(); if (cur_pid == 0) { //子进程 //子进程先接收消息 char buf[20]; if (read(p[REC], buf, sizeof(buf)) > 0) { printf("%d: received pong\n", getpid(), buf); } //子进程用管道发送消息给父进程 fprintf(p[SND], "child"); exit(0); } else if (cur_pid > 0) { char buf[20]; //父进程先发送消息,后接收消息 fprintf(p[SND], "parent"); wait(0); if (read(p[REC], buf, sizeof(buf))) { printf("%d: received ping\n", getpid(), buf); } exit(0); } else { fprintf(STDERR, "failed to fork"); exit(1919810);// homo 特有的 exit 参数(悲 } exit(0);}
创建多个子进程寻找素数。每个子进程筛掉上一个进程传来的数字中,为某个素数倍数的数字,然后把剩下的数传给子进程。因为 xv6 的性能限制,可以只输出前 个质数。具体的解释可以看下面的图。
这是我见过的最奇怪的素数筛了,但其实还是很符合 “筛” 的定义的。每个进程就是一种特定的筛子,会筛掉一个质数的倍数,然后经过很多层“筛子”,我们就能得到最终的素数。注意,传给下一个进程的第一个数字一定是质数,因为该数字不能被任何一个比它小的数字(素数)整除(能被整除就在前面筛掉了)。
需要注意的一点是 fork()
之后,子进程会从 fork()
的下一行开始执行,毕竟 fork()
会把父进程的所有状态拷贝过来,包括 pc 寄存器。(其实是常识,没啥好注意的,只是我之前不知道,然后搞出了很傻逼的错误)。
还有一点是用完了一个管道需要即时关闭,因为 xv6 的资源有限,一直不 close()
可能会让程序崩溃。
int main(int argc, char* argv[]) { int pp[2]; pipe(pp); int pid; pid = fork(); if (pid == 0) { close(pp[SND]); child_proc(pp); } else { int init_num[MAX_P]; int idx = 0; for(int i = 2; i <= MAX_P; i++){ init_num[idx++] = i; } close(pp[REC]); send_to_next(pp[SND], init_num, idx); close(pp[SND]); wait(0); } exit(0);}
首先,在主函数的父进程中,我们需要先创建从 到 的初始数组。然后调用 send_to_next()
函数,这个函数的作用就是把某个数组中的内容通过管道传给下一个进程。
其实现如下:
void send_to_next(int outpp, int msg[], int msg_len) { //发送到下一个子进程 //outpp 是管道的发送端 for (int i = 0; i < msg_len; i++) { write(outpp, msg + i, sizeof(int)); }}
在主函数的子进程中,我们会调用 child_proc()
。这个函数的唯一一个参数是管道的接收端,子进程会从这个管道接收没有被筛掉的数字。
void child_proc(int pp[2]) { int child_pp[2]; pipe(child_pp); int prime; int len = read(pp[REC], &prime, sizeof(int)); if(len == 0){ //如果全部都筛掉了,那自然可以结束了 printf("OK"); exit(0); return; } printf("prime %d\n", prime); int outlen; int* filtered = filter(prime, pp[REC], &outlen); close(pp[REC]); int pid = fork(); if(pid == 0){ close(child_pp[SND]); child_proc(child_pp); } else { close(child_pp[REC]); send_to_next(child_pp[SND], filtered, outlen); close(child_pp[SND]); wait(0); // wait 可以释放子进程的进程号以及别的资源 exit(0); }}
然后我们在这个 child_proc()
中会把接收到的第一个数字当作素数(原因如前面)。
然后用这个素数和 filter()
函数筛掉所有是这个素数倍数的数。filter()
的实现如下。
int* filter(int num, int inpp, int* outlen) { //把 inpp 管道中的 num 倍数全部过滤掉,返回过滤好的数组(没有 num 的倍数) (*outlen) = 0; //len 是过滤之后还有多少个数字 int* out = (int *)malloc(MAX_P * sizeof(int)); int ret = 0; do { ret = read(inpp, out + (*outlen), sizeof(int)); // ret 返回读到了多少字节 if (out[(*outlen)] % num != 0 && ret > 0) { (*outlen)++; } } while (ret > 0); return out;}
筛掉当前素数的倍数后,就可以再创建一个进程,把剩下的数字传过去了。在子进程中,可以继续调用 child_proc()
:
if(pid == 0){ close(child_pp[SND]); child_proc(child_pp);}
需要注意我们调用 child_proc
时,传进去的管道不是原来那个 pp
,是新创建的 child_pp
。这样做是因为在一个进程中,我们既需要读读取上一个进程传进来的数字,也需要把过滤好的数字发到下一个进程中。
而管道是只能单向传输的,如果我们只使用一个管道。那么一个进程在接收上一个进程的数据时,不能 close()
管道的发送端,因为之后还要把过滤好的数据发到下一个进程上。
但 read()
一个管道时,如果不 close()
这个管道的发送端,这个 read()
是会阻塞的,也就是会卡在这里,等待新数据。因为系统不知道之后会不会有信息从发送端发过来。只有关闭了发送端才能表明传输已经结束,之后再也不会有新的数据从发送端传过来。
同时,因为最开始的时候,子进程和父进程的管道都是默认开启的,也就是说有两个进程打开了管道的发送端。那么如果只有一个进程关闭了发送端,我们去 read()
接收端时,还是会阻塞的,因为发送端并不是真正的关闭。
这样讲可能还是有点不清晰,下面这张图可以比较清楚的解释整个过程。
另外还有一点,在子进程执行 child_proc
时,父进程一定要调用 wait()
,要不然可能会产生僵尸进程。
也就是父进程已经执行完而且调用 exit()
释完空间了,而子进程还在执行。
但是和直觉不太一样,子进程调用 exit()
释放资源呢后并没有完全从系统上消失,进程的描述符还存在在系统上,其唯一目的是给父进程提供信息。
所以我们需要父进程调用 wait()
来释放该进程最后剩余的进程标识符,slab缓存等,该调用会阻塞当前父进程,直到某个子进程退出[3]。
像这样的僵尸进程会占用进程号,文件描述符等资源,所以会有危害。
除此之外,不加 wait()
也会导致你的程序通不过提供的单元测试 (./grade-lab-util
)这也是为什么我会发现我程序有问题。具体来说,在跑测试的时候,进程一直都不会结束,然后单元测试就会显示你超时。
在 shell 运行时也是这样,虽然已经输出了所有的质数,但 shell 一直不会输出 $
。说明这个进程一直没有运行完毕。
不过我也不太清楚为什么僵尸进程会导致这样的现象,如果你清楚,可以在评论区说一下。
完整代码如下,参考了[1]:
#include "kernel/fd_types.h"#include "kernel/types.h"#include "user/user.h"#include "kernel/dbg_macros.h"const int MAX_P = 35;// #define FDEBUGenum PIPE_END { REC = 0, SND = 1 };void send_to_next(int outpp, int msg[], int msg_len) { //发送到下一个子进程 for (int i = 0; i < msg_len; i++) { write(outpp, msg + i, sizeof(int)); }}int* filter(int num, int inpp, int* outlen) { //把 inpp 管道中的 num 倍数全部过滤掉,返回过滤好的数组(没有 num 的倍数) (*outlen) = 0; //len 是过滤之后还有多少个数字 int* out = (int *)malloc(MAX_P * sizeof(int)); int ret = 0; do { ret = read(inpp, out + (*outlen), sizeof(int)); // ret 返回读到了多少字节 if (out[(*outlen)] % num != 0 && ret > 0) { (*outlen)++; } } while (ret > 0); return out;} void child_proc(int pp[2]) { int child_pp[2]; pipe(child_pp); int prime; int len = read(pp[REC], &prime, sizeof(int)); DEBUG("len: %d\n", len); if(len == 0){ printf("OK"); exit(0); return; } printf("prime %d\n", prime); int outlen; int* filtered = filter(prime, pp[REC], &outlen); dbg_arr_i32(filtered, 0, outlen); DEBUG("outlen: %d\n", outlen); close(pp[REC]); int pid = fork(); if(pid == 0){ close(child_pp[SND]); child_proc(child_pp); } else { close(child_pp[REC]); send_to_next(child_pp[SND], filtered, outlen); close(child_pp[SND]); wait(0); // wait 可以释放子进程的进程号以及别的资源 exit(0); }}int main(int argc, char* argv[]) { int pp[2]; pipe(pp); int pid; pid = fork(); if (pid == 0) { close(pp[SND]); child_proc(pp); } else { int init_num[MAX_P]; int idx = 0; for(int i = 2; i <= MAX_P; i++){ init_num[idx++] = i; } close(pp[REC]); send_to_next(pp[SND], init_num, idx); close(pp[SND]); wait(0); } exit(0);}
其中的 DEBUG
和 dbg_arr_i32
是一些调试用的函数或者宏,是我自己加在 kernel/dbg_macros.h
里面的,如下:
#pragma once#include "kernel/fd_types.h"#if (!defined FPRINTF)//内核态是没有定义 fprintf 的,只有 printf,所以重新定义 printf#define fprintf(_stream, _fmt, ...) printf(_fmt, ##__VA_ARGS__)#endif#ifdef FDEBUG#define try(_expr, _act) \ { \ if ((_expr) < 0) { \ fprintf(STDERR, "try: %s failed, at line %d, file %s\n", #_expr, \ __LINE__, __FILE__); \ _act; \ } \ }#else#define try(_expr, _act)#endif#ifdef FDEBUG#define DEBUG(fmt, ...) fprintf(STDERR, fmt, ##__VA_ARGS__)#else#define DEBUG(fmt, ...)#endifvoid dbg_arr_i32(int arr[], int st, int ed) {#ifdef FDEBUG for (int i = st; i <= ed; i++) { DEBUG("%d ", arr[i]); } DEBUG("\n");#endif}
实现 find 命令,查找该目录下所有为指定名字的文件。并且输出该文件的绝对路径。
这个可以参考 ls 的实现:
其实就是一个 dfs,如果检测到当前的路径是一个文件夹,那儿就 dfs 这个文件夹下的每一个文件/文件夹。
要获取文件夹里面放的东西,可以直接去 read()
这个文件夹。然后 read()
出来的是一个 dirent
结构体。
这个结构体的定义如下:
struct dirent { ushort inum; char name[DIRSIZ];};
其中里面的 inum
是文件节点,跟文件描述符不太一样,有多个文件描述符可以指向一个文件,但是每个文件的 inum
是唯一的。
注意在文件夹中还需要跳过 .
和 ..
这两个文件,要不然就死循环了。
通过这个 dirent
结构体,我们可以直接把 name
加到当前的路径后面,然后把这个新的路径传入,继续递归。
其实这个程序需要的功能和 ls
不同,所以其实还可以再简化一下。
在 ls
中,因为不是递归实现的,所以对于最开始的文件节点需要调用 fstat()
来判断是文件夹还是文件。然后如果是文件夹,再调用 stat()
来输出该文件夹内每个文件节点的信息。
stat()
和 fstat()
都是用来获取文件节点信息的,唯一的不同是,fstat()
接收的是这个文件的标识符,而 stat()
则接收路径。
但是在 find
中,因为是递归的,所以只需要调用一个 fstat()
就够了(不用 stat()
)是因为我们已经通过 open()
获得了标识符。
#include "kernel/fd_types.h"#include "kernel/types.h"#include "kernel/fs.h"#include "kernel/stat.h"#include "user/user.h"// #define FDEBUG#include "kernel/dbg_macros.h"const int BUF_SIZ = 512;char* get_fname_from_path(char path[]) { char* ptr = path + strlen(path); // ptr 指向 path 的最后一个元素 for (; ptr >= path && *ptr != '/'; ptr--) { } return ++ptr; // 从 for 里出来指向的是 '/',所以要减一下}void dfs_find(char* cur_path, char* name) { int cur_fd; char nexdir_buf[BUF_SIZ]; struct stat cur_stat; struct dirent nex_dir; try(cur_fd = open(cur_path, 0), return ); try(fstat(cur_fd, &cur_stat), return ); // fstat 接收的是一个文件描述符 if (cur_stat.type == T_FILE) { if (strcmp(get_fname_from_path(cur_path), name) == 0) { printf("%s\n", cur_path); } } else if (cur_stat.type == T_DIR) { strcpy(nexdir_buf, cur_path); char* path_end = nexdir_buf + strlen(nexdir_buf); *(path_end) = '/'; path_end++; while (read(cur_fd, &nex_dir, sizeof(struct dirent)) == sizeof(struct dirent)) { if (nex_dir.inum == 0) continue; // inum 就是文件节点,等于 0 为不可用 if (strcmp(".", nex_dir.name) == 0 || strcmp("..", nex_dir.name) == 0){ DEBUG(". or ..\n"); continue; } memmove(path_end, nex_dir.name, DIRSIZ); path_end[DIRSIZ] = '\0'; try(stat(nexdir_buf, &cur_stat), continue); //这里的 stat 接收的是绝对路径,并且这句话是可以删掉的,因为是递归实现。 dfs_find(nexdir_buf, name); } } close(cur_fd);}int main(int argc, char* argv[]) { if (argc != 3) { fprintf(STDERR, "usage: find <directory> <file name>"); exit(114); } dfs_find(argv[1], argv[2]); exit(0);}
实现 UNIX 中的 xargs 命令。
最开始搞了好久都没搞懂这东西是干啥的。其实就是因为把标准输入的数据传到一个命令中。xargs 的第一个参数是另一个命令的名字。然后我们需要把之后所有的参数,和从标准输入输进来的数据,当作那个命令的参数,去执行那个命令。
有这个 xargs 其实是因为很多命令不支持读取管道的输入作为参数,因为 shell 里的管道会把上一个命令的标准输出输出到下一个命令的标准输入上,所以我们需要从标准输入读出这些东西,然后作为参数给另一个命令执行。
比如 echo hello too | xargs echo bye
。管道会往 xargs 的标准输入输入 “hello” 和 “too” 两个字符串,xargs 就需要读取这两个字符串,然后和 “bye” 这个参数一起,作为执行第二个 echo 的参数,去执行 echo。
所以我们首先需要通过换行符和空格来判断不同的参数,然后把它们分割开来,存进入另一个字符数组(std_args
)。
然后再开一个新的字符数组,作为 exec()
时传进去作为参数的字符数组(arg2pass
)。首先需要在 arg2pass
中放入命令的名字(也就是 argv[1]
)然后再放入剩余的 argv
。最后再把 std_args
加进来。
#include "kernel/types.h"#include "user/user.h"// #define FDEBUG#include "kernel/fd_types.h"#include "kernel/param.h"#include "kernel/dbg_macros.h"const char* DEFAULT_CMD = "echo";#define MX_ARG_CNT 32#define MX_ARG_LEN 32char cut_str_by(char* src, char* dst, int* srcpos, char* signs) { // 从 src 串的下标为 srcpos 的位置开始往后找,在第一个碰到 signs 里面的字符时停下来 // 然后把 src[srcpos……<碰到 sign 的位置>] 这段字符串截取下来,放入 dst 中。 // 注意这个 srcpos 是一个指针,也就是调用完这个函数后我们可以通过 srcpos 知道在哪里 // 碰到了 signs // 返回值其实是一个布尔类,但是因为 c 语言没有,就用了 char,其表示是否碰到 signs 里 // 的字符 // 如果没有的话可能是 srcpos 这个位置就是一个 \0,也就是没有新的参数了。也可能是读完 // 了一段字符,后面没空格和 \n 了,那么说明这是最后一个参数。 for (int i = *srcpos; src[i] != '\0'; i++) { for (int s = 0; signs[s] != '\0'; s++) { if (src[i] == signs[s]) { src[i] = '\0'; strcpy(dst, src + *srcpos); *srcpos = i + 1; return 1; } } } return 0;};char std_args[MX_ARG_CNT][MX_ARG_LEN];int main(int argc, char* argv[]) { char* cmd; if (argc == 1) { cmd = DEFAULT_CMD; } else { cmd = argv[1]; } int argcnt = 0; char buf[MX_ARG_LEN * MX_ARG_CNT]; int curlen = 0; int lst_pos = 0; try(read(STDIN, buf, sizeof(buf)), exit(1145)); memset(std_args, 0, sizeof(std_args)); while (cut_str_by(buf, std_args[argcnt], &lst_pos, "\n ")) { while (buf[lst_pos] == '\n' || buf[lst_pos] == ' ') { //可能两个参数之间隔了很多个空格 lst_pos++; } argcnt++; } char* arg2pass[MX_ARG_CNT]; int lst = 0; arg2pass[lst++] = cmd; // 先放 argv[1] for (int i = 2; i < argc; i++) { // 然后是其他 argv arg2pass[lst++] = argv[i]; } for (int i = 0; i < argcnt; i++) { // 最后放从标准输入来的 argv arg2pass[lst++] = std_args[i]; } exec(cmd, arg2pass); exit(0);}
先来张 AC 的照片,也祝在做这个 lab 的人尽快 AC。
感觉大部分还是不难想的,主要是调试浪费了很多时间导致我的速度奇慢无比。因为常年使用 C++ 的 stl,现在对 C 都不是特别熟悉了,特别是调试 cstring 时浪费了很多时间。所以之后还是应该练习一下调试的技巧,以及 C 语言。
]]>书的链接(中文翻译版):https://github.com/duguosheng/6.S081-All-in-one
这章大部分的内容还是能够看懂的,但是在管道的示例程序上卡了很久。最后终于搞懂了,这里把我的理解写一下:
int p[2];//p[0] 储存管道接收端的文件描述符,p[1] 为管道发送端的描述符char *argv[2];argv[0] = "wc"; //第一个参数是命令的名字argv[1] = 0; //stdinpipe(p);//子 p[0] <--------- p[1] 父if(fork() == 0) { close(0); dup(p[0]); close(p[0]); close(p[1]); exec("/bin/wc", argv);} else { write(p[1], "hello world\n", 12); close(p[0]); close(p[1]);}
首先需要注意的是父子进程的文件描述符是共享的,比如这里的 pipe()
函数,看似是运行子进程后又打开了一个管道,但是因为打开的文件描述符是共享的,实际父进程和子进程中的 p[0]
和 p[1]
指向的是一个文件,方便了进程之间的交流。
除了文件之外,根据《UNIX环境高级编程》,还有以下的资源也是父子进程共享的( 虽然我基本都不懂 )。
而不同的地方则是:
对于这段程序中的父进程,通过 pipe()
拿到管道的文件标识符后,会往 p[1]
,也就是管道的发送端,写入 hello world
。然后关闭管道的两端。比较容易理解。
对于子进程,会先 close(0)
,这里的 是代表 stdin 的标识符,然后呢用 dup(p[0])
把 p[0] 这个标识符复制到另一个标识符上。
比如我们写 x = dup(y)
,那么 x 和 y 就会指向相同的文件。可是这里调用 dup()
时并没有接收返回值,那我们如何知道会被复制到哪个标识符上呢?
其实 dup()
会在所有标识符中从小到大找到第一被关闭的表示符,在子进程的程序中,我们先关闭了 stdin,也就是 ,那么调用这个 dup()
自然会把 p[0]
重定向到 stdin 上,也就是我们读取 stdin 相当于读取了 p[0]
。
接下来在子进程中,调用了 exec()
去执行 wc
命令,这个命令会统计文件的字数。可以发现,我们传给这个命令的参数的 argv[1] = 0
。我们希望让 wc
统计标准输入的字数。
需要注意的是调用 exec()
之后,系统会直接把新的程序写进这个进程,也就是说进程直接变成了 exec()
的程序。所以调用 exec()
之后是永远不会返回的。如果希望当前进程不被取代,可以先 fork()
,然后在子进程中 exec()
。
但是因为前面的 dup()
, 现在的标准输入已经被重定向到了 p[0]
,那么我们实际统计的就是从管道 p[1]
传进来的数据,也就是 write(p[1], "hello world\n", 12)
里的 hello world
。
还有一点我最初也感到很奇怪。既然父子进程是共享文件标识符的,那么 p[0]
和 p[1]
被关闭两次,不会出问题吗?
于是我就找到了这篇文章,终于搞懂了。
close()
函数关闭文件时,并不是在任何情况下都直接关闭文件,而是找出file
结构体中f_count
成员,执行自减操作;直到f_count
为0,才是真正的关闭文件。这就是著名的技术——引用计数。
强烈推荐一篇文章:https://zhuanlan.zhihu.com/p/351646541 非常详细的介绍了 xv6 页表相关的知识。下面的文章也大量参考了这篇文章。
页表是一种特殊的数据结构,用于实现操作系统中的内存虚拟化。页表储存了一个从虚拟地址到物理地址的映射。对于每个进程,操作系统都会维护一个页表,在每个进程中,也只能通过这个页表来访问物理内存,这样每个进程都在表面上拥有了一整台机器的资源。并且,一个进程的内存发生了泄露,也不会影响到另一个进程。
从虚拟地址到物理地址的转换是通过 CPU 中的内存管理单元完成的。如下图:
除了增强进程间的隔离性,页表和虚拟地址的作用就是更高效的利用计算机的内存。在实际中,我们是很难在内存中找到一大段空间的,这就会导致一种情况:当程序申请一段内存时,尽管总的空闲内存超过用户需要的内存,但是每段连续的内存都不能满足。因为随着时间推移,程序和数据被不断的移出和加载进内存中,内存中的碎片也越来越多。而页表把内存划分成了很多块,这样我们就可以把一段连续的虚拟内存映射到间断的物理内存上,从而更高效的利用空间。
最后,使用页表和虚拟内存还可以实现很多别的骚操作。如同时把一个物理地址映射到两个虚拟地址上。
最简单的页表实现就是一个类似数组的东西,记录着虚拟内存到每块内存的映射(也就是页帧,xv6 上一个页帧为 4KB)。
但是,这样的线性数组本身就需要很大的空间储存。我感觉现在大部分的个人电脑都有 8GB 的内存,那接下来可以算一下 8GB 的内存会需要多少空间储存页表。
首先,如果一个页帧为 4KB,那么 8GB 的内存中一共有 个页帧。对于每一个页帧,我们都需要有一个从虚拟地址的映射。假设我们使用的是 64 位的机器(8GB 的话只可能是 64 位的),那么就需要 8 个字节来储存一个地址。所以我们共需要 的空间储存页表。
如果一个 8GB 的机器就使用这么多的空间储存页表,也没什么。关键的问题是,对于系统中的每一个进程,我们都需要开一个新的页表来保存这个进程专属虚拟地址空间(并且内核态和用户态程序单独开)。那如果有 50 个进程正在运行,我们就需要 的内存来储存页表,这显然是不可接受的。
为了解决这个问题 risc-v 和几乎所有其他的现代处理器都使用了多级页表的方式。
具体来说,risc-v 使用了三级页表,可以理解为一个三层的树。树的根节点有 512 个子节点,这 512 个子节点是必须有的,剩下的两层子节点可以有最多 512 子节点(不是必须)。
这些子节点的正式名称是页表条目PTE(Page Table Entry)。每个 PTE 为 54 位(risc-v 能处理 64 位的虚拟地址,但是物理内存最多为 56 位数)其中,44 位为物理页帧号,用于索引下一页页表,10 位为一些标志位,用来记录一些关于当前 PTE 指向的 PTE 或内存的信息。
这就是为什么前面说,剩下的两层子节点不一定有 512 个子节点,并且因此,可以节省一定空间。因为我们可以利用 PTE 中的标志位来判断,当前这个 PTE 指向的下一个页表是否存在,如果说标志位指示其不存在,那就完全不需要储存这个 PTE 指向的页表了。而在单级页表中,就算我们可以用标志位指示这个页帧不存在,也必须储存对应的 PTE。
同时,多级页表和 PTE 中的标志位还可以让我们把一些页表交换到硬盘中,需要的时候再取出来,因此更大程度的节省空间。
那么拿到了一个虚拟地址之后,具体是如何转换成物理地址的呢。
我们先要了解 xv6 中使用的虚拟地址是什么样的,(可能因为方便教学?)xv6 中的虚拟地址只有低 39 位数是在使用的,剩下的 25 位都是保留位。
CPU 中的 satp 寄存器会指向当前页表的根页表,我们的要转换的虚拟地址就是基于这个 satp 指向的页表。
然后如下图所描绘:
虚拟地址的前 9 位指定了应该选取根页表中的哪个 PTE,这个 PTE 储存着一个物理地址,根据这个地址就能找到下一级 PTE。因为总共有三级页表,所以这个过程需要重复三次。相应的,对于每一级页表,虚拟地址中有 位来储存应该选取哪一个 PTE。而虚拟地址剩下的 12 位指示了在页帧中的偏移量()。
上图的下半部分还展示了 PTE 中的各种标志位,如:
在 qemu 中,RAM 的内存是从 0x80000000 (KERNBASE) 开始的,在这个地址下面的都是一些 IO 设备,比如网卡或是一些中断控制器,读取和写入这些特殊的内存可以实现和 IO 设备通信,RAM 的截止位置为 0x86400000 (PHYSTOP) ,共计 128 MB。
在内核中,除了两个特殊的页帧外,所有的虚拟地址都被直接的映射到物理地址中。这样就允许了内核更方便的操作物理内存。同时,因为这样的直接映射,内核还可以模拟 MMU 的行为,在页表不同的情况下读取用户态的数据( walk
函数,之后会讲到)。
页表中没有直接映射的两部分为 trampoline(蹦床) 和内核栈。
trampoline 被映射到了虚拟地址的顶端,在用户态中,这个相同的物理地址也被映射到了顶端。(至于为啥这么搞,应该会在 trap 那章有解答,我现在还不知道)
而内核栈在虚拟地址的顶端和中间部分都被映射了一遍(中间部分是直接映射)。这是因为 guard page 的设计。
guard page(可以翻译为保护页?)是一个为了防止栈溢出的设计,其 PTE 中的 V 标志位并没有被设置,这样我们访问保护页时,就会发生 page fault。
可以看到内核栈是被夹在保护页中间的,这样发生栈溢出,并且访问到这些保护页时,就会报错,而不会访问到内核栈不该访问的地方。
为了节省空间,这些保护页其实没有被映射到任何物理地址上,这也是使用虚拟内存才能实现的灵活操作,你可以把一个物理地址映射到多个虚拟地址,也可以不把一个虚拟地址映射到任何物理地址上。
用户态的虚拟内存布局就没什么特别了,除了最上面的 trampoline,如前面所说,这一页也被映射到了内核的虚拟内存中。
而 trampoline 下面的 trapframe 是用来在系统调用和发生 trap 时,保存各个寄存器的状态的。
鸽
在正常的情况下,我们写一个程序,那么这个程序运行起来大概是一个 “线性” 的过程,也就是程序里的内容是一条接着一条的运行下去的。
但在某些特殊情况下,这样“线性”的运行过程会被打破。比如我们熟悉的系统调用,就会暂停用户态程序的状态,跳转到内核态执行一些服务,然后再跳回用户态。
这样在用户态和内核态之间切换,去处理特殊性事件的过程,在 xv6 中称为陷入(trap)。
至于为啥叫陷入,我觉得下面这张图[1]挺形象的,就像是“陷入”到一个坑里又出来了:
通常有以下几种情况会发生陷入:
首先介绍一下陷入机制中会用到的一些寄存器:
下面会以一个系统调用的过程为例去讲解 xv6 的陷入机制。
在 lab2 的试验记录中,提到了会通过 ecall 指令从用户态切换到内核态,但是并没有详细解释这中间发生了什么(其实是我不知道),学过 trap 这章的内容,就可以明白了。
ecall 指令会干以下的事情[2]:
前面说 ecall 之后会跳到 stvec 指向的位置,在 xv6 中的用户态,这个位置指向的是 kernel/trampoline.S
中的 uservec。在内核态时,指向的是 kernel/kernelvec.S
。
更具体的,这个 kernelvec 最早是在 main 函数中被设置的,也就是 main 的 trapinithart()
。
这个函数干了以下的事情:
// set up to take exceptions and traps while in the kernel.voidtrapinithart(void){ w_stvec((uint64)kernelvec);}
也就是把 kernelvec 的地址写入 stvec。
不过我们现在的例子是从用户态陷入,所以先放一下 kernel/trampoline.S
中 uservec 的代码:
uservec: # # trap.c sets stvec to point here, so # traps from user space start here, # in supervisor mode, but with a # user page table. # # sscratch points to where the process's p->trapframe is # mapped into user space, at TRAPFRAME. # # swap a0 and sscratch # so that a0 is TRAPFRAME csrrw a0, sscratch, a0 # save the user registers in TRAPFRAME sd ra, 40(a0) sd sp, 48(a0) sd gp, 56(a0) sd tp, 64(a0) sd t0, 72(a0) sd t1, 80(a0) sd t2, 88(a0) sd s0, 96(a0) sd s1, 104(a0) sd a1, 120(a0) sd a2, 128(a0) sd a3, 136(a0) sd a4, 144(a0) sd a5, 152(a0) sd a6, 160(a0) sd a7, 168(a0) sd s2, 176(a0) sd s3, 184(a0) sd s4, 192(a0) sd s5, 200(a0) sd s6, 208(a0) sd s7, 216(a0) sd s8, 224(a0) sd s9, 232(a0) sd s10, 240(a0) sd s11, 248(a0) sd t3, 256(a0) sd t4, 264(a0) sd t5, 272(a0) sd t6, 280(a0) # save the user a0 in p->trapframe->a0 csrr t0, sscratch sd t0, 112(a0) # restore kernel stack pointer from p->trapframe->kernel_sp ld sp, 8(a0) # make tp hold the current hartid, from p->trapframe->kernel_hartid ld tp, 32(a0) # load the address of usertrap(), p->trapframe->kernel_trap ld t0, 16(a0) # restore kernel page table from p->trapframe->kernel_satp ld t1, 0(a0) csrw satp, t1 sfence.vma zero, zero # a0 is no longer valid, since the kernel page # table does not specially map p->tf. # jump to usertrap(), which does not return jr t0
其中有几个比较重要的地方,第一个是
csrrw a0, sscratch, a0
这行代码交换了 a0
和 sscratch
的值,也就是从这句话开始,a0
就指向了 trapframe。我们不能直接使用 sscratch,而是交换后使用的原因是:sscratch 是特权级的寄存器,而 sd 和 ld 等命令只能操作通用寄存器,特权寄存器一般是 csr 开头的指令操作(这块我也不太懂)。
接下来,我们通过 sd ra, 40(a0)
这样的命令,把寄存器中的值复制到内存的 trapframe 中。在内核态的 c 代码中,我们可以通过 trapframe 结构体访问这个 trapframe:
struct trapframe { /* 0 */ uint64 kernel_satp; // kernel page table /* 8 */ uint64 kernel_sp; // top of process's kernel stack /* 16 */ uint64 kernel_trap; // usertrap() /* 24 */ uint64 epc; // saved user program counter /* 32 */ uint64 kernel_hartid; // saved kernel tp /* 40 */ uint64 ra; …… /* 264 */ uint64 t4; /* 272 */ uint64 t5; /* 280 */ uint64 t6;};
其中,sd
的意思是 store,也就是储存 ra
的值到 a0 寄存器中的地址偏移 40 个字节的位置。
接下来,我们除了 a0 以外的所有寄存器都被复制了一遍,所以要再来复制一遍 a0:
# save the user a0 in p->trapframe->a0 csrr t0, sscratch sd t0, 112(a0)
注意因为前面交换过,现在这个 sscratch 储存着用户态 a0 的值,然后这个寄存器又和 t0 交换了下,t0 就成了用户态的 a0。因此 sd t0, 112(a0)
就保存了 用户态 a0 的值。
接下来,我们需要把处理器的环境完全的切换到内核中。因为我们之前用的是用户态的页表以及栈指针等,所以要更新相关寄存器的值。
然后就有了如下的代码,其中 ld 命令表示 load,及从内存中复制值到寄存器中:
# restore kernel stack pointer from p->trapframe->kernel_spld sp, 8(a0)# make tp hold the current hartid, from p->trapframe->kernel_hartidld tp, 32(a0)# load the address of usertrap(), p->trapframe->kernel_trapld t0, 16(a0)# restore kernel page table from p->trapframe->kernel_satpld t1, 0(a0)csrw satp, t1sfence.vma zero, zero
这里有个比较有意思的点,就是 trampoline 页(uservec 就放在 trampoline 页)在内核态和用户态的虚拟地址都是一样的,也就是同一个物理地址被映射了两次(页表部分有讲)。这样的设计允许使用 csrw satp, t1
命令更换页表后继续执行 uservec 的程序,不得不说还是很巧妙的。
trapframe 中的这些值(及内核态的根页表,内核态的栈指针等)其实是内核态第一次进入用户态时存下来的。
可以在 kernel/trap.c
的 usertrapret()
函数中找到:
// set up trapframe values that uservec will need when// the process next re-enters the kernel.p->trapframe->kernel_satp = r_satp(); // kernel page tablep->trapframe->kernel_sp = p->kstack + PGSIZE; // process's kernel stackp->trapframe->kernel_trap = (uint64)usertrap;p->trapframe->kernel_hartid = r_tp(); // hartid for cpuid()
然后就到了 uservec 的最后一条代码:
jr t0
也就是跳转到 t0 寄存器的位置。注意前面的这句话:
# load the address of usertrap(), p->trapframe->kernel_trapld t0, 16(a0)
我们把 usertrap 函数的值加载到了 t0 中,那么 jr
之后就会跳转到 usertrap 函数中。
总结一下,uservec 一共干了下面这些事情:
接下来,就到了 usertrap,代码如下:
voidusertrap(void){ int which_dev = 0; if((r_sstatus() & SSTATUS_SPP) != 0) panic("usertrap: not from user mode"); // send interrupts and exceptions to kerneltrap(), // since we're now in the kernel. w_stvec((uint64)kernelvec); struct proc *p = myproc(); // save user program counter. p->trapframe->epc = r_sepc(); if(r_scause() == 8){ // scause 储存陷入原因 // system call if(p->killed) exit(-1); // sepc points to the ecall instruction, // but we want to return to the next instruction. p->trapframe->epc += 4; // an interrupt will change sstatus &c registers, // so don't enable until done with those registers. intr_on(); syscall(); } else if((which_dev = devintr()) != 0){ // ok } else { printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid); printf(" sepc=%p stval=%p\n", r_sepc(), r_stval()); p->killed = 1; } if(p->killed) exit(-1); // give up the CPU if this is a timer interrupt. if(which_dev == 2){ yield(); } usertrapret();}
首先用一下代码判断中断是用户态来的还是从内核态来的:
if((r_sstatus() & SSTATUS_SPP) != 0) // sstatus 的 spp 为储存是用户态中断还是内核态中断 panic("usertrap: not from user mode");
如果是内核态来的,就……就……处理不了了,直接来个 panic 摆烂了。
如果是用户态的话,那会先用如下代码把 stvec 改成 kernelvec:
w_stvec((uint64)kernelvec);
因为万一在内核中发生中断,处理逻辑是不一样的,所以不能用 uservec 的程序。
p->trapframe->epc = r_sepc();
注意我们这里把 sepc 储存起来是因为,在内核态处理时,可能会切换到另一个进程,而这个进程也可能会去调用系统调用。这个时候 sepc 寄存器的值会被覆盖,那么我们现在把他存起来了,就算中途去处理另一个进程的系统调用,回来的时候也没问题。
并且我们是在这些都保存好后,才去调用 intr_on()
打开中断的,这样只有保存好信息后在可能去在别的进程中执行中断。(打开中断时因为系统调用可能比较费时间,这段时间中,cpu 可以同时处理别的进程)
后面的话基本上就是按照产生陷入的原因,去做不同的处理。如,如果是因为系统调用产生的陷入,那么一定使用了 ecall 指令,这时候我们就希望系统调用执行好了后在用户态执行的是 ecall 的下一条指令,所以要把 sepc 改成 sepc + 4。
如果是设备产生的中断,那会在 devintr()
函数中相应的处理逻辑。
如果产生了异常,那就直接把那个发生陷入的进程 kill 了。
所以这个 usertrap 函数大概干了如下的事情:
这段代码的最后一行调用了 usertrapret()
这个函数,做了一些返回前的工作。
代码如下:
voidusertrapret(void){ struct proc *p = myproc(); // we're about to switch the destination of traps from // kerneltrap() to usertrap(), so turn off interrupts until // we're back in user space, where usertrap() is correct. intr_off(); // send syscalls, interrupts, and exceptions to trampoline.S w_stvec(TRAMPOLINE + (uservec - trampoline)); // set up trapframe values that uservec will need when // the process next re-enters the kernel. p->trapframe->kernel_satp = r_satp(); // kernel page table p->trapframe->kernel_sp = p->kstack + PGSIZE; // process's kernel stack p->trapframe->kernel_trap = (uint64)usertrap; p->trapframe->kernel_hartid = r_tp(); // hartid for cpuid() // set up the registers that trampoline.S's sret will use // to get to user space. // set S Previous Privilege mode to User. unsigned long x = r_sstatus(); x &= ~SSTATUS_SPP; // clear SPP to 0 for user mode x |= SSTATUS_SPIE; // enable interrupts in user mode w_sstatus(x); // set S Exception Program Counter to the saved user pc. w_sepc(p->trapframe->epc); // tell trampoline.S the user page table to switch to. uint64 satp = MAKE_SATP(p->pagetable); // jump to trampoline.S at the top of memory, which // switches to the user page table, restores user registers, // and switches to user mode with sret. uint64 fn = TRAMPOLINE + (userret - trampoline); ((void (*)(uint64,uint64))fn)(TRAPFRAME, satp);}
这个函数中,我们会先关掉中断,然后把 stvec 从 kernelvec 改回 uservec。
接下来,为了之后用户态发生陷入时,能成功恢复内核态的一些上下文,会把一些寄存器的值存入 trapframe(见 uservec 的部分,uservec 会用到这些):
p->trapframe->kernel_satp = r_satp(); // kernel page tablep->trapframe->kernel_sp = p->kstack + PGSIZE; // process's kernel stackp->trapframe->kernel_trap = (uint64)usertrap;p->trapframe->kernel_hartid = r_tp(); // hartid for cpuid()
然后我们又重置了 sepc,因为从陷入返回时会根据这个寄存器的内容来重置 pc,以执行用户态陷入之后的程序。
函数的最后两行:
// jump to trampoline.S at the top of memory, which // switches to the user page table, restores user registers,// and switches to user mode with sret.uint64 fn = TRAMPOLINE + (userret - trampoline);((void (*)(uint64,uint64))fn)(TRAPFRAME, satp);
用这个骚操作跳到了 trampoline 页中的另一个函数 —— userret。
总结一下,usertrapret 干了下面的事情:
其实我感觉 stvec 和 sepc 这两个东西没必要在 usertrapret 中储存,它们本质也是恢复 trapframe 的数据。
基本上就是 uservec 的“反函数”,代码如下:
要注意这个函数是有两个参数的,即 trapframe 的地址和用户态页表的地址,按照 xv6 的函数调用规则,分别放在 a0 和 a1 寄存器。
userret: # userret(TRAPFRAME, pagetable) # switch from kernel to user. # usertrapret() calls here. # a0: TRAPFRAME, in user page table. # a1: user page table, for satp. # switch to the user page table. csrw satp, a1 sfence.vma zero, zero # put the saved user a0 in sscratch, so we # can swap it with our a0 (TRAPFRAME) in the last step. ld t0, 112(a0) # ld 之后 t0 储存用户的 a0 # 112(a0) 是用户的 a0 # 现在的 a0 是传进来的参数(trapframe)的地址 csrw sscratch, t0 # 交换 t0 和 sscratch,也就是 sscratch 储存用户 a0 # restore all but a0 from TRAPFRAME ld ra, 40(a0) ld sp, 48(a0) ld gp, 56(a0) ld tp, 64(a0) ld t0, 72(a0) ld t1, 80(a0) ld t2, 88(a0) ld s0, 96(a0) ld s1, 104(a0) ld a1, 120(a0) ld a2, 128(a0) ld a3, 136(a0) ld a4, 144(a0) ld a5, 152(a0) ld a6, 160(a0) ld a7, 168(a0) ld s2, 176(a0) ld s3, 184(a0) ld s4, 192(a0) ld s5, 200(a0) ld s6, 208(a0) ld s7, 216(a0) ld s8, 224(a0) ld s9, 232(a0) ld s10, 240(a0) ld s11, 248(a0) ld t3, 256(a0) ld t4, 264(a0) ld t5, 272(a0) ld t6, 280(a0) # restore user a0, and save TRAPFRAME in sscratch csrrw a0, sscratch, a0 # return to user mode and user pc. # usertrapret() set up sstatus and sepc. sret # 和 ecall 相对应
所以这个函数基本就是把所有的通用寄存器从 trapframe 里恢复了一遍,并且切换了页表,最后使用了 sret 指令。
注意这个 sret 指令,也和 ecall 指令一样,能同时做很多事情,具体的,有以下几个:
然后就可以愉快的继续执行用户态的程序了。
总结一下, userret 做了一下事情:
待更新(鸽)
现代操作系统通常会提供多线程的功能,也就是在表面上同时的执行多个任务。实现多线程主要有以下的原因[3]:
在实践中,我们通常会按照经过的时间让处理器运行不同的任务来实现多线程,或者说让处理器快速的在不同线程中切换以实现同步运行的假象。
实现多线程有上面列举的好处,也有很多困难,比如[3]:
下面会以一个具体的从一个用户进程切换到另一个用户进程的例子,来描述 xv6 中对多线程的实现。
下面这张来自 xv6 书中的图就大致的解释了 xv6 中进程切换的过程:
大部分进程切换的开始其实是一个硬件引发的计时器中断。在 xv6 中,我们会设置 rsicv 的处理器来产生计时器中断,也就是,每过一段时间,都会产生一个中断来提醒我们某个进程已经占用处理器够多的时间了,需要切换进程。
如果产生这个中断的时候,我们正在跑用户态的程序(也就是上图展示的),那么处理这个中断的函数就是 kernel/trap.c
中的 usertrap()
:
…… if(p->killed) exit(-1); // give up the CPU if this is a timer interrupt. if(which_dev == 2) // which_dev 为 2 代表产生中断的是计时器 yield(); usertrapret();}
如果我们发现产生中断的外部设备是计时器,就会调用下面的 yield()
函数:
// Give up the CPU for one scheduling round.voidyield(void){ struct proc *p = myproc(); acquire(&p->lock); p->state = RUNNABLE; sched(); release(&p->lock);}
这个 yield()
除了给进程上锁和解锁,就调用了 sched()
。
而 sched()
其实也是给 swtch()
套了层皮:
// Switch to scheduler. Must hold only p->lock// and have changed proc->state. Saves and restores// intena because intena is a property of this// kernel thread, not this CPU. It should// be proc->intena and proc->noff, but that would// break in the few places where a lock is held but// there's no process.voidsched(void){ int intena; struct proc *p = myproc(); if(!holding(&p->lock)) panic("sched p->lock"); if(mycpu()->noff != 1) panic("sched locks"); if(p->state == RUNNING) panic("sched running"); if(intr_get()) panic("sched interruptible"); intena = mycpu()->intena; swtch(&p->context, &mycpu()->context); mycpu()->intena = intena;}
函数中前面的一堆判断加 panic 其实都是一些合法性检查,我们先不用关注,主要看这里的 swtch()
函数。顺便提一嘴,这函数因为跟 c 语言关键重了,所以少了个 i(乐 。
swtch
函数是用汇编实现的,在 kernel/swtch.S
文件中,如下:
# Context switch## void swtch(struct context *old, struct context *new);# # Save current registers in old. Load from new. .globl swtchswtch: sd ra, 0(a0) sd sp, 8(a0) sd s0, 16(a0) sd s1, 24(a0) sd s2, 32(a0) sd s3, 40(a0) sd s4, 48(a0) sd s5, 56(a0) sd s6, 64(a0) sd s7, 72(a0) sd s8, 80(a0) sd s9, 88(a0) sd s10, 96(a0) sd s11, 104(a0) ld ra, 0(a1) ld sp, 8(a1) ld s0, 16(a1) ld s1, 24(a1) ld s2, 32(a1) ld s3, 40(a1) ld s4, 48(a1) ld s5, 56(a1) ld s6, 64(a1) ld s7, 72(a1) ld s8, 80(a1) ld s9, 88(a1) ld s10, 96(a1) ld s11, 104(a1) ret
可以看到,这个函数其实是把一些当前的寄存器储存在了 old->context
里面。然后读取 new->context
里读取了数据,并用这些值给寄存器赋值。
这个函数的实际作用是切换内核线程的上下文,也就是如一开始那张图所示的,从 kstack shell 切换到 kstack scheduler 的线程。
看到这里,你可能会感到很奇怪,既然 swtch()
函数切换的是不同线程的上下文,那为啥没有像 trapframe 一样,保存所有 32 个通用寄存器的值,而只保存了 14 个呢?
这是因为 s0-s11 在 xv6 的函数调用规则中,都是由被调用者保存的。而 32 个通用寄存器中剩下的那些,都是由调用者保存的。
也就是说,这些剩下的寄存器都是可以通过 sp 加上一些偏移量从栈中恢复的,我们也自然没有理由去保存它们。
关于具体的调用者和被调用者保存寄存器,可以参考下面 riscv 文档上截下来的图:
这里要特别注意的是储存和恢复的 ra 和 sp 寄存器。
其中 ra 寄存器表明了 swtch()
函数结束时会返回到哪个地址,而 sp
则表明了当前栈的位置。这意味着,在 swtch()
返回的时候,不会返回到 sched()
的最后一个语句,而是返回到 mycpu()->context.ra
指向的位置。
而 mycpu()->context
中的 ra 指向的是 scheduler()
函数的一个位置(和上图演示的切换过程一样):
// Per-CPU process scheduler.// Each CPU calls scheduler() after setting itself up.// Scheduler never returns. It loops, doing:// - choose a process to run.// - swtch to start running that process.// - eventually that process transfers control// via swtch back to the scheduler.voidscheduler(void){ struct proc *p; struct cpu *c = mycpu(); c->proc = 0; for(;;){ // Avoid deadlock by ensuring that devices can interrupt. intr_on(); for(p = proc; p < &proc[NPROC]; p++) { acquire(&p->lock); if(p->state == RUNNABLE) { // Switch to chosen process. It is the process's job // to release its lock and then reacquire it // before jumping back to us. p->state = RUNNING; c->proc = p; swtch(&c->context, &p->context); // 返回的是这里 // Process is done running for now. // It should have changed its p->state before coming back. c->proc = 0; } release(&p->lock); } }}
那为啥返回的是这里呢?我们可以看 kernel/main.c
的内容:
#include "types.h"#include "param.h"#include "memlayout.h"#include "riscv.h"#include "defs.h"volatile static int started = 0;// start() jumps here in supervisor mode on all CPUs.voidmain(){ if(cpuid() == 0){ consoleinit(); printfinit(); printf("\n"); printf("xv6 kernel is booting\n"); printf("\n"); kinit(); // physical page allocator kvminit(); // create kernel page table kvminithart(); // turn on paging procinit(); // process table trapinit(); // trap vectors trapinithart(); // install kernel trap vector plicinit(); // set up interrupt controller plicinithart(); // ask PLIC for device interrupts binit(); // buffer cache iinit(); // inode table fileinit(); // file table virtio_disk_init(); // emulated hard disk userinit(); // first user process __sync_synchronize(); started = 1; } else { while(started == 0) ; __sync_synchronize(); printf("hart %d starting\n", cpuid()); kvminithart(); // turn on paging trapinithart(); // install kernel trap vector plicinithart(); // ask PLIC for device interrupts } scheduler(); // 注意这里}
在初始化工作完成后,第一个执行的函数就是 scheduler()
。那在 scheduler()
函数中,我们找到了一个 RUNNABLE 的进程,然后执行了 swtch(&c->context, &p->context);
。
这个时候的 sp 寄存器和 ra 寄存器指向的自然是 scheduler()
函数,所以 mycpu()->context
中的 ra 也是 scheduler()
中 swtch()
后面的地址。
这个感觉就很奇妙,像是一个传送门和 “时光机”,相当于我们在某个地方调用了 swtch()
后,返回的是另一个地方很久(对计算机来说)之前调用 swtch()
的地方。或者说,这个函数的调用和返回是分离开的,我们调用的 swtch
,一定是通过另一个地方调用的 swtch
返回的[4]。
在 sched()
函数调用 swtch()
后,我们会从 scheduler()
函数调用 swtch()
的后面开始,继续执行 scheduler()
函数。这个函数的主要用处就是找到一个 RUNNABLE 的进程,然后执行 swtch()
。
在 swtch()
之前,会先执行下面的操作:
p->state = RUNNING;c->proc = p;
也就是把进程结构体的状态改成 RUNNING,以及把 mycpu()
的 proc
属性改成 p
。
这样我们在切换进程后调用 myproc()
就能得知当前处理器正在执行的进程。如下:
// Return the current struct proc *, or zero if none.struct proc*myproc(void) { push_off(); struct cpu *c = mycpu(); struct proc *p = c->proc; pop_off(); return p;}
其实就是返回了处理器上下文中的 proc 属性。
和前面讲的一样,swtch()
像是一个传送门,这个函数的调用和返回是分开的,调用后,会返回另一个地方之前调用 swtch()
的位置。在 scheduler()
函数中,这个位置就是进程 p 的 sched()
:
// Switch to scheduler. Must hold only p->lock// and have changed proc->state. Saves and restores// intena because intena is a property of this// kernel thread, not this CPU. It should// be proc->intena and proc->noff, but that would// break in the few places where a lock is held but// there's no process.voidsched(void){ int intena; struct proc *p = myproc(); …… intena = mycpu()->intena; swtch(&p->context, &mycpu()->context); mycpu()->intena = intena; // 返回后从这里接着执行。}
总结一下,我们在 sched()
中调用的 swtch()
会返回到 scheduler()
中。相应的,在 scheduler()
中调用 swtch()
会返回到 sched()
中。
这样发生定时器中断后,就会到 scheduler()
中找到可用进程。然后通过 swtch()
把这个可用进程的上下文恢复出来。
这样就可以大致的把进程的切换和调度过程搞清楚了,不过还有一些小细节没有提到。
我们可以注意到, yield()
函数,和 scheduler()
函数,都做了锁相关的操作。那这么做的原因是什么呢?
我们可以先梳理一下 yield()
和 scheduler()
中锁操作的过程。
首先,scheduler()
会给 p->lock()
加锁,然后调用 swtch()
来切换上下文。之后 sched()
会返回到 yield()
函数中,而这个函数会释放 p->lock
。
如果发生了定时器中断,那么 yield()
会给进程加锁,随后在 sched()
中调用 swtch()
,返回到 scheduler()
的 swtch()
函数。随后释放进程锁。
和 swtch()
函数相似,进程锁的加锁和释放不在同一个函数中。如果 yield()
给进程加了锁,那一定是 scheduler()
来释放的,反之,如果 scheduler()
加了锁,那一定是 yield()
来释放的。
可以发现进程加锁和解锁的这个区间正是处理器切换上下文的区间。这主要是因为,在进程切换的过程中,线程结构体处于一种不稳定的状态[4]。
比如,我们在 yield()
中把状态标记为了 RUNNABLE,但实际上还没执行 scheduler()
把这个进程切换出去。那如果正好有另一个核心正在执行 scheduler()
,寻找 RUNNABLE 的进程,并且发现了当前这个进程,就会有两个处理器同时执行一个进程,这显然是一个严重的错误。
但是加锁后,如果别的核心刚好遇到了这个没切换完的 RUNNABLE 的进程,也不会执行它,因为在 scheduler()
中,我们会试图去得到进程锁,所以在进程真正完成切换前,是会一直阻塞下去的。
同时,加锁和解锁的操作也关闭了中断。这样就避免了我们正在切换进程时,又发生了一个计时器中断(应该不太可能吧)。
通过之前的代码,可以发现,在 scheduler()
中调用 swtch()
会跳转到 sched()
中,是因为这个进程之前因为定时器中断,执行过 sched()
中的 swtch()
,而现在这个跳转,实际上是 shced()
中 swtch()
的返回。
但对于第一个进程,或者说刚刚被创建出来的进程来说,以前并没有发生定时器中断。并且我们在 main.c
执行完初始化后就执行了 scheduler()
,那么 scheduler()
中的第一次 swtch()
会切换到哪里呢?
这就需要看 allocproc()
函数中的内容了:
// Look in the process table for an UNUSED proc.// If found, initialize state required to run in the kernel,// and return with p->lock held.// If there are no free procs, or a memory allocation fails, return 0.static struct proc*allocproc(void){ …… // Set up new context to start executing at forkret, // which returns to user space. memset(&p->context, 0, sizeof(p->context)); p->context.ra = (uint64)forkret; // 注意这里 p->context.sp = p->kstack + PGSIZE; return p;}
可以看到,进程刚刚被创建的时候,ra 被设成了 forkret。也就是说,第一次被 scheduler()
找到并执行的时候,swtch()
不会跳转到 sched()
中的 swtch()
而是跳转到 forkret()
中。
forkret()
干的事情很简单,其实就是直接返回到用户空间:
// A fork child's very first scheduling by scheduler()// will swtch to forkret.voidforkret(void){ static int first = 1; // Still holding p->lock from scheduler. release(&myproc()->lock); if (first) { // File system initialization must be run in the context of a // regular process (e.g., because it calls sleep), and thus cannot // be run from main(). first = 0; fsinit(ROOTDEV); } usertrapret(); // 注意这里是返回用户空间}
]]>求最大公约数 。
有:
辗转相除法指出:
我们设
为 和 的任意一个公约数。
以及
那么:
因为 。
所以我们可以得出,如果 和 有任意一个公因数 ,这个公因数就一定会是 和 的公因数( 的情况除外,如果 ,那么最大公约数就是 了)。
也就是如果我们设 为 和 的公约数集合, 为 和 的公因数集合,那么 。
但是这还不足以证明 ,因为有可能 中有比 更大的数字。
但如果我们证明了 ,我们就可以证明 ,这样 中就绝对没有比 更大的数了。
我们设 为 中的任意一个数字。
那么有
再回到这个式子上,带入 和 。
说明,,或者说 中的任意一个数字 也在 中。也就是 。
所以 。
把辗转相除法写成程序的话就是下面这样,非常的简洁:
int gcd(int a, int b){ if(b) return gcd(b, a % b); else return a;}
参考资料:
在扩展欧几里得算法中,我们尝试找出方程:
的一个解。
下面是一个辗转相除法计算过程的例子,它计算的是 ,最后的结果是 :
我们可以从这个过程中推出 的一组解。
首先从过程的倒数第二步,也就是 开始看。把这个式子变换一下,变成:
按照相同的方法,也就是 变换辗转相除法的前面几个步骤,可以得到:
再把这些式子带到 中,可以发现,我们能把式中的 替换成 和 的和。
现在这个式子就变成 了,其中 。
进一步替换式中的 为 和 的和,式子也就成了 。
而这个 被替换为 和 的和,最终的式子就成为了。
其中
这正是我们想要的答案。
看的出来 exgcd 有点像是辗转相除法的逆向过程。它利用辗转相除法的计算过程,推出了 的一个解。
现在我们来尝试来推广一下刚刚观察到的规律。首先我们想求的是:
因为 。而 也可以被写成 的形式,就是。
注意虽然这里的 是和 一样的。也就是。
这两个式子的形式是一样的,都是 ,但是它们中的 和 不一样,所以解出来的 和 是不一样的。假设我们已经知道了。解出来的 和 ,那么只要知道如何从 和 中计算出来 和 ,就能递归的求解 和 了,
而我们可以化简 。
可以发现,假设我们已经求出了 的解 ,那么原式 中的 ,而 。这样我们就可以递归的求解了。
而边界条件和普通辗转相除法相似,是 。那么。
虽然这里的 随便怎么搞都可以,但是我们一般返回的是 。
下面是代码(用的是 c++20 的标准):
template<typename T>concept Integral = std::is_integral<T>::value;// gcd, x, ytemplate<Integral T>tuple<T, T, T> ex_gcd(T a, T b){ if (b == 0) { return {a, 1, 0}; } auto [gcd, x2, y2] = ex_gcd(b, a % b); //从 x2, y2 推出 x 和 y T x = y2; T y = x2 - (a / b) * y2; return {gcd, x, y};}
参考资料:
的乘法逆元定义为 的解 。
乘法逆元有点像是模意义下的相反数。
在 和 互质的情况下,我们可以使用 exgcd 解决这个问题。
因为 和 互质,所以:
那么扩展欧几里得可以解决:
我们稍微把 变一下形:
如果我们让 ,那么就得到了。
但是 中的 和 中可能有一个是负数,如果 是负数,那没问题,但如果 是负数,我们得到的答案就不是所有可行的 中最小的正整数了。
观察 这个式子,我们可以给 加 的倍数,让式子变成 (注意 是负数, 会被抵消掉)。这样就可以在不改变 的情况下把 变成正数。
所以我们可以这么写:
x = (x % b + b) % b;
我们假设 是一个负数。
注意这里第一个的 x % b
的作用是先给 加上一些 ,让它变成符合条件的最大的负数。比如 是 , 是 。我们让 x = x % b
, 就变成了 ,相当于把 加上了 。
后面的 +b
就是让这个符合条件的最大负数变成符合条件的最小正数。比如 。那么最后的这个 % b
有什么用呢?
这个是为了应对 为正数的情况,我们可以通过给 减去一些 ,让其变成符合条件的最小正数。
然后对于乘法逆元的模板题,可以写出如下的代码:
int n, p;template<typename T>concept Integral = std::is_integral<T>::value;// gcd, x, ytemplate<Integral T>tuple<T, T, T> ex_gcd(T a, T b){ if (b == 0) { return {a, 1, 0}; } auto [gcd, x2, y2] = ex_gcd(b, a % b); T x = y2; T y = x2 - (a / b) * y2; return {gcd, x, y};}int main() { cin>>n>>p; for(int i = 1; i <= n; i++){ auto[gcd, x, y] = ex_gcd(i, p); x = (x % p + p) % p; cout<<x<<endl; }}
需要注意的是,因为 的数据规模和 的时限,用 的算法是过不了的,需要用下面讲的线性算法。
参考资料:
线性递推的方法可以让我们在 的时间内求出 中所有整数在模质数 下的乘法逆元。
注意如果要求出 的范围中的所有整数, 必须是质数,因为 的这个区间中可能有很多非质数,要保证 中的数和 互质,只能确保 为质数。
因为这是一个递推算法,所以需要有初始条件。
不难发现 在模任何整数意义下的逆元都是 本身(因为 )。所以我们有了初始条件。
假设我们现在已经递推到了数字 。
设 ,那么:
转化为同余方程可以得到:
记 分别为 在模 意义下的乘法逆元。把 同时乘到同余式,可得:
展开,得:
移项,得:
因为 ,所以 。而 。
带入 后,得:
考虑前面用 exgcd 解的时候提到的 ,其中的 可能是负数。所以我们用了这个方法让他变成最小的正整数解。
x = (x % b + b) % b;
现在的 也是一样的,我们可以将其写成 的形式。并且这个 也是负数,于是就可以用相同的方法确保我们得到的 是最小的正整数解。然后就可以写出如下的模板题代码:
ll inv[MAXN];template <typename T>concept Int_t = is_integral<T>::value;template <Int_t T>inline T mod_norm(T val, T m) { return (val % m + m) % m;}int main() { ios::sync_with_stdio(0); cin.tie(0); ll n, p; cin >> n >> p; inv[1] = 1; cout << inv[1] <<'\n'; for (int i = 2; i <= n; i++) { inv[i] = mod_norm(-p / i * inv[p % i] % p, p); cout << inv[i] <<'\n'; }}
参考资料:
]]>这题是真的难想,我 cf 的题解看了好久才搞明白(我太菜了)。
给你一个长度为 的排列 ,请你找出有多少个相同长度的排列 和 相似。
如果对于所有区间 ,下面的条件满足:
我们就称排列 和 是相似的。
其中 对于数组 的定义是:最小的,没有出现在 中的非负整数 。
例如 。
由于答案可能会很大,所以输出时需要打印答案模 的结果。
直接想答案可能比较难,可以先模拟一下给出的样例,尝试构造出一些 。
我们从 这个数字开始考虑样例。可以发现在 中, 的位置一定和 中的 的位置相等。
为什么呢?我们定义在 中数字 出现的位置为 ,比如 (下标从 开始)。
这时,选择 的区间对比 和 的 是否相等,首先,在 中,因为 , 的 一定等于 。
如果 的 的位置被改变了,那么在 中,因为 ,,就一定等于 了。
所以我们可以推断出 的位置是不能变的。
我们还能推出, 的位置也是不能改变的。
可以考虑 和 这两个区间的 值。
因为 的存在, 一定大于 ,而因为有 并且没有 一定等于 。
架设我们在 中改变了 的位置,比如改到了 ,那么在 中, 就大于 了,不符合 中等于 的情况。
现在考虑能合法放置 的位置。我们可以推断出,如果在 中 ,那么在 中, 就可以被放在 这个区间的任意位置。
为啥呢?我们设区间 在 中包含 和 ,也就是 。
那因为在 中 ,a 中所有 的 一定大于 (也就是说,在 中,一个区间如果同时包含 和 ,就一定会包含 )。
同时,在 中,不符合 的其他所有区间的 最大只有 ,(这样的区间最多包含一个 ,那么 就是 了)。
那么在 中,只要 ,就能符合 。并且符合在不动其他数字的位置的情况下 和 相似。
符合这样的位置一共有 个( 是因为 和 占用了区间内的两个位置)。
那么如果在 中, 呢?
比如
我们可以推断出,和前面讲的 和 一样,这种情况下,我们需要在 中把 放到相同的位置上。
考虑 这个区间,其 一定大于 。而 这个区间的 就一定等于 (包含 ,)。
假设我们把 放到 上,那么 的 就大于 了。
在 的情况下。
我们可以把 放在 的区间内。因为只有这样,才能满足所有的 ,其中,,并且除 之外的所有区间, 都小于 。
也就是说,在 中,如果一个区间包含了所有比 小的数,就一定会包含 。或者说,在 中,不可能会有一个区间的 等于 ,而为了满足这一点,我们需要让 。
我们设 为 , 为 。符合上面 条件的位置就有 个( 是因为区间中已经有 了)。
现在我们可以从刚刚的观察中推广一下。我们刚刚发现如果在 中,一个数在所有比他小的数的中间,那这个数就有很多位置可以放,相反,如果在所有比他小的数的外面,那就只能放在一个位置。
我们设正在考虑的数为 , 为 , 为 ,如果 在 这个区间外面,那么 就只能放在 上,否则,可以选择 中任意一个没被占用的地方放置。
我们设每个数能选的位置的数量为 ,那么最终的答案就是所有的 乘起来,也就是
在具体实现的时候,可以从 开始一个一个的考虑,这样可以很方便的确定前面提到的 和 ,以及 区间内,被占用的位置的数量。
#include <bits/stdc++.h>using namespace std;#define ll long long// keywords:const int MOD = 1e9 + 7;int main() { int t; cin >> t; while (t--) { int n; cin >> n; int a[n + 1]; int pos[n + 1]; for (int i = 0; i < n; i++) { cin >> a[i]; pos[a[i]] = i; } ll l = pos[0], r = pos[0]; ll ans = 1; for (int i = 1; i < n; i++) { // l 和 r 就是之前讲的 x, y if (pos[i] < l) l = pos[i]; else if (pos[i] > r) r = pos[i]; //如果在 x, y 的外面 else ans = ans * (r - l + 1 - i) % MOD; } cout << ans << endl; }}
最后,希望这篇题解对你有帮助,如果有问题可以通过评论区或者私信联系我。
]]>大概一年多之前看了《复杂》这本书,最近因为一个比赛又想起了书里面介绍过的遗传算法,书里提供了详细的思路,所以想自己实现以下。
关于《复杂》这本书:不过多介绍内容,看完了只感觉非常牛逼,下面的介绍摘自豆瓣,自己加了一些标点符号:
蚂蚁在组成群体时为何会表现出如此的精密性和具有目的性?数以亿计的神经元是如何产生出像意识这样极度复杂的事物?是什么在引导免疫系统、互联网、全球经济和人类基因组等自组织结构?这些都是复杂系统科学尝试回答的迷人而令人费解的问题的一部分。
理解复杂系统需要有全新的方法。需要超越传统的科学还原论,并重新划定学科的疆域。借助于圣塔菲研究所的工作经历和交叉学科方法,复杂系统的前沿科学家米歇尔以清晰的思路介绍了复杂系统的研究,横跨生物、技术和社会学等领域,并探寻复杂系统的普遍规律,与此同时,她还探讨了复杂性与进化、人工智能、计算、遗传、信息处理等领域的关系。
书中讲的问题大概是这样的:
有一个机器人罗比,它生活在一个 的网格中。这个网格中散落着许多易拉罐,罗比需要在有限的动作中收集尽可能多的易拉罐。罗比的初始位置在 它只能看到自己周围的四个格子,和自己所在的格子的情况,每个格子一共有三种可能,有易拉罐,无易拉罐,和墙。罗比能做的动作有七种:向四个不同方向移动,随机移动,捡起罐子,和不动。
首先需要确定我们想要进化的是什么,因为罗比只能看到周围的格子,然后根据这个几个格子的情况做出动作,所以这就是我们想进化的策略。我们可以把不同的策略表示成一个字符串的形式,它的长度为 的,包含从 的数字。其中 表示的是七种不同的动作,而这个长为 的字符串中的每一个位置代表了罗比看到的不同的情形。其中, 中的 代表的是三种不同的格子, 就是罗比能看到的格子的数量。
这个字符串就代表了一个情形到动作的映射,每次罗比看到周围五个格子后可以检查这个映射,然后做出动作。而我们要进化的就是这个字符串,或者说,基因。
不过呢,在实际实现的时候,我用了一个 map
来实现这个映射,绝对不是我懒得写字符串的这些处理。
适应度是我们用来衡量不同策略好坏的,在遗传算法中,合理的适应度可以加速进化的过程。书中给出的适应度的计算方法是这样的:
捡到一个罐子 | 撞墙 | 没有罐子却做了捡罐子的动作 |
---|---|---|
首先需要随机的生成一个初始群体,书中给出的是 个个体。
然后计算群体中每一个个体的适应度,根据适应度(适应度越高的越容易被选中)让两个不同的基因“繁殖”。为了这个策略的适用性,这里的适应度会随机出很多个地图,然后在每个地图中都计算适应度,最后去平均。
产生下一代。“繁殖” 的实现参考了生物学中的染色体交换(chromosomal crossover),大概是下图的样子:
也就是说随机的选择一个中间点,子代基因的前半部分来自一个父代,后半部分的来自另外一个。此外,还可以在染色体交叉的过程中引入变异,让子代的基因有一定概率发生变化,这也是为了给我们的基因池引入更多的变化。
const int MAP_SIZ = 10; //地图大小const float CAN_RATE = 0.5; //是罐子的概率//地图设置const int SUC_CLCT_PT = 10; //成功捡起罐子的适应度变化const int ERR_CLCT_PT = -1; //没有罐子却做了捡罐子的动作const int HIT_WALL_PT = -5; //撞墙//适应度设置const int MOV_LIM = 200; //一共能做多少动作const int POP_CNT = 500; //群体数量const int GEN_CNT = 1000; //代数const float MUT_RATE = 0.005; //每一个位点的变异概率const int MAP_REP = 50; //计算适应度时用多少个地图//演化的一些设置enum GRD_DIR { DIRNONE = -1, CUR, UP, DN, RT, LF }; // 不同的方向const int DIR_CNT = 5;enum GRD_OBJ { OBJNONE = -1, EPT, WAL, CAN }; // 不同类型的格子const int OBJ_CNT = 3;enum ACTION { // 罗比的不同动作 ACTNONE = -1, MV_UP, MV_DN, MV_RT, MV_LF, MV_RND, CLCT_CAN, HALT};const int ACTION_CNT = 7;
Obj_in_dir
定义了一个格子相对于罗比的方向和格子的类型。其中,重载的小于号主要用于 map
(map
的内部实现是红黑树,也是一种查找树,所以需要对比大小)。
其中定义了几种构造函数,比较有用的是第二个,也就是通过罗比所在的坐标和此格相对于罗比的方向来初始化。
typedef vector<vector<bool>> Map_t;Map_t cur_map;inline bool is_wall(int x, int y, Map_t* mp) { auto [n, m] = make_pair((*mp).size(), (*mp).front().size()); if (x >= n || x < 0 || y >= m || y < 0) return true; else return false;}inline GRD_OBJ get_obj_inpos(int x, int y, Map_t* mp) { GRD_OBJ obj; int n = mp->size(); int m = mp->front().size(); if (is_wall(x, y, mp)) obj = WAL; else if ((*mp)[x][y]) obj = CAN; else if (!(*mp)[x][y]) obj = EPT; return obj;}struct Obj_in_dir { GRD_DIR dir; GRD_OBJ obj; const bool operator<(Obj_in_dir b) const { if (dir != b.dir) return dir < b.dir; return obj < b.obj; } const bool operator==(Obj_in_dir b) const { return dir == b.dir && obj == b.obj; } const bool operator!=(Obj_in_dir b) const { return dir != b.dir || obj != b.obj; } Obj_in_dir(GRD_DIR _dir, GRD_OBJ _obj) : dir(_dir), obj(_obj) {} Obj_in_dir(int x, int y, GRD_DIR _dir, Map_t* mp) { dir = _dir; switch (dir) { case CUR: obj = get_obj_inpos(x, y, mp); break; case UP: obj = get_obj_inpos(x - 1, y, mp); break; case DN: obj = get_obj_inpos(x + 1, y, mp); break; case RT: obj = get_obj_inpos(x, y + 1, mp); break; case LF: obj = get_obj_inpos(x, y - 1, mp); break; } } Obj_in_dir() { dir = DIRNONE; obj = OBJNONE; }};
Srndng
也就是 Surrounding,代表了罗比周围的情形,之后我们会定义一个 map
,把罗比周围的情形映射到一个动作上,而这个 map
就代表了不同的策略或者说基因。
其中注意第一个构造函数,传入坐标和地图的指针后,就能初始化罗比当前看到的情形。
struct Srndng { Obj_in_dir objs[5]; const bool operator<(Srndng b) const { for (int i = 0; i < 5; i++) { if (objs[i] != b.objs[i]) return objs[i] < b.objs[i]; } return false; } Srndng(int x, int y, Map_t* mp) { for (int i = 0; i < 5; i++) { objs[i] = Obj_in_dir(x, y, GRD_DIR(i), mp); } } Srndng() { for (int i = 0; i < 5; i++) objs[i].dir = DIRNONE, objs[i].obj = OBJNONE; }};
下面是一些被重命名的类型:
typedef map<Srndng, ACTION> Gene_t; // 把情形映射到动作,就是我们定义的基因 typedef pair<Gene_t, float> Gene_res_t; // gene result type 基因和其对应的适应度typedef vector<Gene_res_t> Gene_pool_t; // 一个群体的基因池
地图生成器,地图指针传进来之后需要先 resize
,然后根据之前定义的罐子出现的概率,生成地图。
void mp_generator(Map_t* mp, int n = MAP_SIZ, int m = MAP_SIZ) { srand(time(0)); mp->resize(n); for (int i = 0; i < n; i++) { (*mp)[i].resize(m); } for (auto& row : *mp) { for (auto&& unit : row) { //注意这里用两个 && 是因为 unit 是布尔类的 //而这里的 && 是一个右值引用(右值是不能被取地址的) //所以我们改变了 unit,那 mp 这个地图里的值也会改变 unit = (rand() * 1.0 <= CAN_RATE * RAND_MAX); } }}
主要用于产生第一代的个体
这里我用了递归的方式去生成这个基因。实际上就是去枚举罗比可能会遇到的不同情形,如果发现一种情形已经生成完了(周围的格子都被确定了)那就直接随机一个动作出来。
void gene_generator_once(Gene_t* ret_gene, Srndng* ret_srndng, GRD_DIR cur_dir) { if (cur_dir >= DIR_CNT) { //发现已经枚举完一种情形了,就随机生成一个动作 (*ret_gene)[*ret_srndng] = ACTION(rand() % ACTION_CNT); return; } for (int i = 0; i < OBJ_CNT; i++) { (*ret_srndng).objs[cur_dir] = Obj_in_dir(GRD_DIR(cur_dir), GRD_OBJ(i)); gene_generator_once(ret_gene, ret_srndng, GRD_DIR(cur_dir + 1)); }}
首先先随机出一个合并点,这个点前的基因来自 ,后面的来自 ,然后直接把父母基因根据这个合并点复制到子代基因上。
如前文,复制过程中可以模拟基因的变异,所以我们要根据前面定义的变异概率随机一下,然后判断是否变异。
void gene_combine(Gene_t* pa, Gene_t* pb, Gene_t* child) { int cmb_pos = round(double(rand() * 1.0 / RAND_MAX * 1.0) * double(pa->size())); int cur_idx = 0; for (auto [key, val] : *pa) { // pa 是一个map,这里的语法是结构化绑定,key 就是 map 里 pair 的 .first,val 就是 .second if (cur_idx > cmb_pos) break; // 合并点前的都是 pa,反之亦然 if ((rand() * 1.0 / RAND_MAX * 1.0) <= MUT_RATE) // 判断是否变异 (*child)[key] = ACTION(rand() % (ACTION_CNT));// 变异的话直接给他随机一个动作 else (*child)[key] = val; cur_idx++; } cur_idx = 0; for (auto [key, val] : *pb) { if (cur_idx > cmb_pos) { if ((rand() * 1.0 / RAND_MAX * 1.0) <= MUT_RATE) (*child)[key] = ACTION(rand() % (ACTION_CNT)); else (*child)[key] = val; } cur_idx++; }}
这就没啥好解释了,接受罗比当前的坐标,和准备要做的动作,输出一个移动之后的坐标。因为有两种动作不是移动,所以如果接受到这样的参数就会抛出一个 invalid_argument
。
inline pair<int, int> get_pos_after_mv(int x, int y, ACTION mv) { switch (mv) { case MV_UP: return {x - 1, y}; break; case MV_DN: return {x + 1, y}; break; case MV_LF: return {x, y - 1}; break; case MV_RT: return {x, y + 1}; break; case MV_RND: return get_pos_after_mv(x, y, ACTION(rand() % 4)); break; case CLCT_CAN: //捡起易拉罐 throw invalid_argument("not a move"); return {x, y}; break; case HALT: throw invalid_argument("not a move"); return {x, y}; break; }}
直接模拟罗比的移动就好了,需要注意的是,如果罗比撞墙了,我们需要把它弹回来。
inline bool is_mov(ACTION act) { return act <= 4; }int calc_fitness(Gene_t* gene, Map_t* mp) { int cur_x = 0, cur_y = 0; int fit = 0; for (int cur_mov = 1; cur_mov <= MOV_LIM; cur_mov++) { Srndng cur_srnd(cur_x, cur_y, mp); // 传入罗比的坐标和当前地图,来确定罗比周围的情形 ACTION cur_act = (*gene)[cur_srnd]; // 根绝这个基因,获取应作的动作 if (is_mov(cur_act)) { // 如果这个动作是会移动的,就计算移动之后的位置 tie(cur_x, cur_y) = get_pos_after_mv(cur_x, cur_y, cur_act); // 这里的 tie 其实跟结构化绑定是差不多的,但是好像 // 这里的结构化绑定只能写成 auto[cur_x, cur_y] = funct() // 这样就只能新建两个变量了,如果你知道如何不新建变量的结构化绑定 // 可以在评论区说下 } if (is_wall(cur_x, cur_y, mp)) { fit += HIT_WALL_PT; //撞墙了 auto [n, m] = make_pair((*mp).size(), (*mp).front().size()); //把罗比弹回来 if (cur_x < 0) cur_x = 0; if (cur_y < 0) cur_y = 0; if (cur_x >= n) cur_x = n - 1; if (cur_y >= m) cur_y = m - 1; } else if (cur_act == CLCT_CAN) { if ((*mp)[cur_x][cur_y]) { // 如果有罐子还捡了 fit += SUC_CLCT_PT; (*mp)[cur_x][cur_y] = false; // 这里需要标注罐子已经被捡了 } else fit += ERR_CLCT_PT; } } return fit;}
基本就是把前面的单个基因套了个壳
void gene_generator(Gene_pool_t* pool, int cnt) { while (cnt--) { Gene_t temp_gene; Srndng temp_srnd; gene_generator_once(&temp_gene, &temp_srnd, GRD_DIR(0)); pool->push_back({temp_gene, 0}); }}
传入两个参数,数组中每个元素(或者说下标)的权重,和要选择的元素。
这个东西的思路主要是这样的,我们知道 rand()
函数会产生一个 的均匀分布的随机数。那么我们只要根据给定的权值规定好每个下标对应的范围,如果 rand()
给的值是这个范围内的,就选择这个元素。
比如,假设我们的 是 ,然后 possi
数组等于 那么就可以得出下面的映射范围:
这样,权重高的元素就有更大的概率被选中。
接下来要把每个下标对应的范围放入一个 map
中。我们定义下标 的映射范围的下界为 ,比如在上面的例子中 。我们在这个 map
中就可以建立一个 的映射。
如在上面的中,这个映射就是。
接下来,如果我们用 rand()
函数得出了一个随机值 ,我们就可以用 map
的 upper_bound(key)
函数找到第一个这个 map
中键值大于 的位置,那么这个位置的前一个位置就是我们需要的下标。
举个例子,如果我们的 ,那么,根据上面的映射,第一个大于这个数的键就是 ,而 的上一个就是 ,对应的值是 ,所以我们选中了第二个元素。
根据前面下标到范围的映射, 这个下标对应了 ,我们的 ,所以确实应该选 这个下标
vector<int> choose_by_weight(vector<float>& possi, int cnt) { vector<int> ret; ret.reserve(cnt); double tot = 0; for (float cur : possi) { tot += cur; //计算权值的和 } map<int, int> choose_rg; int lst = 0; for (int i = 0; i < possi.size(); i++) { int len = lround(possi[i] * 1.0 / tot * 1.0 * (RAND_MAX * 1.0)); //计算这个范围的长度 if (len == 0) continue; choose_rg[lst] = i; lst = lst + len; } choose_rg[IINF] = possi.size(); while (ret.size() < cnt) { // 选 cnt 个,放到 ret 里 int rd = rand(); int rd_idx = (--choose_rg.upper_bound(rd))->second; // 用 upper_bound() 找到第一个比 key 大的,然后找这个前面的元素 ret.push_back(rd_idx); // 再把这个元素对应的值 push 进去 } return ret;}
我们开两个群体的类(Gene_pool_t
),其中一个代表当前的,还有一个是子代的。
和前面讲的一样,我们首先计算群体中每一个个体的适应度,然后根据适应度选出父母,繁殖出下一代,把这个过程重复 次,就能得到一个不错的策略了。
void evolve(int cur_gen) { if (cur_gen != 1) { temp_pool.clear(); // 新的一代是放进 temp 里的 } for_each(cur_pool.begin(), cur_pool.end(), [](Gene_res_t& a) { a.second = 0; }); // 重置 cur_pool 的适应度 for (int m = 0; m < MAP_REP; m++) { mp_generator(&cur_map); // 重置地图 Map_t temp_map = cur_map; // 因为计算适应性时,会影响生成的地图(比如捡起一个罐子),所以现在先复制一下。计算另一个个体时再复制回去。 for (int i = 0; i < POP_CNT; i++) { cur_pool[i].second += calc_fitness(&cur_pool[i].first, &cur_map); cur_map = temp_map; } } for (auto& res : cur_pool) { res.second /= (MAP_REP * 1.0); // 取平均 } //计算出池中每个基因的概率 float tot_fit = 0; float mx_fit = numeric_limits<float>::min(); for (auto cur : cur_pool) { tot_fit += cur.second; mx_fit = max(mx_fit, cur.second); } fileout << mx_fit << ","; cout << cur_gen <<" "<<mx_fit<<"\n"; sort(cur_pool.begin(), cur_pool.end(), [](Gene_res_t& a, Gene_res_t& b) { return a.second < b.second; }); vector<float> possi; const float TOT_ELE = (0.0 + (POP_CNT - 1) * 1.0) * POP_CNT * 1.0 / 2.0; for (int i = 0; i < cur_pool.size(); i++) { possi.push_back(i * 1.0 * sqrt(i * 1.0)); //每个基因的权重,如果适应性越高,权重也应该更高,这个权重的函数可以自己改 } auto chosen = choose_by_weight(possi, POP_CNT * 2); temp_pool.clear(); while (chosen.size()) { int fir = chosen.back(); chosen.pop_back(); int sec = chosen.back(); chosen.pop_back(); Gene_t child; gene_combine(&cur_pool[fir].first, &cur_pool[sec].first, &child); //产生下一代 temp_pool.push_back({child, 0}); } swap(cur_pool, temp_pool);}int main() { fileout.open("./out"); gene_generator(&cur_pool, POP_CNT); //创建初始基因 for (int i = 1; i <= GEN_CNT; i++) { evolve(i); } system("python ./plotting.py"); // 最后画图 }
下图使用 matplotlib 画出,源码如下:
from matplotlib import lines, pyplot as pltimport csvGEN_CNT = 1000FONTSIZ = 23plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = Falsex = []for i in range(GEN_CNT): x.append(i)y = []with open(".//out", 'r') as csvfile: result = csv.reader(csvfile, delimiter=',') for row in result: for col in row: y.append(float(col)) print(y)mxfit = 0.0for cur in y: mxfit = max(mxfit, cur) plt.figure(figsize = (20, 40.0/3.0));plt.yticks(fontproperties = 'Iosevka', size = 20)plt.xticks(fontproperties = 'Iosevka', size = 20)plt.plot(x, y)plt.hlines(mxfit, 0, 1000, colors='g', linestyles="dashed", label="最大适应度=" + str(mxfit))plt.xlabel("代数", fontsize = FONTSIZ)plt.ylabel("每代最大适应度", fontsize = FONTSIZ)plt.legend(fontsize = FONTSIZ)plt.savefig(fname="ga_result.svg",format="svg")plt.show()
可以看到虽然有波动,但是整体的趋势还是上升的。最好策略的适应度也达到了 ,这是一个非常理想的分数,因为地图中平均就只有 个罐子,这个 可能是随机出来的地图刚好有比较多的罐子,然后一共捡了 个。
/*Date: 22 - 06-29 00 38PROBLEM_NUM: */// #define FDEBUG#if (defined FDEBUG) && (!defined ONLINE_JUDGE)#define DEBUG(fmt, ...) fprintf(stderr, fmt, ##__VA_ARGS__)#define DWHILE(cnd, blk) \ while (cnd) blk#define DFOR(ini, cnd, itr, blk) \ for (ini; cnd; itr) blk#else#define DEBUG(fmt, ...)#define DWHILE(cnd, blk)#define DFOR(ini, cnd, itr, blk)#endif#include <bits/stdc++.h>using namespace std;#define ll long long#define pause system("pause")#define IINF 0x3f3f3f3f#define rg register// keywords:const int MAP_SIZ = 10;const float CAN_RATE = 0.5; //是罐子的概率//地图设置const int SUC_CLCT_PT = 10;const int ERR_CLCT_PT = -1;const int HIT_WALL_PT = -5;//奖励设置const int MOV_LIM = 200;const int POP_CNT = 500;const int GEN_CNT = 1000;const float MUT_RATE = 0.005; //每一个位点的变异概率const int MAP_REP = 50; //计算适应度时用多少个地图//演化的一些设置const int THREAD_CNT = 10;enum GRD_DIR { DIRNONE = -1, CUR, UP, DN, RT, LF };const int DIR_CNT = 5;enum GRD_OBJ { OBJNONE = -1, EPT, WAL, CAN };const int OBJ_CNT = 3;enum ACTION { ACTNONE = -1, MV_UP, MV_DN, MV_RT, MV_LF, MV_RND, CLCT_CAN, HALT};const int ACTION_CNT = 7;typedef vector<vector<bool>> Map_t;Map_t cur_map;inline bool is_wall(int x, int y, Map_t* mp) { auto [n, m] = make_pair((*mp).size(), (*mp).front().size()); if (x >= n || x < 0 || y >= m || y < 0) return true; else return false;}inline GRD_OBJ get_obj_inpos(int x, int y, Map_t* mp) { GRD_OBJ obj; int n = mp->size(); int m = mp->front().size(); if (is_wall(x, y, mp)) obj = WAL; else if ((*mp)[x][y]) obj = CAN; else if (!(*mp)[x][y]) obj = EPT; return obj;}struct Obj_in_dir { GRD_DIR dir; GRD_OBJ obj; const bool operator<(Obj_in_dir b) const { if (dir != b.dir) return dir < b.dir; return obj < b.obj; } const bool operator==(Obj_in_dir b) const { return dir == b.dir && obj == b.obj; } const bool operator!=(Obj_in_dir b) const { return dir != b.dir || obj != b.obj; } Obj_in_dir(GRD_DIR _dir, GRD_OBJ _obj) : dir(_dir), obj(_obj) {} Obj_in_dir(int x, int y, GRD_DIR _dir, Map_t* mp) { dir = _dir; switch (dir) { case CUR: obj = get_obj_inpos(x, y, mp); break; case UP: obj = get_obj_inpos(x - 1, y, mp); break; case DN: obj = get_obj_inpos(x + 1, y, mp); break; case RT: obj = get_obj_inpos(x, y + 1, mp); break; case LF: obj = get_obj_inpos(x, y - 1, mp); break; } } Obj_in_dir() { dir = DIRNONE; obj = OBJNONE; }};struct Srndng { Obj_in_dir objs[5]; const bool operator<(Srndng b) const { for (int i = 0; i < 5; i++) { if (objs[i] != b.objs[i]) return objs[i] < b.objs[i]; } return false; } Srndng(int x, int y, Map_t* mp) { for (int i = 0; i < 5; i++) { objs[i] = Obj_in_dir(x, y, GRD_DIR(i), mp); } } Srndng() { for (int i = 0; i < 5; i++) objs[i].dir = DIRNONE, objs[i].obj = OBJNONE; }};typedef map<Srndng, ACTION> Gene_t;typedef pair<Gene_t, float> Gene_res_t; // 基因对应的适应度typedef vector<Gene_res_t> Gene_pool_t;void mp_generator(Map_t* mp, int n = MAP_SIZ, int m = MAP_SIZ) { srand(time(0)); mp->resize(n); for (int i = 0; i < n; i++) { (*mp)[i].resize(m); } for (auto& row : *mp) { for (auto&& unit : row) { unit = (rand() * 1.0 <= CAN_RATE * RAND_MAX); } }}Map_t* mp_generator(int n = MAP_SIZ, int m = MAP_SIZ) { auto mp = new Map_t(n); mp_generator(mp); return mp;}void gene_generator_once(Gene_t* ret_gene, Srndng* ret_srndng, GRD_DIR cur_dir) { if (cur_dir >= DIR_CNT) { (*ret_gene)[*ret_srndng] = ACTION(rand() % ACTION_CNT); return; } for (int i = 0; i < OBJ_CNT; i++) { (*ret_srndng).objs[cur_dir] = Obj_in_dir(GRD_DIR(cur_dir), GRD_OBJ(i)); gene_generator_once(ret_gene, ret_srndng, GRD_DIR(cur_dir + 1)); }}void gene_combine(Gene_t* pa, Gene_t* pb, Gene_t* child) { int cmb_pos = round(double(rand() * 1.0 / RAND_MAX * 1.0) * double(pa->size())); int cur_idx = 0; for (auto [key, val] : *pa) { if (cur_idx > cmb_pos) break; if ((rand() * 1.0 / RAND_MAX * 1.0) <= MUT_RATE) (*child)[key] = ACTION(rand() % (ACTION_CNT)); else (*child)[key] = val; cur_idx++; } cur_idx = 0; for (auto [key, val] : *pb) { if (cur_idx > cmb_pos) { if ((rand() * 1.0 / RAND_MAX * 1.0) <= MUT_RATE) (*child)[key] = ACTION(rand() % (ACTION_CNT)); else (*child)[key] = val; } cur_idx++; }}Gene_t* gene_combine(Gene_t* pa, Gene_t* pb) { auto child = new Gene_t; gene_combine(pa, pb, child); return child;}inline pair<int, int> get_pos_after_mv(int x, int y, ACTION mv) { switch (mv) { case MV_UP: return {x - 1, y}; break; case MV_DN: return {x + 1, y}; break; case MV_LF: return {x, y - 1}; break; case MV_RT: return {x, y + 1}; break; case MV_RND: return get_pos_after_mv(x, y, ACTION(rand() % 4)); break; case CLCT_CAN: throw invalid_argument("not a move"); return {x, y}; break; case HALT: throw invalid_argument("not a move"); return {x, y}; break; }}inline bool is_mov(ACTION act) { return act <= 4; }int calc_fitness(Gene_t* gene, Map_t* mp) { int cur_x = 0, cur_y = 0; int fit = 0; for (int cur_mov = 1; cur_mov <= MOV_LIM; cur_mov++) { Srndng cur_srnd(cur_x, cur_y, mp); ACTION cur_act = (*gene)[cur_srnd]; if (is_mov(cur_act)) { tie(cur_x, cur_y) = get_pos_after_mv(cur_x, cur_y, cur_act); } if (is_wall(cur_x, cur_y, mp)) { fit += HIT_WALL_PT; auto [n, m] = make_pair((*mp).size(), (*mp).front().size()); if (cur_x < 0) cur_x = 0; if (cur_y < 0) cur_y = 0; if (cur_x >= n) cur_x = n - 1; if (cur_y >= m) cur_y = m - 1; } else if (cur_act == CLCT_CAN) { if ((*mp)[cur_x][cur_y]) { fit += SUC_CLCT_PT; (*mp)[cur_x][cur_y] = false; } else fit += ERR_CLCT_PT; } } return fit;}void gene_generator(Gene_pool_t* pool, int cnt) { while (cnt--) { Gene_t temp_gene; Srndng temp_srnd; gene_generator_once(&temp_gene, &temp_srnd, GRD_DIR(0)); pool->push_back({temp_gene, 0}); }}Gene_pool_t cur_pool, temp_pool;void calc_popfit_mul_th() { thread* calc_fit_th[THREAD_CNT]; const int PER_TH = POP_CNT / THREAD_CNT; for (int i = 0; i < THREAD_CNT; i++) { calc_fit_th[i] = new thread([i]() { for (int j = i * PER_TH; j < (i + 1) * PER_TH; j++) cur_pool[j].second = calc_fitness(&cur_pool[j].first, &cur_map); }); } for (int i = 0; i < THREAD_CNT; i++) { calc_fit_th[i]->join(); }}vector<int> choose_by_weight(vector<float>& possi, int cnt) { vector<int> ret; ret.reserve(cnt); double tot = 0; for (float cur : possi) { tot += cur; } map<int, int> choose_rg; int lst = 0; for (int i = 0; i < possi.size(); i++) { int len = lround(possi[i] * 1.0 / tot * 1.0 * (RAND_MAX * 1.0)); if (len == 0) continue; choose_rg[lst] = i; lst = lst + len; //下一个的下标 } choose_rg[IINF] = possi.size(); while (ret.size() < cnt) { int rd = rand(); int rd_idx = (--choose_rg.upper_bound(rd))->second; ret.push_back(rd_idx); } return ret;}ofstream fileout;void evolve(int cur_gen) { if (cur_gen != 1) { temp_pool.clear(); // 新的一代是放进 temp 里的 } for_each(cur_pool.begin(), cur_pool.end(), [](Gene_res_t& a) { a.second = 0; }); for (int m = 0; m < MAP_REP; m++) { mp_generator(&cur_map); Map_t temp_map = cur_map; for (int i = 0; i < POP_CNT; i++) { cur_pool[i].second += calc_fitness(&cur_pool[i].first, &cur_map); cur_map = temp_map; } } for (auto& res : cur_pool) { res.second /= (MAP_REP * 1.0); } float tot_fit = 0; float mx_fit = numeric_limits<float>::min(); for (auto cur : cur_pool) { tot_fit += cur.second; mx_fit = max(mx_fit, cur.second); } fileout << mx_fit << ","; cout << cur_gen <<" "<<mx_fit<<"\n"; sort(cur_pool.begin(), cur_pool.end(), [](Gene_res_t& a, Gene_res_t& b) { return a.second < b.second; }); vector<float> possi; const float TOT_ELE = (0.0 + (POP_CNT - 1) * 1.0) * POP_CNT * 1.0 / 2.0; for (int i = 0; i < cur_pool.size(); i++) { possi.push_back(i * 1.0 * sqrt(i * 1.0)); // possi.push_back(i); } auto chosen = choose_by_weight(possi, POP_CNT * 2); temp_pool.clear(); while (chosen.size()) { int fir = chosen.back(); chosen.pop_back(); int sec = chosen.back(); chosen.pop_back(); Gene_t child; gene_combine(&cur_pool[fir].first, &cur_pool[sec].first, &child); DEBUG("fir: %d sec: %d\n", fir, sec); temp_pool.push_back({child, 0}); } swap(cur_pool, temp_pool);}int main() { fileout.open("./out"); gene_generator(&cur_pool, POP_CNT); //创建初始基因 for (int i = 1; i <= GEN_CNT; i++) { evolve(i); } system("python ./plotting.py"); pause;}
]]>给你一个 () 的格点图,每个格子的值要么是 ,要么是 ,现在问你,是否有一条从 到 的路径,使得路径上经过的格点的值的和为 。在路径中,只能从 移动到 或是 (向右或是向下走)。
看到这个 () 的数据范围就知道暴搜肯定要寄了(别学我),所以得想一些别的办法。
首先,如果经过奇数个格子,或者说 为奇数,那么肯定没有这样的一条路径(经过的 和 点没有办法相等)。
直接判断某个格子图是否符合要求太麻烦,我们可以思考,如果有任意一条路径,我们是否能根据这条路径的值(也就是途径的格子的和),来做一些改变,最后让路径的值变为 。
如下图这样就是对路径做了一次改变(改变前后只有一个格子不同)。最后让路径的值产生了变化。
在一次改变中,路径的值会产生 或是 的变化,那么如果我们刚开始的路径值是一个偶数,就可以把这个路径通过这样的改变变为 ……吗?
显然是不行的,如果整个格点图全是 或是全是 就不行,所以我们还得做一些改进。
首先就得确保在这个格点图中不会只有值特别离谱的路径,如果只有值特别离谱的路径,那无论你怎么变,也搞不出值为 的路径。
所以我们需要找出值最大的路径,以及值最小的路径。
设值最大的路径的值为 ,最小的路径的值为 。
那么如果:
我们就一定可以通过这样的变化把任意一个值为偶数的路径变为值为 的路径。
或者可以这样理解,如果符合上面那个条件,那我们就可以逐渐把值最小的路径向值最大的路径变换,在这个过程中,一定有一个路径的值等于 。
至于求这样的格子图的最大和最小路径,就属于是典中典了(用 dp),这里不赘述,如果有不熟悉的可以看洛谷P1004。
#include <bits/stdc++.h>using namespace std;int main() { int t; scanf("%d", &t); while (t--) { int n, m; scanf("%d%d", &n, &m); int a[n + 1][m + 1]; int mx[n + 1][m + 1], mn[n + 1][m + 1]; //mx[i][j] 的意思是到 i, j 这个点的最大路径的值 //mn[i][j] 是最小 memset(a, 0, sizeof(a)); memset(mx, 0, sizeof(mx)); memset(mn, 0, sizeof(mn)); for (int i = 1; i <= n; i++) { for (int j = 1; j <= m; j++) { scanf("%d", &a[i][j]); } } for (int i = 1; i <= n; i++) mx[i][1] = mn[i][1] = mx[i - 1][1] + a[i][1]; //给 dp 设置边界条件,如果在格子图的左边界,显然只能从上面走过来 for (int j = 1; j <= m; j++) mx[1][j] = mn[1][j] = mn[1][j - 1] + a[1][j]; //如果在格子图的上边界,只能从左边走过来 for (int i = 2; i <= n; i++) { for (int j = 2; j <= m; j++) { mx[i][j] = max(mx[i - 1][j], mx[i][j - 1]) + a[i][j]; mn[i][j] = min(mn[i - 1][j], mn[i][j - 1]) + a[i][j]; //经典 dp,选择是从左边走过来还是从上面走过来 } } if (mx[n][m] & 1 || mn[n][m] > 0 || mx[n][m] < 0) { //mx[n][m] & 1 是判断这个路径是否是奇数的 //当然也可以直接前面判断 n + m - 1,这样还能快一点 printf("NO\n"); } else { printf("YES\n"); } }}
]]>在这里放一些写程序时报的一些错(或者配置环境之类的很杂的问题),这样下次遇到了可以直接来这里看:
按照语言分类,每个错误前面会有发生的时间。
2022/6/20
如果用 std::thread()
创建线程,并且传入函数指针是非静态的,需要这么写:
thread(&class_name::func_name, this, arg1, arg2, other_args...);
因为对于每个实例,这个函数是不一样的,所以只有传入 this
指针,执行这个线程时才知道具体是执行哪个实例的函数。
2022/7/12
今天太无语了,本来花了好久时间编译,然后准备打开 qemu 和 riscv64-unknown-elf-gdb 单步 xv6 的内核,结果输入一个 layout split
,居然告诉我 Undefined command: "layout". Try "help".
。
然后又尝试在开启 gdb 时输入一个 -tui
参数,居然显示 riscv64-unknown-elf-gdb: TUI mode is not supported
。
网上查了一圈之后发现是因为没有安装 curses,但是为啥我别的 gdb 就可以啊??
于是就只能下载 curses 之后重新编译一遍了,而且这个编译速度贼慢。。。。
之后终于能成功使用 layout
了。
这应该是一个 bug?(见这个链接)。如果在 windows 环境下的 vscode 中删除一个文件,被删除的文件会自动被移动到回收站,不过 WSL 下相当于直接 rm
了,不可恢复。
这时候就需要用些奇怪的方法了,我们知道 vscode 有个很好的功能叫时间线(timeline),通过这个功能可以查看到以前版本的文件。虽然我们把文件删除了没法查看时间线,但是缓存还是在的,在 WSL 中,这些缓存存在 /root/.vscode-server/data/User/History
这个文件夹中。不过文件名全都是乱码,可能需要花点时间找。
最后感谢 stackoverflow 这个回答 下 “@iutlu” 的评论,要不然今天中午写的东西就要重写一遍了。
]]>打的第一场 CF div.4
这个题就是那种想到点了就很很简单,没想到的话就……寄了的题(我就属于是寄了)。
给你一个长度为 的数组 ,问你在这个数组中,有多少个长度为 的区间,符合以下的条件:
暴力还是很好搞的,就把数组中每个可能的区间都算一遍就行了,但是看到 这个条件就知道要寄了。
所以我们需要一种能在 的时间内判断区间是否符合条件的方法。(你要能搞出来 的也不是不可以)
可以发现,采用暴力的方法是因为每次这个区间的起始位都会变,所以数组的每一项前面要乘的数都不确定的。但如果我们能找到一种跟区间起始位置不相关的判断条件,这个问题就解决了。
接下来,重点来了
再仔细观察题目中给的条件,可以发现如果想要符合条件,数组中的前一项必须小于后一项的两倍,也就是:
这个性质是和区间的位置,以及长度不相关的,只要两个相邻的数符合这个条件,那么它可以出现在任何长度,位置的区间中。
不过,这只是区间中的两个数,如果我们想要让一整个区间都合法,那么我们就需要让整个区间内,任意的 都小于 。
也就是说,只要长度为 的区间内,有 个符合 的数对,那这个区间就是符合条件的。
统计区间内的合法数量……并且需要在 的时间内查询到结果,那这不就是前缀和吗?
于是我们很自然的就想到了判断完每个数对是否合法后,开一个前缀和数组 valid_sum[i]
来统计,到 为止,有多少个符合条件的数对。
最后再搞个循环统计一下符合条件的区间就好了。
代码还是比较简单的,这题的 这个点还是比较难想。
#include<bits/stdc++.h>using namespace std;int main(){ int t; scanf("%d", &t); while(t--){ int n, k; scanf("%d%d", &n, &k); //注意他给你的是 k,但是这个区间实际的长度是 k + 1 int a[n + 1]; for(int i = 1; i <= n; i++){ scanf("%d", &a[i]); } bool valid[n + 1]; memset(valid, 0, sizeof(valid)); int valid_sum[n + 1]; memset(valid_sum, 0, sizeof(valid_sum)); for(int i = 1; i < n; i++){ if(a[i] < 2 * a[i + 1]){ valid[i] = true; //判断和记录 a[i] 和 a[i + 1] 这个树对是否合法。 } valid_sum[i] = valid_sum[i - 1] + valid[i]; //前缀和 } int ans = 0; for(int i = 1; i <= n - k; i++){ if(valid_sum[i + k - 1] - valid_sum[i - 1] == k){ //实际长度是 k + 1,所以 k + 1 再 - 1 = k ans++; } } printf("%d\n", ans); }}
最后,希望这篇题解对你有帮助,如果有问题可以通过评论区或者私信联系我。
]]>今天发完 Treap 的博客之后自己看了一遍,突然感觉很不爽。Treap 的这篇博客大量的用到了指针和箭头运算符,因为博客的字体没有连字(ligature)的功能,所以看着特别傻。于是我决定把博客的字体换成 Iosevka。
以下是一些我参考的网站
基本的思路还是写一个 css 文件,然后在 html 的 head 部分注入进去。最后再在 hexo 的设置中把字体改成你想用的字体。
我改完写完 css 然后上传了字体文件之后 hexo s
了一下,发现没问题。不过很奇怪的是部署到 github 之后如果开 linux 的虚拟机访问,或是用手机访问,都不能正确的加载字体。
于是我就尝试 f12 了一下,打开 css 文件之后发现 css 文件变成了这样:
<link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/hint.css/2.4.1/hint.min.css">@font-face{ font-family: 'Iosevka'; font-display: swap; src: url('/font/iosevka-regular.ttf') format('truetype');}body { font-family: 'Iosevka';}
其中,@font-face
前面的 link
不是我自己添加的,于是我就尝试在 css 中把这一行删掉,字体也能正常显示了。因为这个文件我不管怎么写,hexo g
的时候都会加入这一行代码,部署的时候也会一起传上去,所以我干脆直接在这个 css 文件前面加了这个:
nothing{}
这样这个 link
就会加到 nothing
的前面,不会影响 @font-face
,字体也就能正常显示了。
不过我还是不清楚为什么在生成的时候会在这个文件中加入这行代码,如果有知道的可以联系我,也许是我加的一些插件?后面我又直接在 vscode 中搜索了一下,发现很多文件都被添加了这个代码,还是挺奇怪的。
只通过博客看的话好像可以很方便的解决这个问题,实际上因为网上没有现成的资料,我浪费了大量的时间去解决它,希望下次不会再遇到这种奇怪的问题了。
]]>最近发现以前学过的算法都非常容易忘,像什么后缀自动机,AC 自动机,插头 dp,都基本全忘了(甚至 KMP 和网络流这种太久没写也有点忘了,我太菜了)。所以在想,可以在学的时候就做笔记,以后忘了直接看笔记就行。
然后就有了这篇博客,不过学习笔记类的文章不会跟题解一样详细,主要还是给自己看的,之后有时间的话可能会写教程类的文章。
upd 2022/6/19:把给 OI-wiki 写的文章搞过来了,感觉现在已经可以作为一个教程类文章了。不过还需要添加无旋 treap 的内容。
upd 2020/7/1:把给 OI-wiki 写的无旋 treap 和无旋 treap 的区间操作部分复制了过来、本文中的无旋操作几乎都是我写的,如果你想了解具体的贡献者可以看 OI-wiki 的 GitHub 页面。
Treap(树堆) 是一种 弱平衡 的 二叉搜索树。它同时符合二叉搜索树和堆的性质,名字也因此为 tree (树) 和 heap (堆) 的组合。
其中,二叉搜索树的性质是:
堆的性质是:
不难看出,如果用的是同一个值,那这两种数据结构的性质是矛盾的,所以我们再在搜索树的基础上,引入一个给堆的值 。对于 值,我们维护搜索树的性质,对于 值,我们维护堆的性质。其中 这个值是随机给出的。
下图就是一个 Treap 的例子(这里使用的是小根堆,即根节点的值最小)。
那我们为什么需要大费周章的去让这个数据结构符合树和堆的性质,并且随机给出堆的值呢?
要理解这个,首先需要理解朴素二叉搜索树的问题。在给朴素搜索树插入一个新节点时,我们需要从这个搜索树的根节点开始递归,如果新节点比当前节点小,那就向左递归,反之亦然。
最后当发现当前节点没有子节点时,就根据新节点的值的大小,让新节点成为当前节点的左或右子节点。
如果新插入的节点的值是随机的,那这个朴素搜索树的形状会非常的 “胖”,上图的 Treap 就是一个例子。也就是说,每一层的节点比较多。
在这样的情况下,这个搜索树的层数是会比较接近 ( 为节点数) 的,查询的复杂度也是 (因为只要递归这么多层就能查到)。
不过,这只是在随机情况下的复杂度,如果我们按照下面这个非常有序的顺序给一个朴素的搜索树插入节点。
1 2 3 4 5
那……
这个树就会变得非常 “瘦长”(每次插入的节点都比前面的大,所以都被安排到右子节点了):
不难看出,现在这个二叉搜索树已经退化成链了,查询的复杂度也从 变成了线性。
而 treap 要解决的正是这个问题。它通过随机化的 属性,以及维护堆性质的过程,“打乱”了节点的插入顺序。从而让二叉搜索树达到了理想的复杂度,避免了退化成链的问题。
我并不清楚如何去严格的证明这样随机化的过程可以让搜索树的复杂度的期望值保持在 ,但我们可以试着感性的去理解一下。
首先,我们需要认识到一个节点的 属性是和它所在的层数有直接关联的。再回忆堆的性质:
我们发现层数低的节点,比如整个树的根节点,它的 属性也会更小(在小根堆中)。并且,在朴素的搜索树中,先被插入的节点,也更有可能会有比较小的层数。我们可以把这个 属性和被插入的顺序关联起来理解,这样,也就理解了为什么 treap 可以把节点插入的顺序通过 打乱。
在给 treap 插入新节点时,需要同时维护树和堆的性质,为了达到这个目的,有两种方法被发明了出来,分别是旋转和分裂、合并。使用这两种方法的 treap 被分别成为有旋式 treap 和 无旋式 treap。
旋转 treap 维护平衡的方式为旋转,和 AVL 树的旋转操作类似,分为 左旋 和 右旋。即在满足二叉搜索树的条件下根据堆的优先级对 treap 进行平衡操作。
旋转 treap 在做普通平衡树题的时候,是所有平衡树中常数较小的。因为普通的二叉搜索树会被递增或递减的数据卡,用 treap 对每个节点定义一个由 rand
得到的权值,从而防止特殊数据卡。同时在每次删除/插入时通过这个权值决定要不要旋转即可,其他操作与二叉搜索树类似。
大部分的树形数据结构都有指针和数组模拟两种实现方法,下面将会详细的分部分讲解指针版的代码,
注意本代码中的 rank
代表前面讲的 变量(堆的值)。并且,维护的堆的性质是小根堆( 小的在上面)。
本文的树堆实现了洛谷模板题中的操作,因为是在竞赛中使用的,所以没有用模板一类的东西封装。
下面的代码也大量参考了这篇文章。
struct Node { Node *ch[2];//两个子节点的地址 int val, rank; int rep_cnt;//当前这个值(val)重复出现的次数 int siz; // Node(int val) : val(val), rep_cnt(1), siz(1) { ch[0] = ch[1] = nullptr; rank = rand(); //注意初始化的时候,rank 是随机给出的 } void upd_siz() { //用于旋转和删除过后,重新计算 siz 的值 siz = rep_cnt; if (ch[0] != nullptr) siz += ch[0]->siz; if (ch[1] != nullptr) siz += ch[1]->siz; }};
旋转操作是 treap 的一个非常重要的操作,主要用来在保持 treap 树性质的同时,调整不同节点的层数,以达到维护堆性质的作用。
旋转操作的左旋和右旋可能不是特别容易区分,以下是两个较为明显的特点:
旋转操作的含义:
左旋和右旋操作是相互的,如下图。
enum rot_type { LF = 1, RT = 0 };void _rotate(Node *&cur, rot_type dir) { //dir参数代表旋转的方向 0为右旋,1为左旋 //注意传进来的 cur 是指针的引用,也就是改了这个 cur,变量是跟着一起改的,如果这个 cur 是别的 //树的子节点,根据 ch 找过来的时候,也是会找到这里的 //以下的代码解释的均是左旋时的情况 Node *tmp = cur->ch[dir];//让 C 变成根节点, //这里的 tmp 是一个临时的节点指针,指向成为新的根节点的节点 /* 左旋:也就是让右子节点变成根节点 * A C * / \ / \ * B C ----> A E * / \ / \ * D E B D */ cur->ch[dir] = tmp->ch[!dir]; //让 A 的右子节点变成 D tmp->ch[!dir] = cur; //让 C 的左子节点变成 A tmp->upd_siz(), cur->upd_siz();//更新大小信息 cur = tmp; //最后把临时储存 C 树的变量赋值到当前根节点上(注意 cur 是引用) }
跟普通搜索树插入的过程没啥区别,但是需要在插的过程中通过旋转来维护树堆中堆的性质。
void _insert(Node *&cur, int val) { if (cur == nullptr) { //没这个节点直接新建 cur = new Node(val); return; } else if (val == cur->val) { //如果有这个值相同的节点,就把重复数量加一 cur->rep_cnt++; cur->siz++; } else if (val < cur->val) { //维护搜索树性质,val 比当前节点小就插到左边,反之亦然 _insert(cur->ch[0], val); if (cur->ch[0]->rank < cur->rank) { //树根永远是最小的 //因为新插的左子节点比根节点小,现在需要让左子节点变成根节点 _rotate(cur, RT); //注意前面的旋转性质,要把左子节点转上来,需要右旋 } cur->upd_siz(); //插入之后大小会变化,需要更新 } else { _insert(cur->ch[1], val); if (cur->ch[1]->rank < cur->rank) { _rotate(cur, LF); } cur->upd_siz(); } }
主要就是分类讨论,不同的情况有不同的处理方法,删完了树的大小会有变化,要注意更新。并且如果要删的节点左子树和右子树都有,就要考虑删除之后让谁来当根(维护 rank 小的节点在上面)。
主要就是分类讨论,不同的情况有不同的处理方法,删完了树的大小会有变化,要注意更新。并且如果要删的节点有左子树和右子树,就要考虑删除之后让谁来当父节点(维护 rank 小的节点在上面)。
void _del(Node *&cur, int val) { if (val > cur->val) { _del(cur->ch[1], val); //值更大就在右子树,反之亦然 cur->upd_siz(); } else if (val < cur->val) { _del(cur->ch[0], val); cur->upd_siz(); } else { if (cur->rep_cnt > 1) { //如果要删除的节点是重复的,可以直接把重复值减小 cur->rep_cnt--, cur->siz--; return; } uint8_t state = 0; state |= (cur->ch[0] != nullptr); state |= ((cur->ch[1] != nullptr) << 1); // 00都无,01有左无右,10,无左有右,11都有 Node *tmp = cur; switch (state) { case 0: delete cur; cur = nullptr; //没有任何子节点,就直接把这个节点删了 break; case 1: //有左无右 cur = tmp->ch[0]; //把根变成左儿子,然后把原来的根节删了,注意这里的 tmp 是从 cur 复制的,而 cur //是引用 delete tmp; break; case 2: //有右无左 cur = tmp->ch[1]; delete tmp; break; case 3: rot_type dir = cur->ch[0]->rank < cur->ch[1]->rank ? RT : LF;// dir 是 rank 更小的那个儿子 _rotate(cur, dir); //这里的旋转可以把优先级更小的儿子转上去,rt 是 0, 而 lf //是 1,刚好跟实际的子树下标反过来 _del(cur->ch[!dir], val);//旋转完成后原来的根节点就在旋方向那边,所以需要 //继续把这个原来的根节点删掉 //如果说要删的这个节点是在整个树的“上层的”,那我们会一直通过这 //这里的旋转操作,把它转到没有子树了(或者只有一个),再删掉它。 cur->upd_siz(); //删除会造成大小改变 break; } } }
操作含义:查询以 cur 为根节点的子树中,val 这个值的大小的排名 (该子树中小于 val 的节点的个数 + 1)
int _query_rank(Node *cur, int val) { int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; //这个树中小于 val 的节点的数量 if (val == cur->val) //如果这个节点就是要查的节点 return less_siz + 1; else if (val < cur->val) { if (cur->ch[0] != nullptr) return _query_rank(cur->ch[0], val); else return 1; //如果左子树是空的,说比最小的节点还要小,那这个数字就是最小的 } else { if (cur->ch[1] != nullptr) //如果要查的值比这个节点大,那这个节点的左子树以及这个节点自身肯定都比要查的值小 //所以要加上这两个值,再加上往右边找的结果(以右子树为根的子树中,val 这个值的大小的排名) return less_siz + cur->rep_cnt + _query_rank(cur->ch[1], val); else return cur->siz + 1; //没有右子树的话直接整个树 + 1 相当于 less_siz + cur->rep_cnt + 1 } }
要根据排名查询值,我们首先要知道如何判断要查的节点在树的哪个部分:
以下是一个判断方法的表:
左子树 | 根节点 / 当前节点 | 右子树 |
---|---|---|
排名一定小于等于左子树的大小 | 排名应该 >= 左子树的大小,并且 <= 左子树的大小 + 根节点的重复次数 | 不然的话就在右子树 |
注意如果在右子树,递归的时候需要对原来的 rank
进行处理。递归的时候就相当去查,在右子树中为这个排名的值,为了把排名转换成基于右子树的,需要把原来的 rank
减去左子树的大小和根节点的重复次数。
可以把所有节点想象成一个排好序的数组,或者数轴(如下),
1 -> |左子树的节点|根节点|右子树的节点| -> n ^ 要查的排名 ⬇转换成基于右子树的排名1 -> |右子树的节点| -> n ^ 要查的排名
这里的转换方法就是直接把排名减去左子树的大小和根节点的重复数量。
int _query_val(Node *cur, int rank) { //查询树中第 rank 大的节点的值 int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; //less siz 是左子树的大小 if (rank <= less_siz) return _query_val(cur->ch[0], rank); else if (rank <= less_siz + cur->rep_cnt) return cur->val; else return _query_val(cur->ch[1], rank - less_siz - cur->rep_cnt);//见前文 }
注意这里使用了一个类中的全局变量,q_prev_tmp
。
这个值是只有在 val 比当前节点值大的时候才会被更改的,所以返回这个变量就是返回 val 最后一次比当前节点的值大,之后就是更小了。
int _query_prev(Node *cur, int val) { if (val <= cur->val) { //还是比 val 大,所以往右子树找 if (cur->ch[0] != nullptr) return _query_prev(cur->ch[0], val); } else { //只有能进到这个 else 里,才会更新 q_prev_tmp 的值 q_prev_tmp = cur->val; //当前节点已经比 val,小了,但是不确定是否是最大的,所以要到右子树继续找 if (cur->ch[1] != nullptr) _query_prev(cur->ch[1], val); //接下来的递归可能不会更改 q_prev_tmp 了,那就直接返回这个值,总之返回的就是最后一次进到 //这个 else 中的 cur->val return q_prev_tmp; } return -1145; }
跟前一个很相似,只是大于小于号换了一下。
int _query_nex(Node *cur, int val) { if (val >= cur->val) { if (cur->ch[1] != nullptr) return _query_nex(cur->ch[1], val); } else { q_nex_tmp = cur->val; if (cur->ch[0] != nullptr) _query_nex(cur->ch[0], val); return q_nex_tmp; } return -1145; }
无旋 treap 的操作方式使得它天生支持维护序列、可持久化等特性。
无旋 treap 又称分裂合并 treap。它仅有两种核心操作,即为 分裂 与 合并。通过这两种操作,在很多情况下可以比旋转 treap 更方便的实现别的操作。下面逐一介绍这两种操作。
分裂过程接受两个参数:根指针 、关键值 。结果为将根指针指向的 treap 分裂为两个 treap,第一个 treap 所有结点的值()小于等于 ,第二个 treap 所有结点的值大于 。
该过程首先判断 是否小于 的值,若小于,则说明 及其右子树全部小于 ( 可能等于 ),\textit{key}\textit{key}\textit{cur}\textit{cur}\textit{key}$ 的。
相应的,如果 大于等于 的值,说明 的整个左子树以及其自身都小于 ,属于分裂后的第一个 treap。并且, 的部分右子树也可能有部分小于 ,因此我们需要继续递归地分裂右子树。把小于 的那部分作为 的右子树,这样,整个 上的节点都小于 。
下图展示了 的值小于等于 时按值分裂的情况。[2]
pair<Node *, Node *> split(Node *cur, int key) { if (cur == nullptr) return {nullptr, nullptr}; if (cur->val <= key) { // cur 以及它的左子树一定属于分裂后的第一个树 auto temp = split(cur->ch[1], key); // 但是它可能有部分右子树也比 key 小 cur->ch[1] = temp.first; // 我们把小于 key 的那部分拿出来,作为 cur 的右子树,这样整个 cur 都是小于 // key 的 剩下的那部分右子树成为分裂后的第二个 treap cur->upd_siz(); // 分裂过后树的大小会变化,需要更新 return {cur, temp.second}; } else { // 同上 auto temp = split(cur->ch[0], key); cur->ch[0] = temp.second; cur->upd_siz(); return {temp.first, cur}; }}
比起按值分裂,这个操作更像是旋转 treap 中的根据排名(某个节点的排名是树中所有小于此节点值的节点的数量 )查询值:
此函数接受两个参数,节点指针 和排名 ,返回分裂后的三个 treap。
其中,第一个 treap 中每个节点的排名都小于 ,第二个的排名等于 ,并且第二个 treap 只有一个节点(不可能有多个等于的,如果有的话会增加 Node
结构体中的 cnt
),第三个则是大于。
此操作的重点在于判断排名和 相等的节点在树的哪个部分,这也是旋转 treap 根据排名查询值操作时的重要部分,在前文有非常详细的解释,这里不过多讲解。
并且,此操作的递归部分和按值分裂也非常相似,这里不赘述。
#define _3 second.second#define _2 second.firstpair<Node *, pair<Node *, Node *>> split_by_rk(Node *cur, int rk) { if (cur == nullptr) return {nullptr, {nullptr, nullptr}}; int ls_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; if (rk <= ls_siz) { // 排名和 cur 相等的节点在左子树 auto temp = split_by_rk(cur->ch[0], rk); cur->ch[0] = temp._3; // 返回的第三个 treap 中的排名都大于 rk // cur 的左子树被设成 temp._3 后,整个 cur 中节点的排名都大于 rk cur->upd_siz(); return {temp.first, {temp._2, cur}}; } else if (rk <= ls_siz + cur->cnt) { // 和 cur 相等的就是当前节点 Node *lt = cur->ch[0]; Node *rt = cur->ch[1]; cur->ch[0] = cur->ch[1] = nullptr; // 分裂后第二个 treap 只有一个节点,所有要把它的子树设置为空 return {lt, {cur, rt}}; } else { // 排名和 cur 相等的节点在右子树 // 递归过程同上 auto temp = split_by_rk(cur->ch[1], rk - ls_siz - cur->cnt); cur->ch[1] = temp.first; cur->upd_siz(); return {cur, {temp._2, temp._3}}; }}
合并过程接受两个参数:左 treap 的根指针 、右 treap 的根指针 。必须满足 中所有结点的值小于等于 中所有结点的值。一般来说,我们合并的两个 treap 都是原来从一个 treap 中分裂出去的,所以不难满足 中所有节点的值都小于
在旋转 treap 中,我们借助旋转操作来维护 符合堆的性质,同时旋转时还不能改变树的性质。在无旋 treap 中,我们用合并达到相同的效果。
因为两个 treap 已经有序,所以我们在合并的时候只需要考虑把哪个树“放在上面”,把哪个“放在下面”,也就是是需要判断将哪个一个树作为子树。显然,根据堆的性质,我们需要把 小的放在上面(这里采用小根堆)。
同时,我们还需要满足搜索树的性质,所以若 的根结点的 小于 的,那么 即为新根结点,并且 因为值比 更大,应与 的右子树合并;反之,则 作为新根结点,然后因为 的值比 小,与 的左子树合并。
Node *merge(Node *u, Node *v) { // 传进来的两个树的内部已经符合搜索树的性质了 // 并且 u 内所有节点的值 < v 内所有节点的值 // 所以在合并的时候需要维护堆的性质 // 这里用的是小根堆 if (u == nullptr && v == nullptr) return nullptr; if (u != nullptr && v == nullptr) return u; if (v != nullptr && u == nullptr) return v; if (u->prio < v->prio) { // u 的 prio 比较小,u应该作为父节点 u->ch[1] = merge(u->ch[1], v); // 因为 v 比 u 大,所以把 v 作为 u 的右子树 u->upd_siz(); return u; } else { // v 比较小,v应该作为父节点 v->ch[0] = merge(u, v->ch[0]); // u 比 v 小,所以递归时的参数是这样的 v->upd_siz(); return v; }}
在无旋 treap 中,插入,删除,根据值查询排名等基础操作既可以用普通二叉查找树的方法实现,也可以用分裂和合并来实现。通常来说,使用分裂和合并来实现更加简洁,但是速度会慢一点[3]。为了帮助更好的理解无旋 treap,下面的操作全部使用分裂和合并实现。
在实现插入操作时,我们利用了分裂操作的一些性质。也就是值小于等于 的节点会被分到第一个 treap。
所以,假设我们根据 分裂当前这个 treap。会有下面两棵树,并符合以下条件:
其中 表示分裂后所有被分到第一个 treap 的节点的集合, 则是第二个。
如果我们再按照 继续分裂 ,那么会产生下面两棵树,并符合以下条件:
其中 表示 分裂后所有被分到第一个 treap 的节点的集合, 则是第二个。并且上面的式子中,后半部分的 来自于 所符合的条件 。
不难发现,只要 和节点的值是一个整数(大多数使用场景下会使用整数)那么符合 条件的节点只有一个,也就是值等于 的节点。
在插入时,如果我们发现符合 的节点存在,那就可以直接增加重复次数,否则,就新开一个节点。
注意把树分裂好了还需要用合并操作把它“粘”回去,这样下次还能继续使用。并且,还需要注意合并操作的参数顺序是有要求的,第一个树的所有节点的值都需要小于第二个。
void insert(int val) { auto temp = split(root, val); // 根据 val 的值把整个树分成两个 // 注意 split 的实现,等于 val 的子树是在左子树的 auto l_tr = split(temp.first, val - 1); // l_tr 的左子树 <= val - 1,如果有 = val 的节点,那一定在右子树 Node *new_node; if (l_tr.second == nullptr) { // 没有这个节点就新开,否则直接增加重复次数。 new_node = new Node(val); } else { l_tr.second->cnt++; l_tr.second->upd_siz(); } Node *l_tr_combined = merge(l_tr.first, l_tr.second == nullptr ? new_node : l_tr.second); // 合并 T_1 left 和 T_1 right root = merge(l_tr_combined, temp.second); // 合并 T_1 和 T_2}
删除操作也使用和插入操作相似的方法,找到值和 相等的节点,并且删除它。
void del(int val) { auto temp = split(root, val); auto l_tr = split(temp.first, val - 1); if (l_tr.second->cnt > 1) { // 如果这个节点的重复次数大于 1,减小即可 l_tr.second->cnt--; l_tr.second->upd_siz(); l_tr.first = merge(l_tr.first, l_tr.second); } else { if (temp.first == l_tr.second) { // 有可能整个 T_1 只有这个节点,所以也需要把这个点设成 null 来标注已经删除 temp.first = nullptr; } delete l_tr.second; l_tr.second = nullptr; } root = merge(l_tr.first, temp.second);}
排名是比这个值小的节点的数量 ,所以我们根据 分裂当前树,那么分裂后的第一个树就符合:
如果树的值和 为整数,那么 就包含了所有值小于 的节点。
int qrank_by_val(Node* cur, int val) { auto temp = split(cur, val - 1); int ret = (temp.first == nullptr ? 0 : temp.first->siz) + 1; // 根据定义 + 1 root = merge(temp.first, temp.second); // 拆好了再粘回去 return ret;}
调用 split_by_rk()
函数后,会返回分裂好的三个 treap,其中第二个只包含一个节点,它的排名等于 ,所以我们直接返回这个节点的 。
int qval_by_rank(Node *cur, int rk) { auto temp = split_by_rk(cur, rk); int ret = temp._2->val; root = merge(temp.first, merge(temp._2, temp._3)); return ret;}
可以把这个问题转化为,在比 小的所有节点中,找出排名最大的。我们根据 来分裂这个 treap,返回的第一个 treap 中的节点的值就全部小于 ,然后我们调用 qval_by_rank()
找出这个树中值最大的节点。
int qprev(int val) { auto temp = split(root, val - 1); // temp.first 就是值小于 val 的子树 int ret = qval_by_rank(temp.first, temp.first->siz); // 这里查询的是,所有小于 val 的节点里面,最大的那个的值 root = merge(temp.first, temp.second); return ret;}
和上个操作类似,可以把这个问题转化为,在比 大的所有节点中,找出排名最大的。那么根据 分裂后,返回的第二个 treap 中的所有节点的值就大于 。
然后我们去查询这个树中排名为 的节点(也就是值最小的节点)的值,就可以成功查到第一个比 大的节点。
int qnex(int val) { auto temp = split(root, val); int ret = qval_by_rank(temp.second, 1); // 查询所有大于 val 的子树里面,值最小的那个 root = merge(temp.first, temp.second); return ret;}
注:建树的部分不是我写的。具体看 OI-wiki 的 GitHub 页面。之后这部分可能会加入笛卡尔树的详细建树过程。
将一个有 个节点的序列 转化为一棵 treap。
可以依次暴力插入这 个节点,每次插入一个权值为 的节点时,将整棵 treap 按照权值分裂成权值小于等于 的和权值大于 的两部分,然后新建一个权值为 的节点,将两部分和新节点按从小到大的顺序依次合并,单次插入时间复杂度 ,总时间复杂度 。
在某些题目内,可能会有多次插入一段有序序列的操作,这是就需要在 的时间复杂度内完成建树操作。
方法一:在递归建树的过程中,每次选取当前区间的中点作为该区间的树根,并对每个节点钦定合适的优先值,使得新树满足堆的性质。这样能保证树高为 。
方法二:在递归建树的过程中,每次选取当前区间的中点作为该区间的树根,然后给每个节点一个随机优先级。这样能保证树高为 ,但不保证其满足堆的性质。这样也是正确的,因为无旋式 treap 的优先级是用来使 merge
操作更加随机一点,而不是用来保证树高的。
方法三:观察到 treap 是笛卡尔树,利用笛卡尔树的 建树方法即可,用单调栈维护右链即可。
无旋 treap 相比旋转 treap 的一大好处就是可以实现各种区间操作,下面我们以文艺平衡树的 模板题 为例,介绍 treap 的区间操作。
您需要写一种数据结构(可参考题目标题),来维护一个有序数列。
其中需要提供以下操作:翻转一个区间,例如原有序序列是 ,翻转区间是 的话,结果是 。
对于 的数据,(初始区间长度)(翻转次数)
在这道题目中,我们需要实现的是区间翻转,那么我们首先需要考虑如何建树,建出来的树需要是初始的区间。
我们只需要把区间的下标依次插入 treap 中,这样在中序遍历(先遍历左子树,然后当前节点,最后右子树)时,就可以得到这个区间[4]。
我们知道在朴素的二叉查找树中按照递增的顺序插入节点,建出来的树是一个长链,按照中序遍历,自然可以得到这个区间。
如上图,按照 的顺序给朴素搜索树插入节点,中序遍历时,得到的也是 。
但是在 treap 中,按增序插入节点后,在合并操作时还会根据 调整树的结构,在这样的情况下,如何确保中序遍历一定能正确的输出呢?
可以参考 笛卡尔树的单调栈建树方法 来理解这个问题。
设新插入的节点为 。
首先,因为时递增的插入节点,每一个新插入的节点肯定会被连接到 treap 的右链(即从根结点一直往右子树走,经过的结点形成的链)上。
从根节点开始,右链上的节点的 是递增的(小根堆)。那我们可以找到右链上第一个 大于 的节点,我们叫这个节点 ,并把这个节点换成 。
因为 一定大于这个树上其他的全部节点,我们需要把 以及它的子树作为 的左子树。并且此时 没有右子树。
可以发现,中序遍历时 一定是最后一个被遍历到的(因为 是右链中的最后一个,而中序遍历中,右子树是最后被遍历到的)。
下图是一个 treap 根据递增顺序插入 号节点时,插入 号节点时的变化,可以用这张图更好的理解按照增序插入的过程。
翻转 这个区间时,基本思路是将树分裂成 三个区间,再对中间的 进行翻转[4]。
翻转的具体操作是把区间内的子树的每一个左,右子节点交换位置。如下图就展示了翻转上图中 treap 的 和 区间后的 treap。
注意如果按照这个方法翻转,那么每次翻转 区间时,就会有 个节点会被交换位置,这样频繁的操作显然不能满足 的数据范围,其 的单次翻转复杂度甚至不如暴力(因为我们除了需要花线性时间交换节点外,还需要在树中花费 的时间找到需要交换的节点。
再观察题目要求,可以发现因为只需要最后输出操作完的区间,所以并不需要每次都真的去交换。如此一来,便可以使用线段树中常用的懒标记(lazy tag)来优化复杂度。交换时,只需要在父节点打上标记,代表这个子树下的每个左右子节点都需要交换就行了。
在线段树中,我们一般在更新和查询时下传懒标记。这是因为,在更新和查询时,我们想要更新/查询的范围不一定和懒标记代表的范围重合,所以要先下传标记,确保查到和更新后的值是正确的。
在无旋 treap 中也是一样。具体操作时我们会把 treap 分裂成前文讲到的三个树,然后给中间的树打上懒标记后合并这三棵树。因为我们想要翻转的区间和懒标记代表的区间不一定重合,所以要在分裂时下传标记。并且,分裂和合并操作会造成每个节点及其懒标记所代表的节点发生变动,所以也需要在合并前下传懒标记。
换句话说,是当树的结构发生改变的时候,当我们进行分裂或合并操作时需要改变某一个点的左右儿子信息时之前,应该下放标记,而非之后,因为懒标记是需要下传给儿子节点的,但更改左右儿子信息之后若懒标记还未下放,则懒标记就丢失了下放的对象。[5]
以下为代码讲解,代码参考了[4]。
因为区间操作中大部分操作都和普通的无旋 treap 相同,所以这里只讲解和普通无旋 treap 不同的地方。
需要注意这里的懒标记代表需要把这个树中的每一个子节点交换位置。所以如果当前节点的子节点也有懒标记,那两次翻转就抵消了。如果子节点不需要翻转,那么这个懒标记就需要继续被下传到子节点上。
// 这里这个 pushdown 是 Node 类的成员函数,其中 to_rev 是懒标记inline void pushdown() { swap(ch[0], ch[1]); if (ch[0] != nullptr) ch[0]->to_rev ^= 1; if (ch[1] != nullptr) ch[1]->to_rev ^= 1; to_rev = false;}inline void check_tag() { if (to_rev) pushdown();}
注意在这个题目中,因为翻转操作,treap 中的 会不符合二叉搜索树的性质(见区间翻转部分的图),所以我们不能根据 来判断应该往左子树还是右子树递归。
所以这里的分裂跟普通无旋 treap 中的按排名分裂更相似,是根据当前树的大小判断往左还是右子树递归的,换言之,我们是按照开始时这个节点在树中的位置来判断的。
返回的第一个 treap 中节点的排名全部小于等于 ,而第二个 treap 中节点的排名则全部大于 。
#define siz(_) (_ == nullptr ? 0 : _->siz)pair<Node*, Node*> split(Node* cur, int sz) { // 按照树的大小判断 if (cur == nullptr) return {nullptr, nullptr}; cur->check_tag(); // 分裂前先下传 if (sz <= siz(cur->ch[0])) { auto temp = split(cur->ch[0], sz); cur->ch[0] = temp.second; cur->upd_siz(); return {temp.first, cur}; } else { auto temp = split(cur->ch[1], sz - siz(cur->ch[0]) - 1); // 这里的转换在有旋 treap 的 “根据排名查询值有讲” cur->ch[1] = temp.first; cur->upd_siz(); return {cur, temp.second}; }}
唯一需要注意的是在合并前下传懒标记
Node *merge(Node *sm, Node *bg) { // small, big if (sm == nullptr && bg == nullptr) return nullptr; if (sm != nullptr && bg == nullptr) return sm; if (sm == nullptr && bg != nullptr) return bg; sm->check_tag(), bg->check_tag(); if (sm->prio < bg->prio) { sm->ch[1] = merge(sm->ch[1], bg); sm->upd_siz(); return sm; } else { bg->ch[0] = merge(sm, bg->ch[0]); bg->upd_siz(); return bg; }}
和前面介绍的一样,分裂出 三个区间,然后对中间的区间打上标记后再合并。
void seg_rev(int l, int r) { // 这里的 less 和 more 是相对于 l 的 auto less = split(root, l - 1); // 所有小于等于 l - 1 的会在 less 的左子树 auto more = split(less.second, r - l + 1); // 从 l 开始的前 r - l + 1 个元素的区间 more.first->to_rev = true; root = merge(less.first, merge(more.first, more.second));}
要注意在打印时要下传标记。
void print(Node* cur) { if (cur == nullptr) return; cur->check_tag(); // 中序遍历 -> 先左子树,再自己,最后右子树 print(cur->ch[0]); cout << cur->val << " "; print(cur->ch[1]);}
/*Date: 22 - 06-11 23 29PROBLEM_NUM: P3369 【模板】普通平衡树*/#include <bits/stdc++.h>using namespace std;#define pause system("pause")struct Node { Node *ch[2]; int val, rank; int rep_cnt; int siz; Node(int val) : val(val), rep_cnt(1), siz(1) { ch[0] = ch[1] = nullptr; rank = rand(); } void upd_siz() { siz = rep_cnt; if (ch[0] != nullptr) siz += ch[0]->siz; if (ch[1] != nullptr) siz += ch[1]->siz; }};class Treap { private: Node *root; enum rot_type { LF = 1, RT = 0 }; int q_prev_tmp = 0, q_nex_tmp = 0; void _rotate(Node *&cur, rot_type dir) { // 0为右旋,1为左旋 Node *tmp = cur->ch[dir]; // tmp指向成为新的根节点的节点(左旋情况下是右子节点) //让 C 变成根节点 /* 左旋:也就是让右子节点变成根节点 * A C * / \ / \ * B C ----> A E * / \ / \ * D E B D */ cur->ch[dir] = tmp->ch[!dir]; //让 A 的右子节点变成 D tmp->ch[!dir] = cur; //让 C 的左子节点变成 A tmp->upd_siz(), cur->upd_siz(); cur = tmp; } void _insert(Node *&cur, int val) { if (cur == nullptr) { cur = new Node(val); return; } else if (val == cur->val) { cur->rep_cnt++; cur->siz++; } else if (val < cur->val) { _insert(cur->ch[0], val); if (cur->ch[0]->rank < cur->rank) { //树根永远是最小的 //现在需要让左子节点变成根节点 _rotate(cur, RT); } cur->upd_siz(); } else { _insert(cur->ch[1], val); if (cur->ch[1]->rank < cur->rank) { _rotate(cur, LF); } cur->upd_siz(); } } void _del(Node *&cur, int val) { if (val > cur->val) { _del(cur->ch[1], val); cur->upd_siz(); } else if (val < cur->val) { _del(cur->ch[0], val); cur->upd_siz(); } else { if (cur->rep_cnt > 1) { cur->rep_cnt--, cur->siz--; return; } uint8_t state = 0; state |= (cur->ch[0] != nullptr); state |= ((cur->ch[1] != nullptr) << 1); // 00都无,01有左无右,10,无左有右,11都有 Node *tmp = cur; switch (state) { case 0: delete cur; cur = nullptr; break; case 1: //有左无右 cur = tmp->ch[0]; //把根变成左儿子 delete tmp; break; case 2: //有右无左 cur = tmp->ch[1]; delete tmp; break; case 3: rot_type dir = cur->ch[0]->rank < cur->ch[1]->rank ? RT : LF; // dir 也是更小的那个儿子 _rotate(cur, dir); //这里的旋转可以把优先级更小的儿子转上去 //旋转完成后,原来的根节点就在旋转方向的那边 _del(cur->ch[!dir], val); cur->upd_siz(); break; } } } int _query_rank(Node *cur, int val) { //查询以 cur 为根节点的子树中,val 这个值的大小的排名 (该子树中小于 val //的树的个数 + 1) int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; //这个树中小于 val 的节点的数量 if (val == cur->val) return less_siz + 1; else if (val < cur->val) { if (cur->ch[0] != nullptr) return _query_rank(cur->ch[0], val); else return 1; //如果说比最小的节点还要小,那这个数字就是最小的 } else { if (cur->ch[1] != nullptr) return less_siz + cur->rep_cnt + _query_rank(cur->ch[1], val); else return cur->siz + 1; } } int _query_val(Node *cur, int rank) { //查询树中第 rank 大的节点的值 DEBUG("qval: %d\n", cur->val); int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; if (rank <= less_siz) return _query_val(cur->ch[0], rank); else if (rank <= less_siz + cur->rep_cnt) return cur->val; else return _query_val(cur->ch[1], rank - less_siz - cur->rep_cnt); } int _query_prev(Node *cur, int val) { //查找树中最大的,小于 val 的节点 if (val <= cur->val) { if (cur->ch[0] != nullptr) return _query_prev(cur->ch[0], val); } else { q_prev_tmp = cur->val; //当前节点已经比 val //小了,但是不确定是否是最大的,所以要到右子树继续找 if (cur->ch[1] != nullptr) _query_prev(cur->ch[1], val); return q_prev_tmp; } return -1145; } int _query_nex(Node *cur, int val) { //找到树中最小的,大于 val 的节点 if (val >= cur->val) { if (cur->ch[1] != nullptr) return _query_nex(cur->ch[1], val); } else { q_nex_tmp = cur->val; if (cur->ch[0] != nullptr) _query_nex(cur->ch[0], val); return q_nex_tmp; } return -1145; } public: void insert(int val) { _insert(root, val); } void del(int val) { _del(root, val); } int query_rank(int val) { return _query_rank(root, val); } int query_val(int rank) { return _query_val(root, rank); } int query_prev(int val) { return _query_prev(root, val); } int query_nex(int val) { return _query_nex(root, val); }};Treap tr;int main() { srand(0); int t; scanf("%d", &t); while (t--) { int mode; int num; scanf("%d%d", &mode, &num); switch (mode) { case 1: tr.insert(num); break; case 2: tr.del(num); break; case 3: printf("%d\n", tr.query_rank(num)); break; case 4: printf("%d\n", tr.query_val(num)); break; case 5: printf("%d\n", tr.query_prev(num)); break; case 6: printf("%d\n", tr.query_nex(num)); break; } } pause;}
#include <bits/stdc++.h>using namespace std;struct Node { Node *ch[2]; int val, rank; int rep_cnt; int siz; Node(int val) : val(val), rep_cnt(1), siz(1) { ch[0] = ch[1] = nullptr; rank = rand(); } void upd_siz() { siz = rep_cnt; if (ch[0] != nullptr) siz += ch[0]->siz; if (ch[1] != nullptr) siz += ch[1]->siz; }};class Treap {private: Node *root; enum rot_type { LF = 1, RT = 0 }; int q_prev_tmp = 0, q_nex_tmp = 0; void _rotate(Node *&cur, rot_type dir) { // 0为右旋,1为左旋 Node *tmp = cur->ch[dir]; cur->ch[dir] = tmp->ch[!dir]; tmp->ch[!dir] = cur; tmp->upd_siz(), cur->upd_siz(); cur = tmp; } void _insert(Node *&cur, int val) { if (cur == nullptr) { cur = new Node(val); return; } else if (val == cur->val) { cur->rep_cnt++; cur->siz++; } else if (val < cur->val) { _insert(cur->ch[0], val); if (cur->ch[0]->rank < cur->rank) { _rotate(cur, RT); } cur->upd_siz(); } else { _insert(cur->ch[1], val); if (cur->ch[1]->rank < cur->rank) { _rotate(cur, LF); } cur->upd_siz(); } } void _del(Node *&cur, int val) { if (val > cur->val) { _del(cur->ch[1], val); cur->upd_siz(); } else if (val < cur->val) { _del(cur->ch[0], val); cur->upd_siz(); } else { if (cur->rep_cnt > 1) { cur->rep_cnt--, cur->siz--; return; } uint8_t state = 0; state |= (cur->ch[0] != nullptr); state |= ((cur->ch[1] != nullptr) << 1); // 00都无,01有左无右,10,无左有右,11都有 Node *tmp = cur; switch (state) { case 0: delete cur; cur = nullptr; break; case 1: //有左无右 cur = tmp->ch[0]; delete tmp; break; case 2: //有右无左 cur = tmp->ch[1]; delete tmp; break; case 3: rot_type dir = cur->ch[0]->rank < cur->ch[1]->rank ? RT : LF; _rotate(cur, dir); _del(cur->ch[!dir], val); cur->upd_siz(); break; } } } int _query_rank(Node *cur, int val) { int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; if (val == cur->val) return less_siz + 1; else if (val < cur->val) { if (cur->ch[0] != nullptr) return _query_rank(cur->ch[0], val); else return 1; } else { if (cur->ch[1] != nullptr) return less_siz + cur->rep_cnt + _query_rank(cur->ch[1], val); else return cur->siz + 1; } } int _query_val(Node *cur, int rank) { int less_siz = cur->ch[0] == nullptr ? 0 : cur->ch[0]->siz; if (rank <= less_siz) return _query_val(cur->ch[0], rank); else if (rank <= less_siz + cur->rep_cnt) return cur->val; else return _query_val(cur->ch[1], rank - less_siz - cur->rep_cnt); } int _query_prev(Node *cur, int val) { if (val <= cur->val) { if (cur->ch[0] != nullptr) return _query_prev(cur->ch[0], val); } else { q_prev_tmp = cur->val; if (cur->ch[1] != nullptr) _query_prev(cur->ch[1], val); return q_prev_tmp; } return -1145; } int _query_nex(Node *cur, int val) { if (val >= cur->val) { if (cur->ch[1] != nullptr) return _query_nex(cur->ch[1], val); } else { q_nex_tmp = cur->val; if (cur->ch[0] != nullptr) _query_nex(cur->ch[0], val); return q_nex_tmp; } return -1145; }public: void insert(int val) { _insert(root, val); } void del(int val) { _del(root, val); } int query_rank(int val) { return _query_rank(root, val); } int query_val(int rank) { return _query_val(root, rank); } int query_prev(int val) { return _query_prev(root, val); } int query_nex(int val) { return _query_nex(root, val); }};Treap tr;int main() { srand(0); int t; scanf("%d", &t); while (t--) { int mode; int num; scanf("%d%d", &mode, &num); switch (mode) { case 1: tr.insert(num); break; case 2: tr.del(num); break; case 3: printf("%d\n", tr.query_rank(num)); break; case 4: printf("%d\n", tr.query_val(num)); break; case 5: printf("%d\n", tr.query_prev(num)); break; case 6: printf("%d\n", tr.query_nex(num)); break; } }}
// author: (ttzytt)[ttzytt.com]#include <bits/stdc++.h>using namespace std;// 参考:https://www.cnblogs.com/Equinox-Flower/p/10785292.htmlstruct Node { Node* ch[2]; int val, prio; int cnt; int siz; bool to_rev = false; // 需要把这个子树下的每一个节点都翻转过来 Node(int _val) : val(_val), cnt(1), siz(1) { ch[0] = ch[1] = nullptr; prio = rand(); } inline int upd_siz() { siz = cnt; if (ch[0] != nullptr) siz += ch[0]->siz; if (ch[1] != nullptr) siz += ch[1]->siz; return siz; } inline void pushdown() { swap(ch[0], ch[1]); if (ch[0] != nullptr) ch[0]->to_rev ^= 1; // 如果原来子节点也要翻转,那两次翻转就抵消了,如果子节点不翻转,那这个 // tag 就需要继续被 push 到子节点上 if (ch[1] != nullptr) ch[1]->to_rev ^= 1; to_rev = false; } inline void check_tag() { if (to_rev) pushdown(); }};struct Seg_treap { Node* root;#define siz(_) (_ == nullptr ? 0 : _->siz) pair<Node*, Node*> split(Node* cur, int sz) { // 按照树的大小划分 if (cur == nullptr) return {nullptr, nullptr}; cur->check_tag(); if (sz <= siz(cur->ch[0])) { // 左边的子树就够了 auto temp = split(cur->ch[0], sz); // 左边的子树不一定全部需要,temp.second 是不需要的 cur->ch[0] = temp.second; cur->upd_siz(); return {temp.first, cur}; } else { // 左边的加上右边的一部分(当然也包括这个节点本身) auto temp = split(cur->ch[1], sz - siz(cur->ch[0]) - 1); cur->ch[1] = temp.first; cur->upd_siz(); return {cur, temp.second}; } } Node* merge(Node* sm, Node* bg) { // small, big if (sm == nullptr && bg == nullptr) return nullptr; if (sm != nullptr && bg == nullptr) return sm; if (sm == nullptr && bg != nullptr) return bg; sm->check_tag(), bg->check_tag(); if (sm->prio < bg->prio) { sm->ch[1] = merge(sm->ch[1], bg); sm->upd_siz(); return sm; } else { bg->ch[0] = merge(sm, bg->ch[0]); bg->upd_siz(); return bg; } } void insert(int val) { auto temp = split(root, val); auto l_tr = split(temp.first, val - 1); Node* new_node; if (l_tr.second == nullptr) new_node = new Node(val); Node* l_tr_combined = merge(l_tr.first, l_tr.second == nullptr ? new_node : l_tr.second); root = merge(l_tr_combined, temp.second); } void seg_rev(int l, int r) { // 这里的 less 和 more 是相对于 l 的 auto less = split(root, l - 1); // 所有小于等于 l - 1 的会在 less 的左边 auto more = split(less.second, r - l + 1); // 拿出从 l 开始的前 r - l + 1 个 more.first->to_rev = true; root = merge(less.first, merge(more.first, more.second)); } void print(Node* cur) { if (cur == nullptr) return; cur->check_tag(); print(cur->ch[0]); cout << cur->val << " "; print(cur->ch[1]); }};Seg_treap tr;int main() { srand(time(0)); int n, m; cin >> n >> m; for (int i = 1; i <= n; i++) tr.insert(i); while (m--) { int l, r; cin >> l >> r; tr.seg_rev(l, r); } tr.print(tr.root);}
博客中观看体验更佳
给你一个有 个节点的树,一开始,每个节点都是健康的。每秒钟你可以进行下面两种操作:
现在问你最少需要多少秒才能感染整个树。
看完题我们要注意到,这个题说的是节点可以把病毒传播给他们的兄弟节点,而不是传播给他们的子节点,所以这个树的每一层之间是完全独立的,不可能把病毒从一层传播给另一层。
所以我们肯定需要在一开始的时候就给每个节点的至少一个子节点注射病毒(具体哪个不重要),这样每秒钟能感染的节点更多(根据操作 1)。
那先给谁的子节点注射呢?考虑先被注射的子节点会有更多的时间把病毒传播给更多的子节点。所以我们应该先给子节点更多的节点注射病毒。
(如果先给子节点少的节点注射,那在给所有节点注射完之前,这个节点的所有子节点可能都被感染了,也就是有很多时间被浪费了)。
在确保每个节点都有至少一个子节点被注射后,我们还可以给子节点特别多的那些节点注射病毒,防止有些特别大的节点靠传播来传染特别慢。
当然我们不能直接跟前面一样直接根据子节点数量来排序,并且一直给子节点多的节点注射病毒,因为这样可能会让某个节点以及它的子节点迅速的被完全感染,而其他的节点还需要很长时间。
比如有两个节点,他们在经过注射后的,健康的子节点的数量分别是 100 和 98,如果我们直接排序,先给大的节点注射,那感染完整个树就还需要 秒(/2是因为有传播和注射两种感染方式同时进行,而 98 - 50 之后再 /2 是因为我们先处理完 100 那个节点才去给 98 那个节点注射)。但是如果我们同时给两个节点注射,只需要 秒的时间(因为可以把两个节点看成一个节点,那每秒就有两个子节点因为传播而感染,一个因为注射而感染)。
所以我们可以把健康的子节点数压入堆中,每次注射子节点最多的那个。
不过多解释,代码中有详细注释
#include <bits/stdc++.h>using namespace std;const int MAXN = 2e5 + 10;int n;int siz[MAXN], t;//siz 表示这个节点的子节点的数量int main() { scanf("%d", &t); while (t--) { scanf("%d", &n); memset(siz, 0, sizeof(siz)); for (int i = 1; i < n; i++) { int fa; scanf("%d", &fa); siz[fa]++; } siz[0] = 1; //让 0 号节点连到根节点,所以 0 号节点的子节点数量是 1 sort(siz, siz + 1 + n); int fir_n_zero = -1; //第一个子节点数量不是 0 的节点的下标 fir_n_zero = find_if(siz, siz + 1 + n, [](int a) { return a != 0; }) - siz; priority_queue<int> pq; for (int i = fir_n_zero; i <= n; i++) { //循环中 i 小的节点是后被注射的,可以把 i 理解为倒数第 i 个被注射 pq.push(siz[i] - (i - fir_n_zero) - 1); //给所有节点的某个子节点注射一遍,但是在注射的过程中还会传播,因为传播而被感染的子节点数就是 // i - fir_n_zero,因为注射而被感染的是 1,所以被 push // 进去的数就是经过这一轮注射后,每个树还剩下几个子节点未被感染。 } int tm_used = n - fir_n_zero + 1; //这一轮注射用掉的时间,也就是有子节点的节点的数量 int spreaded = 0; while(pq.top() > spreaded){ //这里的 pq 没有减去因为传播而感染的数量,因为每一个节点都会有传播 spreaded++; //每次都会因为传播而多感染一个 int tp = pq.top(); pq.pop(); pq.push(tp - 1); //每次挑选最大的节点注射 tm_used++; } printf("%d\n", tm_used); }}
最后,希望这篇题解对你有帮助,如果有问题可以通过评论区或者私信联系我。
]]>upd@2022/12/18:感谢@adpitacor和@iterator_traits在评论区中指出的几个错别字,现在已经修正
upd@2022/11/24:这篇文章能入选洛谷日报我也是很惊喜。同时深感以前写的东西水平不太行,不过既然入选了日报,还是希望尽量能把文章的水平提高一点的。显然我不可能把整个文章重写一遍,但是可以加入一些自己这段时间新学到的东西,比如栈溢出攻击和 backtrace 的实现,所以赶在发表之前把他们添加了进来。
除了内容上的变化,这次更新还修复了一些错别字,这主要需要感谢 @cancan123456 在洛谷评论区的提醒。文章发出后有很多洛谷的朋友在评论区提醒和提出建议。比如 @szTom指出了 __fastcall
的影子空间。@LiuTianyou 介绍了强制内联的方法(这次新加的内容就使用了强制内联,以前属实是太菜了不知道有这个东西),以及调用约定中的 thiscall。@小菜鸟 提到了 longjump
函数(和本文中的一次返回多个函数有关)。 非常感谢这些网友,在学会了这些东西后我会考虑把它们陆续加进来,如果你有别的任何建议也欢迎像他们一样在评论区提出。此外,我希望这篇文章尽可能的易懂,所以如果文章中有语言模糊不清或是过于简洁影响理解的地方,欢迎以任何方式联系我修改。
#include<stdio.h>int add2(int a, int b) {return (a + b);}int add1(int a, int b) {return (a + add2(a, b));}int main(){ int c = add1(114, 514); printf("%d\n", c);}
观察这样一段程序,在主函数中会调用一个 add1 函数,对于 add1,它又会去调用一个 add2,然后返回计算的结果,最后才在主函数中执行 printf。
把这段程序中函数执行完成的顺序: ,和开始执行的顺序: 列出来就可以发现。更早开始执行的函数更晚结束,这是因为先开始执行的函数需要用到后开始的函数的结果,所以必须先运行完后开始的函数。
这似乎是一个栈的结构,也就是先进后出的结构,对应到函数调用的场景下,就是先开始执行的就要后结束执行。后开始执行的函数先结束执行。
因为函数调用和栈的相关性,我们可以把每一次函数调用抽象成栈中的一个元素。每次我们都执行栈顶上的函数,如果遇到新调用的函数,就把他推入栈中。而每当一个函数执行完了,就把他弹出栈。
具体可以参考下面这个我用 manim 制作的演示视频:
前面说到可以把函数调用抽象成栈里的一个元素,这个元素就被称之为栈帧 (stack frame),那我们具体要往栈帧里放什么,才能让 cpu 读取了栈帧数据之后就能正确运行函数呢?
首先,在函数中,我们可能会申明一些局部变量,如果我们想要成功的调用一个函数,肯定需要访问函数中的局部变量。
对于传进来的参数,其实也可以看作时一种局部变量。
其次,我们调用的函数执行完之后,需要返回。可是在返回时,计算机并不知道具体应该返回到哪条指令。用前面的例子说,main 函数调用了 add1,可是当 add1 返回的时候,不知道接下来应该执行 printf
还是直接在 main 中 return 0
了。所以我们还需要在栈帧中存一个返回的位置,也就是返回之后应该执行哪条指令。
最后呢,就像在用数组模拟栈的时候一样,需要标记栈的栈顶,这样才知道下一个数据往哪里放。对于单个栈帧,除了栈顶(结束地址),还需要知道这个栈帧的 “栈底(起始地址)”,通过这个起始地址,我们才能知道在弹出这个栈帧的时候,要弹到哪里。
在 x86/x64 架构的计算机中,有两个专门的寄存器来标记当前函数栈帧的起始地址和结束地址,分别为 xbp (base pointer 栈基址/帧指针,本文称之为帧指针) 和 xsp (stack pointer 栈指针)。其中的 x 代表这个字母会变化,它代表的计算机的位宽,如果是 64 位的机器,那就是 rbp 和 rsp,而 32 位机器的是 ebp 和 esp。
下面这张图[1]很好的解释了栈帧在栈中的具体结构:
注:栈的增长方向是高地址到低地址,也就是调用者栈帧的地址比被调用者栈帧的小。
图片的上半部分是调用者的栈帧,可以看到里面存有参数(也就是一种局部变量)。也有当前函数的返回地址,通过这个地址可以找到当前这个函数运行完了应该返回哪里。
返回地址是通过当前栈帧的帧指针确定的,它总是储存在当前帧指针 +8 的位置(在 64 位机器中,如果是图中的 32 位,那就是 +4 的位置)。
下半部分存的是当前函数的栈帧,里面同样存有局部变量。ebp 和 esp 分别标注了这个栈帧的起始和结束位置。
通过帧指针加上一些偏移量,就可以访问到这个栈帧里的局部变量。
看文字解释可能不是很清楚,可以参考下面这个我用 manim 制作的演示视频。演示的是下面这个 C 程序运行时栈帧的变化。
为了演示和理解的方便,假设以下程序中每一行就是 cpu 执行的一条指令,在汇编中实际需要执行更多的步骤(后面会讲)。
int add(int a, int b) {return a + b;}int main(){ int c = add(114, 514); int d = c + 1919;}
虽然前面的解释和视频已经能让大部分人了解函数调用的原理了。但要想要深入了解函数调用的详细过程,还是得看编译后的汇编代码。不过不用担心你看不懂汇编代码,我在这部分写了非常详细的解释。
这里介绍两种查看 C 或 C++ 代码对应的汇编的方法:
gcc -S [文件]
,不过默认输出的汇编是 at&t 风格的,我个人比较喜欢 intel 风格的汇编,如果你也希望输出 intel 风格的汇编,可以加入编译选项 -m asm=intel
进入这个网站,基本的界面是这样的:
这里只讲几个比较基本的选项,但是这个网站的功能是非常强大的,完全是一个线上的 IDE,具体可以看这个视频
可以看到从左到右有一些选项被框起来了,他们的功能分别是:
再回到这张图:
可以看到左边的 C++ 和右边的汇编代码都被用不同的颜色标注了起来,被同个颜色标注的代码就表示它们是对应的。
先放代码:
add: push rbp mov rbp, rsp mov DWORD PTR [rbp-4], edi mov DWORD PTR [rbp-8], esi mov edx, DWORD PTR [rbp-4] mov eax, DWORD PTR [rbp-8] add eax, edx pop rbp retmain: push rbp mov rbp, rsp sub rsp, 16 mov esi, 514 mov edi, 114 call add mov DWORD PTR [rbp-4], eax mov eax, DWORD PTR [rbp-4] add eax, 1919 mov DWORD PTR [rbp-8], eax mov eax, 0 leave ret
int add(int a, int b) {return a + b;}int main(){ int c = add(114, 514); int d = c + 1919;}
在 main 函数中,我们直接通过 int c = add(114, 514);
调用了 add 函数,在 C 中,看似一行指令就能成功调用函数,但这一句话实际需要下面的汇编来实现:
mov esi, 514 ; 把 514 赋给 esi 寄存器,用于给 add 函数传递参数mov edi, 114 ; 把 114 赋给 edi 寄存器,用于给 add 函数传递参数call add ; 调用 call 函数,详见下面的解释mov DWORD PTR [rbp-4], eax; 见下面的解释
这段代码中的前两个 mov 指令都是比较容易理解的,但是 call 指令一句话却做了两件事情。
首先,call 会把 call 指令执行时的 pc 压入栈中(这个程序中就是 mov DWORD PTR [rbp-4], eax
)。然后,它会把 pc 的值改成 add 函数的起始地址。(pc 存的是 cpu 执行的下一条指令的地址)。然后 cpu 就会开始执行 add 函数。
call 的下一句可能比较难以理解,特别是 DWORD PTR [rbp-4]
。其中的 DWORD
表示的其实是一种数据类型,WORD
表示的是两字节的整数,DWORD
,也就是 double word,表示的就是四个字节的整数,所以 DWORD
其实就是 C 中的 int
。
PTR
跟 C 中的解指针操作很像。mov DWORD PTR [rbp-4], eax
。 这一句话就是把 eax 寄存器的值复制到内存中地址为 rbp-4 的位置。并且这个值是四个字节的。所以 eax 的值会被赋值到地址为 rbp-4 到 rbp 的这个范围的内存。
注意 rbp 就是前面说到的帧指针,它保存的是栈帧的开始地址,在函数中,局部变量都是通过帧指针来访问的。而 eax 保存的是 add 的返回值。所以这句话转换成 C 语言就是把 add(114, 514)
的返回值保存到局部变量 c
。
现在再来看 add 函数中的内容:
add: push rbp ; 把 rbp 压入栈中,push 指令会先减少 sp 寄存器的值,然后把要入栈的数据存入 sp 指向的位置 (栈顶) mov rbp, rsp ; 把 rsp 的值赋到 rbp,这表明新的栈帧内没有存任何数据 mov DWORD PTR [rbp-4], edi ; edi 和 esi 存着参数 mov DWORD PTR [rbp-8], esi ; 所以这两行是把参数存入栈中 mov edx, DWORD PTR [rbp-4] ; mov eax, DWORD PTR [rbp-8] ; 把 a 和 b 这两个参数移动到 edx 和 eax 两个寄存器中 add eax, edx ; 等价于 eax += edx pop rbp ; 把栈顶的元素送到 rbp,也就是恢复之前备份的 rbp ret ; 把之前存的返回地址弹出到 pc,以便继续执行 main 函数
你可能会感到奇怪,在之前的解释中,sp 指针会先被减去一个值来分配栈帧的空间。在返回时,sp 的值会被设成 bp 的值来释放栈空间。而在上面的代码中,这些操作都没有被执行。
这一系列骚操作其实都是编译器干的,编译器会优化掉一些不必要的操作。对于第一个操作,sp 指针可以告诉我们下次增加栈帧的时候应该往哪加,防止把之前的栈帧覆盖掉,但是这个 add 函数没有调用任何别的函数,也就是不需要再它的基础上增加任何栈帧,所以给 sp 减一个值来分配空间自然就没有必要了。
对于第二个操作,因为 sp 一直没变,自然也无需在返回时更改 sp 的值。
如果你有兴趣,可以去 Compiler Explorer 的网站上加一个 o2 的编译选项,看下我们平时常用的 o2 优化到底是怎么实现的。如果你去看了,发现编译器居然会提前把 114 + 514 + 1919 的值算好,然后就不调用 add 函数了。。。
假设我们把这个 add 函数改成一个递归的函数,那么刚刚的那些优化就不能加了,要不然就会把之前的栈帧覆盖掉。可以看下这个例子。注意其中的 leave 指令会干两件事。第一是把栈指针指向帧指针(帧指针和栈指针相等就表示当前栈帧没有数据),用于恢复之前分配的内存,第二是恢复备份的栈指针。相当于是 mov rsp, rbp
和 pop rsp
的结合。
看了刚刚的汇编代码,你可能会好奇,有很多种方法可以实现汇编中的函数调用,为什么编译器采取的就是这样特定的一种。比如为什么函数的参数是由 edi 和 esi 寄存器来进行传递的,不是直接压入栈中或者是用别的寄存器来传,又比如栈帧的释放工作既可以由被调用者完成,也可以由调用者完成,但为什么在刚刚的汇编代码中,是让被调用者来释放栈帧的。
其实,这些看似玄学的问题都是有答案的,答案就是函数调用的约定。
函数调用约定,是指当一个函数被调用时,函数的参数会被传递给被调用的函数和返回值会被返回给调用函数。函数的调用约定就是描述参数是怎么传递和由谁平衡堆栈的,当然还有返回值。-- 百度百科
所以这里就来介绍几种比较经典的函数调用约定。如果你自己想要写汇编的话,也可以遵守这些函数调用的规则。
还是先介绍下查看 x86 汇编代码的方法。gcc 编译器默认输出的汇编是 64 位的,如果想让 gcc 输出 32 位的汇编代码,需要加入 -m32
编译选项,经测试,可以在我的电脑中输出 32 位程序(我的电脑用的是 MinGW),也可以在程序中加入 __cdecl
或是 __stdcall
这样的指令来指定函数调用约定。不过在 Compiler Explorer 中,就有些奇怪了,即使加入了 -m32
的编译选项,还是不能指定函数调用约定,所以我把 Compiler Explorer 的编译器换成了 msvc(用 Compiler Explorer 是因为分享代码很方便)。如果你知道为什么在 gcc 中指定了函数调用约定就过不了编译,欢迎在评论区告诉我。
为了对比不同函数调用约定的具体区别,我使用了同一段代码。然后再在 add 函数的前面加入不同的函数调用约定。这里附上 Compiler Explorer 的链接
int add(int a, int b, int c) {return a + b + c;}int main(){ int c = add(114, 514, 1919);}
如果想要指定函数的调用方法为 cdecl,需要这样申明函数:int __cdecl add(int a, int b, int c)
。
cdecl 是 C 语言的默认函数调用方法(32 位时)。它的特点由如下几个:
前面的代码使用 cdecl 约定生成的汇编如下:
_a$ = 8 ; size = 4_b$ = 12 ; size = 4_c$ = 16 ; size = 4_add PROC push ebp mov ebp, esp mov eax, DWORD PTR _a$[ebp] add eax, DWORD PTR _b$[ebp] add eax, DWORD PTR _c$[ebp] pop ebp ret 0_add ENDP_c$ = -4 ; size = 4_main PROC push ebp mov ebp, esp push ecx push 1919 ; 0000077fH push 514 ; 00000202H push 114 ; 00000072H call _add add esp, 12 ; 0000000cH mov DWORD PTR _c$[ebp], eax xor eax, eax mov esp, ebp pop ebp ret 0_main ENDP
注意这几句话:
push 1919 ; 0000077fHpush 514 ; 00000202Hpush 114 ; 00000072H
可以看到函数的参数是以从右到左的顺序被压入栈中的。因为使用了 push 指令,在把数据存入栈中的时候就已经减少了 esp 的值,所以你会发现 add 函数中没有减少 esp 指针的值来开辟内存。
和之前分析函数调用原理的那部分一样(见1.2.5.3),这个函数调用没有备份 ebp 也是因为 add 函数没有调用别的函数,所以被编译器优化掉了。
而 add esp, 12
这句话的作用是释放 add 函数占用的内存。并且这句话是出现在 main 函数中的,可以说明 cdecl 的特点,也就是由调用者来释放内存。
那么这样的约定有什么好处呢?
它最主要的好处就是可以采用变长参数(参数的数量不固定)。我们在 C 中最常使用的变长参数函数就是 printf()
和 scanf()
。printf
的函数申明是这样的:int printf (const char *__format, ...)
后面的那三个点就代表变长参数。如果你对这样的可变参数有兴趣,推荐去看一看这篇洛谷日报。
如果一个程序中有很多地方调用了可变参数函数,每个位置传进去的参数数量可能是不固定的,这就让在可变参数函数内部释放内存变得不现实了。因为在这个函数内部只能释放固定容量的内存,而每次调用需要释放的内存是不同的。如果是让调用者来释放内存的,就可以根据每次调用的参数数量和大小来决定具体要释放空间了。
(当然,你也许可以通过某个寄存器传入需要释的内存大小,或者让被调用函数释放固定的那一部分参数,再让调用者释放可变的那部分参数,不过现在还没有这样的函数调用规则,所以只能在直接编写汇编的时候这样做,而不能在 C/C++ 中指定这样的函数调用约定)。
如果想要指定函数的调用方法为 stdcall,需要这样申明函数:int __stdcall add(int a, int b, int c)
。
stdcall 是绝大多数 Win32 API 使用的函数调用约定。它的特点由如下几个:
前面的代码使用 stdcall 约定生成的汇编如下:
_a$ = 8 ; size = 4_b$ = 12 ; size = 4_c$ = 16 ; size = 4_add@12 PROC push ebp mov ebp, esp mov eax, DWORD PTR _a$[ebp] add eax, DWORD PTR _b$[ebp] add eax, DWORD PTR _c$[ebp] pop ebp ret 12 ; 0000000cH_add@12 ENDP_c$ = -4 ; size = 4_main PROC push ebp mov ebp, esp push ecx push 1919 ; 0000077fH push 514 ; 00000202H push 114 ; 00000072H call _add@12 mov DWORD PTR _c$[ebp], eax xor eax, eax mov esp, ebp pop ebp ret 0_main ENDP
从下面这几句话可以看出,stdcall 的压栈顺序和 cdecl 完全一样,也是从左到右的:
push 1919 ; 0000077fHpush 514 ; 00000202Hpush 114 ; 00000072H
接下来就是和 cdecl 不同的地方了。注意这一句出现在 add 函数中的话:ret 12
它代表着先 add esp, 12
再 ret 0
。也就是先释放掉 12 字节的内存,然后再返回。这句话说明了在 stdcall 中,函数占用的栈是由函数自己释放掉的。
这样做的主要好处就是可以节省程序的大小。如果参数数量一样的话,清栈就是一件重复的事情,没必要每次调用都多写一句话来清栈,直接在函数内部释放空间就好了。
如果想要指定函数的调用方法为 fastcall,需要这样申明函数:int __fastcall add(int a, int b, int c)
。
fastcall 是一种用于提升函数调用速度的函数调用约定。它会利用寄存器来传递参数。不过,不同于 cdecl 和 stdcall,fastcall 的实现并没有一种明确的标准,不同的编译器可能会编译出不同的东西。以下的特点来自于Visual Studio 2022 的标准。
前面的代码使用 fastcall 约定生成的汇编如下:
_b$ = -8 ; size = 4_a$ = -4 ; size = 4_c$ = 8 ; size = 4@add@12 PROC push ebp mov ebp, esp sub esp, 8 mov DWORD PTR _b$[ebp], edx mov DWORD PTR _a$[ebp], ecx mov eax, DWORD PTR _a$[ebp] add eax, DWORD PTR _b$[ebp] add eax, DWORD PTR _c$[ebp] mov esp, ebp pop ebp ret 4@add@12 ENDP_c$ = -4 ; size = 4_main PROC push ebp mov ebp, esp push ecx push 1919 ; 0000077fH mov edx, 514 ; 00000202H mov ecx, 114 ; 00000072H call @add@12 mov DWORD PTR _c$[ebp], eax xor eax, eax mov esp, ebp pop ebp ret 0_main ENDP
观察这几句话:
push 1919 ; 0000077fHmov edx, 514 ; 00000202Hmov ecx, 114 ; 00000072H
可以看到前两个参数,也就是 114 和 154 都是被寄存器传递的,而最后一个参数,也就是 1919,被推入了栈中。这符合前面提到的第一个特点。
而在 add 函数中的 ret 4
又说明了被调用函数释放了内存。因为只有一个参数是在栈中的,其他两个都在寄存器中,所以这个函数只占用了 4 字节的空间,释放掉的空间也自然是 4 字节。
在 x86 的机器中,一共只有 8 个通用寄存器,这就造成了大部分的函数调用都只能使用栈来传递参数,不过这样的速度是比较慢的。在 x64 平台中,一共有 16 个通用寄存器,比 x86 多了 8 个,充足的硬件资源也让我们有机会使用寄存器来传递参数。所以在 x64 平台上,几乎所有的函数调用约定都和 x86 上的 fastcall 相似,也就是尽量使用寄存器传参。
x64 平台下的函数调用约定主要有微软的调用约定和 System V AMD64 ABI 两种。这里我主要介绍 System V AMD64 ABI 约定。此约定主要在Solaris,GNU/Linux,FreeBSD和其他非微软OS上使用。如果你想了解微软的约定,可以参考这个网页。
这个调用约定的代码我就不放了,因为之前解释栈帧和函数调用原理时的汇编代码遵守的就是这个约定。
这个约定又如下的主要特点:
不过了解了这些函数调用规则又有什么用呢?除了更深入的了解函数调用的实现方法,还可以跨语言的调用函数。函数调用约定详细的规定了调用者和被调用者的职责,也规定了参数的传递方法。这样,只要调用者和被调用者都遵守约定,就可以在一个语言中调用另一个语言写成的函数了。比如在 Python 中使用 Ctypes 库调用 C 函数时,就需要指定函数的调用约定来加载动态链接库(dll 文件)。
此外,还有一些别的函数调用约定,如果你有兴趣,可以参考这几个网页:
这个词我不太清楚准确的中文翻译,下面以调用回溯来指代。
调用回溯是一个常用于调试的方法。想象这样的一个场景:我们的程序运行到了某个地方出现了 bug,这个时候大概是希望知道这个 bug 具体是在哪个函数中出现的。不过这还不够,因为有很多不同的位置可能调用同一个函数,所以我们希望知道函数之间的调用关系。gdb 里的 backtrace
(简写为 bt
)命令就提供了这种功能。
考虑下面这样一个程序:
#include <stdio.h>volatile int add1(int a, int b) { int* bug_val = 0; printf("%d\n", *bug_val); // 这里会出 bug return a + b;}volatile int add2(int a, int b) { return add1(a, b); }volatile int add3(int a, int b) { return add2(a, b); }volatile int add4(int a, int b) { return add3(a, b); }int main() { int c = add4(1, 2); return 0;}
add1
函数中尝试取出 0 地址的值会造成段错误(因为 0 就是 NULL),如果我们希望得到这个位置的函数调用关系,就可以先用 gdb 在 add1
上打一个断点,然后使用 bt
命令:
函数关系的信息如下:
#0 add1 (a=1, b=2) at bt_bug.c:3#1 0x00005555555551aa in add2 (a=1, b=2) at bt_bug.c:7#2 0x00005555555551cd in add3 (a=1, b=2) at bt_bug.c:8#3 0x00005555555551f0 in add4 (a=1, b=2) at bt_bug.c:9#4 0x000055555555520d in main () at bt_bug.c:11
这个功能很很有用,那我们有办法自己写一个吗(可以是简化版的,只需要显示地址)?
仔细想一想的话就可以惊喜的发现,我们需要的信息全部藏在栈帧里。
其实诸如 0x00005555555551aa
这类的地址是某个函数的返回地址。参考栈帧的结构图,这个返回地址储存在帧指针指向位置的向上一个单位(对于 64 位机,就是 bp + 8 字节)。而我们只要遍历栈帧,就可以得到所有的返回地址,也就是函数调用关系。
那如何遍历呢?帧指针指向的位置(bp + 0 字节)其实就储存了上一个栈帧的帧指针。只要递归的去查找上一个栈帧的帧指针,我们就能打印出每个函数调用的返回地址了。此过程很好了体现了 backtrace
或是回溯这个名字。
不过到现在,我们还有两个问题没有解决。
对于第一个问题,不同的操作系统有不同的情况。在我的实验环境(Ubuntu 22.04.1 on WSL2)中,这个终止条件是上一个栈帧的帧指针为 0x1
的时候。
我不确定 linux 系统中是否有这样的规定,这个终止条件只是我在调试的时候观察到的。
在别的操作系统中,比如 MIT6.s081 这门公开课使用的教学系统 xv6,栈空间的最大大小就是一个页帧的大小,那么递归到超出该页帧范围的时候就意味着到达了终止条件。
如果你了解一个通用的判断栈帧是否结束的方法,欢迎在评论区留言。
对于第二个问题,一个简单的办法是使用 gcc 的内置函数 __builtin_frame_address
,这个函数可以返回当前函数的帧指针。
不过如果你希望体验一下 gcc 的骚操作(这些都没被包含在 c 语言的标准中),可以使用内连汇编,如下:
#define FORCE_INLINE __attribute__((always_inline)) inlineFORCE_INLINE void* r_bp() { // 读取帧指针 size_t x; asm volatile("mov %0, rbp" : "=r"(x)); return (void*)x; // 注意这里使用的是英特尔的汇编,编译选项里必须加 -masm=intel}
在 "mov %0, rbp" : "=r"(x)
中,mov %0, rbp
是一个汇编的模板,并不是真正的汇编,这有点类似于 C++ 中的模板,在编译的时候会把类型替换掉。gcc 编译的时候也会把 %0
这个东西替换成后面 : "=r" (x)
规定的变量(这里是 x)所在的寄存器。那么这个内联汇编的意思就变成了:“把 rbp
的值存进 %0
所在的寄存器,其中 %0
会被替换成 x
”
如果你对内联汇编有兴趣,可以参考 gcc 的文档,我的这篇文章 中也有些更详细的解释。
上面的代码中,除了这个离谱的内联汇编,还使用了一些很骚的操作:
#define FORCE_INLINE __attribute__((always_inline)) inline
单个 inline
关键字只能向编译器建议内联,不保证一定内联[4]。而 __attribute__((always_inline)) inline
就能让 gcc 强制内联。这里的 __attribute__
还有很多种别的用法,详细内容可以参考文档 和网上的一些博客。
解决这些问题后代码就不是很难了,不过这份代码不可避免的涉及到了很多指针,如果不熟悉的话可以先去学习下。
FORCE_INLINE void* r_bp() { // 读取帧指针 size_t x; asm volatile("mov %0, rbp" : "=r"(x)); return (void*)x;}size_t btrace(void** buffer_arr, size_t size) { // buffer_arr 是一个储存通用指针类型(void*)的数组 // 我们把每个栈帧的返回地址储存在 buffer_arr 里 // size 表示希望回溯的函数调用数量 size_t* cur_frame_addr = (size_t*)r_bp(); // 通过栈指针和帧指针,获取函数调用栈 int i = 0; while (i < size && (size_t)cur_frame_addr != 0x1) { size_t* returning_addr = cur_frame_addr[1]; // 返回地址储存在 bp + 8字节的位置 size_t* prev_frame_addr = cur_frame_addr[0]; // 上个栈帧的 bp 储存在 bp + 0 字节的位置 buffer_arr[i++] = returning_addr; cur_frame_addr = prev_frame_addr; // 递归回溯 }}
刚刚那样的实现只能打印出返回地址,但是 gdb 的调用回溯是可以显示函数名的,那我们有什么办法通过地址显示函数名吗?
一种方法是使用 linux 中的命令行工具 addr2line
,其可以把一个地址转化为函数名,不过我在使用的时候出现了一些问题,没有成功。
还有一种方法是使用 backtrace_symbols
函数,这个函数可以把一个地址数组转化为函数名数组,其包含在 execinfo.h
头文件中,如下:
/* Return names of functions from the backtrace list in ARRAY in a newly malloc()ed memory block. */extern char **backtrace_symbols (void *const *__array, int __size)
需要注意的是,我们在编译的时候需要加上 -rdynamic
选项,这样才能让链接器把符号加入动态符号表(其实我也不太懂,原文如下)。
https://stackoverflow.com/questions/6934659/how-to-make-backtrace-backtrace-symbols-print-the-function-names
The symbols are taken from the dynamic symbol table; you need the -rdynamic option to gcc, which makes it pass a flag to the linker which ensures that all symbols are placed in the table.
然后就可以写出完整代码了:
#include <execinfo.h>#include <stddef.h>#include <stdio.h>#define FORCE_INLINE __attribute__((always_inline)) inlineFORCE_INLINE void* r_bp() { size_t x; asm volatile("mov %0, rbp" : "=r"(x)); return (void*)x;}size_t btrace(void** buffer_arr, size_t size) { size_t* cur_frame_addr = (size_t*)r_bp(); int i = 0; while (i < size && (size_t)cur_frame_addr != 0x1) { size_t* returning_addr = cur_frame_addr[1]; size_t* prev_frame_addr = cur_frame_addr[0]; buffer_arr[i++] = returning_addr; cur_frame_addr = prev_frame_addr; }}volatile int add1(int a, int b) { void* buf_arr[10]; btrace(buf_arr, 10); char** func_names = backtrace_symbols(buf_arr, 10); for (int i = 0; i < 10; i++) { printf("%s\n", func_names[i]); } // 释放 func_names,backtrace_symbols 返回的是一个 malloc 出来的数组 free(func_names); return a + b;}volatile int add2(int a, int b) { return add1(a, b); }volatile int add3(int a, int b) { return add2(a, b); }volatile int add4(int a, int b) { return add3(a, b); }int main() { int c = add4(1, 2); return 0;}
用
gcc backtrace.c -o bt -masm=intel -ggdb3 -rdynamic
编译后,运行 ./bt
,可以得到如下输出:
./bt(add1+0x32) [0x55bfc409e25b]./bt(add2+0x21) [0x55bfc409e2ed]./bt(add3+0x21) [0x55bfc409e310]./bt(add4+0x21) [0x55bfc409e333]./bt(main+0x1b) [0x55bfc409e350]/lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7f5c1391ed90][(nil)][(nil)][(nil)][(nil)]
最后提一嘴,其实写这些纯粹没事找事,因为 execinfo.h
这个头文件里还有个函数就叫 backtrace()
注:思路来自这个视频
栈溢出攻击可以在没有显式调用一个函数的时候执行某个函数,比如下面这个程序:
#include <stdio.h>#include <stdlib.h>void malfunc() { asm volatile("pop rbp"); puts("hello world"); exit(0);}void set_arr() { size_t a[2]; a[0] = 114; a[1] = 514; a[3] = (size_t)malfunc;}int main() { set_arr(); return 0;}
虽然直觉上觉得离谱,但是用下面这个编译选项
gcc stk_ov.c -o stk_ov -fno-stack-protector -ggdb3 -masm=intel
编译执行后,就会发现 hello world
被打印出来了。这样的现象其实是比较危险的,因为通过修改栈,可以直接执行一些恶意代码。不过现代的编译器也知道这种技巧,所以如果我不开 -fno-stack-protector
这个选项,程序是运行不了的。
所以这个程序到底是如何执行 malfunc
的?把 set_arr()
函数的栈帧画出来就能理解了:
低地址a[0]------------------------------------a[1]------------------------------------原本的帧指针(main 函数的帧指针) <--- 当前帧指针,a[2]------------------------------------此函数的返回地址 (main 函数) <--- a[3]高地址
可以发现,这个 a[3]
刚好指向了储存 set_arr
返回地址的位置,所以我们把这个地方改了,自然就会跳转到 malfunc()
中。那 malfunc
里面为啥要加一个 pop rbp
呢?
其实我也不知道
如果不加这一行代码,在 Compiler Explorer 里是可以正常运行的,具体可以看这个链接。
但是如果不加这一行,在本地用刚刚的编译选项就会产生段错误,具体的情况我写在了这个 StackOverflow 的帖子里,如果你知道欢迎在评论区或者 StackOverflow 上回答。
通过刚刚的分析,我们已经非常清楚函数调用的实现原理了。如果要实现一个不递归的 dfs,最简单的方法就是自己模拟汇编中函数调用的过程。
先来看一下一个使用递归的 dfs 是怎么写的,相信大家都很熟悉:
#include<bits/stdc++.h>using namespace std;const int MAXN = 200;vector<int> e[MAXN];int dfs(int cur, int fa){ printf("vised %d\n", cur); for(int nex:e[cur]){ if(nex != fa) dfs(nex, cur); }}int main(){ int n; scanf("%d", &n); for (int i = 1; i <= n; i++) { int from, to; scanf("%d%d", &from, &to); e[from].push_back(to); e[to].push_back(from); } dfs(1, 0); system("pause");}
可以看到,dfs 函数中的局部变量或参数有两个: cur
和 fa
分别表示当前节点和父节点
回想一下一个栈帧的结构,里面包含着局部变量,备份的 bp 以及返回地址(调用时的 pc)。其中,备份的 bp 是为了让 bp 回到调用者的状态而准备的。调用者者需要通过 bp 来正确的访问局部变量。不过,我们可以把单个栈帧封装成一个结构体,然后把整个栈当作类型为栈帧类的数组,再用数组来模拟栈。这样,不需要存 bp 也能正确的访问每个栈帧里的局部变量了。
可以这样写这个结构体来代表栈帧,里面只包含 pc 作为返回地址(或者说当前这个函数执行到了哪里)和局部变量(参数)。对于 dfs,必须要备份 pc 的值,因为当前这个函数还没执行完就要去执行下一个函数了,等到被调用的函数执行好时,我们需要备份的 pc 来继续执行当前的函数 (而不是从头开始执行当前函数):
template <typename PARA_TYPE> // PARA_TYPE 是参数的类型struct Frame{ int pc;//如名 PARA_TYPE paras;//当前栈帧的参数};
然后通过这个结构体来模拟栈的操作:
template <typename FRAME_TYPE>//栈帧的类型struct Mystk{ FRAME_TYPE stk[E_SZ]; int sp;//指向栈顶 Mystk() {sp = 0; memset(stk, 0, sizeof(stk));}//构造函数,用于初始化这个栈 inline void push(FRAME_TYPE x) { stk[++sp] = x;}//这些操作估计都很熟悉了,不解释 inline FRAME_TYPE& top() {return stk[sp];} inline bool empty() {return sp <= 0;} inline bool pop() {return (--sp) <= 0;}};
最后,还有这个结构体,相当于把前面的两个结合了一下
template <typename PARA_TYPE>struct Func_stk{ struct Frame{ int pc; PARA_TYPE paras; inline void my_goto(int line){pc = line - 1;} //自定义的 goto 语句,pc 指向将要执行的指令,直接修改 pc 相当于直接修改下个执行的指令 }; Mystk<Frame> cur_stk; inline void call(PARA_TYPE paras) {cur_stk.push({.pc = 0, .paras = paras});} //新调用一次函数就相当于新把一个栈帧推入栈中,并且刚刚调用的时候,这个函数应该执行第一行。 inline void ret() {cur_stk.pop();} //返回一个函数就相当于在栈中弹出一个栈帧};
有了这些结构体,要如何在 dfs 函数中使用呢?只要模拟汇编中函数调用的过程,就一定不会出问题,我们可以根据下面这些条件来写出非递归的 dfs。
Func_stk
的 call
函数Func_stk
的 ret
函数根据 pc 执行不同的语句可以这样实现
然后就可以写出下面的代码:
void dfs(int cur, int fa){ Func_stk<Dfs_paras> dfs_stk; dfs_stk.call({cur, fa}); //压入第一个栈帧 Func_stk<Dfs_paras>::Frame *cur_frame = &dfs_stk.cur_stk.top();//指向当前栈帧的指针 for (; !dfs_stk.cur_stk.empty(); cur_frame->pc++, cur_frame = &dfs_stk.cur_stk.top()) //只要栈帧不为空就一直循环下去,每执行完一条指令把当前栈帧的 pc++, //也就是如果某个时候有一个函数想返回当前这个函数,那每当前这个函数每执行完一条指令,返回的位置都要增加 1 //cur_frame = &dfs_stk.cur_stk.top() 用来确保指向当前栈帧的指针一定指向栈顶的栈帧 { if (cur_frame->pc == 0)//dfs函数的第一条指令是打印当前访问的节点,会在 pc=0 的时候被执行 printf("vised %d\n", cur_frame->paras.cur); else if (cur_frame->pc <= e[cur_frame->paras.cur].size()){ //如果 pc 小于等于跟这个节点相连的边的数量 //那肯定还没有完全访问完跟这个节点相连的子树, if (e[cur_frame->paras.cur][cur_frame->pc - 1] != cur_frame->paras.fa){//所以如果下一个节点不是自己的父节点,就继续 “递归” dfs_stk.call({.cur = e[cur_frame->paras.cur][cur_frame->pc - 1], .fa = cur_frame->paras.cur}); } } else{ dfs_stk.ret();//如果 pc 的值大于 e[cur_frame->paras.cur].size() 了,就说明和这个节点相连的子树已经全部访问完了,所以需要返回 } }}
下面是完整代码,欢迎大家赋值下来去自己的电脑上试一试:
#include <bits/stdc++.h>using namespace std;const int E_SZ = 200; // 最大边数struct Dfs_paras{ int cur, fa;};vector<int> e[E_SZ];template <typename FRAME_TYPE>struct Mystk{ FRAME_TYPE stk[E_SZ]; int sp;//指向栈顶 Mystk() {sp = 0; memset(stk, 0, sizeof(stk));}//构造函数,用于初始化这个栈 inline void push(FRAME_TYPE x) { stk[++sp] = x;}//这些操作估计都很熟悉了,不解释 inline FRAME_TYPE& top() {return stk[sp];} inline bool empty() {return sp <= 0;} inline bool pop() {return (--sp) <= 0;}};template <typename PARA_TYPE>struct Func_stk{ struct Frame{ int pc; PARA_TYPE paras; inline void my_goto(int line){pc = line - 1;} }; Mystk<Frame> cur_stk; inline void call(PARA_TYPE paras) {cur_stk.push({.pc = 0, .paras = paras});} inline void ret() {cur_stk.pop();}};void dfs(int cur, int fa){ Func_stk<Dfs_paras> dfs_stk; dfs_stk.call({cur, fa}); Func_stk<Dfs_paras>::Frame *cur_frame = &dfs_stk.cur_stk.top(); for (; !dfs_stk.cur_stk.empty(); cur_frame->pc++, cur_frame = &dfs_stk.cur_stk.top()) //执行当前dfs函数,每次pc都要++ { if (cur_frame->pc == 0) printf("vised %d\n", cur_frame->paras.cur); else if (cur_frame->pc <= e[cur_frame->paras.cur].size()){ if (e[cur_frame->paras.cur][cur_frame->pc - 1] != cur_frame->paras.fa){ dfs_stk.call({.cur = e[cur_frame->paras.cur][cur_frame->pc - 1], .fa = cur_frame->paras.cur}); } } else{ dfs_stk.ret(); } }}int main(){ int n; scanf("%d", &n); for (int i = 1; i <= n; i++){ int from, to; scanf("%d%d", &from, &to); e[from].push_back(to); e[to].push_back(from); } dfs(1, 0); system("pause");}
观察原来的 dfs 函数
int dfs(int cur, int fa){ printf("vised %d\n", cur); for(int nex:e[cur]){ if(nex != fa) dfs(nex, cur); }}
不难发现新调用的函数和当前函数有一个相同的参数,那就是 cur
。也就是说,下一个被调用的函数的 fa
参数就是当前函数的 cur
参数。所以我们完全可以在判断 nex != fa
的时候不适用 fa
,而是直接去访问上一个栈帧中的 cur
参数,具体写法的话,可以把 fa
改成这样:dfs_stk.stk[dfs_stk.sp-1].paras
(paras 是一个 int
,因为不需要再在参数中包含 fa
了)。
这样就可以省下一部分空间了。
纯教学意义,加深对于函数调用实现原理的理解,没有实际用途其实还是有点用的
为了更准确的对比非递归 dfs 和正常的写法,我使用 python 加洛谷的 CYaRon 测试数据生成器(强烈推荐,真的方便)生成了 10 个测试点。每个测试点都是一个节点数量为 的树。
输入数据生成器的代码:
from cyaron import *def generate(): MX_PT = int(1e6) for _ in range(1, 11): test_data = IO(file_prefix="tree", data_id=_) cur_tree = Graph.tree(MX_PT) test_data.input_writeln(MX_PT - 1) test_data.input_writeln(cur_tree)if __name__ == "__main__": generate()
答案生成器:
#include<bits/stdc++.h>using namespace std;const int MAXN = 1e6 + 5;vector<int> e[MAXN];void dfs(int cur, int fa){ printf("%d\n", cur); for(int nex:e[cur]){ if(nex != fa) dfs(nex, cur); }}int main(){ for(int fid = 1; fid <= 10; fid++){ string cur_name = "tree" + to_string(fid); for(int _ = 0; _ < MAXN; _++) e[_].clear(); freopen((cur_name + ".in").c_str(), "r", stdin); freopen((cur_name + ".out").c_str(), "w", stdout); int n; scanf("%d", &n); for (int i = 1; i <= n; i++) { int from, to, none; scanf("%d%d%d", &from, &to, &none); e[from].push_back(to); e[to].push_back(from); } dfs(1, 0); }}
随后在洛谷上开了个题目,然后把数据传上去了。之后所有的测试均使用这个题目。
理论上来说,经过刚才的优化,非递归 dfs 的空间占用应该会比正常写法小大约 4MB (每个栈帧中都少了一个 int
,最多能有 个栈帧),以及 bp 的大小(见前文,使用结构体封装栈帧,不需要记录 bp)。
想到这里,我赶紧去把常规写法的 dfs 交了一下,以便等下可以对比数据来体现我这个写法的高明。
结果如下:
时间(s) | 空间(MB) |
---|---|
9.06 | 55 |
详见提交记录
那实际上呢?
一顿操作猛如虎,一看空间 62(MB)。一顿操作猛如虎,一看时间 9.2(s)
不仅空间不降反增,时间也更长了。
详见提交记录
为啥呢?
经过我一段时间的思考,感觉多出来的空间占用是栈的问题。虽然单个栈帧占用的空间更少了,但我是使用数组模拟栈的,弹出的栈帧不能被释放掉,而是还留在内存中。而且很多开出来的内存是空的,并没有被使用。在常规的 dfs 中,弹出一个栈帧后,内存立刻就被释放掉了。可是如何证明这个呢?也许我可以不使用数组模拟栈,而是使用一个真正的栈,只要一个栈帧被弹出,就把它占用的内存释放掉。
要达到这一点,可以选择 stl 的 stack
。事实证明,使用 stl 后,空间占用和常规的写法完全一样,可是时间就比较一言难尽了,毕竟是 stl,达到了 10.26 秒。提交记录
至于为什么没有比常规的写法占用更少的内存,我就不是很清楚了,如果你知道,欢迎在评论区告诉我。
现在我们已经了解了空间占用的问题,可为什么时间会更慢呢?理论上来说,这样模拟的函数调用,应该会比正常写法的 dfs 快一点。因为我弹出或者推入一个栈帧只需要把栈顶指针 ++
或者 --
。而常规的 dfs 则需要一堆繁琐的步骤(见 1.2.2 和 1.2.3)。
我想了挺久还是没想出来,还是看下汇编吧。
左边和右边被我圈出来的是互相对应的代码段,乍一看这好像也没什么问题,函数在汇编里被正确调用了。
可是我明明在写这些函数的时候加了 inline[3] 啊(如下图)。
inline FRAME_TYPE& top() {return stk[sp];}inline bool empty() {return sp <= 0;}
如果说 inline 没起作用的话,那化简掉的函数调用在这里就还回来了,甚至还增加了函数调用。
这也提醒了我们,inline 关键字只是建议编译器把函数改成内联函数,如果编译器觉得函数比较复杂,是可以不内联的。(但是这函数真的超级简单啊,为什么不内联。。。)
所以我把这些内联函数全部换成了宏定义,这样就是真正的“内联”了,效果如下:
时间(s) | 空间(MB) |
---|---|
8.83 | 63.16 |
提交记录
提升了 0.2 秒左右。不过为了这 0.2 秒多写几十行代码就。。。
我们知道在常规的 dfs(或是其他递归函数)中,return
一次,只会返回到调用这个函数的函数。这是因为执行一次 return
会弹出一个栈帧。但既然我们能通过模拟栈帧的方法,完全把函数调用的过程掌握在我们自己的手中,为什么不能一次弹出多个栈帧呢?虽然听起来挺离谱的,但是也许在某些时候会有些用处。
比如,如果我们想通过递归来暴搜出某一个答案,现在在某一层递归中,答案已经找到了。正常情况下,我们需要一层一层的退出递归调用。而使用模拟栈帧的方法,我们可以直接把前面所有的栈帧都弹出,或者更直接一点,直接从模拟栈帧的循环中 break
出来。
为了测试这个骚操作对性能的提升,我又在洛谷上传了一道题目。题目大概是给你一个 的网格,每个格子都可以是 或是 ,分别表示不可以走和可以走,问你能否从 ,通过八个方向的移动,到达 。并且,在搜索的过程中,需要按照 dfs 的顺序输出访问的位置。
输入数据生成器如下:
#include<bits/stdc++.h>using namespace std;float valid_possiblity = 0.7;const int MAXN = 5000;int main(){ for(int _ = 1; _ <= 10; _++){ string f_name = "test" + to_string(_); freopen((f_name + ".in").c_str(), "w", stdout); printf("%d %d\n", MAXN, MAXN); // int endx = rand() % MAXN; // int endy = rand() % MAXN; printf("%d %d\n", MAXN, MAXN); for(int i = 1; i <= MAXN; i++){ for(int j = 1; j <= MAXN; j++){ if(i == 1 && j == 1 || i == MAXN && j == MAXN){ printf("1 "); continue; } if(double(rand()) <= double(RAND_MAX) * valid_possiblity){ printf("1 "); } else{ printf("0 "); } } printf("\n"); } }}
如果使用的非递归的 dfs,在发现能够到达 点之后就可以立刻退出搜索,而正常的 dfs 会需要一层一层的退出。
所以,也许非递归的 dfs 会快一点?
具体的结果可以见下表:
常规 dfs | 非递归 dfs+数组模拟栈 | 非递归 dfs+stl stack | |
---|---|---|---|
提交记录 | 记录 | 记录 | 记录 |
时间(s) | 7.77 | 9.53 | 10.50(时间超限) |
空间(MB) | 512+(内存超限) | 335.54 | 187.01 |
结果还是挺出乎我意料的。在最后一个点中,常规 dfs 因为内存超限被卡掉了,但是前面的点中,常规 dfs 都比非递归的快,不管是用数组模拟栈的还是使用 stl stack 的。
对比 stl stack 的非递归 dfs 和常规 dfs,可以发现在这个问题中,使用非递归 dfs 对节省内存有比较显著的作用。(至于为什么用数组模拟的内存占用看起来很大,已经在前面解释过了)。
不过这些测试还是不能较好的展现逐层返回和直接返回的区别,所以我使用了 chrono 库(精度比 clock()
更高,可以获取纳秒级别的时间)来测量函数返回的时间占用。
结果就比较一言难尽了,高情商的说法是直接返回的返回速度比逐层返回快了约 倍,低情商的说法是逐层返回的时间占用也就 纳秒 ( 毫秒)。当然,函数的返回速度也跟返回值类型有关,每次传递返回值都需要一定的时间,如果递归的层数特别多,并且返回值类型非常大,使用直接返回也许就能产生显著的效果了(这样的情况似乎基本上没有呢)。
这个骚操作已经在前面小优化的部分提到过了,因为所有函数的栈帧都储存在一个栈里,如果你用的是数组模拟栈,那就可以访问到之前被调用的函数的局部变量。在一些场景中,比如之前讲到的树的 dfs 遍历,就可以用到这个方法节省空间。至于别的用途我还真没想到,如果你有想法的话欢迎在评论区分享。
此外,就像是我们能弹出任意数量的栈帧一样,如果你愿意,用模拟栈帧的方法,你还可以在一个函数中同时调用任意数量的函数,也就是压入任意数量的栈帧。当然我也没想出来如何利用这种阴间操作。
总的来说,非递归 dfs 的教学意义是大于实际意义的。虽然有的时候非递归 dfs 可以带来一些常数提升,但是会需要更多的时间写出非递归 dfs。而且这一点微弱的常数提升在 O2 的加持下也变的没有意义了。除非一个题目非常的卡常,还不能使用 bfs 和 O2,不然最好还是不要写这种奇怪的东西。
所有非递归 dfs 能带来的优化是建立在递归这种特殊的函数调用的基础上的。在递归中,每次函数调用的栈帧都有着相同的结构,相同的大小,所以我们才能使用结构体把栈帧封装起来,并简化函数调用的过程。
之前提到的骚操作也是因为我们对函数调用有了完全的控制,可以随意访问栈中的内存,并且弹出和压入任意数量的栈帧。如果函数调用不是在递归中的,那我们就不知道每个栈帧的长度和结构,自然也没法实现这样的操作。
最后,如果你有问题或是建议,都欢迎在评论区分享或者是联系我。
博客中观看体验更佳
给你两个字符串, 和 ( 和 的长度都不超过 )。再给你一些询问 ( 询问数量不超过 ),每个询问为小写字母 'a'
到 'r'
的子集,对于每个询问,请你回答在 串和 串只包含询问中给定的字母时是否相等。
很容易想到暴力的方法,对于每个询问,我们可以只考虑包含在集合中的字符,然后对比两个字符串。当然,这样我们就会需要对于每个询问重新遍历一遍字符串,复杂度也会到达 ( 为询问的数量和字符串的长度)。通过这个方法,我们可以拿到这道题的部分分。
不过如何才能拿到其他分数呢?
直接解决这个问题可能太复杂了,我们可以试试看化简一下这个问题,再把化简过的解法推广到原问题。
我们首先考虑询问中只包含两个字母的情况。设这两个字母为 a 和 b。那么我们如何判断两个只包含 a 和 b 的字符串是否相等呢?
首先需要考虑的肯定是两个串中 a 和 b 的数量是否相等,如果 a 和 b 的数量不等,那这两个串一定不一样。
其次,我们得考虑字符串中每个 a 和 b 的位置,如果位置和数量都对了,这两个串就一定相等了。
判断 a 和 b 的位置时,我们肯定不能直接看它们的下标是否相等,因为我们比较的是这两个字符串中只包含 a 和 b 时的位置,而把其他字符删除后,它们的下标一定会变化。
删除其他字符后,每个字符串的下标其实就是它前面 a 的数量加上 b 的数量(其他的字符都被删除了)。
当然,依次判断每个 a 和 b 的下标太废时间了,我们可以做一些优化。比如我们只需要判断 a 和 b 中一个字符的位置是否全部相等就行了。因为两个字符串 a 和 b 的数量相等,所以只要确定了其中一个字符的位置,另一个的就能确定了(所有不是 a 的位置肯定都是 b)。
这个判断过程其实还可以进一步简化,我们可以只考虑 a 前面 b 的数量,考虑这样一个字符串: "baa"
。可以发现如果把 a 前面 b 的数量当作 a 的下标,那么这两个 a 的下标都是一样的。如果我们交换这两个 a,这个字符串还是一样的,所以这两个 a 的下标一样并不会对我们判断 a 的位置产生影响。
总结一下,只包含两个字符的字符串(假设这两个字符为 a 和 b),如果是相等的,一定满足:
可是我们为什么要用这样的方法求字符串是否相等呢?
因为通过前缀和的方法,我们可以很快的速度处理出这两个字符串只包含两个字符时是否相等。
考虑前文中提到的两个条件。要求出每个 a 前面 b 的数量是否相等(注意这里的 a 和 b 可以是任何字符),我们需要快速求出:
对于第一个问题,我们可以使用前缀和来预处理。
我们开两个数组 char_sum_s[i][j]
和 char_sum_t[i][j]
,分别表示在串 和 中,从下标 到下标 为止(包括 )有多少个字符 。
然后使用下面这段代码求出:
for(int i = 0; i < s.length(); i++){ char_sum_s[i][s[i] - 'a'] = 1; // 打上标记}for(int i = 1; i < s.length(); i++){ for(int j = 0; j < 20; j++){ // 枚举字符 char_sum_s[i][j] += char_sum_s[i - 1][j]; // 求前缀和 }}
对于第二个问题,我们开两个 vector char_pos_s[i]
和 char_pos_s[i]
, 分别表示串 和 中,字符 的所有位置,并且使用下面的代码求出:
for(int i = 0; i < s.length(); i++){ char_pos_s[s[i] - 'a'].push_back(i);}
现在我们已经能快速求出只包含两个字符时两个串是否相等了,下面我们来考虑如何把它用到原问题中。
假设两串原本是一样的,有下面几种方法使它们变得不一样:
注:为了方便,我们把判断两字符串在只包含字符 a , b 时是否相等的函数记为 isok(a, b)
对于前面两种改变方式,两串中每种字符的数量肯定会改变。假设增加或删除的字符为 a,那么 isok(a, 其他任何字符)
返回的一定是 false
,这是因为在 串中和 串中,字符 a 的数量不相等了。
现在考虑交换这种改变方式,假设交换的字符为 a 和 b, 那么 isok(a, b)
返回的也一定是 false
。因为在 和 中,仅由 a 和 b 组成的字符串一定不相等。
所以,对于每个询问,我们只需要枚举询问中包括的两个不同的字符,然后判断 串和 串在只包含这两种字符时是否相等就可以了。
注意我们需要把每个 isok(a, b)
的结果存下来,这样下此使用时就不需要重新算了。
isok(a, b)
:因为需要知道每个 a 前面 b 的数量,所以需要枚举 a,复杂度就为 a 的数量。isok(a, b)
: ( 为字符串长度) 因为要枚举所有的 a 和 b 来计算 isok(a, b)
,而所有的 a 和 b 的数量和一定是字符串长度。isok(a, b)
了,所以在枚举枚举询问中包括的两个不同的字符时,只需要返回结果,复杂度为 。而一共要枚举 次。有详细注释,相对来说还是比较快的,提交记录。
/*Date: 22 - 03-26 16 22PROBLEM_NUM: Subset Equality*/#include<bits/stdc++.h>using namespace std;const int MAXN = 1e5 + 10;string s, t;int q;vector<int> char_pos_s[20], char_pos_t[20];int char_sum_s[MAXN][20], char_sum_t[MAXN][20]; short isok_result[20][20];bool ans[MAXN];bool isok(char a, char b){// 判断 s 和 t 串在只包含 a 和 b 的情况下是否等价 if(isok_result[a][b] != -1 || isok_result[b][a] != -1){// 如果之前已经计算过了,直接返回结果,注意 isok(a, b) == isok(b, a) return isok_result[a][b]; } if(a == b){// 如果 a 和 b 相等,则返回这个字符在两串中出现的次数是否相等 return isok_result[a][b] = (char_pos_s[a].size() == char_pos_t[a].size()); } if(char_pos_s[a].size() != char_pos_t[a].size() || char_pos_s[b].size() != char_pos_t[b].size()){// 如果 a 和 b 的个数在 s 串和 t 串中不相等,返回 false return isok_result[a][b] = false; } vector<int> b_cnt_s;//s串中,某个 a 前面的 b 的数量 for(int cur_apos : char_pos_s[a]){ // 枚举 s 串中 a 的位置 b_cnt_s.push_back(char_sum_s[cur_apos][b]); } for(int i = 0; i < char_pos_t[a].size(); i++){// 枚举 t 串中 a 的位置,对比 t 串中 a 前面 b 的 // 数量是否和 s 串中 a 前面的 b 的数量相等 if(char_sum_t[char_pos_t[a][i]][b] != b_cnt_s[i]){ return isok_result[a][b] = false; } } return isok_result[a][b] = true;}void pre_proc(){ for(int i = 0; i < s.length(); i++){ char_pos_s[s[i] - 'a'].push_back(i); char_sum_s[i][s[i] - 'a'] = 1; // 打标记 } for(int i = 0; i < t.length(); i++){ char_pos_t[t[i] - 'a'].push_back(i); char_sum_t[i][t[i] - 'a'] = 1; } for(int i = 1; i < s.length(); i++)// s 串前缀和 for(int j = 0; j < 20; j++) char_sum_s[i][j] += char_sum_s[i - 1][j]; for(int i = 1; i < t.length(); i++)// t 串前缀和 for(int j = 0; j < 20; j++) // j 为字符 char_sum_t[i][j] += char_sum_t[i - 1][j]; for(int i = 0; i < 20; i++) for(int j = 0; j < 20; j++) isok_result[i][j] = -1;// 没计算过的时候,设置为 -1}int main(){ ios::sync_with_stdio(false); cin>>s>>t>>q; pre_proc();// 预处理 for(int i = 1; i <= q; i++){// 枚举每个询问中的每个字符 string cur_query; cin>>cur_query; ans[i] = true; for(char char_a : cur_query){ for(char char_b : cur_query){ if(!isok(char_a - 'a', char_b - 'a')){ // 如果有一个 isok(a, b) == false,则说明不等价 //(当 s 和 t 只包含询问中的字符时) ans[i] = false; break; } } if(!ans[i]) break; } } for(int i = 1; i <= q; i++){ if(ans[i]) cout<<"Y"; else cout<<"N"; } system("pause");}
最后希望这篇题解能帮到你。如果有看不懂的,或者是发现题解有问题,欢迎通过评论区和私信联系我。
]]>有 个奶牛,奶牛 想访问奶牛 。如果 已经离开去访问别的奶牛了,则 不能成功访问 ,否则,这次成功访问可以增加 次哞叫。 现在让你找出可能的最大哞叫次数
理解题目后,我们可以首先分析下样例,试试看找一些有用的信息。
为了方便分析样例,我们可以把样例用图的形式展示,图中有向边连接的两个节点就是一头牛和这头牛希望访问的牛 ( 和 )。而边权是这次访问能产生的哞叫次数。
通过这张图,我们可以发现,不管以什么样的顺序访问,最多都只能成功的访问三次,最后的一次访问一定会遇到之前已经遇到过的牛,所以选择 , 和 可以达到最大的哞叫次数,也就是 次。
再仔细思考这个样例,可以发现不能同时选四条边的本质原因是这样会在图中产生一个环。如果图中有环,并且必须要经过环上的每一条边,那么我们必然会访问到之前访问过的节点。
而如果我们能从原来的图中选出一些边,建出一个没有环的图,那么就一定能找出一种访问顺序,使得我们在遍历所有节点时不会重复访问节点。在不构成环的前提下,我们还需要尽量选择边权大的边,这样就能满足题目的要求——产生最多的哞叫次数。
没有环的,权值最大的图?好像跟最小(大)生成树很相似。
分析到这里,我们就比较容易想到使用最小(大)生成树算法了。通过这类算法,我们可以在图中找出权值最大的树。不过,这还是跟这道题不完全一样。我们还需要解决下面这个问题
(这部分如果理解了可以直接看代码),代码就是个标准的 kruskal
换一种说法解释这个问题就是:从有向图转换来的无向图是否和原图等价?
比如上图这样的情况,不管使用什么访问顺序,三条边我们都是可以选的。但是转换成了无向图之后,就只能选择两条边了(选三条边会产生环)。
在题目中,每一奶牛只有一个想访问的奶牛,也就是说图中的每个节点出度都是 ,在这样的条件下,上图中的情况就是不可能出现的(上图中节点 的出度为 ),并且转换出来的无向图和原图也是等价的。
那为什么只有入度大于 时才会导致转换之后的无向图不等价于原来的有向图呢?
我们知道如果有 个节点,要把这 个节点包含在环中的最少边数是 个。并且这 个节点里面的每个节点的出度和入度都等于 。就和样例中的一样。
一个边可以产生一个出度和一个入度。所以这个环里总共有 个度。如果我们允许一些节点的出度大于 ,那么有一些节点的入度可能是 了(度的和一定为 ,那出度增加了入度就一定会减少),这样一来,入度为 时没有别的节点能到达这个节点,自然就不能产生环。
但是如果把直接转换成无向图,出度和入度的总和还是 ,每个节点的度也是 ,所以能构成环。
这样一来,在转换时就会产生问题了。
这里我才用的是 Kruskal 来求的最大生成树,相较于这题的思维,代码还是比较简单的,只要把最小生成树中的排序改一下。
如果不熟悉最小生成树的算法,可以参考模板题里的题解
需要注意的是权值和可能会超过 int
的范围,需要开 long long
。
/*Date: 22 - 03-26 15 28PROBLEM_NUM: USACO MAR Problem 1. Visits*/#include <bits/stdc++.h>using namespace std;const int MAXN = 2e5 + 10;#define ll long longstruct E{ int from, to, val;} e[MAXN];int n;int fa[MAXN];int find_fa(int cur){ if (cur == fa[cur]) return cur; return fa[cur] = find_fa(fa[cur]);}void merge(int a, int b){ int af = find_fa(a), bf = find_fa(b); fa[af] = bf;}//并查集操作ll ans;int main(){ scanf("%d", &n); iota(fa + 1, fa + 1 + n, 1);//最开始 fa[i] = i for (int i = 1; i <= n; i++) { scanf("%d%d", &e[i].to, &e[i].val); e[i].from = i; } sort(e + 1, e + 1 + n, [](E a, E b) { return a.val > b.val; });//权值大的放前面 int used_edge = 0; for (int i = 1; i <= n; i++)//kruskal { if (find_fa(e[i].from) != find_fa(e[i].to)) { used_edge++; ans += e[i].val; merge(e[i].from, e[i].to); if (used_edge == n - 1) { break; } } } printf("%lld\n", ans); system("pause");}
最后希望这篇题解能帮到你。如果有看不懂的,或者是发现题解有问题,欢迎通过评论区和私信联系我。
]]>目录:
博客中观看体验更佳
给你 个二维的向量,对于任意一个 ,求出有多少选取的方案能满足在这 个向量中选 个,并且他们的和为 。
看到这个题我们可以比较快的想到拿部分分的做法,就是暴力枚举所有的选取方案,然后看他们加起来是否等于目标向量,再把符合要求的方案累加到答案中。但是我们发现 ,并且这个算法的复杂度是 的,所以一定会超时。
这道题中的折半搜索指的是把 种向量分成两部分,对这两个部分分别用暴力的方法求出所有可能的选取方案,再把这些选取的方案,以及这个方案的结果(他们的和)按照某种方案储存下来,最后匹配这两部分的方案,把符合题目要求的(和等于 )的)累加进答案里。
要达到前面提到的效果,我们可以使用 STL 的 map 或者是 unordered_map (手写哈希表也可以,但是可能要花更多时间来实现)来储存每个选取的方案。这里推荐使用 unordered_map,因为 map 的时间复杂度是 的,在这道题中会被卡,而 unodered_map 和手写哈希表的理想时间复杂度是 (当然使用 unodered_map 的话也要确保哈希函数写得好才不会被卡,比如我现在写的就过不了)。
对于暴力枚举状态的部分,一般的方法是 dfs ,也比较好写,这篇题解的双指针部分讲了一个比较奇怪的方法,想看的可以跳到下面。
注:下文的 map 指 unodered_map 或是 map 或是 手写的哈希表
我们首先创建两个 map ,fir
和 sec
,分别储存前半部分向量的选取方案和后半部分的选取方案。对于这两个 map 的键值,我们可以设成包含 三个整数的结构体。其中 表示当前这个方案下,选取的所有向量的和。 表示当前的这个方案一共选取了多少个向量。因为可能有很多方案的 都完全一样,所以对于 map 的值,我们设置成 这三个值都相同的方案数。
对于答案的储存,我们开一个 ans[n]
的数组,ans[i]
表示 时的选取方案数。
找到两部分的方案后我们要把符合要求的方案组合累加到答案中。具体来说,每个在 map 中储存的方案都包含一个这个方案的向量和 。 我们设一个从前半部分向量得到的方案的向量和为 , 一个后半部分方案的向量和为 ,那么如果 ,我们就把这个方案记录进答案。
因为我们使用了 map ,所以并不需要真的用双层循环把每个方案都遍历一遍。我们知道对于一个可能的匹配方案 。那么我们可以使用一个双层循环,一层枚举 fir 这个 map,一维枚举当前处理的 。然后在循环里,我们可以写:
ans[当前k] += it_fir.值 * sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}]
其中 it_fir
指的是 fir 的迭代器。而 sec[{x_g - it_fir.键.x, y_g - it_fir.键.y, 当前k - it_fir.键.k}]
中的方案和 it_fir
遍历到的方案符合 ,并且这两种方案选取的向量数的和也等于当前 k 。因为这两个 map 的值是所有符合这些条件的方案的数量,所以我们把这两个值相乘以求出所有符合要求的匹配数量。
双指针的理论复杂度似乎是和哈希表一样的,但是如果哈希函数有问题的话,哈希表的速度就会慢很多。而双指针就没有这个问题。
我们可以首先开两个 vector ,fir 和 sec ,其数据类型和之前 map 的一样,也是 。fir 储存由前半部分的向量得来的方案,sec储存后半部分的。
枚举完所有方案后我们对这两个 vector 进行排序,排序规则如下:
if(a.x != b.x) return x < b.x;if(a.y != b.y) return y < b.y;return a.k < b.k;
然后我们创建两个指针, 的初值设成 1, 的初值设成 sec.size() - 1
,代表 sec
的最后一个元素。此时 指向的是 fir
中最小的元素,而 指向 sec
中最大的元素。我们可以想一下,如果想要让当前的这两个指针所指的
的和等于 ,需要怎么做。如果我们想要增加 的值,只能使 增加,因为 已经指向这个数组里最大的元素了。反过来也是一样的,如果我们想要减少 的值,也只能使 减少。写成程序就是下面这样:
int p1 = 0, p2 = sec_half.size() - 1;while(p1 < fir_half.size() && p2 >= 0){ Instruct &f = fir_half[p1], &s = sec_half[p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ //如果两个向量相加小于目标值,我们只能加 p1 的值, //因为 p2 指向的元素最开始就是最大的。 p1++; } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ //如果两个向量相加大于目标值,我们只能减 p2 的值, //因为 p1 指向的元素最开始就是最小的。 p2--; } //下面后半段代码插入的位置}
注:Instruct
是包含 {x, y, k}
三种整数的结构体。
通过这样的方法,我们最终就一定能找到 的情况。不过呢,这两个数组中都可能有连续的一段是完全一样的值,也就是有多个 , 满足 。因此我们需要找出符合条件的这个连续段具体是什么。通过上面的代码,我们已经知道了,满足条件的最小 和最大的 ,因为我们希望找出连续段的具体范围,所以还需要找出最大的 和最小的 。那么如何找呢?很简单,因为连续的这一段值一定都完全相等,所以我们只需要判断当前元素是否和最开始的元素相等就可以了。
当然,因为我们还需要把符合的匹配统计进答案,而答案是按照 来输出的。所以我们可以开两个数组 fir_same_k
和 sec_same_k
。fir_same_k[i]
就表示,对于第一个数组,在符合条件的这一长段中, 的有多少。而 sec_same_k
是对于第二个数组的。
然后我们就可以得到下面的代码了:
注意这段代码是插入前面那段代码的 else if
后面的
else{ int p1t, p2t; memset(fir_same_k, 0, sizeof(fir_same_k)); memset(sec_same_k, 0, sizeof(sec_same_k)); //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组 for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){ //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1 fir_same_k[fir_half[p1t].k]++; } for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){ //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2 sec_same_k[sec_half[p2t].k]++; } //统计答案,对于前半段和后半段都枚举可能的 for(int i = 0; i <= 20; i++){ for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的 ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j]; //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同 } } p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环}
这个方法还是跑的相对比较快的,可以看下提交记录
我们发现双指针 a 的方法会需要在统计答案时开 fir_same_k
和 sec_same_k
这两个数组来统计 相同的情况。我们其实可以改进一下这个方法,直接在枚举状态的时候把 相同的方案放到一起。
具体来说,我们把前面的 fir
和 sec
这两个 vector 改成 vector<Instruct> fir[20], sec[20]
。fir[i]
就储存前半部分 时的所有方案,sec[i]
是后半部分的。既然把储存方案的方法改了,后面的双指针部分自然也要改。
这一次我们需要用一个双重循环来分别枚举不同的 fir[i]
和 sec[j]
。在循环内部我们再做和双指针 a 相似的事情。也就是说,我们已经知道了当前前半部分的方案和后半部分的方案的 ,现在只需要通过双指针找出能满足 的 和 范围。
for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){ //枚举当前前半部分的 k for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k int p1 = 0, p2 = sec_half[sec_k].size() - 1; while(p1 < fir_half[fir_k].size() && p2 >= 0){ Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ p1++; //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1 } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ p2--; //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2 } else{ int p1t = p1, p2t = p2; while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){ p1t++; } while(p2t >= 0 && sec_half[sec_k][p2t] == s){ p2t--; } ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); //把 p1 范围长度乘上 p2 范围的长度 //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while // 在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果 // 它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。 p1 = p1t, p2 = p2t; } } }}
那么双指针b的写法比a有什么好处呢?答案就是节省空间。如果我们使用的是双指针a,那储存方案的结构体必须包含 三种整数。注意其中的 最大只有 ,而我们却必须开一个 int
或是 short
来储存这个值,考虑到 这个值非常小,不管哪种数据类型都会浪费大量的空间。而采用双指针b后,我们的结构体中只会包含 两种整数, 这个值储存在数组的下标中,只要你开的数组大小为 的最大值,就不会有任何的浪费。
具体的对比可以参考这个提交记录,可以发现相比双指针 a 的做法,双指针 b 的内存占用大约少了 17MB 左右
当然,代价也是有的,双指针 b 会稍微慢一些。我估计这主要出在双指针的环节。排序的部分甚至还会快一点。当然,不管哪种双指针,他们的理论复杂度都是一样的,因为每一种选取方案最多会被遍历到一次。
这道题中常见的状态枚举方法就是 dfs。这里提供一种比较奇怪的枚举方法。在一个选取方案中,对于每个向量,都有两种状态,选或者是不选。因为只有选或不选两种状态,我们可以想到通过二进制数字表示这个状态。数字的第 位表示向量 是选还是不选,例如: 就表示选择第 1 ,3 个向量,不选第 2 个。要枚举所有的状态,我们只需要把一个数从 一直累加到 就可以了,并且每次累加的时候检查他每一位是 0 还是 1 。当然,因为我们这里采用的是折半搜索,所以只需要累加到 。
至于复杂度的话,可能比 dfs 还慢?而且码量更大?毕竟每次累加还要写一个循环把这个数字从第一位检查到第二十位。不过,因为不需要递归,所以不需要一直给递归的函数开栈,内存占用可能会少一些。
真在比赛的时候还是不建议这样写的,毕竟 dfs 写起来是真的方便,这里只是提供一种好玩的做法。
#include<bits/stdc++.h>using namespace std;#define ll long long#define rg registerconst int MAXN = 45;struct Instruct{ ll x, y; int k; const bool operator < (Instruct b) const{ if(x != b.x) return x < b.x; if(y != b.y) return y < b.y; return k < b.k; } const bool operator == (Instruct b) const{ return x == b.x && y == b.y; }}ins[MAXN];vector<Instruct> fir_half, sec_half;ll ans[MAXN];int mx_state;int n;int tar_x, tar_y;void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){ //根据当前提供的状态 cur_state, 把选中的向量累加起来 //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。 ll tot_x = 0, tot_y = 0; int k = 0; int len = ed - st + 1; for(int i = 1; i <= len; i++){ if(cur_state & (1 << (i - 1))){ tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y; k++; } } (*cur_half).push_back({tot_x, tot_y, k});}void state_generator(bool mode){ //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分 rg int cur_state = 0;//初始的状态,对应的就是什么都不选 int st, ed; vector<Instruct> *cur_half;//fir_half和sec_half储存这前半部分的方案和后半部分的方案 //cur_half就是当前这次搜索要把方案存到哪里 if(mode){ cur_half = &sec_half; st = n / 2 + 1, ed = n; if(n & 1) mx_state = mx_state * 2 + 1; //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数 //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1 } else{ cur_half = &fir_half; st = 1, ed = n / 2; } while(cur_state <= mx_state){ vec_sum(st, ed, cur_state, cur_half); cur_state++; }}int main(){ scanf("%d%d%d", &n, &tar_x, &tar_y); for(int i = 1; i <= n; i++){ scanf("%lld%lld",&ins[i].x, &ins[i].y); } for(int i = 0; i < n / 2; i++){ mx_state = mx_state | (1 << i); //最大的状态是 n/2 位都是 1 } state_generator(0); state_generator(1); sort(fir_half.begin(), fir_half.end()); sort(sec_half.begin(), sec_half.end()); rg int p1 = 0, p2 = sec_half.size() - 1; int fir_same_k[21], sec_same_k[21]; while(p1 < fir_half.size() && p2 >= 0){ Instruct &f = fir_half[p1], &s = sec_half[p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ //如果两个向量相加小于目标值,我们只能加 p1 的值, //因为 p2 指向的元素最开始就是最大的。 p1++; } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ //如果两个向量相加大于目标值,我们只能减 p2 的值, //因为 p1 指向的元素最开始就是最小的。 p2--; } else{ int p1t, p2t; memset(fir_same_k, 0, sizeof(fir_same_k)); memset(sec_same_k, 0, sizeof(sec_same_k)); //因为每次找到的符合条件的段都是不重合的,所以每次都清空一下数组 for(p1t = p1; p1t < fir_half.size() && fir_half[p1t] == f; p1t++){ //p1t 代表能满足 v_1 + v_2 == (x_g, y_g) 的最大 p1 fir_same_k[fir_half[p1t].k]++; } for(p2t = p2; p2t >= 0 && sec_half[p2t] == s; p2t--){ //p2t 代表满足 v_1 + v_2 == (x_g, y_g) 的最小 p2 sec_same_k[sec_half[p2t].k]++; } //统计答案,对于前半段和后半段都枚举可能的 for(int i = 0; i <= 20; i++){ for(int j = 0; j <= 20; j++){//这个20其实是可以改成 n / 2 + 1 的 ans[i + j] += 1LL * fir_same_k[i] * sec_same_k[j]; //相乘是因为同一个 fir_same_k[i] 和 sec_same_k[j] //中代表的任意一种选取方案都是完全相同的,(x,y,k) 都相同 } } p1 = p1t, p2 = p2t;//不加这个会一直在相同的一段死循环 } } for(int i = 1; i <= n; i++){ printf("%lld\n", ans[i]); } system("pause");}
#include<bits/stdc++.h>using namespace std;#define ll long long#define rg registerconst int MAXN = 45;struct Instruct{ ll x, y; const bool operator < (Instruct b) const{ if(x != b.x) return x < b.x; return y < b.y; } const bool operator == (Instruct b) const{ return x == b.x && y == b.y; }}ins[MAXN];vector<Instruct> fir_half[MAXN], sec_half[MAXN];ll ans[MAXN];int mx_state;int n;int tar_x, tar_y;void vec_sum(int st, int ed, int cur_state, vector<Instruct> *cur_half){ //根据当前提供的状态 cur_state, 把选中的向量累加起来 //因为整个搜索的过程分成了两部分,所以需要参数表示是哪个部分,st, ed 表示的就是参与搜索的第一份向量,和最后一个。 ll tot_x = 0, tot_y = 0; int k = 0; int len = ed - st + 1; for(int i = 1; i <= len; i++){ if(cur_state & (1 << (i - 1))){ tot_x += ins[st + i - 1].x, tot_y += ins[st + i - 1].y; k++; } } cur_half[k].push_back({tot_x, tot_y});}void state_generator(bool mode){ //mode 表示当前处理的是前半部分还是后半部分,0是前半部分,1是后半部分 rg int cur_state = 0;//初始的状态,对应的就是什么都不选 int st, ed; vector<Instruct> *cur_half; if(mode){ cur_half = sec_half; st = n / 2 + 1, ed = n; if(n & 1) mx_state = mx_state * 2 + 1; //mx_state就是表示状态的数字最大能达到多少,它的初始值是 2^(n/2) 但是如果 n 是奇数 //后半部分会比前半部分多包含一个向量,所以还要把原来的 mx_state * 2 + 1 } else{ cur_half = fir_half; st = 1, ed = n / 2; } while(cur_state <= mx_state){ vec_sum(st, ed, cur_state, cur_half); cur_state++; }}int main(){ scanf("%d%d%d", &n, &tar_x, &tar_y); for(int i = 1; i <= n; i++){ scanf("%lld%lld",&ins[i].x, &ins[i].y); } for(int i = 0; i < n / 2; i++){ mx_state = mx_state | (1 << i); //最大的状态是 n/2 位都是 1 } state_generator(0); state_generator(1); for(int i = 0; i <= n / 2 + 1; i ++){ sort(fir_half[i].begin(), fir_half[i].end()); sort(sec_half[i].begin(), sec_half[i].end()); } for(int fir_k = 0; fir_k <= n / 2 + 1; fir_k++){ //枚举当前前半部分的 k for(int sec_k = 0; sec_k <= n / 2 + 1; sec_k++){//枚举后半部分的 k int p1 = 0, p2 = sec_half[sec_k].size() - 1; while(p1 < fir_half[fir_k].size() && p2 >= 0){ Instruct &f = fir_half[fir_k][p1], &s = sec_half[sec_k][p2]; if(f.x + s.x < tar_x ||(f.x + s.x == tar_x && f.y + s.y < tar_y)){ p1++; //找出当前 fir_k 时满足 v_1 + v_2 == (x_g, y_g) 的最小的 p1 } else if(f.x + s.x > tar_x ||(f.x + s.x == tar_x && f.y + s.y > tar_y)){ p2--; //找出当前 sec_k 时满足 v_1 + v_2 == (x_g, y_g) 的最大 p2 } else{ int p1t = p1, p2t = p2; while(p1t < fir_half[fir_k].size() && fir_half[fir_k][p1t] == f){ p1t++; } while(p2t >= 0 && sec_half[sec_k][p2t] == s){ p2t--; } ans[fir_k + sec_k] += 1LL * (p1t - p1) * (p2 - p2t); //把 p1 范围长度乘上 p2 范围的长度 //小细节:本来要表示长度的话应该是 p1t - p1 + 1的,但是我们可以观察前面两个while // 在跳出之后 p1t 会比正确的 p1t 多 1,而 p2t 会比正确的 p2t 少1,因为如果 // 它们还是正确的话会又回到循环中。因我们计算长度的时候就不需要 + 1 了。 p1 = p1t, p2 = p2t; } } } } for(int i = 1; i <= n; i++){ printf("%lld\n", ans[i]); } system("pause");}
最后,希望这篇题解对你有帮助。有任何问题都可以在私信和评论区提出,我会尽量解决问题。
]]>前言:题解可能比较啰嗦,因为这题比赛的时候没做出来,所以写题解主要用于整理自己的思路。如果你有思路只是代码打挂了简易直接跳到代码部分。
update@2022/3/13: 感谢@小木虫的提醒,当前的解法不是正解! 如果USACO的数据够强的话,目前我使用的匈牙利算法因为复杂度是 ,是过不了这道题的。如果想用我这个二分图匹配 + 拓扑的方法实现,可以使用dinic算法求二分图最大匹配(不过写起来会比较麻烦)。本人有时间的时候也会尝试用dinic实现这个解法并且更新题解。
有 头牛, 种麦片(每种一箱),每头都有第一和第二喜欢的麦片种类(下文简称为一选和二选),牛会优先选择自己最喜欢的麦片,当最喜欢的麦片被占用后会选择第二喜欢的麦片,问:
对于第一个小问,可以发现这是一个标准的二分图最大匹配问题,很容易想到使用匈牙利算法解决(然而这次比赛的时候我并没有想到)。不熟悉匈牙利算法和二分图匹配问题的同学可以参考模板题里的题解。这篇题解将主要关注第二小问的求解。
对于第二小问,我一开始想的是先输出成功匹配到一选的牛,其次是成功匹配到二选的牛,最后输出没有成功匹配的牛。结果交上去只过了样例。经过@lutongyu大佬的指导,我终于理解了这个做法的问题。
具体来说,对于成功匹配到一选的奶牛,可以先输出,最后输出没有成功匹配的奶牛也是没有问题的。真正的问题在于二选奶牛的顺序。考虑下面这样的一个数据(如下图):
1 (cow) -> [1 (fir), 2 (sec)]2 (cow) -> [1 (fir), 3 (sec)] 3 (cow) -> [3 (fir), 4 (sec)]
我们可以试着手动模拟一下这个数据
我们首先尝试这个数据下的最优排列 1 2 3
在这样的情况下,每一头奶牛都能吃到麦片
然后我们调换一下 和 的顺序,得到1 3 2
的顺序,以及下面的模拟过程
在这样的情况下, 并不能吃到任何麦片
通过这个数据,我们可以发现直接输出匹配到二选的牛是不行的,还需要在输出匹配到二选的牛时做一些处理,保证这个排列能达到最大匹配数。
具体来说,我们可以使用一种类似拓扑排序的算法来解决二选奶牛的冲突问题。
我们首先来考虑当一头奶牛成功匹配到自己的一选,并且它的一选同时也是别的奶牛的一选会发生什么,拿上图中的 举例子,它会影响到 的选择( 占用了 的一选,迫使其选择二选),而 又会影响到 的选择( 占用了 的一选,迫使其选择二选)。通过观察,我们可以发现只要按照这样一个 “影响链” 来输出奶牛,就可以保证达到最大匹配。
这一条链的开始一定是成功匹配到一选,并且迫使别的奶牛选择二选的奶牛(这个奶牛的一选也是别的奶牛的一选)。我们把这样的奶牛全部入队。我们再来考虑被影响的奶牛,为了找出 “影响链” 我们需要把这些被影响到的奶牛也入队(因为这些奶牛只能选择二选,而他们的二选可能会占用别的奶牛的一选,就像上图中的 )。
为了找出有哪些牛是可能被影响的,我们可以引入一个 inv_e[i]
的动态数组(链式前向星),表示所有把 i 号麦片作为一选的牛(只有一选先被别的牛选了才会被影响)。
比如在上图中 inv_e[1] = [
,
]
我们知道在实现匈牙利算法的时候会用到 matched[i]
数组。它的下标表示右部节点,值表示匹配到这个右部节点的左部节点。在这题中,matched[i]
的下标就是麦片的编号,而值是匹配到这个麦片的牛。我们可以引入一个 inv_match[i]
数组,它的下标是牛,而值是麦片。通过 inv_match
我们可以知道每头牛最终匹配到的麦片是哪个(一选二选或者匹配失败)。
用上图举例子, inv_match[
]
就等于麦片 3 (最终匹配到的是麦片 3)
下面的代码展示了如何找到所有被一选奶牛所影响的奶牛
for(int i = 1; i <= n; i++){ //i遍历的是成功匹配到一选的奶牛 if(invmatched[i] != e[i][0]) continue;//e[i][0]表示i的一选 //invmatched[i] 是i最终匹配到的麦片 //所以这句话的意思是如果不是一选就直接continue printf("%d\n",i);//如果是一选直接输出 for(int cur:inve[e[i][0]]){ //遍历当前成功匹配一选的牛 可能影响的牛 //e[i][0]表示的是i的一选,而i一定是成功匹配一选的牛 //inve[e[i][0]]是所有一选和i的一选相同的奶牛,这些奶牛可能被i影响 if(invmatched[cur] == e[cur][1])//前文提到了invmatched[cur] 表示的是cur最终匹配到的麦片 //而e[cur][1] 表示的是 cur 的二选 { //所以这句话确保了cur最终选到的是二选(说明这头牛被影响到了,没有选一选,同时也可以防止把i自己入队) q.push(cur); } }}
接下来,对于已经进入队列的奶牛,他们的二选可能会占用别的奶牛的一选,所以我们也可以用相似的方法找出这个 “影响链”
while(!q.empty()){ int cur = q.front(); printf("%d\n",cur);//队列是先进先出的结构,所以可以先输出在影响链上方的牛(更早被影响到的牛) q.pop(); for(int nex:inve[e[cur][1]]){ //e[cur][1]是编号为cur的牛的二选 //inve[e[cur][1]] 就是所有把 cur 的二选当作一选的牛,也就是可能被 cur 影响到的牛 if(invmatched[nex] == e[nex][1]) {//最终选到的是二选(说明这头牛被影响到了) q.push(nex); } }}
最后给出详细代码(有注释解释细节)
/*Date: 22 - 02-03 22 19PROBLEM_NUM: P8095 [USACO22JAN] Cereal 2 S*/#include<bits/stdc++.h>using namespace std;const int MAXN = 2e5 + 10;int n, m;vector<int> e[MAXN], inve[MAXN];queue<int> q;int vised[MAXN], matched[MAXN];int invmatched[MAXN];bool found(int cur){//匈牙利算法 for(int nex:e[cur]){ if(vised[nex]) continue; vised[nex] = true; if(!matched[nex] || found(matched[nex])){ matched[nex] = cur; invmatched[cur] = nex; vised[nex] = false; return true; } } return false;}int main(){ int match_cnt = 0; scanf("%d%d",&n,&m); for(int i = 1; i<=n; i++){ int f,s; scanf("%d%d",&f,&s); e[i].push_back(f); //e[i][0]是i的一选 e[i].push_back(s); //e[i][1]是i的二选 inve[f].push_back(i);//inve[f] 表示所有把 f 号麦片作为一选的牛 } for(int i = 1;i <= n; i++){//匈牙利算法部分 if(found(i)){ match_cnt++; } } printf("%d\n", n - match_cnt);//饥饿的奶牛 = 所有奶牛 - 吃到麦片的奶牛 for(int i = 1; i <= n; i++){ //i遍历的是成功匹配到一选的奶牛 if(invmatched[i] != e[i][0]) continue;//如果不是一选就直接continue printf("%d\n",i); //如果是一选直接输出 for(int cur:inve[e[i][0]]){ //遍历当前成功匹配一选的牛 可能影响的牛 //e[i][0]表示的是i的一选,而i一定是成功匹配一选的牛 //inve[e[i][0]]是所有一选和i的一选相同的奶牛,这些奶牛可能被i影响 if(invmatched[cur] == e[cur][1])//前文提到了invmatched[cur] 表示的是cur最终匹配到的麦片 //而e[cur][1] 表示的是 cur 的二选 { //所以这句话确保了cur最终选到的是二选(说明这头牛被影响到了,没有选一选,同时也可以防止把i自己入队) q.push(cur); } } } while(!q.empty()){ int cur = q.front(); printf("%d\n",cur);//队列是先进先出的结构,所以可以先输出在影响链上方的牛(更早被影响到的牛) q.pop(); for(int nex:inve[e[cur][1]]){ //e[cur][1]是编号为cur的牛的二选 //inve[e[cur][1]] 就是所有把 cur 的二选当作一选的牛,也就是可能被 cur 影响到的牛 if(invmatched[nex] == e[nex][1]) {//最终选到的是二选(说明这头牛被影响到了) q.push(nex); } } } for(int i = 1; i<=n; i++){//最后输出没有成功匹配到的奶牛 if(!invmatched[i]){ //invmatched[i] == 0 说明奶牛 i 没有匹配到任何麦片 printf("%d\n",i); } } system("pause");}
最后,希望这篇题解能帮到你,如果还没看懂或者是发现了题解有问题都可以私信我或者在评论区指出,我会尽量回答或是解决问题。
]]>给定一个 的区域,区域中的每个点由 0 或 1 组成,1 类点不能走 0 类点可以走。问你从左上角走到右下角,并且最多转向 次有多少种走法。
看到在格子图中问有多少走法的题目,可以比较容易的想到使用 dp 算法解决,具体做法参考 P1002过河卒,但是本题的难点以及本题解的重点在于如何处理对于转向次数的限制。
根据上图,我们可以看到,路径是否包含转向不仅跟从哪个点转移来有关,还和转移来的那个点是从哪个点转移来的有关(从左边或是从右边)。具体的判断规则如下,可以对照着图来理解:
map[i - 1][j]
转移来的,并且上面那个点是从他的左边转移来的,那么会发生一次转移。(图中当前点上方的蓝色虚线)map[i][j - 1]
转移来的,并且左边那个点是从他的上方转移来的,那么会发生一次转移。(图中当前点左方的蓝色虚线)我们让 dp[i][j][k][t]
表示走到 ,花费了 次转向,并且是从左边 (0) ,或是上面 (1) 的格子转移来的。
有了以上的判断规则,我们可以写出 dp 的转移方程。
发生转向的情况:
dp[i][j][k][0] += dp[i][j - 1][k - 1][1];dp[i][j][k][1] += dp[i - 1][j][k - 1][0];
不发生转向的情况:
dp[i][j][k][0] += dp[i][j - 1][k][0];dp[i][j][k][1] += dp[i - 1][j][k][1];
其中,我们需要注意如果一个点如果是从 ,也就是起点转移过来的话,是不可能发生转向的。并且如果循环中的 的话也是不可能发生转向的( 代表转向次数)。所以我们还需要在发生转向时的状态转移方程加入如下的判断语句:
if (k != 0 && i != 1 && j != 1)
最后,还需要注意一点。因为题目问的是“至多转向 次”,所以最后输出时需要把所有符合条件的情况加起来再一起输出。
完整代码及注释如下:
#include <bits/stdc++.h>using namespace std;const int MAXN = 55;int n, k;int t, mp[MAXN][MAXN];int dp[MAXN][MAXN][4][2]; //dp[i j k t]是到i,j,转了k次的方法数量,并且上次是从左/上边的格子转移来的void calc_dp(){ memset(dp, 0, sizeof(dp)); dp[1][1][0][0] = dp[1][1][0][1] = 1; for (int i = 1; i <= n; i++) { for (int j = 1; j <= n; j++) { if (mp[i][j]) for (int k = 0; k <= 3; k++) { dp[i][j][k][0] += dp[i][j - 1][k][0]; dp[i][j][k][1] += dp[i - 1][j][k][1]; if (k != 0 && i != 1 && j != 1) { dp[i][j][k][0] += dp[i][j - 1][k - 1][1]; dp[i][j][k][1] += dp[i - 1][j][k - 1][0]; } } } }}void input(){ scanf("%d%d", &n, &k); for (int i = 1; i <= n; i++) { char temp[MAXN]; scanf("%s", temp + 1); for (int j = 1; temp[j]; j++) { if (temp[j] == 'H') mp[i][j] = 0; else mp[i][j] = 1; } }}int main(){ scanf("%d", &t); while (t--) { input(); calc_dp(); if (k == 3) printf("%d\n", dp[n][n][0][0] + dp[n][n][0][1] + dp[n][n][1][0] + dp[n][n][1][1] + dp[n][n][2][0] + dp[n][n][2][1] + dp[n][n][3][0] + dp[n][n][3][1]); if (k == 2) printf("%d\n", dp[n][n][0][0] + dp[n][n][0][1] + dp[n][n][1][0] + dp[n][n][1][1] + dp[n][n][2][0] + dp[n][n][2][1]); if (k == 1) printf("%d\n", dp[n][n][0][0] + dp[n][n][0][1] + dp[n][n][1][0] + dp[n][n][1][1]); } system("pause");}
如果题解有问题或者没看懂的欢迎在评论区和私信中指出。
]]>给定一个 的区域,并且这个区域内有两类点,J 类与 B 类。现在让你在区域中添加一个J类点(也可以不添加,并且添加时不能添加到已经存在 B 类点的地方),然后找出最大的由 J 类点构成的正方形。
因为数据较小 ,并且我们可以通过一条J边来确定正方形的剩下两条边,所以可以尝试通过枚举J边,并且加以判断的方法找到最大的正方形。
假设我们通过两点 和 确定了一条直线,并且 的纵坐标总是比 高,那么我们可以画出下图:
可是我们如何计算出 和 的坐标呢?
首先通过观察,我们可以发现图中的四个三角形都是全等的,我们只要计算出三角形的长直角边和短直角边的长度(或者说是两个不同的直角边,只是在这种特殊情况下,长边和短边的位置是图中的样式),再加上一个偏移量,就可以得到 和 的坐标了。
三角形的长直角边:
三角形的短直角边:
通过观察,我们可以从上面两个式子得出 和 的坐标
除了现在图中的样式,通过 这条直线还能确定另一种正方形:
当然,我们还是可以通过刚刚的方法得到 和 的坐标。我们刚刚是在一个偏移量的基础上加上或是减去三角形的两个不同边的长度 以及 来得到 和 的坐标,通过观察,我们可以发现只要在偏移量上进行相反的操作,就可以得到 和 的坐标了。
因为此需要我们计算的是正方形的面积,所以我们可以用以下方法计算出面积:
通过刚刚的方法,我们已经能够得到 的坐标了,接下来我们需要判断通过一条边确定的这两个正方形是否合法。
因为我们可以自由的放置一个 J 点,所以整个正方形中可以只有三个现成的 J 点,当然,这一个 J 点必须放置在没有被占用的点上。
因此只要以 组成的正方形符合:
注:J 表示的是 J 类点的集合,B 表示的是 B 类点的集合
我们就可以说这一个正方形是合法的。由 组成的正方形同理
想到思路之后迅速的打出了代码,结果两个点 T 了。。。
#include <bits/stdc++.h>using namespace std;int mp[120][120];#define debug falsestruct node{ int x, y;};vector<node> jc; //J类点int n;void input(){ scanf("%d", &n); for (int i = 0; i < n; i++) { char temps[110]; scanf("%s", temps); for (int j = 0; temps[j]; j++) { if (temps[j] == 'J') { mp[i][j] = 1; jc.push_back(node{i, j}); } else if (temps[j] == 'B') { mp[i][j] = -1; } else if (temps[j] == '*') { mp[i][j] = 0; } } }}int main(){ input(); int ans = 0; for (auto t1 : jc) { for (auto t2 : jc) { if (t1.x == t2.x && t1.y == t2.y) { continue; } if (debug) printf("1x %d 1y %d 2x %d 2y %d\n", t1.x, t1.y, t2.x, t2.y); node p3, p4; node p1 = t1, p2 = t2; if (p1.y < p2.y) { swap(p1, p2); } p3.x = (p2.x + (p1.y - p2.y)); p3.y = (p2.y - (p1.x - p2.x)); p4.x = (p1.x + (p1.y - p2.y)); p4.y = (p1.y - (p1.x - p2.x)); if (debug) printf("3x %d 3y %d 4x %d 4y %d\n", p3.x, p3.y, p4.x, p4.y); if (p3.x >= 0 && p3.y >= 0 && p4.x >= 0 && p4.y >= 0 && p3.x < n && p3.y < n && p4.x < n && p4.y < n) { if ((mp[p3.x][p3.y] == 1 && mp[p4.x][p4.y] != -1) || (mp[p3.x][p3.y] != -1 && mp[p4.x][p4.y] == 1)) { ans = max(ans, ((p1.y - t2.y) * (p1.y - p2.y) + (p1.x - p2.x) * (p1.x - p2.x))); } } p3.x = (p2.x - (p1.y - p2.y)); p3.y = (p2.y + (p1.x - p2.x)); p4.x = (p1.x - (p1.y - p2.y)); p4.y = (p1.y + (p1.x - p2.x)); if (p3.x >= 0 && p3.y >= 0 && p4.x >= 0 && p4.y >= 0 && p3.x < n && p3.y < n && p4.x < n && p4.y < n) { if ((mp[p3.x][p3.y] == 1 && mp[p4.x][p4.y] != -1) || (mp[p3.x][p3.y] != -1 && mp[p4.x][p4.y] == 1)) { ans = max(ans, ((p1.y - t2.y) * (p1.y - t2.y) + (p1.x - p2.x) * (p1.x - p2.x))); } } } } printf("%d", ans); system("pause");}
我们可以发现这个代码做了大量的无用计算,因为我们假设了 的纵坐标总是比 高,所以在程序中通过一个判断来确保我们的假设成立。
if (p1.y < p2.y){ swap(p1, p2);}
我们实现枚举边的方法是用两个循环枚举所有的J点,所以会有重复枚举的情况(比如 和 实际相同,但是换了个顺序)。对于这样的情况,我们完全可以直接跳过这次循环来节省时间。
所以我们可以把代码改为:
if ((p1.x == p2.x && p1.y == p2.y) || (p1.y < p2.y)){ continue;}
这样进入后续判断正方形合法性环节的点就都符合我们的假设了,也减少了之前的重复计算。
再仔细分析程序,我们可以发现,其实在循环内部只是做了判断正方形合法性的工作,如果在枚举的时候我们发现一个正方形的面积比我们目前得到的答案还要小,那么就没必要继续判断合法性了,直接跳过就可以了。
所以我们可以把代码改为:
if ((p1.x == p2.x && p1.y == p2.y) || (p1.y < p2.y) || (((p1.y - p2.y) * (p1.y - p2.y) + (p1.x - p2.x) * (p1.x - p2.x)) <= ans)){ continue;}
经过这次改进,我们就可以愉快的 AC 了:
完整代码以及注释如下:
#include <bits/stdc++.h>using namespace std;int mp[120][120];#define debug falsestruct node{ int x, y;};vector<node> jc; //J类点的集合int n;void input(){ scanf("%d", &n); for (int i = 0; i < n; i++) { char temps[110]; scanf("%s", temps); for (int j = 0; temps[j]; j++) { if (temps[j] == 'J') { mp[i][j] = 1; jc.push_back(node{i, j}); } else if (temps[j] == 'B') { mp[i][j] = -1; } else if (temps[j] == '*') { mp[i][j] = 0; } } }}int main(){ input(); int ans = 0; int cnt = 0; for (auto p1 : jc) { for (auto p2 : jc)//通过枚举两个点来实现对于边的枚举 { if ((p1.x == p2.x && p1.y == p2.y) || (p1.y < p2.y) || (((p1.y - p2.y) * (p1.y - p2.y) + (p1.x - p2.x) * (p1.x - p2.x)) <= ans)) { // printf("temp area %d ans^2 %d ans %d\n", (t1.y - p2.y) * (t1.y - p2.y) + (t1.x - p2.x) * (t1.x - p2.x),ans*ans,ans); continue; } cnt++; if (debug) printf("1x %d 1y %d 2x %d 2y %d\n", p1.x, p1.y, p2.x, p2.y); //通过一条边确定的第一种正方形 node p3; p3.x = (p2.x + (p1.y - p2.y)); p3.y = (p2.y - (p1.x - p2.x)); node p4; p4.x = (p1.x + (p1.y - p2.y)); p4.y = (p1.y - (p1.x - p2.x)); if (debug) printf("3x %d 3y %d 4x %d 4y %d\n", p3.x, p3.y, p4.x, p4.y); if (p3.x >= 0 && p3.y >= 0 && p4.x >= 0 && p4.y >= 0 && p3.x < n && p3.y < n && p4.x < n && p4.y < n)//判断合法性 { if ((mp[p3.x][p3.y] == 1 && mp[p4.x][p4.y] != -1) || (mp[p3.x][p3.y] != -1 && mp[p4.x][p4.y] == 1)) { ans = max(ans, (p1.y - p2.y) * (p1.y - p2.y) + (p1.x - p2.x) * (p1.x - p2.x)); } } //通过一条边确定的第二种正方形 p3.x = (p2.x - (p1.y - p2.y)); p3.y = (p2.y + (p1.x - p2.x)); p4.x = (p1.x - (p1.y - p2.y)); p4.y = (p1.y + (p1.x - p2.x)); if (p3.x >= 0 && p3.y >= 0 && p4.x >= 0 && p4.y >= 0 && p3.x < n && p3.y < n && p4.x < n && p4.y < n)//判断合法性 { if ((mp[p3.x][p3.y] == 1 && mp[p4.x][p4.y] != -1) || (mp[p3.x][p3.y] != -1 && mp[p4.x][p4.y] == 1)) { ans = max(ans, (p1.y - p2.y) * (p1.y - p2.y) + (p1.x - p2.x) * (p1.x - p2.x)); } } } } printf("%d\n", ans); if (debug) printf("%d", cnt); system("pause");}
]]>前言:这篇题解写的可能比较啰嗦,主要时是因为我把所有思考的过程都写下来了,所以如果你已经有了基本的思路,或者是希望找一篇简洁的题解,就可以跳过这篇题解了。
总共有 头牛, 类形容词,有 个第 类形容词,第 类的第 种形容词是 ,每头牛都需要有这 类形容词按照顺序来修饰。现在告诉你要删除这 头牛中的 头,问你在这剩下的 头牛中,按照字典序排序,排在第 位的牛是哪一头?
单看这样的描述可能有些抽象,现在我们来看一下样例是怎么样的,再在样例的基础上思考应该怎么解决这道题。
在样例中 ,。的值如下
(第一类形容词) | “large” | “small” | N/A |
(第二类形容词) | “brown” | “white” | “spotted” |
(第三类形容词) | “noisy” | “silent” | N/A |
因为样例中让我们求的是按照字典序排在第7的牛,我们可以思考一下以字典序为关键字进行排序的过程是什么样的。
举个例子:有两个字符串"abc"和字符串"cde"需要进行字典序排序,首先我们应该比较在第一位的字符"a"和"c"的字典序,再比较第二位"b"和"c"的字典序,最后才是第三位的字符。
从这个过程中我们可以发现每一位字符对于字符串整体字典序的影响是不一样的,其中第一位的影响最大,最后一位最小。因此我们就可以说他们对于整个字符串的字典序的影响的“权值”不同。如果我们按照字典序给不同的字符串从小到大排序,对于一个字符串,不管从第二个字符到最后一个字符的字典序有多小,如果第一个字符的字典序很大,那么它也会排在很后面
再观察这道题目中让我们求解的问题,我们可以发现,第1类形容词比如"large"对于整串字符的影响是最大的,其次是第二类,比如"brown",最后才是第三类。
分析到这里,相信你已经体会到这个问题和数字系统的相关性了。
那就是我们在比较数字时,采用的方法也是从高位到低位进行比较
比如有这样一个 进制数字
数字 代表的值是 它代表的值是所有数字中最大的(1是第一位)
数字 代表的值是 它代表的值是第二大的(2是第二位)
数字 在第 位代表的值就是 ,整个 进制数代表的值就每一位数字代表的值的累加
把整个规则推广到 进制,那么数字 在第 位代表的值就是
那么问题就来了,不管是我们刚刚讨论的 进制还是 进制,它们的机制都是 “逢 就进 ”,因此第 位的 “数字 ” 在十进制中代表的值一定是第 位的 “数字 ” 代表的值的 倍。而在我们这个问题中,每一类形容词的数量是不一定的。
我们可以先尝试解决每类形容词数量一定的情况,设每类形容词的数量都是 ,首先我们要对每类形容词进行字典序排序,把结果存在 rank[i][j]
中,其中 代表形容词类型, 代表排名。
这一步的目的是把字符串转化成数字,方便后续的计算。(把每个形容词映射到数字系统中的 “第 位的数字 ”,但是需要注意的一点是,形容词的类数越小,对整体字典序影响越大,数字的位数越小,对整体的值的影响越小)
因为我们已经完成了形容词到数字的映射,所以下面要做的就等于“把 进制转换到 进制”,再把得到的数字转换回对应的字符串
这样的例子可能有些抽象,下面我们来模拟一遍这个过程
我们规定第一类形容词有以下两个 {“a”,“b”}, 第二类形容词也有以下两个{“c”,“d”}, 第三类是{“e”,“f”}。
那么我们可以求出以下的rank
数组
(因为方便计算,所以排名从0开始)
(第一类形容词) | “a” | “b” |
(第二类形容词) | “c” | “d” |
(第三类形容词) | “e” | “f” |
如果我们想求字典序排在第 位的牛,那么我们需要先求出 的二进制数,也就是 。然后把这个数字倒过来,变成 (形容词的类数越小,对整体字典序影响越大,数字的位数越小,对整体的值的影响越小),最后再把 映射回对应的字符串(第几位对应第几类),最后的答案也就是 “a, d, f”
在解决刚刚简化过的问题的过程中,我们把每类形容词的类映射到了数字系统的 “第几位数”,把它们的排名 映射到了数字系统中的数字 ,而每类形容词的数量就成了这个数字系统的进制数。
我们可以发现,解决原问题的关键就在于进制,在刚刚简化过的问题中,数字系统中每一位的进制是一样的,并且每一类形容词的数量也是一定的,那么再解决当前问题时,每一类的形容词数量是不一定的,所以相应的,每一位的进制也要有所改变。
回到题目给的样例,每类形容词的数量是 , ,
而第三类的形容词被我们映射到了数字系统中的第一位数字,第二类是第二位,同样的,第一类就是第三位
因此我们可以规定,在第一位的时候,这个数字系统是二进制的,在第二位的时候,这个数字系统是三进制的,第三位也是二进制。
虽然我们可以做到用这样的 “” 进制来描述每一种牛,但是在解决这道题的时候,我们还需要把十进制转换成这样的 “” 进制。
大家肯定对十进制转二进制这样的问题非常熟悉,比如要把一个十进制数 转换成 位的二进制数 ,要做的就是从 的最高位 开始,每次都进行 的操作,然后再计算
写成程序的话,就是这样的
//k是十进制数//weight_in_pos[i]是第i位(从最高位开始计,和我们平常用的方法相反)在十进制中代表的值//i是当前位(从最高位开始计,和我们平常用的方法相反)for (int i = 1; i <= adj_num; i++){ cout << adj_by_pos[i][(k) / weight_in_pos[i]] << " "; k %= weight_in_pos[i];}
所以呢,对于 “” 这样的进制,我们只需要提前计算好他们每一位在十进制中代表的值就可以把十进制转换成这样的进制了。
(每一位在十进制中代表的值的意思就是 在原进制系统中,如果这一位是 ,其他位都是 ,转换成 进制之后的值)
那么这样的进制的每一位在十进制中代表的值如何计算呢?
不管是在几进制的系统里,只要两个数的进制系统相同,那么一个 位数更多的数 所代表的值一定比位数更少的那个要大。
所以我们可以这样计算第 位在十进制中所代表的值,就是第 位数字在十进制中代表的值再乘上第 位上能表示的最大的数字 (确保位数更多的数一定比位数更少的大),并且我们可以发现最大的数字 刚好就是这一位的进制。(比如 进制最大的数是 )
有了这个结论就可以递推的求解每一位在十进制中代表的值,我们可以把答案存在 weight_in_pos[i]
数组中(第 位代表的值),并且把第一位代表的值初始化为
在样例的 “” 进制中,第一位在十进制中代表的值被初始化为 ,第二位的 weight_in_pos 值就是 ,第三位就是 ,
至此,我们已经能计算出所有牛中排在第 个的牛了,但是题目问的是在这剩下的 头牛中,按照字典序排序,排在第 位的牛是哪一头。
这个小问题的求解就比较简单了,我们可以把 转化成在所有牛中的排名,而不是剩下的牛的排名,我们可以先计算出要删除的那 头牛在所有牛中的排名,如果这 头牛中有牛的排名比 小,或是等于 ,那么就需要把 加 。(相当于排名前 的这些牛中有一些是不能选取的,而我们要选出 头,所以要加上删去的 头牛中排名比 小的)
细节都有注释
#include <bits/stdc++.h>using namespace std;#define ll long longint n, k;int adj_num = 0;vector<string> str[105]; //str[i]表示用于修饰第i头牛的形容词vector<string> adj_by_pos[35]; //adj_by_pos[i]表示所有在位置i出现的形容词set<string> is_appear[35]; //is_appear[i]用于判断在位置i上,某个形容词是否出现int weight_in_pos[35]; //每一位代表的值(按照字典序排在第几)map<string, int> rank_in_pos[35]; //rank_in_pos[i][j]代表在位置i上,字符串j按照字典序排序,排在第几int cow_rank[105]; //fj没有的牛的排名bool debug = false; //调试开关,可以打开去体会一下解题的过程void mapping() //对每类形容词进行字典序排序,把结果存在 rank[i][j]中,其中 i 代表形容词类型,j 代表排名{ //注意这里的排名从0开始(这是把单词映射到数字上,数字是从0开始的) for (int i = 1; i <= adj_num; i++) { int rank = 0; for (auto j : adj_by_pos[i]) //c++11的新特性,意思是用j遍历所有adj_by_pos[i]的元素 { rank_in_pos[i][j] = rank; if (debug) cout << j << " rank = " << rank << " i = " << i << endl; rank++; } }}int get_pos(int cow_id){ int ans = 0; for (int i = 1; i <= adj_num; i++) { ans += weight_in_pos[i] * (rank_in_pos[i][str[cow_id][i - 1]]); } return ans + 1; //答案可能是0,但是排位应该从1开始}int main(){ ios::sync_with_stdio(false); cin >> n >> k; string temp_str; for (int i = 1; i <= n; i++) { cin >> temp_str; while (temp_str != "no") { cin >> temp_str; } int adj_pos = 1; while (1) { cin >> temp_str; if (temp_str == "cow.") { break; } str[i].push_back(temp_str); if (!is_appear[adj_pos].count(temp_str)) //还没有出现过,这里是去重操作 { adj_by_pos[adj_pos].push_back(temp_str); is_appear[adj_pos].insert(temp_str); } if (i == 1) { adj_num++; //计算形容词的种类数 } adj_pos++; } } for (int i = 1; i <= adj_num; i++) { sort(adj_by_pos[i].begin(), adj_by_pos[i].end()); //对每个类型的形容词进行排序 } weight_in_pos[adj_num + 1] = 1; //第一位代表的数字应该是1 adj_by_pos[adj_num + 1].push_back("temp"); //1和1相乘是1,所以要push一个元素进去,这样子size就是1 for (int i = adj_num; i >= 1; i--) //计算每一位在十进制中代表的值(这里的位的计算方法和平时的方法是反的) { //这是因为本题中的形容词类型和数字系统的 “位” 是相反的 if (debug) cout << "i" << i << endl; weight_in_pos[i] = weight_in_pos[i + 1] * adj_by_pos[i + 1].size(); } mapping(); for (int i = 1; i <= n; i++) { cow_rank[i] = get_pos(i); //计算出要删除的n头牛的整体排名 if (debug) cout << "cowrkw " << i << " = " << cow_rank[i] << endl; } sort(cow_rank + 1, cow_rank + n + 1); //按照每种牛的排名进行排序 for (int i = 1; i <= n; i++) { if (cow_rank[i] <= k) { k++; } else { break; } } k--; //k原本代表的是排在第几的牛,但是我们已经把牛的排名转化成数字了, //排名最小的牛的数字是0而不是1,所以这里要减1 if (debug) cout << "new k" << k << endl; for (int i = 1; i <= adj_num; i++) { cout << adj_by_pos[i][(k) / weight_in_pos[i]] << " "; if (debug) cout << "i " << i << " (k) / weight_in_pos[i] " << (k) / weight_in_pos[i] << endl; k %= weight_in_pos[i]; } system("pause");}
题解就到这里了,如果你发现题解有问题,或是有看不懂的地方,都欢迎私信我或者是在评论区里讲,如果你觉得对你有帮助就点个赞吧,谢谢。
]]>看到题解里面还没有用STL vector做的,所以我就来交一发。
在一张图中,一共有 个节点, 条双向边,有 个节点不能删除,求出最少需要删除多少个节点才能使得这 个固定点都到达不了号节点。
在完成题意的转化之后,我们发现题目要让我们删除一些点(尽量少),使得整张图变成两个不连通的部分,网络流算法中的最小割(最大流)算法可以处理这这个问题。
【不熟悉最大流算法的同学可以先做一下模板题】
最大流模板
但是我们又发现,一般的最小割处理的是 “删除图的一部分边使得图的两部分变得不连通” 而这道题目让我们删除的是图中一部分节点。于是我们就需要把节点转换成边。
我使用的方法是把每一个节点拆分成两个节点(出点和入点),具体的做法可以参考
P1345 奶牛的电信Telecowmunication
这道题中的题解。
这里来简单解释一下这种做法:首先,我们把图中的每一个点拆分成两个点:出和入点。
并且这两个节点之间有一条单向边连接
每条指向这个点的有向边都只能连接这个点的入点。并且从这个节点出发的有向边都只能从它的出点出发。
那么,把每个节点分成出点和入点之后有什么用呢?在一般的最小割问题中,如果我们想知道要取消掉多少条边,可以使得这张图的汇点和源点不连通,就可以把每条边的权值设置为,并且可以付出的代价删除这条边。
在以割点为基础的最小割问题中,我们可以把每个节点中连接出点和入点的那条边的权值设置为。这样子如果我们想要删除这个节点,就可以付出的代价,把这条边切断,这个点也就被删除了。
那么问题就来了,这道题目中明确的说明了有一些节点是不能删除的,如果都把权值设置成,如何处理不能删除的节点呢?
对于这些关键节点(不能删除的点),我们可以把他们的内部权值设置成 ,这样就不会把这些点删掉了(最小割算法计算的是付出最小的代价使图变得不连通,而设置成 会让删掉这个点变得很不合算)。
另外,我们需要注意,除了题目中说的关键点,源点和汇点也是不能删除的,所以在建图的时候需要处理一下。并且题目让我们求的是最少删去多少个节点,所以连接这些节点的边也是不能删除的,需要把容量设置成 。
解决了边的容量问题后我们再来考虑源点和汇点,我们可以把号节点设置成源点,把所有关键点连接到汇点上,这样子求出的答案就是让所有关键点都到达不了号节点的最小删除节点数(如果任何一个关键点可以到达一号节点,那么汇点也可以到达号节点)。
我采用的是dinic算法,因为每次增广可以找到多条增广路,所以算法的速度会比EK算法高一些,不熟悉这个算法的同学可以去看一下之前提到的最大流模板题的题解。
在实现拆分节点这个操作的时候,我们可以把一个节点的入点的编号设置成它本身的编号,而出点的编号就设置成本身的编号 + (节点总数),这样子可以确保不会重复。
在实现dinic算法时,需要进行对反向边的操作,我使用的是STL vector来存边,因此需要在node结构体中加入rev(reverse)变量,记录当前边的反向边的下标。
#include <bits/stdc++.h>using namespace std;const int MAXM = 100000;const int INF = 0x3f3f3f3f;struct node{ int to, mflow, rev; //to连接的下个节点的编号 //mflow(maxflow)记录当前边的容量 //rev(reverse)记录当前边的反向边的下标};int p, c, n, s, t;vector<node> edge[MAXM];int g_farm[MAXM]; //完好的农场(关键点)int layer[MAXM]; //每个节点的层数node assign_node(int to, int mflow, int rev) //赋值函数{ node temp; temp.to = to, temp.mflow = mflow, temp.rev = rev; return temp;}void add_edge(int from, int to, int mflow) //加边{ edge[from].push_back(assign_node(to, mflow, edge[to].size())); //不需要-1是因为edge[to]还没有push过 from这个节点 edge[to].push_back(assign_node(from, 0, edge[from].size() - 1)); //-1是因为vector的下标是从0开始的,而.size()会返回元素的数量}namespace dinic{ bool layering() //分层 { bool vis[MAXM]; memset(vis, false, sizeof(vis)); memset(layer, 0, sizeof(layer)); queue<int> q; vis[s] = true; layer[s] = 1; q.push(s); while (!q.empty()) { int cur = q.front(); q.pop(); for (auto nex : edge[cur]) //c++11的新特性,意思是用nex遍历edge[cur]中的所有元素 { if (nex.mflow > 0 && vis[nex.to] == false) { layer[nex.to] = layer[cur] + 1; q.push(nex.to); vis[nex.to] = true; } } } return (layer[t] != 0); //返回分层操作是否成功(是否能从源点到达汇点) } int find_aug_path(int cur, int cur_flow) //寻找增广路 { if (cur == t) { return cur_flow; } int ans = 0; for (int i = 0; i < int(edge[cur].size()); i++) { if (edge[cur][i].mflow > 0 && layer[edge[cur][i].to] == layer[cur] + 1) { int nex_flow = find_aug_path(edge[cur][i].to, min(cur_flow, edge[cur][i].mflow)); edge[cur][i].mflow -= nex_flow; //正向边 edge[edge[cur][i].to][edge[cur][i].rev].mflow += nex_flow; //反向边 cur_flow -= nex_flow; ans += nex_flow; if (cur_flow <= 0) //如果当前的容量已经不够了,就直接返回来节省时间 { return ans; } } } return ans; } int find_maxflow() { int ans = 0; while (layering()) { ans += find_aug_path(s, INF); } return ans; }}void input_creat() //输入和建图{ scanf("%d%d%d", &p, &c, &n); s = 0, t = 2 * p + 1; add_edge(0, 1, INF); //再搞一个点接入源点的入点,容量也要设成INF add_edge(1, 1 + p, INF); //源点的入点和出点设置成INF for (int i = 1; i <= c; i++) { int from, to; scanf("%d%d", &from, &to); add_edge(from + p, to, INF); //from的出点和to的入点相连 add_edge(to + p, from, INF); //to的出点和from的入点相连 } for (int i = 1; i <= n; i++) { //n是不能割的点 int point; scanf("%d", &point); add_edge(point + p, t, INF); //把所有关键点连接到汇点 add_edge(point, point + p, INF); //所有关键点的内部边的容量都要设成INF g_farm[point] = 1; //标记关键点 } for (int i = 2; i <= p; i++) { if (!g_farm[i]) { add_edge(i, i + p, 1); //除了关键点的其他点可以删,所以内部边的容量设成1 } }}main(){ input_creat(); printf("%d", dinic::find_maxflow()); system("pause");}
第一次写题解,问题可能比较多,如果看到题解有什么不对的欢迎在评论区提出,或者私信我,有看不懂的地方也欢迎提问。最后,如果这篇题解对你有帮助就点个赞吧,或者在评论区中交流你的看法。
]]>