diff --git a/examples/hesm2/BUILD.bazel b/examples/hesm2/BUILD.bazel new file mode 100644 index 00000000..edaf18be --- /dev/null +++ b/examples/hesm2/BUILD.bazel @@ -0,0 +1,40 @@ +# Copyright 2024 Guowei Ling. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:yacl.bzl", "yacl_cc_binary") + +package(default_visibility = ["//visibility:public"]) + +yacl_cc_binary( + name = "sm2_example", + srcs = [ + "ahesm2.cc", + "ahesm2.h", + "ciphertext.h", + "config.cc", + "config.h", + "main.cc", + "private_key.h", + "public_key.h", + "t1.h", + "t2.h", + ], + deps = [ + "//yacl/crypto/ecc:spi", + "//yacl/crypto/ecc/openssl", + "//yacl/math/mpint", + "//yacl/utils:cuckoo_index", # 添加 cuckoo_index 依赖 + "//yacl/utils/spi", + ], +) diff --git a/examples/hesm2/README.md b/examples/hesm2/README.md new file mode 100644 index 00000000..17deaad9 --- /dev/null +++ b/examples/hesm2/README.md @@ -0,0 +1,99 @@ +# 加法同态SM2+FastECDLP + +本代码是SM2加法同态加密 ([密码学报 2022](http://www.jcr.cacrnet.org.cn/CN/10.13868/j.cnki.jcr.000532)) 结合FastECDLP([IEEE TIFS 2023](https://ieeexplore.ieee.org/document/10145804))。 + +注:本实现的SM2加法同态加密并非是标准SM2公钥加密算法。标准SM2公钥加密算法并不具备加同态性。 + +## 快速开始 + +首先,进入项目目录并构建示例: + +```bash +cd yacl + +bazel build --linkopt=-ldl //... + +bazel build --linkopt=-ldl //examples/hesm2:sm2_example + +cd bazel-bin/examples/hesm2 + +./sm2_example +``` + +**注:** 第一次使用需要生成预计算表,请等待几分钟。 + +## 示例代码 + +以下是一个简单的使用示例,展示了如何进行参数配置、加密、同态运算及解密操作。 + +```cpp +#include + +#include "examples/hesm2/ahesm2.h" +#include "examples/hesm2/config.h" +#include "examples/hesm2/private_key.h" + +#include "yacl/crypto/ecc/ecc_spi.h" +#include "yacl/math/mpint/mp_int.h" + +using yacl::crypto::EcGroupFactory; +using namespace examples::hesm2; + +int main() { + // 参数配置并读取预计算表 + InitializeConfig(); + + // 生成SM2椭圆曲线群 + auto ec_group = + EcGroupFactory::Instance().Create("sm2", yacl::ArgLib = "openssl"); + if (!ec_group) { + std::cerr << "Failed to create SM2 curve using OpenSSL" << std::endl; + return 1; + } + + // 公私钥对生成 + PrivateKey private_key(std::move(ec_group)); + const auto& public_key = private_key.GetPublicKey(); + + // 指定明文 + auto m1 = yacl::math::MPInt(100); + auto m2 = yacl::math::MPInt(6); + + // 加密 + auto c1 = Encrypt(m1, public_key); + auto c2 = Encrypt(m2, public_key); + + // 标量乘,即密文乘明文 + auto c3 = HMul(c1, m2, public_key); + + // 同态加,即密文加密文 + auto c4 = HAdd(c1, c2, public_key); + + // 单线程解密 + auto res3 = Decrypt(c3, private_key); + + // 并发解密 + auto res4 = ParDecrypt(c4, private_key); + + // 打印结果 + std::cout << res3.m << std::endl; + std::cout << res4.m << std::endl; + + // 打印是否解密正确 + std::cout << res3.success << std::endl; + std::cout << res4.success << std::endl; + + return 0; +} +``` + +## 高阶使用 + +您可以通过修改config.cc中的以下两个参数修改明文空间。 + +```cpp +int Ilen = 12; // l2-1 +int Jlen = 20; // l1-1 +``` + +明文空间的绝对值大小为:(1< + +#include "examples/hesm2/ciphertext.h" +#include "examples/hesm2/config.h" +#include "examples/hesm2/private_key.h" +#include "examples/hesm2/t1.h" +#include "examples/hesm2/t2.h" + +#include "yacl/crypto/ecc/ec_point.h" +#include "yacl/math/mpint/mp_int.h" + +namespace examples::hesm2 { + +Ciphertext Encrypt(const yacl::math::MPInt& message, const PublicKey& pk) { + YACL_ENFORCE(message.Abs() <= yacl::math::MPInt(Mmax)); + const auto& ec_group = pk.GetEcGroup(); + auto generator = ec_group->GetGenerator(); + yacl::math::MPInt r; + yacl::math::MPInt::RandomLtN(ec_group->GetOrder(), &r); + auto c1 = ec_group->MulBase(r); + const auto& pk_point = pk.GetPoint(); + auto mG = ec_group->MulBase(message); + auto rpk = ec_group->Mul(pk_point, r); + auto c2 = ec_group->Add(mG, rpk); + return Ciphertext{c1, c2}; +} + +bool CheckDec(const std::shared_ptr& ecgroup, + const yacl::crypto::EcPoint& m_g, const yacl::math::MPInt& m) { + yacl::crypto::EcPoint checkmG = ecgroup->MulBase(m); + return ecgroup->PointEqual(m_g, checkmG); +} + +DecryptResult Decrypt(const Ciphertext& ciphertext, const PrivateKey& sk) { + const auto& ec_group = sk.GetEcGroup(); + auto c1_sk = ec_group->Mul(ciphertext.GetC1(), sk.GetK()); + const auto& c2 = ciphertext.GetC2(); + if (ec_group->PointEqual(c1_sk, c2)) { + return {yacl::math::MPInt(0), true}; + } + auto mG = ec_group->Sub(c2, c1_sk); + auto affmG = ec_group->GetAffinePoint(mG); + auto affmGx = affmG.x; + const auto value = + t1_loaded.Op_search(affmGx.ToMagBytes(yacl::Endian::native)); + if (value.second) { + yacl::math::MPInt m(value.first); + if (CheckDec(ec_group, mG, m)) { + return {m, true}; + } else { + return {-(m), true}; + } + } + yacl::math::MPInt m; // Declare the variable 'm' + const auto& t2 = t2_loaded.GetVector(); + std::vector Z(Imax); + for (int i = 1; i <= Imax; ++i) { + yacl::math::MPInt difference = t2[i].x - affmGx; + Z[i - 1] = difference; + if (difference.IsZero()) { + m = yacl::math::MPInt(static_cast(L1) * static_cast(i)); + if (CheckDec(ec_group, mG, m)) { + return {m, true}; + } else { + return {-m, true}; + } + } + } + std::vector ZTree(Treelen); + for (int i = 0; i < Imax; i++) { + ZTree[i] = Z[i]; + } + int offset = Imax; + int treelen = Imax * 2 - 3; + yacl::math::MPInt P = ec_group->GetField(); + for (int i = 0; i < treelen; i += 2) { + yacl::math::MPInt product; + yacl::math::MPInt::Mul(ZTree[i], ZTree[i + 1], &product); + + ZTree[offset] = product.Mod(P); + offset = offset + 1; + } + yacl::math::MPInt treeroot_inv; + treeroot_inv.Set(ZTree[Treelen - 2]); + treeroot_inv = treeroot_inv.InvertMod(P); + std::vector ZinvTree(Treelen); + treelen = Imax * 2 - 2; + int prevfloorflag = treelen; + int prevfloornum = 1; + int thisfloorflag = treelen; + int thisfloornum; + int thisindex; + int ztreeindex; + ZinvTree[prevfloorflag] = treeroot_inv; + for (int i = 0; i < Ilen; i++) { + thisfloornum = prevfloornum * 2; + thisfloorflag = prevfloorflag - thisfloornum; + for (int f = 0; f < thisfloornum; f++) { + thisindex = f + thisfloorflag; + ztreeindex = thisindex ^ 1; + yacl::math::MPInt product; + yacl::math::MPInt::Mul(ZTree[ztreeindex], + ZinvTree[prevfloorflag + (f / 2)], &product); + ZinvTree[thisindex] = product.Mod(P); + } + prevfloorflag = thisfloorflag; + prevfloornum = prevfloornum * 2; + } + auto affmGy = affmG.y; + for (int j = 1; j <= Imax; j++) { + yacl::math::MPInt Qx; + yacl::math::MPInt Qxinv; + yacl::math::MPInt k; + yacl::math::MPInt::Add(affmGx, t2[j].x, &k); + k = k.Mod(P); + yacl::math::MPInt::Sub(t2[j].y, affmGy, &Qx); + Qx = Qx.MulMod(ZinvTree[j - 1], P); + Qx = Qx.MulMod(Qx, P); + Qx = Qx.SubMod(k, P); + const auto value = t1_loaded.Op_search(Qx.ToMagBytes(yacl::Endian::native)); + if (value.second) { + m = yacl::math::MPInt(static_cast(L1) * static_cast(j)); + yacl::math::MPInt m1; + yacl::math::MPInt m2; + auto jint = yacl::math::MPInt(value.first); + yacl::math::MPInt::Add(m, jint, &m1); + yacl::math::MPInt::Sub(m, jint, &m2); + if (CheckDec(ec_group, mG, m1)) { + return {m1, true}; + } else { + return {m2, true}; + } + } + yacl::math::MPInt::Sub(-t2[j].y, affmGy, &Qxinv); + Qxinv = Qxinv.MulMod(ZinvTree[j - 1], P); + Qxinv = Qxinv.MulMod(Qxinv, P); + Qxinv = Qxinv.SubMod(k, P); + const auto invvalue = + t1_loaded.Op_search(Qxinv.ToMagBytes(yacl::Endian::native)); + if (invvalue.second) { + m = yacl::math::MPInt(static_cast(-L1) * + static_cast(j)); + yacl::math::MPInt m1; + yacl::math::MPInt m2; + auto jint = yacl::math::MPInt(invvalue.first); + yacl::math::MPInt::Add(m, jint, &m1); + yacl::math::MPInt::Sub(m, jint, &m2); + if (CheckDec(ec_group, mG, m1)) { + return {m1, true}; + } else { + return {m2, true}; + } + } + } + SPDLOG_INFO("Decrypt failed. |m| should be <= {}", Mmax); + return {yacl::math::MPInt(0), false}; +} + +DecryptResult search(int start, int end, const yacl::math::MPInt& affm_gx, + const yacl::math::MPInt& affm_gy, + const std::vector& zinv_tree, + const yacl::math::MPInt& p, + const yacl::crypto::EcPoint& m_g, + const std::shared_ptr& ec_group, + std::atomic& found, std::mutex& mtx) { + const auto& t2 = t2_loaded.GetVector(); + for (int j = start; j < end && !found.load(); j++) { + yacl::math::MPInt Qx; + yacl::math::MPInt Qxinv; + yacl::math::MPInt k; + yacl::math::MPInt::Add(affm_gx, t2[j].x, &k); + k = k.Mod(p); + yacl::math::MPInt::Sub(t2[j].y, affm_gy, &Qx); + Qx = Qx.MulMod(zinv_tree[j - 1], p); + Qx = Qx.MulMod(Qx, p); + Qx = Qx.SubMod(k, p); + const auto value = t1_loaded.Op_search(Qx.ToMagBytes(yacl::Endian::native)); + if (value.second) { + yacl::math::MPInt m = + yacl::math::MPInt(static_cast(L1) * static_cast(j)); + yacl::math::MPInt m1; + yacl::math::MPInt m2; + auto jint = yacl::math::MPInt(value.first); + yacl::math::MPInt::Add(m, jint, &m1); + yacl::math::MPInt::Sub(m, jint, &m2); + if (CheckDec(ec_group, m_g, m1)) { + std::lock_guard lock(mtx); + found.store(true); + return {m1, true}; + } else { + std::lock_guard lock(mtx); + found.store(true); + return {m2, true}; + } + } + yacl::math::MPInt::Sub(-t2[j].y, affm_gy, &Qxinv); + Qxinv = Qxinv.MulMod(zinv_tree[j - 1], p); + Qxinv = Qxinv.MulMod(Qxinv, p); + Qxinv = Qxinv.SubMod(k, p); + const auto invvalue = + t1_loaded.Op_search(Qxinv.ToMagBytes(yacl::Endian::native)); + if (invvalue.second) { + yacl::math::MPInt m = yacl::math::MPInt(static_cast(-L1) * + static_cast(j)); + yacl::math::MPInt m1; + yacl::math::MPInt m2; + auto jint = yacl::math::MPInt(invvalue.first); + yacl::math::MPInt::Add(m, jint, &m1); + yacl::math::MPInt::Sub(m, jint, &m2); + if (CheckDec(ec_group, m_g, m1)) { + std::lock_guard lock(mtx); + found.store(true); + return {m1, true}; + } else { + std::lock_guard lock(mtx); + found.store(true); + return {m2, true}; + } + } + } + return {yacl::math::MPInt(), false}; // 返回一个无效的结果 +} + +DecryptResult ParDecrypt(const Ciphertext& ciphertext, const PrivateKey& sk) { + const auto& ec_group = sk.GetEcGroup(); + auto c1_sk = ec_group->Mul(ciphertext.GetC1(), sk.GetK()); + const auto& c2 = ciphertext.GetC2(); + if (ec_group->PointEqual(c1_sk, c2)) { + return {yacl::math::MPInt(0), true}; + } + auto mG = ec_group->Sub(c2, c1_sk); + auto affmG = ec_group->GetAffinePoint(mG); + auto affmGx = affmG.x; + yacl::math::MPInt m; + const auto value = + t1_loaded.Op_search(affmGx.ToMagBytes(yacl::Endian::native)); + if (value.second) { + m = yacl::math::MPInt(value.first); + if (CheckDec(ec_group, mG, m)) { + return {m, true}; + } else { + return {-(m), true}; + } + } + + const auto& t2 = t2_loaded.GetVector(); + + std::vector Z(Imax); + for (int j = 1; j <= Imax; ++j) { + yacl::math::MPInt difference = t2[j].x - affmGx; + Z[j - 1] = difference; + if (difference.IsZero()) { + m = yacl::math::MPInt(static_cast(L1) * static_cast(j)); + if (CheckDec(ec_group, mG, m)) { + return {m, true}; + } else { + return {-m, true}; + } + } + } + std::vector ZTree(Treelen); + for (int i = 0; i < Imax; i++) { + ZTree[i] = Z[i]; + } + int offset = Imax; + int treelen = Imax * 2 - 3; + yacl::math::MPInt P = ec_group->GetField(); + for (int i = 0; i < treelen; i += 2) { + yacl::math::MPInt product; + yacl::math::MPInt::Mul(ZTree[i], ZTree[i + 1], &product); + ZTree[offset] = product.Mod(P); + offset = offset + 1; + } + yacl::math::MPInt treeroot_inv; + treeroot_inv.Set(ZTree[Treelen - 2]); + treeroot_inv = treeroot_inv.InvertMod(P); + std::vector ZinvTree(Treelen); + treelen = Imax * 2 - 2; + int prevfloorflag = treelen; + int prevfloornum = 1; + int thisfloorflag = treelen; + int thisfloornum; + ZinvTree[prevfloorflag] = treeroot_inv; + for (int i = 0; i < Ilen; i++) { + thisfloornum = prevfloornum * 2; + thisfloorflag = prevfloorflag - thisfloornum; + yacl::parallel_for(0, thisfloornum, 1, [&](int64_t start, int64_t end) { + for (int f = start; f < end; f++) { + int thisindex = f + thisfloorflag; + int ztreeindex = thisindex ^ 1; + yacl::math::MPInt product; + yacl::math::MPInt::Mul(ZTree[ztreeindex], + ZinvTree[prevfloorflag + (f / 2)], &product); + ZinvTree[thisindex] = product.Mod(P); + } + }); + prevfloorflag = thisfloorflag; + prevfloornum = prevfloornum * 2; + } + auto affmGy = affmG.y; + const int num_threads = std::thread::hardware_concurrency(); + const int chunk_size = Imax / num_threads; + std::vector threads; + std::vector results(num_threads); + + std::atomic found(false); + std::mutex mtx; + std::atomic result_found(false); + DecryptResult final_result; + + for (int i = 0; i < num_threads; ++i) { + int start = i * chunk_size + 1; + int end = (i == num_threads - 1) ? (Imax + 1) : start + chunk_size; + threads.emplace_back([&, start, end]() { + DecryptResult result = search(start, end, affmGx, affmGy, ZinvTree, P, mG, + ec_group, found, mtx); + if (result.success && !result_found.exchange(true)) { + final_result = result; + found.store(true); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + if (result_found) { + return final_result; + } else { + SPDLOG_INFO("Decrypt failed. |m| should be <= {}", Mmax); + return DecryptResult{yacl::math::MPInt(0), false}; + } +} + +Ciphertext HAdd(const Ciphertext& ciphertext1, const Ciphertext& ciphertext2, + const PublicKey& pk) { + const auto& ec_group = pk.GetEcGroup(); + auto c1 = ec_group->Add(ciphertext1.GetC1(), ciphertext2.GetC1()); + auto c2 = ec_group->Add(ciphertext1.GetC2(), ciphertext2.GetC2()); + return Ciphertext{c1, c2}; +} + +Ciphertext HSub(const Ciphertext& ciphertext1, const Ciphertext& ciphertext2, + const PublicKey& pk) { + const auto& ec_group = pk.GetEcGroup(); + auto c1 = ec_group->Sub(ciphertext1.GetC1(), ciphertext2.GetC1()); + auto c2 = ec_group->Sub(ciphertext1.GetC2(), ciphertext2.GetC2()); + return Ciphertext{c1, c2}; +} + +Ciphertext HMul(const Ciphertext& ciphertext1, const yacl::math::MPInt& scalar, + const PublicKey& pk) { + const auto& ec_group = pk.GetEcGroup(); + auto c1 = ec_group->Mul(ciphertext1.GetC1(), scalar); + auto c2 = ec_group->Mul(ciphertext1.GetC2(), scalar); + return Ciphertext{c1, c2}; +} +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/ahesm2.h b/examples/hesm2/ahesm2.h new file mode 100644 index 00000000..528f0fc2 --- /dev/null +++ b/examples/hesm2/ahesm2.h @@ -0,0 +1,42 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "examples/hesm2/ciphertext.h" +#include "examples/hesm2/private_key.h" + +namespace examples::hesm2 { + +struct DecryptResult { + yacl::math::MPInt m; + bool success; +}; + +Ciphertext Encrypt(const yacl::math::MPInt& message, const PublicKey& pk); + +DecryptResult Decrypt(const Ciphertext& ciphertext, const PrivateKey& sk); + +DecryptResult ParDecrypt(const Ciphertext& ciphertext, const PrivateKey& sk); + +Ciphertext HAdd(const Ciphertext& ciphertext1, const Ciphertext& ciphertext2, + const PublicKey& pk); + +Ciphertext HSub(const Ciphertext& ciphertext1, const Ciphertext& ciphertext2, + const PublicKey& pk); + +Ciphertext HMul(const Ciphertext& ciphertext1, const yacl::math::MPInt& scalar, + const PublicKey& pk); + +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/ciphertext.h b/examples/hesm2/ciphertext.h new file mode 100644 index 00000000..903f6dbb --- /dev/null +++ b/examples/hesm2/ciphertext.h @@ -0,0 +1,35 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/crypto/ecc/ec_point.h" + +namespace examples::hesm2 { + +class Ciphertext { + public: + Ciphertext(yacl::crypto::EcPoint c1, yacl::crypto::EcPoint c2) + : c1_(std::move(c1)), c2_(std::move(c2)) {} + + const yacl::crypto::EcPoint& GetC1() const { return c1_; } + const yacl::crypto::EcPoint& GetC2() const { return c2_; } + + private: + yacl::crypto::EcPoint c1_; + yacl::crypto::EcPoint c2_; +}; +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/config.cc b/examples/hesm2/config.cc new file mode 100644 index 00000000..33b29af9 --- /dev/null +++ b/examples/hesm2/config.cc @@ -0,0 +1,74 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "examples/hesm2/config.h" + +#include "examples/hesm2/t1.h" +#include "examples/hesm2/t2.h" + +namespace examples::hesm2 { + +uint32_t GetSubBytesAsUint32(const yacl::Buffer& bytes, size_t start, + size_t end) { + uint32_t result = 0; + for (size_t i = start; i < end; ++i) { + result = (result << 8) | bytes.data()[i]; + } + return result; +} + +CuckooT1 t1_loaded(Jmax); +T2 t2_loaded(nullptr, false); + +void InitializeConfig() { + auto ec_group = yacl::crypto::EcGroupFactory::Instance().Create( + "sm2", yacl::ArgLib = "openssl"); + + // 检查是否成功创建 + if (!ec_group) { + std::cerr << "Failed to create SM2 curve using OpenSSL" << std::endl; + return; + } + // 检查文件是否存在,如果存在则从文件加载 + std::string filet1 = "cuckoo_t1.dat"; + std::ifstream ifs(filet1); + if (ifs.good()) { + t1_loaded.Deserialize(filet1); + SPDLOG_INFO("t1_loaded from file: {}", filet1); + } else { + SPDLOG_INFO("t1_loaded generated and serialized to file:{} ", filet1); + SPDLOG_INFO( + "The process might be slow; you may need to wait a few minutes..."); + t1_loaded.InitializeEcGroup(std::move(ec_group)); + t1_loaded.Initialize(); + t1_loaded.Serialize(filet1); + } + + auto ec_group_t2 = yacl::crypto::EcGroupFactory::Instance().Create( + "sm2", yacl::ArgLib = "openssl"); + std::string filet2 = "t2.dat"; + std::ifstream ifst2(filet2); + if (ifst2.good()) { + t2_loaded.Deserialize(filet2); + SPDLOG_INFO("t2_loaded from file: {}", filet2); + } else { + SPDLOG_INFO("t2_loaded generated and serialized to file:{} ", filet2); + t2_loaded.InitializeEcGroup(std::move(ec_group_t2)); + t2_loaded.InitializeVector(); + t2_loaded.Serialize(filet2); + t2_loaded.Deserialize(filet2); + } +} + +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/config.h b/examples/hesm2/config.h new file mode 100644 index 00000000..f296e504 --- /dev/null +++ b/examples/hesm2/config.h @@ -0,0 +1,39 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/base/buffer.h" + +namespace examples::hesm2 { + +void InitializeConfig(); + +uint32_t GetSubBytesAsUint32(const yacl::Buffer& bytes, size_t start, + size_t end); + +constexpr int Ilen = 12; // l2-1 +constexpr int Jlen = 20; // l1-1 +constexpr int Imax = 1 << Ilen; // 1<< Ilen +constexpr int Jmax = 1 << Jlen; // 1<(Jmax * 1.3); +constexpr uint64_t Mmax = + static_cast(Imax) * static_cast(L1) + Jmax; + +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/main.cc b/examples/hesm2/main.cc new file mode 100644 index 00000000..3aed5de7 --- /dev/null +++ b/examples/hesm2/main.cc @@ -0,0 +1,72 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "examples/hesm2/ahesm2.h" +#include "examples/hesm2/config.h" +#include "examples/hesm2/private_key.h" + +#include "yacl/crypto/ecc/ecc_spi.h" +#include "yacl/math/mpint/mp_int.h" + +using yacl::crypto::EcGroupFactory; +using namespace examples::hesm2; + +int main() { + // 参数配置并读取预计算表 + InitializeConfig(); + + // 生成SM2椭圆曲线群 + auto ec_group = + EcGroupFactory::Instance().Create("sm2", yacl::ArgLib = "openssl"); + if (!ec_group) { + std::cerr << "Failed to create SM2 curve using OpenSSL" << std::endl; + return 1; + } + + // 公私钥对生成 + PrivateKey private_key(std::move(ec_group)); + const auto& public_key = private_key.GetPublicKey(); + + // 指定明文 + auto m1 = yacl::math::MPInt(100); + auto m2 = yacl::math::MPInt(6); + + // 加密 + auto c1 = Encrypt(m1, public_key); + auto c2 = Encrypt(m2, public_key); + + // 标量乘,即密文乘明文 + auto c3 = HMul(c1, m2, public_key); + + // 同态加,即密文加密文 + auto c4 = HAdd(c1, c2, public_key); + + // 单线程解密 + auto res3 = Decrypt(c3, private_key); + + // 并发解密 + auto res4 = ParDecrypt(c4, private_key); + + // 打印结果 + std::cout << res3.m << std::endl; + std::cout << res4.m << std::endl; + + // 打印是否解密正确 + std::cout << res3.success << std::endl; + std::cout << res4.success << std::endl; + + return 0; +} diff --git a/examples/hesm2/private_key.h b/examples/hesm2/private_key.h new file mode 100644 index 00000000..fce775ad --- /dev/null +++ b/examples/hesm2/private_key.h @@ -0,0 +1,54 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "examples/hesm2/public_key.h" + +#include "yacl/crypto/ecc/ecc_spi.h" +#include "yacl/math/mpint/mp_int.h" + +namespace examples::hesm2 { + +class PrivateKey { + public: + explicit PrivateKey(std::shared_ptr ec_group) + : ec_group_(ec_group), public_key_(ec_group_->GetGenerator(), ec_group_) { + Initialize(); + } + + const yacl::math::MPInt& GetK() const { return k_; } + const PublicKey& GetPublicKey() const { return public_key_; } + std::shared_ptr GetEcGroup() const { + return ec_group_; + } + + private: + void Initialize() { + yacl::math::MPInt k; + yacl::math::MPInt::RandomLtN(ec_group_->GetOrder(), &k); + public_key_ = GeneratePublicKey(); + } + + PublicKey GeneratePublicKey() const { + auto generator = ec_group_->GetGenerator(); + auto point = ec_group_->Mul(generator, k_); + return {point, ec_group_}; + } + + std::shared_ptr ec_group_; + yacl::math::MPInt k_; + PublicKey public_key_; +}; +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/public_key.h b/examples/hesm2/public_key.h new file mode 100644 index 00000000..1b651c43 --- /dev/null +++ b/examples/hesm2/public_key.h @@ -0,0 +1,39 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "yacl/crypto/ecc/ecc_spi.h" + +namespace examples::hesm2 { + +class PublicKey { + public: + PublicKey(yacl::crypto::EcPoint point, + std::shared_ptr ec_group) + : point_(point), ec_group_(std::move(ec_group)) {} + + const yacl::crypto::EcPoint& GetPoint() const { return point_; } + std::shared_ptr GetEcGroup() const { + return ec_group_; + } + + private: + yacl::crypto::EcPoint point_; + std::shared_ptr ec_group_; +}; + +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/t1.h b/examples/hesm2/t1.h new file mode 100644 index 00000000..495839c6 --- /dev/null +++ b/examples/hesm2/t1.h @@ -0,0 +1,160 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "examples/hesm2/config.h" + +#include "yacl/crypto/ecc/ecc_spi.h" +#include "yacl/math/mpint/mp_int.h" +#include "yacl/utils/parallel.h" + +namespace examples::hesm2 { + +class CuckooT1 { + public: + explicit CuckooT1(int jmax) + : jmax_(jmax), cuckoolen_(static_cast(jmax * 1.3)) { + if (jmax_ <= 0) { + throw std::invalid_argument("jmax must be positive"); + } + table_v_.resize(cuckoolen_, 0); // 初始化值为0 + table_k_.resize(cuckoolen_, 0); // 初始化值为0 + } + + void Initialize() { + std::vector XS(jmax_); + constexpr int64_t batch_size = 1 << 10; // 可以根据需要调整批处理大小 + if (!ec_group_) { + throw std::runtime_error("EcGroup not initialized"); + } + yacl::parallel_for(1, Jmax + 1, batch_size, [&](int64_t beg, int64_t end) { + for (int64_t i = beg; i < end; ++i) { + yacl::math::MPInt value(i); + auto point = ec_group_->MulBase(value); + // 获取横坐标作为键 + auto affine_point = ec_group_->GetAffinePoint(point); + auto key = affine_point.x.ToMagBytes(yacl::Endian::native); + XS[i - 1] = key; + } + }); + Insert(XS); + } + + void Insert(std::vector data) { + std::vector hash_index_; + hash_index_.resize(cuckoolen_, 0); + for (int i = 0; i < Jmax; ++i) { + int v = i + 1; + uint8_t old_hash_id = 1; + int j = 0; + for (; j < maxiter_; ++j) { + const auto& X = data[v - 1]; + size_t start = (old_hash_id - 1) * 8; + size_t end = start + 4; + uint32_t x = GetSubBytesAsUint32(X, end, end + 4); + uint32_t x_key = x; + uint32_t h = GetSubBytesAsUint32(X, start, end) % cuckoolen_; + uint8_t* hash_id_address = &hash_index_[h]; + int* key_index_address = &table_v_[h]; + uint32_t* key_address = &table_k_[h]; + + if (*hash_id_address == empty_) { + *hash_id_address = old_hash_id; + *key_index_address = v; + *key_address = x_key; + break; + } else { + std::swap(v, *key_index_address); + std::swap(old_hash_id, *hash_id_address); + std::swap(x_key, *key_address); + old_hash_id = old_hash_id % 3 + 1; + } + } + if (j == maxiter_) { + SPDLOG_INFO("insert failed, ", i); + throw std::runtime_error("insert failed, " + std::to_string(i)); + } + } + } + + std::pair Op_search(const yacl::Buffer& xbytes) const { + for (int i = 0; i < 3; ++i) { + size_t start = i * 8; + size_t end = start + 4; + uint32_t x = GetSubBytesAsUint32(xbytes, end, end + 4); + uint32_t x_key = x; + uint32_t h = GetSubBytesAsUint32(xbytes, start, end) % cuckoolen_; + if (table_k_[h] == x_key) { + return {table_v_[h], true}; + } + } + return {0, false}; + } + + void Serialize(const std::string& filename) const { + std::ofstream ofs(filename, std::ios::binary); + if (!ofs) { + throw std::runtime_error("Failed to open file for writing: " + filename); + } + + ofs.write(reinterpret_cast(&jmax_), sizeof(jmax_)); + ofs.write(reinterpret_cast(&cuckoolen_), sizeof(cuckoolen_)); + ofs.write(reinterpret_cast(table_v_.data()), + table_v_.size() * sizeof(uint32_t)); + ofs.write(reinterpret_cast(table_k_.data()), + table_k_.size() * sizeof(uint32_t)); + } + + void Deserialize(const std::string& filename) { + std::ifstream ifs(filename, std::ios::binary); + if (!ifs) { + throw std::runtime_error("Failed to open file for reading: " + filename); + } + + ifs.read(reinterpret_cast(&jmax_), sizeof(jmax_)); + ifs.read(reinterpret_cast(&cuckoolen_), sizeof(cuckoolen_)); + table_v_.resize(cuckoolen_); + table_k_.resize(cuckoolen_); + ifs.read(reinterpret_cast(table_v_.data()), + table_v_.size() * sizeof(uint32_t)); + ifs.read(reinterpret_cast(table_k_.data()), + table_k_.size() * sizeof(uint32_t)); + } + + void InitializeEcGroup(std::shared_ptr ec_group) { + ec_group_ = std::move(ec_group); + } + + private: + int jmax_; + uint32_t cuckoolen_; + std::shared_ptr ec_group_; + std::vector table_v_; + std::vector table_k_; + const uint8_t empty_ = 0; + const int maxiter_ = 500; + mutable std::shared_mutex mutex_; +}; + +extern CuckooT1 t1_loaded; + +} // namespace examples::hesm2 \ No newline at end of file diff --git a/examples/hesm2/t2.h b/examples/hesm2/t2.h new file mode 100644 index 00000000..d75144ef --- /dev/null +++ b/examples/hesm2/t2.h @@ -0,0 +1,118 @@ +// Copyright 2024 Guowei Ling. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "examples/hesm2/config.h" + +#include "yacl/crypto/ecc/ecc_spi.h" +#include "yacl/math/mpint/mp_int.h" + +namespace examples::hesm2 { + +class T2 { + public: + explicit T2(std::shared_ptr ec_group, + bool initialize = true) + : ec_group_(std::move(ec_group)) { + if (initialize) { + InitializeVector(); + } + } + const yacl::crypto::AffinePoint& GetValue(size_t index) const { + return vec_.at(index); + } + const std::vector& GetVector() const { + return vec_; + } + void Serialize(const std::string& filename) const { + std::shared_lock lock(mutex_); + std::ofstream ofs(filename, std::ios::binary); + if (!ofs) { + throw std::runtime_error("Failed to open file for writing: " + filename); + } + size_t vec_size = vec_.size(); + ofs.write(reinterpret_cast(&vec_size), sizeof(vec_size)); + for (const auto& point : vec_) { + auto x_bytes = point.x.ToMagBytes(yacl::Endian::native); + auto y_bytes = point.y.ToMagBytes(yacl::Endian::native); + size_t x_size = x_bytes.size(); + size_t y_size = y_bytes.size(); + ofs.write(reinterpret_cast(&x_size), sizeof(x_size)); + ofs.write(reinterpret_cast(x_bytes.data()), x_size); + ofs.write(reinterpret_cast(&y_size), sizeof(y_size)); + ofs.write(reinterpret_cast(y_bytes.data()), y_size); + } + } + void Deserialize(const std::string& filename) { + std::unique_lock lock(mutex_); + std::ifstream ifs(filename, std::ios::binary); + if (!ifs) { + throw std::runtime_error("Failed to open file for reading: " + filename); + } + size_t vec_size; + ifs.read(reinterpret_cast(&vec_size), sizeof(vec_size)); + vec_.resize(vec_size); + for (size_t i = 0; i < vec_size; ++i) { + size_t x_size; + size_t y_size; + ifs.read(reinterpret_cast(&x_size), sizeof(x_size)); + yacl::Buffer x_bytes(x_size); + ifs.read(reinterpret_cast(x_bytes.data()), x_size); + yacl::math::MPInt x; + x.FromMagBytes(x_bytes, yacl::Endian::native); + + ifs.read(reinterpret_cast(&y_size), sizeof(y_size)); + yacl::Buffer y_bytes(y_size); + ifs.read(reinterpret_cast(y_bytes.data()), y_size); + yacl::math::MPInt y; + y.FromMagBytes(y_bytes, yacl::Endian::native); + + vec_[i] = yacl::crypto::AffinePoint{x, y}; + } + } + + void InitializeVector() { + vec_.resize(Imax + 1); + auto G = ec_group_->GetGenerator(); + yacl::math::MPInt Jmax_val(Jmax); + yacl::math::MPInt two(2); + yacl::math::MPInt factor = Jmax_val * two; // Correcting the multiplication + auto T2basepoint = ec_group_->MulBase(factor); + for (int i = 0; i <= Imax; ++i) { + yacl::math::MPInt value(-i); + auto point = ec_group_->Mul(T2basepoint, value); + vec_[i] = ec_group_->GetAffinePoint(point); + } + } + + void InitializeEcGroup(std::shared_ptr ec_group) { + ec_group_ = std::move(ec_group); + } + + private: + std::shared_ptr ec_group_; + std::vector vec_; + mutable std::shared_mutex mutex_; +}; + +extern T2 t2_loaded; + +} // namespace examples::hesm2 \ No newline at end of file