基于同态加密的PSI开源库-1

下面介绍一个PSI的开源库,它复现了以下两篇论文:(1)CCS2017:Fast Private Set Intersection from Homomorphic Encryption;(2)CCS2018:Labeled PSI from Fully Homomorphic Encryption with Malicious Security。
这两篇论文设计了基于同态加密的PSI协议:CCS2017设计了基于同态加密的非平衡场景的PSI协议;CCS2018对CCS2017中的协议进行改进,并提出了非平衡场景下的Labeled-PSI协议。

安装

地址
注:该库没有经过安全审核,只能用作实验。

安装前提:

1、需要安装好Seal库,用于同态加密
2、需要gcc和cmake环境
3、需要C++的库:Boost,介绍和安装请参考:C++:Boost库

安装

// 先下载
git clone https://github.com/aleksejspopovs/6857-private-categorization

//进入到目录进行编译
cd src
cmake .
make

//如果seal和boost已经安装好,可以选用下面的编译方式
cd src
cmake -DCMAKE_PREFIX_PATH="/path/to/seal;/path/to/boost" .
make

编译完成后,在/src/bin下会有生成的可执行文件:

其中,bin/benchmark用于测试给定参数的协议性能;bin/private_categorization是PSI运行的一个实例(没有中间输出);bin/private_categorization_debug_entropy有中间参数输出;bin/pc_client和bin/pc_server是通过网络进行交互的一个协议运行实例。

使用

说明

(1)在main中可以看出,label模式是默认关闭的;
(2)这里对Label的加密使用的是AES(对称加密),与论文中的做法相对应。

主要模块

(1)window:用于receiver计算部分次幂
(2)network:用于sender和receiver通信
(3)polynomials:用于多项式计算和编码解码(打包)
(4)hash:用于cuckoo hash
(5)aes:用于label加解密
(6)random:用于随机数生成
(7)psi:定义参数集,sender和receiver类

分析

main

// Demo driver: runs one complete PSI exchange between a receiver ("user") and a
// sender ("server") inside a single process, on randomly generated data sets,
// and prints the chosen parameters, both data sets and the number of matches.
int main()
{
    // random number generator
    auto random_factory = UniformRandomGeneratorFactory::default_factory();
    auto random = random_factory->create();

    size_t receiver_N = 10;// receiver set size (commented-out alternative: 5535)
    size_t sender_N = 100; // sender set size (commented-out alternative: 1ull << 24)
    size_t input_bits = 32;// bit length passed to the item and label generators below
    size_t poly_modulus_degree = 8192;// polynomial modulus degree N
    size_t partition_count = 2;// number of sender partitions (commented-out alternative: 256)
    size_t window_size = 1;// window size l; with l = 1 the receiver presumably sends y, y^2, y^4, ... (see Windowing::prepare)
    bool labeled = false;// labeled mode (off by default)

    vector<uint64_t> sender_inputs(sender_N);// items of the sender's data set
    vector<uint64_t> sender_labels(sender_N);// labels of the sender's data set (stay zero unless labeled is set)
    vector<uint64_t> receiver_inputs(receiver_N);// items of the receiver's data set

    // draw random values to build the data sets
    generate_random_sender_set(random, sender_inputs, input_bits);
    if (labeled) {
        generate_random_labels(random, sender_labels, input_bits);
    }
    // NOTE(review): the trailing 50 presumably controls how much of the receiver
    // set overlaps with the sender set - confirm against the helper's definition.
    generate_random_receiver_set(random, receiver_inputs, sender_inputs, input_bits, 50);

    // step 1: both parties agree on the parameters
    PSIParams params(receiver_inputs.size(), sender_inputs.size(), input_bits, poly_modulus_degree);
    params.set_sender_partition_count(partition_count);
    params.set_window_size(window_size);
    params.generate_seeds();

    cout << "Parameters chosen:" << endl;
    cout << "  - sender set size: " << params.sender_size << endl;
    cout << "  - receiver set size: " << params.receiver_size << endl;
    cout << "  - element bit length: " << params.input_bits << endl;
    cout << "  - # of hash functions: " << params.hash_functions() << endl;// number of hash functions used for cuckoo hashing
    cout << "  - hash seeds: ";
    for (auto seed : params.seeds) cout << seed << " ";// random seeds of the hash functions
    cout << endl;

    cout << "  - log(bucket count), bucket_count: "
         << params.bucket_count_log() << " " << (1ull << params.bucket_count_log()) << endl;// number of hash buckets
    cout << "  - sender bucket capacity: " << params.sender_bucket_capacity() << endl;// capacity of each sender-side bucket
    cout << endl;

    // all integers are going to be printed as hex now
    cout << hex;

    // step 2: the receiver generates its keys and encrypts its inputs
    PSIReceiver user(params);
    cout << "User's set: ";// print the receiver's data set
    for (uint64_t x : receiver_inputs) {
        cout << x << " ";
    }
    cout << endl;

    vector<bucket_slot> receiver_buckets;// receiver's hash table, presumably filled by encrypt_inputs
    // encryption (per the original notes: hashing, encoding, windowing,
    // encrypting the partial powers)
    auto receiver_encrypted_inputs = user.encrypt_inputs(receiver_inputs, receiver_buckets);

    // print, for every bucket, the index and value of the receiver item stored there
    cout << "User's buckets: ";
    for (auto x : receiver_buckets) {
        if (x == BUCKET_EMPTY) {
            cout << "--:-- ";
        } else {
            cout << x.first << ":" << receiver_inputs[x.first] << " ";
        }
    }
    cout << endl;
    cout << endl;


    // step 3: the sender evaluates the intersection polynomial (after receiving
    // the public key and the ciphertexts)
    PSISender server(params);
    optional<vector<uint64_t>> labels;
    if (labeled) {
        labels = sender_labels;
    }
    // polynomial evaluation (per the original notes: hashing, partitioning,
    // batching, computing all powers, evaluating the polynomial, label
    // interpolation)
    auto sender_matches = server.compute_matches(
        sender_inputs,
        labels,
        user.public_key(),
        user.relin_keys(),
        receiver_encrypted_inputs
    );
    // print the sender's data set as item-label pairs
    cout << "Sender's set: ";
    for (size_t i = 0; i < sender_inputs.size(); i++) {
        cout << sender_inputs[i] << "-" << sender_labels[i] << " ";
    }
    cout << endl;
    cout << endl;

    // step 4: the receiver decrypts the computation result
    if (labeled) {// labeled mode
        vector<pair<size_t, uint64_t>> receiver_labeled_matches;
        receiver_labeled_matches = user.decrypt_labeled_matches(sender_matches);// decrypt the result (per the original notes: decrypt, concatenate, match)
        cout << receiver_labeled_matches.size() << " matches found: ";
        // for (auto i : receiver_matches) {
        //     assert(i.first < receiver_buckets.size());
        //     assert(receiver_buckets[i.first] != BUCKET_EMPTY);
        //     cout << receiver_inputs[receiver_buckets[i.first].first] << "-" << i.second << " ";
        // }
    } else {// unlabeled mode
        vector<size_t> receiver_matches;
        receiver_matches = user.decrypt_matches(sender_matches);// decrypt the result (per the original notes: decrypt, concatenate, match)
        cout << receiver_matches.size() << " matches found: ";
        // for (auto i : receiver_matches) {
        //     assert(i < receiver_buckets.size());
        //     assert(receiver_buckets[i] != BUCKET_EMPTY);
        //     cout << receiver_inputs[receiver_buckets[i].first] << " ";
        // }
    }

    cout << endl;

    return 0;
}

模块

window

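window技术,主要是receiver计算部分次幂\(y^{2^{l \cdot i} \cdot j}\),其中\(0 \le i < \text{window\_count}\),\(1 \le j \le 2^{l}-1\)。当\(l=1\)时,就是计算指数为2的幂的各次幂\(y, y^2, y^4, \dots\);当\(l=0\)时,receiver只发送\(y\)本身的一个密文,其余次幂全部由sender计算。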

该类主要分为两部分:prepare:receiver计算部分次幂;compute_powers:sender计算全部次幂

window_size;  //分窗大小,即l
max_power;  //最高次幂,即m
window_width; //每个窗口内的次幂个数,即2^l - 1
window_count; //窗口个数(receiver共发送window_width * window_count个部分次幂的密文)
// Receiver side: encode and encrypt the partial powers of the input.
//
// With y denoting the initial contents of `input`, slot
// windows[i * window_width + (j - 1)] receives the encryption of
// y^{2^{l * i} * j} (element-wise, modulo `modulus`), where l = window_size,
// 0 <= i < window_count and 1 <= j <= window_width.
// When window_size == 0 only y itself is encrypted, into windows[0].
//
// Note that `input` is modified in place: it is repeatedly raised to the
// 2^l-th power while the windows are produced.
void Windowing::prepare(vector<uint64_t> &input,
                        vector<Ciphertext> &windows,
                        uint64_t modulus,
                        BatchEncoder &encoder,
                        Encryptor &encryptor)
{
    Plaintext packed;

    // No windowing: ship a single ciphertext holding y.
    if (window_size == 0) {
        windows.resize(1);
        encoder.encode(input, packed);
        encryptor.encrypt(packed, windows[0]);
        return;
    }

    windows.resize(window_width * window_count);

    vector<uint64_t> running;
    for (size_t win = 0; win < window_count; ++win) {
        // Invariant at this point: input == y^{2^{l * win}}.
        // Inside the slot loop:     running == input^{slot + 1}.
        running = input;
        const size_t first_slot = win * window_width;

        for (size_t slot = 0; slot < window_width; ++slot) {
            // encode, then encrypt this partial power
            encoder.encode(running, packed);
            encryptor.encrypt(packed, windows[first_slot + slot]);

            // Advance to the next multiple unless this was the last slot.
            if (slot + 1 < window_width) {
                for (size_t pos = 0; pos < input.size(); ++pos) {
                    running[pos] = MUL_MOD(running[pos], input[pos], modulus);
                }
            }
        }

        // Move on to the next window: raise input to the 2^l-th power.
        if (win + 1 < window_count) {
            for (size_t pos = 0; pos < input.size(); ++pos) {
                input[pos] = modexp(input[pos], 1ull << window_size, modulus);
            }
        }
    }
}
// Sender side: reconstruct every power y^1 .. y^{powers.size() - 1} from the
// partial powers received from the receiver.
//
// `windows` holds the ciphertexts produced by Windowing::prepare and `powers`
// must already be sized by the caller; powers[k] receives the encryption of
// y^k (powers[0] is left untouched). Every ciphertext-ciphertext product is
// relinearized with `relin_keys`.
void Windowing::compute_powers(vector<Ciphertext> &windows,
                               vector<Ciphertext> &powers,
                               Evaluator &evaluator,
                               RelinKeys &relin_keys)
{
    if (window_size == 0) {
        assert(windows.size() == 1);
        // Nothing to compute if the caller did not ask for at least y^1; this
        // mirrors the bounds check in the windowed branch below.
        if (powers.size() < 2) {
            return;
        }
        powers[1] = windows[0];
        for (size_t i = 2; i < powers.size(); i++) {
            // Even exponent: square y^{i/2}; odd exponent: multiply y^{i-1} by y.
            // The parentheses matter: `i & 2 == 0` parses as `i & (2 == 0)`,
            // which is always 0 and would silently disable the squaring path,
            // turning the computation into a linear chain of multiplications.
            if ((i & 1) == 0) {
                evaluator.square(powers[i >> 1], powers[i]);
            } else {
                evaluator.multiply(powers[i - 1], powers[1], powers[i]);
            }
            evaluator.relinearize_inplace(powers[i], relin_keys);
        }
    } else {
        assert(windows.size() == window_width * window_count);
        // the first 2^l - 1 powers are directly copied over
        for (size_t i = 1; i <= window_width; i++) {
            if (i >= powers.size()) {
                return;
            }
            powers[i] = windows[i - 1];
        }

        for (size_t i = 1; i < window_count; i++) {
            for (size_t j = 1; j <= window_width; j++) {
                // now, for each new window i, we go over all of its elements
                // j (encoding y^{2^{l * i} * j}) and compute every power of the
                // form y^{2^{l * i} * j + k}, where k < 2^{l * i} (equivalently,
                // y^k was computed before we started working on window i).
                size_t high_bits = (j << (window_size * i));
                if (high_bits >= powers.size()) {
                    break;
                }
                powers[high_bits] = windows[i * window_width + j - 1];
                for (size_t low_bits = 1; low_bits < (1ull << (window_size * i)); low_bits++) {
                    size_t new_power = high_bits | low_bits;
                    if (new_power >= powers.size()) {
                        // TODO: figure out if there's a smarter way to break here.
                        break;
                    }
                    // ciphertext-ciphertext multiplication
                    evaluator.multiply(powers[low_bits], powers[high_bits], powers[new_power]);
                    // relinearize to bring the ciphertext size back down
                    evaluator.relinearize_inplace(powers[new_power], relin_keys);
                }
            }
        }
    }
}

network

主要是序列化,读写密文,读写密钥

polynomials

// Modular exponentiation by repeated squaring: returns base^exponent mod modulus.
// All products go through MUL_MOD (presumably a modular multiplication helper
// defined elsewhere in the project). An exponent of 0 yields 1.
uint64_t modexp(uint64_t base, uint64_t exponent, uint64_t modulus) {
    uint64_t acc = 1;
    // Walk the exponent from its least significant bit upwards.
    for (; exponent != 0; exponent >>= 1) {
        if ((exponent & 1) != 0) {
            acc = MUL_MOD(acc, base, modulus);
        }
        base = MUL_MOD(base, base, modulus);
    }
    return acc;
}

// Modular inverse: returns x^(modulus - 2) mod modulus. By Fermat's little
// theorem this equals x^-1 mod modulus when modulus is prime and x is not a
// multiple of it; no check for either condition is performed here.
uint64_t modinv(uint64_t x, uint64_t modulus) {
    const uint64_t fermat_exponent = modulus - 2;
    return modexp(x, fermat_exponent, modulus);
}

// Builds the monic polynomial (x - roots[0]) * (x - roots[1]) * ... modulo
// `modulus`. On return coeffs has roots.size() + 1 entries and coeffs[d] is the
// coefficient of x^d; any previous contents of coeffs are discarded.
void polynomial_from_roots(vector<uint64_t> &roots, vector<uint64_t> &coeffs, uint64_t modulus) {
    // Start from the constant polynomial 1.
    coeffs.assign(roots.size() + 1, 0);
    coeffs[0] = 1;

    size_t degree = 0;
    for (uint64_t root : roots) {
        // -root mod modulus, i.e. the constant term of the factor (x - root)
        const uint64_t minus_root = modulus - (root % modulus);

        // Multiply the current polynomial by (x - root). Going from the top
        // coefficient downwards lets us update in place.
        ++degree;
        for (size_t idx = degree; idx >= 1; --idx) {
            const uint64_t scaled = MUL_MOD(minus_root, coeffs[idx], modulus);
            coeffs[idx] = (coeffs[idx - 1] + scaled) % modulus;
        }
        coeffs[0] = MUL_MOD(coeffs[0], minus_root, modulus);
    }
}

// Polynomial interpolation modulo `modulus` (Newton form with divided
// differences): computes coeffs such that f(xs[i]) = ys[i] for every i, where
// coeffs[d] is the coefficient of x^d and coeffs.size() == xs.size().
//
// Preconditions: xs.size() == ys.size(). Division is done through modinv,
// which raises to the power modulus - 2, so the result is only meaningful when
// modulus is prime and the xs are pairwise distinct modulo modulus.
// An empty input produces an empty coeffs.
void polynomial_from_points(vector<uint64_t> &xs,
                            vector<uint64_t> &ys,
                            vector<uint64_t> &coeffs,
                            uint64_t modulus)
{
    assert(xs.size() == ys.size());
    coeffs.clear();
    coeffs.resize(xs.size());

    if (xs.size() == 0) {
        return;
    }

    // at iteration i of the loop, basis contains the coefficients of the basis
    // polynomial (x - xs[0]) * (x - xs[1]) * ... * (x - xs[i - 1])
    vector<uint64_t> basis(xs.size());
    basis[0] = 1;

    // at iteration i of the loop, ddif[j] contains the divided difference
    // [ys[j], ys[j + 1], ..., ys[j + i]]. thus initially, when i = 0,
    // ddif[j] = [ys[j]] = ys[j]
    vector<uint64_t> ddif = ys;
    // Reduce up front so that the unsigned subtractions below cannot wrap
    // incorrectly; this is a no-op when the ys are already below modulus.
    for (uint64_t &value : ddif) {
        value %= modulus;
    }

    for (size_t i = 0; i < xs.size(); i++) {
        // add ddif[0] * basis to the result
        for (size_t j = 0; j < i + 1; j++) {
            coeffs[j] = (coeffs[j] + MUL_MOD(ddif[0], basis[j], modulus)) % modulus;
        }

        if (i < xs.size() - 1) {
            // update basis: multiply it by (x - xs[i])
            uint64_t neg_x = modulus - (xs[i] % modulus);

            for (size_t j = i + 1; j > 0; j--) {
                basis[j] = (basis[j - 1] + MUL_MOD(neg_x, basis[j], modulus)) % modulus;
            }
            basis[0] = MUL_MOD(basis[0], neg_x, modulus);

            // update ddif: compute the divided differences of order i + 1.
            // The last valid one ends at index xs.size() - 1, hence the bound
            // j + i + 1 < xs.size(); a looser bound reads xs and ddif one
            // element past the end.
            for (size_t j = 0; j + i + 1 < xs.size(); j++) {
                // dd_{j,j+i+1} = (dd_{j+1, j+i+1} - dd_{j, j+i}) / (x_{j+i+1} - x_j)
                uint64_t num = (ddif[j + 1] + modulus - ddif[j]) % modulus;
                uint64_t den = ((xs[j + i + 1] % modulus) + modulus - (xs[j] % modulus)) % modulus;
                ddif[j] = MUL_MOD(num, modinv(den, modulus), modulus);
            }
        }
    }
}

hash

// Cuckoo hashing as used by the receiver: every bucket holds at most one item
// and each item is placed with one randomly chosen hash function.
//
// Parameters:
//   random  - randomness source used to pick hash functions
//   inputs  - items to insert
//   m       - log2 of the number of buckets (the table has 2^m buckets)
//   buckets - output table; each slot is BUCKET_EMPTY or a pair
//             (index into inputs, index of the hash function used)
//   seeds   - one seed per hash function (used as the low half of an AES key)
//
// Returns true once every input has been placed. Returns false if a collision
// occurs while fewer than two hash functions are available, since the evicted
// item then has nowhere else to go. There is still no bound on the number of
// evictions, so an unlucky input may keep this function busy indefinitely.
bool cuckoo_hash(shared_ptr<UniformRandomGenerator> random,
	             vector<uint64_t> &inputs,
	             size_t m,
	             vector<bucket_slot> &buckets,
	             vector<uint64_t> &seeds)
{
	// shift in size_t rather than int so that large m cannot overflow
	buckets.resize(size_t(1) << m);
	for (size_t i = 0; i < buckets.size(); i++) {
		buckets[i] = BUCKET_EMPTY;
	}

	// one AES instance per hash function, keyed with the matching seed
	vector<AES> aes(seeds.size());
	for (size_t i = 0; i < seeds.size(); i++) {
		aes[i].set_key(0, seeds[i]);
	}

	for (size_t i = 0; i < inputs.size(); i++) {
		bool resolved = false;
		bucket_slot current_item = make_pair(
			i,
			random_integer(random, seeds.size())
		);

		// TODO: keep track of # of operations and abort if exceeding some limit
		while (!resolved) {
			size_t loc = loc_aes_hash(
				aes[current_item.second],
				m,
				inputs[current_item.first]
			);
			// place the item; whatever occupied the slot is kicked out and
			// becomes the item we have to re-insert
			buckets[loc].swap(current_item);

			if (current_item == BUCKET_EMPTY) {
				resolved = true;
			} else {
				if (seeds.size() < 2) {
					// with a single hash function the re-draw below could
					// never pick a different one and would spin forever
					return false;
				}
				// re-insert the evicted item with a different hash function
				size_t old_hash = current_item.second;
				while (current_item.second == old_hash) {
					current_item.second = random_integer(random, seeds.size());
				}
			}
		}
	}

	return true;
}
// Hashing as used by the sender: every item is inserted under ALL hash
// functions, and each bucket can hold up to `capacity` items.
//
// Parameters:
//   random   - randomness source used for the final per-bucket shuffle
//   inputs   - items to insert
//   m        - log2 of the number of buckets (the table has 2^m buckets)
//   capacity - number of slots per bucket
//   buckets  - output table of capacity * 2^m slots, bucket b occupying slots
//              [capacity * b, capacity * (b + 1)); each slot is BUCKET_EMPTY
//              or a pair (index into inputs, index of the hash function)
//   seeds    - one seed per hash function (used as the low half of an AES key)
//
// Returns false as soon as some bucket would overflow (buckets is then left
// partially filled), true otherwise.
bool complete_hash(shared_ptr<UniformRandomGenerator> random,
	               vector<uint64_t> &inputs,
                   size_t m,
                   size_t capacity,
                   vector<bucket_slot> &buckets,
                   vector<uint64_t> &seeds)
{
	// computed in size_t (not int) so it matches `capacity << m` below and
	// cannot overflow for large m
	const size_t bucket_count = size_t(1) << m;

	buckets.resize(capacity << m);
	for (size_t i = 0; i < buckets.size(); i++) {
		buckets[i] = BUCKET_EMPTY;
	}

	// one AES instance per hash function, keyed with the matching seed
	vector<AES> aes(seeds.size());
	for (size_t i = 0; i < seeds.size(); i++) {
		aes[i].set_key(0, seeds[i]);
	}

	// number of occupied slots in each bucket
	vector<size_t> capacity_used(bucket_count);

	// insert all elements into the table in a deterministic order (filling each
	// bucket sequentially)
	for (size_t i = 0; i < inputs.size(); i++) {
		for (size_t j = 0; j < seeds.size(); j++) {
			size_t loc = loc_aes_hash(aes[j], m, inputs[i]);

			if (capacity_used[loc] == capacity) {
				// all slots in the bucket are used, so we cannot add this
				// element
				return false;
			}

			buckets[capacity * loc + capacity_used[loc]] = make_pair(i, j);
			capacity_used[loc]++;
		}
	}

	// now shuffle each bucket, to avoid leaking information about bucket load
	// distribution through partitioning
	for (size_t bucket = 0; bucket < bucket_count; bucket++) {
		for (size_t slot = 1; slot < capacity; slot++) {
			// uniformly pick a random slot before this one (possibly this
			// very same one) and swap
			size_t prev_slot = random_integer(random, slot + 1);
			buckets[capacity * bucket + slot].swap(buckets[capacity * bucket + prev_slot]);
		}
	}
	return true;
}

aes

// AES key setup: stores the 128-bit key (key_high in the upper 64 bits, key_low
// in the lower 64 bits) as round_key[0] and derives round_key[1..10] from it.
// The immediates 0x01 .. 0x36 follow the AES-128 round-constant sequence. The
// steps are written out one by one because _mm_aeskeygenassist_si128 requires
// its second argument to be a compile-time constant.
// NOTE(review): key_gen_helper is defined elsewhere; presumably it combines the
// previous round key with the keygen-assist output - confirm.
void AES::set_key(uint64_t key_high, uint64_t key_low)
{
    round_key[0] = _mm_set_epi64x(key_high, key_low);
    round_key[1] = key_gen_helper(round_key[0], _mm_aeskeygenassist_si128(round_key[0], 0x01));
    round_key[2] = key_gen_helper(round_key[1], _mm_aeskeygenassist_si128(round_key[1], 0x02));
    round_key[3] = key_gen_helper(round_key[2], _mm_aeskeygenassist_si128(round_key[2], 0x04));
    round_key[4] = key_gen_helper(round_key[3], _mm_aeskeygenassist_si128(round_key[3], 0x08));
    round_key[5] = key_gen_helper(round_key[4], _mm_aeskeygenassist_si128(round_key[4], 0x10));
    round_key[6] = key_gen_helper(round_key[5], _mm_aeskeygenassist_si128(round_key[5], 0x20));
    round_key[7] = key_gen_helper(round_key[6], _mm_aeskeygenassist_si128(round_key[6], 0x40));
    round_key[8] = key_gen_helper(round_key[7], _mm_aeskeygenassist_si128(round_key[7], 0x80));
    round_key[9] = key_gen_helper(round_key[8], _mm_aeskeygenassist_si128(round_key[8], 0x1B));
    round_key[10] = key_gen_helper(round_key[9], _mm_aeskeygenassist_si128(round_key[9], 0x36));
}

// Encrypts a single 128-bit block, given as two 64-bit halves, with the round
// keys prepared by set_key: an initial key XOR, nine full AES rounds and one
// final round. Returns the ciphertext as (high half, low half).
// Per the original notes, the sender uses this to encrypt labels.
pair<uint64_t, uint64_t> AES::encrypt(uint64_t block_high, uint64_t block_low)
{
    // initial whitening with round key 0
    __m128i state = _mm_xor_si128(_mm_set_epi64x(block_high, block_low), round_key[0]);

    // rounds 1..9
    for (int round = 1; round <= 9; round++) {
        state = _mm_aesenc_si128(state, round_key[round]);
    }

    // final round
    state = _mm_aesenclast_si128(state, round_key[10]);

    // halves[0] receives the low 64 bits, halves[1] the high 64 bits
    uint64_t halves[2];
    _mm_storeu_si128((__m128i*) &halves[0], state);

    return make_pair(halves[1], halves[0]);
}

random

// Draws `bits` random bits (1 <= bits <= 64) from the given generator and
// returns them in the low-order bits of the result.
// Each random->generate() call is treated as yielding one 32-bit word - the
// shift amounts below rely on that; confirm against the generator's interface.
uint64_t random_bits(shared_ptr<UniformRandomGenerator> random, size_t bits) {
    assert((bits > 0) && (bits <= 64));

    if (bits <= 32) {
        // one word is enough; keep only its top `bits` bits
        const uint64_t word = random->generate();
        return word >> (32 - bits);
    }

    // combine two words into 64 bits, then keep only the top `bits` bits
    const uint64_t wide = random->generate() | ((uint64_t) random->generate() << 32);
    return wide >> (64 - bits);
}
// Returns a random integer x with 0 <= x < limit, using rejection sampling.
// Requires limit > 0; limit == 1 returns 0 without consuming any randomness.
uint64_t random_integer(shared_ptr<UniformRandomGenerator> random, uint64_t limit) {
    /* here's the trick: let k be the smallest integer with limit <= 2^k. then
       we draw a random k-bit number x, i.e. 0 <= x < 2^k. if it's less than
       limit, we return it, otherwise we draw again (since limit > 2^{k-1}, the
       probability of success is at least 1/2). */
    assert(limit > 0);
    if (limit == 1) {
        return 0;
    }

    // k is capped at 64: for limit > 2^63 no 64-bit power of two is >= limit,
    // and evaluating 1ULL << 64 would be undefined behaviour.
    uint64_t k = 0;
    while ((k < 64) && (limit > (1ULL << k))) {
        k++;
    }

    uint64_t result;
    do {
        result = random_bits(random, k);
    } while (result >= limit);

    return result;
}
// Returns a random integer x with 0 < x < limit (requires limit > 1) by
// drawing from random_integer until a non-zero value comes up.
uint64_t random_nonzero_integer(shared_ptr<UniformRandomGenerator> random, uint64_t limit) {
    assert (limit > 1);

    for (;;) {
        const uint64_t candidate = random_integer(random, limit);
        if (candidate != 0) {
            return candidate;
        }
    }
}

测试

无标签

image

带标签

image

posted @ 2022-04-21 11:58  PamShao  阅读(1159)  评论(0编辑  收藏  举报