[THUSC 2019工程] B. 常见网络校验和计算
让我们在 上一题的基础上 继续扩充。
首先,实践发现使用 operator>>
从输入流中获取文件的方法并不是非常好用。我将上一题的四个结构体添加了构造函数:
struct pcap_hdr {
// 所有字段都是大端序
uint32_t magic_number;// 用于文件类型识别,始终为 0xA1B2C3D4,
uint16_t version_major;// 始终为 2
uint16_t version_minor;// 始终为 4
int32_t thiszone;// 始终为 0
uint32_t sigfigs;// 始终为 0
uint32_t snaplen;// 允许的最大包长度,始终为 262144
uint32_t network;// 数据类型,本次学习题中始终为 1 (以太网)
pcap_hdr(std::vector<uint8_t>::const_iterator begin):
magic_number((begin[0]<<24)+(begin[1]<<16)+(begin[2]<<8)+(begin[3]<<0)),
version_major((begin[4]<<8)+(begin[5]<<0)),
version_minor((begin[6]<<8)+(begin[7]<<0)),
thiszone((begin[8]<<24)+(begin[9]<<16)+(begin[10]<<8)+(begin[11]<<0)),
sigfigs((begin[12]<<24)+(begin[13]<<16)+(begin[14]<<8)+(begin[15]<<0)),
snaplen((begin[16]<<24)+(begin[17]<<16)+(begin[18]<<8)+(begin[19]<<0)),
network((begin[20]<<24)+(begin[21]<<16)+(begin[22]<<8)+(begin[23]<<0)){
assert(magic_number==0xA1B2C3D4&&
version_major==2&&
version_minor==4&&
thiszone==0&&
sigfigs==0&&
snaplen==262144&&
network==1
);
}
friend std::ostream& operator<<(std::ostream& out,pcap_hdr data){
data.magic_number=htonl(data.magic_number),
data.version_major=htons(data.version_major),
data.version_minor=htons(data.version_minor),
data.thiszone=htonl(data.thiszone),
data.sigfigs=htonl(data.sigfigs),
data.snaplen=htonl(data.snaplen),
data.network=htonl(data.network);
out.write((char*)&data,sizeof(pcap_hdr));
return out;
}
};
struct pcaprec_hdr{
// 所有字段都是大端序
uint32_t ts_sec;// 时间戳(秒)
uint32_t ts_usec;// 时间戳(微秒)
uint32_t incl_len;// 该片段的存储长度
uint32_t orig_len;// 该片段实际的长度
pcaprec_hdr(std::vector<uint8_t>::const_iterator begin):
ts_sec((begin[0]<<24)+(begin[1]<<16)+(begin[2]<<8)+(begin[3]<<0)),
ts_usec((begin[4]<<24)+(begin[5]<<16)+(begin[6]<<8)+(begin[7]<<0)),
incl_len((begin[8]<<24)+(begin[9]<<16)+(begin[10]<<8)+(begin[11]<<0)),
orig_len((begin[12]<<24)+(begin[13]<<16)+(begin[14]<<8)+(begin[15]<<0)){
assert(ts_usec<1e6&&incl_len==orig_len);
}
friend std::ostream& operator<<(std::ostream& out,pcaprec_hdr data){
data.incl_len=htonl(data.incl_len),
data.orig_len=htonl(data.orig_len),
data.ts_sec=htonl(data.ts_sec),
data.ts_usec=htonl(data.ts_usec);
out.write((char*)&data,sizeof(pcaprec_hdr));
return out;
}
};
struct pcaprec{
pcaprec_hdr header;
std::vector<uint8_t> data;
pcaprec(std::vector<uint8_t>::const_iterator begin):
header(begin),data(begin+16,begin+16+header.orig_len){}
friend std::ostream& operator<<(std::ostream& out,pcaprec const& data){
out<<data.header;
out.write((char*)data.data.data(),data.header.incl_len);
return out;
}
};
struct pcap{
pcap_hdr header;
std::vector<pcaprec> data;
pcap(std::vector<uint8_t>::const_iterator begin,std::vector<uint8_t>::const_iterator end):
header(begin),data(){
begin+=24;
while (begin!=end)
data.emplace_back(begin),
begin+=data.back().header.orig_len+16;
}
pcap(pcap_hdr const& p):header(p),data(){}
friend std::ostream& operator<<(std::ostream& out,pcap const& data){
out<<data.header;
for (pcaprec const& i:data.data) out<<i;
return out;
}
};
有了构造函数的类型就不再 trivial 了,不再建议直接操作内存 虽然我没有找到相关标准。因此,所有成员变量都直接从字符拼接而来。输出是写到缓冲区内存里,不受影响。
我们继续按照题目要求完成以太网协议的内容,所有的成员变量同样通过手动拼接的方式赋值:
struct ethernet_frame{
uint8_t destinationMAC[6],sourceMAC[6];
enum class EtherType:uint16_t{
ipv4=0x0800,arp=0x0806
} etherType;
std::vector<uint8_t> data;
uint32_t fcs;
struct CheckError{};
ethernet_frame(pcaprec const& f):ethernet_frame(f.data.begin(),f.data.end()){}
ethernet_frame(std::vector<uint8_t>::const_iterator beg,std::vector<uint8_t>::const_iterator end):
destinationMAC{beg[0],beg[1],beg[2],beg[3],beg[4],beg[5]},
sourceMAC{beg[6],beg[7],beg[8],beg[9],beg[10],beg[11]},
etherType(EtherType((beg[12]<<8)+(beg[13]<<0))),
data(beg+14,end-4),
fcs((end[-4]<<24)+(end[-3]<<16)+(end[-2]<<8)+(end[-1]<<0)) {
if (!check(std::vector<uint8_t>(beg,end)))
throw CheckError();
}
};
题目要求我们实现 FCS 校验。仔细阅读手册,我们发现算法的实质实际为:
- 将每个字节的位进行反转;
- 将整张帧的最后 32 位(原先的 FCS)保存下来,然后填充为 0;
- 前 32 位取反;
- 将
G=0b100000100110000010001110110110111
与数据首位对齐,重复执行直到 G 的末位超过数据:- 若 G 首位对应的数据位为 1,则将数据与 G 做一次异或;
- G 向后移动一位;
- 显然在循环结束后除末尾的 32 位其他必定为 0。将最后 32 位提出,反转,取反,即我们计算得到的 FCS。
check
就是按照这个流程检查 FCS 的静态成员函数。若检查失败使用 throw
抛出定义好的错误供外面捕获 闲的。由于要修改数据,参数直接使用了值传递而非常规的常量引用。
/* 反转字节内的位。大体步骤为:
01234567
-> 交换 0/1,2/3,4/5,6/7 得到 10325476
-> 交换 10/32,54/67 得到 32107654
-> 交换 3210/7654 得到 76543210
*/
u_int8_t bitreverse(uint8_t x){
x=((x&0xaa)>>1)|((x&0x55)<<1);
x=((x&0xcc)>>2)|((x&0x33)<<2);
x=((x&0xf0)>>4)|((x&0x0f)<<4);
return x;
}
static bool ethernet_frame::check_bit(std::vector<uint8_t> f){
static constexpr int64_t G=0b100000100110000010001110110110111;
uint32_t fcs;std::copy_n(f.end()-4,4,(uint8_t*)&fcs),fcs=htonl(fcs);
std::fill_n(f.end()-4,4,0);
for (uint8_t& i:f) i=bitreverse(i);
f[0]=~f[0],f[1]=~f[1],f[2]=~f[2],f[3]=~f[3];
for (size_t i=4;i<f.size();++i) for (size_t j=0;j<8;++j)
if (f[i-4]>>(8-1-j)&1)
f[i-4]^=uint8_t(G>>(33-8+j)),
f[i-3]^=uint8_t(G>>(33-16+j)),
f[i-2]^=uint8_t(G>>(33-24+j)),
f[i-1]^=uint8_t(G>>(33-32+j)),
f[i-0]^=uint8_t(G<<(8-j-1));
uint32_t res=~(
(bitreverse((uint32_t)f.rbegin()[3])<<24u)+
(bitreverse((uint32_t)f.rbegin()[2])<<16u)+
(bitreverse((uint32_t)f.rbegin()[1])<<8u)+
(bitreverse((uint32_t)f.rbegin()[0])<<0u)
);
return fcs==res;
}
main 函数如下所示。std::istreambuf_iterator
负责从缓冲区读取二进制数据,传入一个 basic_istream
实例构造迭代器,迭代器每前进一次读取一字节;不传参表示 EOF。std::vector
接受这两个迭代器来构造元素。 try 捕获构造时产生的错误。
int main(void){
std::ios::sync_with_stdio(false),std::cin.tie(nullptr),std::cout.tie(nullptr);
std::vector<uint8_t> buf((std::istreambuf_iterator<char>(fin)),std::istreambuf_iterator<char>());
pcap f(buf.begin(),buf.end());
for (auto const& i:f.data)
try{
ethernet_frame x(i);
std::cout<<"Yes\n";
}catch (ethernet_frame::FrameCheckError const&){
std::cout<<"No\n";
}
return 0;
}
以上,我们就得到了 8 分的好成绩,可喜可贺!
题目还要求识别 IPv4 协议并校验。我定义结构体如下:
struct ipgroup_hdr{
uint8_t version:4; // 始终为 0b0100
uint8_t ihl:4;
uint8_t type; // 始终为 0
uint16_t total_length;
uint16_t identification;
uint8_t flag:3; // 始终为 0b010
uint16_t offset:13; // 始终为 0
uint8_t time_to_live;
uint8_t protocol;
uint16_t header_checksum;
uint8_t source_ip[4],destination_ip[4];
std::vector<uint8_t> options;
struct CheckError{};
ipgroup_hdr(ethernet_frame const& f):ipgroup_hdr(f.data.begin()){}
ipgroup_hdr(std::vector<uint8_t>::const_iterator beg):
version(beg[0]>>4),ihl(beg[0]&0xf),type(beg[1]),
total_length((beg[2]<<8)+(beg[3]<<0)),
identification((beg[4]<<8)+(beg[5]<<0)),
flag(beg[6]>>5),
offset(((beg[6]&0b00011111)<<8)+(beg[7]<<0)),
time_to_live(beg[8]),protocol(beg[9]),
header_checksum((beg[10]<<8)+(beg[11]<<0)),
source_ip{beg[12],beg[13],beg[14],beg[15]},
destination_ip{beg[16],beg[17],beg[18],beg[19]},
options(beg+20,beg+ihl*4){
assert(
version==0b0100&&type==0&&
flag==0b010&&offset==0&&
identification==0&&ihl>=5&&total_length>=20
);
if (!check(beg)) throw CheckError();
}
static bool check(std::vector<uint8_t>::const_iterator beg){
size_t n=(beg[0]&0xf)*4;
uint32_t sum=0;
for (size_t i=0;i<n;i+=2) if (i!=10){
sum+=(beg[i+0]<<8)+(beg[i+1]<<0);
uint16_t x;
while ((x=(sum>>16))) sum&=0xffff,sum+=x;
}
return uint16_t(~sum)==uint16_t((beg[10]<<8)+(beg[11]<<0));
}
};
这里玩了玩一点小小的 C++ 特性 位域,不过毕竟是构造函数手动初始化所以不用也没关系 还是闲的。此处 check 就相对简单了,不再赘述。
main 函数添加对 ipgroup 的 catch:
int main(void){
std::ios::sync_with_stdio(false),std::cin.tie(nullptr),std::cout.tie(nullptr);
std::vector<uint8_t> buf((std::istreambuf_iterator<char>(fin)),std::istreambuf_iterator<char>());
pcap f(buf.begin(),buf.end());
for (auto const& i:f.data)
try{
ethernet_frame a(i);
if (a.etherType==ethernet_frame::EtherType::ipv4)
ipgroup b(a);
std::cout<<"Yes\n";
}catch (ethernet_frame::CheckError const&){
std::cout<<"No\n";
}catch (ipgroup_hdr::CheckError const&){
std::cout<<"No\n";
}
return 0;
}
这样,我们就有 32 分的好成绩了!
我们会发现 T 了最后一个点。再回题面看一眼:。我们在以太网协议的校验是逐位计算的,循环次数还要乘 8,因此在如此极限的情况下会 T 飞了。我们还需要一些小小的优化。
注意到,每一位是否进行异或 G 操作,仅由当前位的值决定。由因为异或的结合律,我们可以以字节为单位计算。实现一个函数,传入一个 byte,返回 5 位分别应该异或多少。使用记忆化优化,就可以压掉 8 倍的常数。最终时间为 550ms。
static uint64_t query_check_byte(uint8_t x){
static std::array<uint64_t,256> mem={};
static constexpr int64_t G=0b100000100110000010001110110110111;
if (mem[x]) return mem[x];
if (x==0) return 0;
uint8_t y=x;
for (uint8_t i=7;i<8;--i) if ((x>>i)&1)
x^=uint8_t(G<<i>>32),
mem[y]^=G<<i;
return mem[y];
}
static bool check(std::vector<uint8_t> f){
uint32_t fcs;std::copy_n(f.end()-4,4,(uint8_t*)&fcs),fcs=htonl(fcs);
std::fill_n(f.end()-4,4,0);
for (uint8_t& i:f) i=bitreverse(i);
f[0]=~f[0],f[1]=~f[1],f[2]=~f[2],f[3]=~f[3];
for (size_t i=4;i<f.size();++i){
int64_t x=query_check_byte(f[i-4]);
f[i-4]^=uint8_t(x>>(40-8)),
f[i-3]^=uint8_t(x>>(40-16)),
f[i-2]^=uint8_t(x>>(40-24)),
f[i-1]^=uint8_t(x>>(40-32)),
f[i-0]^=uint8_t(x>>(40-40));
}
uint32_t res=~(
(bitreverse((uint32_t)f.rbegin()[3])<<24u)+
(bitreverse((uint32_t)f.rbegin()[2])<<16u)+
(bitreverse((uint32_t)f.rbegin()[1])<<8u)+
(bitreverse((uint32_t)f.rbegin()[0])<<0u)
);
return fcs==res;
}
呼!终于干满了。完整代码 共 228 行。