[THUSC 2019工程] B. 常见网络校验和计算

让我们在 上一题的基础上 继续扩充。

首先,实践发现使用 operator>> 从输入流中获取文件的方法并不是非常好用。我将上一题的四个结构体添加了构造函数:

struct pcap_hdr {
    // 所有字段都是大端序
    uint32_t magic_number;// 用于文件类型识别,始终为 0xA1B2C3D4,
    uint16_t version_major;// 始终为 2
    uint16_t version_minor;// 始终为 4
    int32_t thiszone;// 始终为 0
    uint32_t sigfigs;// 始终为 0
    uint32_t snaplen;// 允许的最大包长度,始终为 262144
    uint32_t network;// 数据类型,本次学习题中始终为 1 (以太网)
    pcap_hdr(std::vector<uint8_t>::const_iterator begin):
        magic_number((begin[0]<<24)+(begin[1]<<16)+(begin[2]<<8)+(begin[3]<<0)),
        version_major((begin[4]<<8)+(begin[5]<<0)),
        version_minor((begin[6]<<8)+(begin[7]<<0)),
        thiszone((begin[8]<<24)+(begin[9]<<16)+(begin[10]<<8)+(begin[11]<<0)),
        sigfigs((begin[12]<<24)+(begin[13]<<16)+(begin[14]<<8)+(begin[15]<<0)),
        snaplen((begin[16]<<24)+(begin[17]<<16)+(begin[18]<<8)+(begin[19]<<0)),
        network((begin[20]<<24)+(begin[21]<<16)+(begin[22]<<8)+(begin[23]<<0)){
        assert(magic_number==0xA1B2C3D4&&
               version_major==2&&
               version_minor==4&&
               thiszone==0&&
               sigfigs==0&&
               snaplen==262144&&
               network==1
        );
    }
    friend std::ostream& operator<<(std::ostream& out,pcap_hdr data){
        data.magic_number=htonl(data.magic_number),
        data.version_major=htons(data.version_major),
        data.version_minor=htons(data.version_minor),
        data.thiszone=htonl(data.thiszone),
        data.sigfigs=htonl(data.sigfigs),
        data.snaplen=htonl(data.snaplen),
        data.network=htonl(data.network);
        out.write((char*)&data,sizeof(pcap_hdr));
        return out;
    }
};
struct pcaprec_hdr{
    // 所有字段都是大端序
    uint32_t ts_sec;// 时间戳(秒)
    uint32_t ts_usec;// 时间戳(微秒)
    uint32_t incl_len;// 该片段的存储长度
    uint32_t orig_len;// 该片段实际的长度
    pcaprec_hdr(std::vector<uint8_t>::const_iterator begin):
        ts_sec((begin[0]<<24)+(begin[1]<<16)+(begin[2]<<8)+(begin[3]<<0)),
        ts_usec((begin[4]<<24)+(begin[5]<<16)+(begin[6]<<8)+(begin[7]<<0)),
        incl_len((begin[8]<<24)+(begin[9]<<16)+(begin[10]<<8)+(begin[11]<<0)),
        orig_len((begin[12]<<24)+(begin[13]<<16)+(begin[14]<<8)+(begin[15]<<0)){
        assert(ts_usec<1e6&&incl_len==orig_len);
    }
    friend std::ostream& operator<<(std::ostream& out,pcaprec_hdr data){
        data.incl_len=htonl(data.incl_len),
        data.orig_len=htonl(data.orig_len),
        data.ts_sec=htonl(data.ts_sec),
        data.ts_usec=htonl(data.ts_usec);
        out.write((char*)&data,sizeof(pcaprec_hdr));
        return out;
    }
};
struct pcaprec{
    pcaprec_hdr header;
    std::vector<uint8_t> data;
    pcaprec(std::vector<uint8_t>::const_iterator begin):
        header(begin),data(begin+16,begin+16+header.orig_len){}
    friend std::ostream& operator<<(std::ostream& out,pcaprec const& data){
        out<<data.header;
        out.write((char*)data.data.data(),data.header.incl_len);
        return out;
    }
};
struct pcap{
    pcap_hdr header;
    std::vector<pcaprec> data;
    pcap(std::vector<uint8_t>::const_iterator begin,std::vector<uint8_t>::const_iterator end):
        header(begin),data(){
        begin+=24;
        while (begin!=end)
            data.emplace_back(begin),
            begin+=data.back().header.orig_len+16;
    }
    pcap(pcap_hdr const& p):header(p),data(){}
    friend std::ostream& operator<<(std::ostream& out,pcap const& data){
        out<<data.header;
        for (pcaprec const& i:data.data) out<<i;
        return out;
    }
};

有了构造函数的类型就不再 trivial 了,不再建议直接操作内存 虽然我没有找到相关标准。因此,所有成员变量都直接从字符拼接而来。输出是写到缓冲区内存里,不受影响。

我们继续按照题目要求完成以太网协议的内容,所有的成员变量同样通过手动拼接的方式赋值:

struct ethernet_frame{
    uint8_t destinationMAC[6],sourceMAC[6];
    enum class EtherType:uint16_t{
        ipv4=0x0800,arp=0x0806
    } etherType;
    std::vector<uint8_t> data;
    uint32_t fcs;

    struct CheckError{};
    
    ethernet_frame(pcaprec const& f):ethernet_frame(f.data.begin(),f.data.end()){}
    ethernet_frame(std::vector<uint8_t>::const_iterator beg,std::vector<uint8_t>::const_iterator end):
        destinationMAC{beg[0],beg[1],beg[2],beg[3],beg[4],beg[5]},
        sourceMAC{beg[6],beg[7],beg[8],beg[9],beg[10],beg[11]},
        etherType(EtherType((beg[12]<<8)+(beg[13]<<0))),
        data(beg+14,end-4),
        fcs((end[-4]<<24)+(end[-3]<<16)+(end[-2]<<8)+(end[-1]<<0)) {
        if (!check(std::vector<uint8_t>(beg,end)))
            throw CheckError();
    }
};

题目要求我们实现 FCS 校验。仔细阅读手册,我们发现算法的实质实际为:

  1. 将每个字节的位进行反转;
  2. 将整张帧的最后 32 位(原先的 FCS)保存下来,然后填充为 0;
  3. 前 32 位取反;
  4. G=0b100000100110000010001110110110111 与数据首位对齐,重复执行直到 G 的末位超过数据:
    1. 若 G 首位对应的数据位为 1,则将数据与 G 做一次异或;
    2. G 向后移动一位;
  5. 显然在循环结束后除末尾的 32 位其他必定为 0。将最后 32 位提出,反转,取反,即我们计算得到的 FCS。

check 就是按照这个流程检查 FCS 的静态成员函数。若检查失败使用 throw 抛出定义好的错误供外面捕获 闲的。由于要修改数据,参数直接使用了值传递而非常规的常量引用。

/* 反转字节内的位。大体步骤为:
   01234567
-> 交换 0/1,2/3,4/5,6/7 得到 10325476
-> 交换 10/32,54/67 得到 32107654
-> 交换 3210/7654 得到 76543210
*/
u_int8_t bitreverse(uint8_t x){
    x=((x&0xaa)>>1)|((x&0x55)<<1);
    x=((x&0xcc)>>2)|((x&0x33)<<2);
    x=((x&0xf0)>>4)|((x&0x0f)<<4);
    return x;
}
static bool ethernet_frame::check_bit(std::vector<uint8_t> f){
    static constexpr int64_t G=0b100000100110000010001110110110111;
    uint32_t fcs;std::copy_n(f.end()-4,4,(uint8_t*)&fcs),fcs=htonl(fcs);
    std::fill_n(f.end()-4,4,0);
    for (uint8_t& i:f) i=bitreverse(i);
    f[0]=~f[0],f[1]=~f[1],f[2]=~f[2],f[3]=~f[3];
    for (size_t i=4;i<f.size();++i) for (size_t j=0;j<8;++j)
        if (f[i-4]>>(8-1-j)&1)
            f[i-4]^=uint8_t(G>>(33-8+j)),
            f[i-3]^=uint8_t(G>>(33-16+j)),
            f[i-2]^=uint8_t(G>>(33-24+j)),
            f[i-1]^=uint8_t(G>>(33-32+j)),
            f[i-0]^=uint8_t(G<<(8-j-1));
    uint32_t res=~(
        (bitreverse((uint32_t)f.rbegin()[3])<<24u)+
        (bitreverse((uint32_t)f.rbegin()[2])<<16u)+
        (bitreverse((uint32_t)f.rbegin()[1])<<8u)+
        (bitreverse((uint32_t)f.rbegin()[0])<<0u)
    );
    return fcs==res;
}

main 函数如下所示。std::istreambuf_iterator 负责从缓冲区读取二进制数据,传入一个 basic_istream 实例构造迭代器,迭代器每前进一次读取一字节;不传参表示 EOF。std::vector 接受这两个迭代器来构造元素。 try 捕获构造时产生的错误。

int main(void){
    std::ios::sync_with_stdio(false),std::cin.tie(nullptr),std::cout.tie(nullptr);
    std::vector<uint8_t> buf((std::istreambuf_iterator<char>(fin)),std::istreambuf_iterator<char>());
    pcap f(buf.begin(),buf.end());
    for (auto const& i:f.data)
        try{
            ethernet_frame x(i);
            std::cout<<"Yes\n";
        }catch (ethernet_frame::FrameCheckError const&){
            std::cout<<"No\n";
        }
    return 0;
}

以上,我们就得到了 8 分的好成绩,可喜可贺!


题目还要求识别 IPv4 协议并校验。我定义结构体如下:

struct ipgroup_hdr{
    uint8_t version:4; // 始终为 0b0100
    uint8_t ihl:4;
    uint8_t type; // 始终为 0
    uint16_t total_length;
    uint16_t identification;
    uint8_t flag:3; // 始终为 0b010
    uint16_t offset:13; // 始终为 0
    uint8_t time_to_live;
    uint8_t protocol;
    uint16_t header_checksum;
    uint8_t source_ip[4],destination_ip[4];
    std::vector<uint8_t> options;

    struct CheckError{};

    ipgroup_hdr(ethernet_frame const& f):ipgroup_hdr(f.data.begin()){}
    ipgroup_hdr(std::vector<uint8_t>::const_iterator beg):
        version(beg[0]>>4),ihl(beg[0]&0xf),type(beg[1]),
        total_length((beg[2]<<8)+(beg[3]<<0)),
        identification((beg[4]<<8)+(beg[5]<<0)),
        flag(beg[6]>>5),
        offset(((beg[6]&0b00011111)<<8)+(beg[7]<<0)),
        time_to_live(beg[8]),protocol(beg[9]),
        header_checksum((beg[10]<<8)+(beg[11]<<0)),
        source_ip{beg[12],beg[13],beg[14],beg[15]},
        destination_ip{beg[16],beg[17],beg[18],beg[19]},
        options(beg+20,beg+ihl*4){
        assert(
            version==0b0100&&type==0&&
            flag==0b010&&offset==0&&
            identification==0&&ihl>=5&&total_length>=20
        );
        if (!check(beg)) throw CheckError();
    }
    static bool check(std::vector<uint8_t>::const_iterator beg){
        size_t n=(beg[0]&0xf)*4;
        uint32_t sum=0;
        for (size_t i=0;i<n;i+=2) if (i!=10){
            sum+=(beg[i+0]<<8)+(beg[i+1]<<0);
            uint16_t x;
            while ((x=(sum>>16))) sum&=0xffff,sum+=x;
        }
        return uint16_t(~sum)==uint16_t((beg[10]<<8)+(beg[11]<<0));
    }
};

这里玩了玩一点小小的 C++ 特性 位域,不过毕竟是构造函数手动初始化所以不用也没关系 还是闲的。此处 check 就相对简单了,不再赘述。

main 函数添加对 ipgroup 的 catch:

int main(void){
    std::ios::sync_with_stdio(false),std::cin.tie(nullptr),std::cout.tie(nullptr);
    std::vector<uint8_t> buf((std::istreambuf_iterator<char>(fin)),std::istreambuf_iterator<char>());
    pcap f(buf.begin(),buf.end());
    for (auto const& i:f.data)
        try{
            ethernet_frame a(i);
            if (a.etherType==ethernet_frame::EtherType::ipv4)
                ipgroup b(a);
            std::cout<<"Yes\n";
        }catch (ethernet_frame::CheckError const&){
            std::cout<<"No\n";
        }catch (ipgroup_hdr::CheckError const&){
            std::cout<<"No\n";
        }
    return 0;
}

这样,我们就有 32 分的好成绩了!


我们会发现 T 了最后一个点。再回题面看一眼:。我们在以太网协议的校验是逐位计算的,循环次数还要乘 8,因此在如此极限的情况下会 T 飞了。我们还需要一些小小的优化。

注意到,每一位是否进行异或 G 操作,仅由当前位的值决定。由因为异或的结合律,我们可以以字节为单位计算。实现一个函数,传入一个 byte,返回 5 位分别应该异或多少。使用记忆化优化,就可以压掉 8 倍的常数。最终时间为 550ms。

static uint64_t query_check_byte(uint8_t x){
    static std::array<uint64_t,256> mem={};
    static constexpr int64_t G=0b100000100110000010001110110110111;
    if (mem[x]) return mem[x];
    if (x==0) return 0;
    uint8_t y=x;
    for (uint8_t i=7;i<8;--i) if ((x>>i)&1)
        x^=uint8_t(G<<i>>32),
        mem[y]^=G<<i;
    return mem[y];
}
static bool check(std::vector<uint8_t> f){
    uint32_t fcs;std::copy_n(f.end()-4,4,(uint8_t*)&fcs),fcs=htonl(fcs);
    std::fill_n(f.end()-4,4,0);
    for (uint8_t& i:f) i=bitreverse(i);
    f[0]=~f[0],f[1]=~f[1],f[2]=~f[2],f[3]=~f[3];
    for (size_t i=4;i<f.size();++i){
        int64_t x=query_check_byte(f[i-4]);
        f[i-4]^=uint8_t(x>>(40-8)),
        f[i-3]^=uint8_t(x>>(40-16)),
        f[i-2]^=uint8_t(x>>(40-24)),
        f[i-1]^=uint8_t(x>>(40-32)),
        f[i-0]^=uint8_t(x>>(40-40));
    }
    uint32_t res=~(
        (bitreverse((uint32_t)f.rbegin()[3])<<24u)+
        (bitreverse((uint32_t)f.rbegin()[2])<<16u)+
        (bitreverse((uint32_t)f.rbegin()[1])<<8u)+
        (bitreverse((uint32_t)f.rbegin()[0])<<0u)
    );
    return fcs==res;
}

呼!终于干满了。完整代码 共 228 行。

posted @ 2024-04-11 18:51  MrPython  阅读(4)  评论(0)    收藏  举报  来源