【密码学】椭圆曲线代码实现

admin 2022年12月22日11:52:22评论7 views字数 22807阅读76分1秒阅读模式

椭圆曲线代码实现

【密码学】椭圆曲线代码实现
椭圆曲线

发现,自己好像好久没有写过密码学相关代码了,之前介绍过有限域上的椭圆曲线,那么本篇文章,我再来水一篇,搞一个椭圆曲线的相关算法的简单实现吧。

背景知识

这里我们先来简单回顾下,椭圆曲线相关运算的知识,这里只是做一个简单的回顾,有关于详细的知识,可以看下我之前写的文章。

我们还是来考虑上的椭圆曲线,因为的有限域上的运算不好写(对就是因为这个原因),我们给出标准型。

这里有2个参数,a和b,对于p来说,这个不应该属于椭圆曲线的"参数",仅代表个人看法,这个限制的是椭圆曲线在什么上面,毕竟我们之前讨论实数域上的椭圆曲线的时候,也没将这个实数也作为参数。

然后我们给出加法运算的规则。

其中:

好了,有了上面的基础知识,我们就可以来编写代码了,一般的,我们会取椭圆曲线上面的一个循环子群,通过一个生成元g来生成所需的所有的点。

编码实现

这里,我还是用rust来写吧,回归老本行,好久不写rust,差点不会写了。

依赖库

这里,我们需要大数运算相关的库,因为通过普通的i32或者i64实现出来作用不大,所有给出的椭圆曲线都是超过这个数的,然后随机数的库,因为后面作为测试的算法ECDH要用到,顺道也就加进来了。

[dependencies]
rand = "0.8.5"
num-bigint = { version = "0.4.3", features = ["rand"] }
num-traits = "0.2.15"

基本结构

#[derive(Debug, Clone)]
pub enum Point {
    Inf { curve: Curve },
    Pair { x: BigInt, y: BigInt, curve: Curve },
}

#[derive(Debug, Clone, PartialEq)]
pub struct Curve {
    a: BigInt,
    b: BigInt,
    p: BigInt,
    field: SubGroup,
}

#[derive(Debug, Clone, PartialEq)]
pub struct SubGroup {
    p: BigInt,
    g: (BigInt, BigInt),
    n: BigInt,
    h: BigInt,
}

#[derive(Debug, Clone)]
pub struct KeyPair {
    priv_key: BigInt,
    pub_key: Point,
    curve: Curve,
}

#[derive(Debug, Clone)]
pub struct ECDH {
    pub key_pair: KeyPair,
}

这里,核心能用到的其实就前面的那三个,后面的那俩其实是做ECDH用到的。

辅助代码

因为,这里需要用到有限域上的除法操作,所以搞一个求逆元的函数,简单的写个egcd算法,这个算法在我之前的文章当中也说过。

// egcd
fn egcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
    return if a == &BigInt::zero() {
        (b.clone(), BigInt::zero(), BigInt::one())
    } else {
        let (g, y, x) = egcd(&(b % a), a);
        (g, x - (b / a) * y.clone(), y)
    };
}

// mod inv
fn mod_inv(a: &BigInt, p: &BigInt) -> BigInt {
    if a < &BigInt::zero() {
        return mod_inv(&(a % p + p), p);
    }
    let (g, x, _) = egcd(a, p);
    if g == BigInt::one() {
        return (x % p + p) % p;
    }
    panic!("modular inverse does not exist");
}

具体实现

先来看一下椭圆曲线的具体实现,这里给出了一个判定构成群的条件,也就是判别式不等于0。

impl Curve {
    pub fn new(a: BigInt, b: BigInt, field: SubGroup) -> Curve {
        Curve { a, b, p: field.p.clone(), field }
    }

    pub fn is_singular(&self) -> bool {
        let lhs = self.a.clone() * self.a.clone() * self.a.clone() * BigInt::from(4u8);
        let rhs = BigInt::from(27u8) * self.b.clone() * self.b.clone();
        lhs.abs() == rhs.abs()
    }

    pub fn is_on_curve(&self, p: &Point) -> bool {
        match p {
            Point::Inf { .. } => true,
            Point::Pair { x, y, curve: _ } => {
                let lhs = y.clone() * y.clone();
                let rhs = x.clone() * x.clone() * x.clone() + self.a.clone() * x.clone() + self.b.clone();
                lhs % self.p.clone() == rhs % self.p.clone()
            }
        }
    }

    #[allow(non_snake_case)]
    pub fn brainpoolP160r1() -> Curve {
        let field = SubGroup::new(
            BigInt::from_str_radix(&"E95E4A5F737059DC60DFC7AD95B3D8139515620F"16).expect(""),
            (
                BigInt::from_str_radix(&"BED5AF16EA3F6A4F62938C4631EB5AF7BDBCDBC3"16).expect(""),
                BigInt::from_str_radix(&"1667CB477A1A8EC338F94741669C976316DA6321"16).expect(""),
            ),
            BigInt::from_str_radix(&"E95E4A5F737059DC60DF5991D45029409E60FC09"16).expect(""),
            BigInt::from(1u32),
        );
        Curve::new(
            BigInt::from_str_radix(&"340E7BE2A280EB74E2BE61BADA745D97E8F7C300"16).expect(""),
            BigInt::from_str_radix(&"1E589A8595423412134FAA2DBDEC95C8D8675E58"16).expect(""),
            field,
        )
    }
}

impl Display for Curve {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "y^2 = x^3 + {}x + {} over F_{}"self.a, self.b, self.p)
    }
}

impl SubGroup {
    fn new(p: BigInt, g: (BigInt, BigInt), n: BigInt, h: BigInt) -> SubGroup {
        SubGroup { p, g, n, h }
    }
}

impl Display for SubGroup {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "SubGroup of {}"self.p)
    }
}

然后给出点的运算的实现

impl Point {
    pub fn new(x: BigInt, y: BigInt, curve: Curve) -> Point {
        Point::Pair { x, y, curve }
    }

    pub fn inf(curve: Curve) -> Point {
        Point::Inf { curve }
    }

    pub fn is_inf(&self) -> bool {
        match self {
            Point::Inf { .. } => true,
            _ => false,
        }
    }

    pub fn get_x(&self) -> Option<BigInt> {
        match self {
            Point::Pair { x, .. } => Some(x.clone()),
            _ => None,
        }
    }

    pub fn get_y(&self) -> Option<BigInt> {
        match self {
            Point::Pair { y, .. } => Some(y.clone()),
            _ => None,
        }
    }

    pub fn is_on_curve(&self) -> bool {
        match self {
            Point::Inf { .. } => true,
            Point::Pair { x, y, curve } => {
                let lhs = y.clone() * y.clone();
                let rhs = x.clone() * x.clone() * x.clone() + curve.a.clone() * x.clone() + curve.b.clone();
                lhs % curve.p.clone() == rhs % curve.p.clone()
            }
        }
    }

    fn curve(&self) -> Curve {
        match self {
            Point::Inf { curve } => curve.clone(),
            Point::Pair { x: _, y: _, curve } => curve.clone(),
        }
    }

    fn lambda(&self, other: &Point) -> BigInt {
        match (self, other) {
            (Point::Inf { .. }, _) => panic!("lambda is not defined for Inf"),
            (_, Point::Inf { .. }) => panic!("lambda is not defined for Inf"),
            (Point::Pair { x: x1, y: y1, curve: curve1 }, Point::Pair { x: x2, y: y2, curve: _ }) => {
                if x1 == x2 && y1 == y2 {
                    let lhs = (x1.clone() * x1.clone() * BigInt::from(3u8) + curve1.a.clone()) % curve1.p.clone();
                    let rhs = (y1.clone() * BigInt::from(2u8)) % curve1.p.clone();
                    let inv = mod_inv(&rhs, &curve1.p);
                    (lhs * inv) % curve1.p.clone()
                } else {
                    let lhs = (y2.clone() - y1.clone()) % curve1.p.clone();
                    let rhs = (x2.clone() - x1.clone()) % curve1.p.clone();
                    let inv = mod_inv(&rhs, &curve1.p);
                    (lhs * inv) % curve1.p.clone()
                }
            }
        }
    }
}

impl PartialEq for Point {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Point::Inf { .. }, Point::Inf { .. }) => true,
            (Point::Pair { x: x1, y: y1, curve: curve1 }, Point::Pair { x: x2, y: y2, curve: curve2 }) => {
                x1 == x2 && y1 == y2 && curve1 == curve2
            }
            _ => false,
        }
    }

    fn ne(&self, other: &Self) -> bool {
        !self.eq(other)
    }
}

impl Add for Point {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        match (self.clone(), rhs.clone()) {
            (Point::Inf { .. }, p) => p,
            (p, Point::Inf { .. }) => p,
            (Point::Pair { x: x1, y: y1, curve: curve1 }, Point::Pair { x: x2, y: y2, curve: curve2 }) => {
                if curve1 != curve2 {
                    panic!("points are not on the same curve");
                }
                if x1 == x2 && y1 != y2 {
                    return Point::Inf { curve: curve1 };
                }
                let lambda = self.lambda(&rhs);
                let x3 = (lambda.clone() * lambda.clone() - x1.clone() - x2.clone()) % curve1.p.clone();
                let y3 = (lambda.clone() * (x1.clone() - x3.clone()) - y1.clone()) % curve1.p.clone();
                let x3 = if x3 < BigInt::zero() {
                    x3 + curve1.p.clone()
                } else {
                    x3
                };
                let y3 = if y3 < BigInt::zero() {
                    y3 + curve1.p.clone()
                } else {
                    y3
                };
                Point::new(x3, y3, curve1)
            }
        }
    }
}

impl Neg for Point {
    type Output = Self;

    fn neg(self) -> Self {
        match self {
            Point::Inf { .. } => Point::Inf { curve: self.curve() },
            Point::Pair { x, y, curve } => Point::Pair { x, y: -y % curve.p.clone(), curve },
        }
    }
}

impl Sub for Point {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        self + (-rhs)
    }
}

impl Mul<&BigInt> for Point {
    type Output = Self;

    fn mul(self, rhs: &BigInt) -> Self {
        let mut res = Point::inf(self.curve());
        let mut p = self.clone();
        let mut n = rhs.clone();
        while n > BigInt::zero() {
            if n.clone() % BigInt::from(2u8) == BigInt::one() {
                res = res + p.clone();
            }
            p = p.clone() + p.clone();
            n = n / BigInt::from(2u8);
        }
        res
    }
}

impl Display for Point {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Point::Inf { .. } => write!(f, "Inf"),
            Point::Pair { x, y, curve } => write!(f, "({}, {}) on {:?}", x, y, curve),
        }
    }
}

简单写一下,实现下点的加法和数乘操作,根据上面的公式描述,这段代码也不难实现,注意数乘操作这里用到了个小技巧,类似快速幂的方案。

代码测试

我们用ECDH来测试下上面写的运算是不是正确的

impl KeyPair {
    pub fn new(curve: &Curve, priv_key: &BigInt, pub_key: &Point) -> KeyPair {
        KeyPair { curve: curve.clone(), priv_key: priv_key.clone(), pub_key: pub_key.clone() }
    }

    pub fn from_priv_key(curve: &Curve, priv_key: &BigInt) -> KeyPair {
        let g = Point::new(curve.field.g.0.clone(), curve.field.g.1.clone(), curve.clone());
        let pub_key = g * &priv_key;
        KeyPair { curve: curve.clone(), priv_key: priv_key.clone(), pub_key }
    }

    pub fn make_key_pair(curve: &Curve) -> KeyPair {
        // 生成一个小于curve.field.n的随机数
        let mut rng = rand::thread_rng();
        let priv_key = rng.gen_bigint_range(&BigInt::one(), &curve.field.n);
        KeyPair::from_priv_key(curve, &priv_key)
    }
}

impl Display for KeyPair {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "KeyPair {{ curve: {}, priv_key: {}, pub_key: {} }}"self.curve, self.priv_key, self.pub_key)
    }
}

impl ECDH {
    pub fn new(key_pair: KeyPair) -> ECDH {
        ECDH { key_pair }
    }

    pub fn make_key_pair(&self, curve: &Curve) -> KeyPair {
        KeyPair::make_key_pair(curve)
    }

    pub fn make_shared_secret(&self, priv_key: &BigInt, pub_key: &Point) -> BigInt {
        let shared_point = pub_key.clone() * priv_key;
        shared_point.get_x().expect("shared point is inf")
    }
}

编写下测试代码

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_ecdh() {
        let curve = Curve::brainpoolP160r1();
        let alice = ECDH::new(KeyPair::make_key_pair(&curve));
        let bob = ECDH::new(KeyPair::make_key_pair(&curve));
        let secret_alice = alice.make_shared_secret(&alice.key_pair.priv_key, &bob.key_pair.pub_key);
        let secret_bob = bob.make_shared_secret(&bob.key_pair.priv_key, &alice.key_pair.pub_key);
        assert_eq!(secret_alice, secret_bob);
    }
}

发现,这个测试是可以通过的,然后这段代码就写好了~~欧耶

编码实现(C++)

然后,这里顺道用C++也写一下吧,福利会C++的读者,虽然我C++也写的贼烂,希望C++大佬不要喷我,这里因为C++是不带大数运算库的,需要自己安排一个,这里我找了个简单易用的,引入个头文件就完事儿了,或者其他读者可以考虑NTL或者GMP,应该也都可以,只需要自己加入下依赖。(才不是为了凑字数),这里直接给出代码了,不在拆开了。

#include <iostream>
#include <utility>
#include "BigInt.hh"

class SubGroup {
 public:
  BigInt p, n, h;
  std::pair<BigInt, BigInt> g;

  SubGroup(const BigInt &p, const BigInt &n, const BigInt &h, const std::pair<BigInt, BigInt> &g)
      : p(p), n(n), h(h), g(g) {}

  // display
  friend std::ostream &operator<<(std::ostream &os, const SubGroup &subGroup) {
    os << "p: " << subGroup.p << std::endl;
    os << "n: " << subGroup.n << std::endl;
    os << "h: " << subGroup.h << std::endl;
    os << "g: " << subGroup.g.first << ", " << subGroup.g.second << std::endl;
    return os;
  }
};

class Curve {
 public:
  BigInt a, b, p;
  SubGroup field;

  Curve(const BigInt &a, const BigInt &b, SubGroup field)
      : a(a), b(b), field(std::move(field)) {
    this->p = this->field.p;
  }

  [[nodiscard]] bool is_singular() const {
    return 4 * this->a * this->a * this->a + 27 * this->b * this->b == 0;
  }

  // equal
  bool operator==(const Curve &other) const {
    return this->a == other.a && this->b == other.b && this->p == other.p;
  }

  // not equal
  bool operator!=(const Curve &other) const {
    return !(*this == other);
  }

  // display
  friend std::ostream &operator<<(std::ostream &os, const Curve &curve) {
    os << "y^2 = x^3 + " << curve.a << "x + " << curve.b << " (mod " << curve.p << ")";
    return os;
  }

  static Curve brainpoolP160r1() {
    BigInt p("1332297598440044874827085558802491743757193798159");
    BigInt n("1332297598440044874827085038830181364212942568457");
    BigInt h("1");
    std::pair<BigInt, BigInt> g("1089473557631435284577962539738532515920566082499""127912481829969033206777085249718746721365418785");
    SubGroup field(p, n, h, g);
    return Curve{BigInt("297190522446607939568481567949428902921613329152"), BigInt("173245649450172891208247283053495198538671808088"), field};
  }
};

#pragma clang diagnostic push
#pragma ide diagnostic ignored "misc-no-recursion"
std::tuple<BigInt, BigInt, BigInt> egcd(const BigInt &a, const BigInt &b) {
  if (b == 0) {
    return std::make_tuple(a, 10);
  }
  BigInt g, x, y;
  std::tie(g, x, y) = egcd(b, a % b);
  return std::make_tuple(g, y, x - (a / b) * y);
}
#pragma clang diagnostic pop

BigInt mod_inv(const BigInt &a, const BigInt &p) {
  if (a < 0) {
    return mod_inv(a + p, p);
  }
  BigInt g, x, y;
  std::tie(g, x, y) = egcd(a, p);
  if (g != 1) {
    throw std::runtime_error("modular inverse does not exist");
  }
  return (x % p + p) % p;
}

class Point {
 private:
  bool inf;
  BigInt lambda(Point &other) const {
    if (this->x == other.x) {
      if (this->y == other.y) {
        return ((3 * this->x * this->x + this->curve.a) % this->curve.p) * mod_inv(2 * this->y, this->curve.p) % this->curve.p;
      }
    }
    return (this->y - other.y) * mod_inv(this->x - other.x, this->curve.p) % this->curve.p;
  }
 public:
  BigInt x, y;
  Curve curve;

  Point(const BigInt &x, const BigInt &y, Curve curve) : x(x), y(y), curve(std::move(curve)) {
    this->inf = false;
  }

  explicit Point(Curve curve) : x(0)y(0)curve(std::move(curve)) {
    this->inf = true;
  }

  [[nodiscard]] bool is_inf() const {
    return this->inf;
  }

  [[nodiscard]] bool is_on_curve() const {
    if (this->is_inf()) {
      return true;
    }
    return (this->y * this->y - (this->x * this->x * this->x + this->curve.a * this->x + this->curve.b)) % this->curve.p == 0;
  }

  bool operator==(const Point &p) const {
    if (this->curve != p.curve) {
      return false;
    }
    if (this->inf && p.inf) {
      return true;
    }
    return this->x == p.x && this->y == p.y;
  }

  // not equal
  bool operator!=(const Point &p) const {
    return !(*this == p);
  }

  // add
  Point operator+(const Point &p) const {
    if (this->curve != p.curve) {
      throw std::runtime_error("curves are not equal");
    }
    if (this->is_inf()) {
      return p;
    }
    if (p.is_inf()) {
      return *this;
    }
    if (this->x == p.x) {
      if ((this->y + p.y) % this->curve.p == 0) {
        return Point(this->curve);
      }
    }
    BigInt lambda = this->lambda(const_cast<Point &>(p));
    BigInt x3 = (lambda * lambda - this->x - p.x) % this->curve.p;
    BigInt y3 = (lambda * (this->x - x3) - this->y) % this->curve.p;
    if (x3 < 0) {
      x3 += this->curve.p;
    }
    if (y3 < 0) {
      y3 += this->curve.p;
    }
    return Point{x3, y3, this->curve};
  }

  // neg
  Point operator-() const {
    return Point{this->x, -this->y, this->curve};
  }

  // sub
  Point operator-(const Point &p) const {
    return *this + (-p);
  }

  // scalar-multiplication
  Point operator*(const BigInt &k) const {
    Point res = Point(this->curve);
    Point p = *this;
    BigInt n = k;
    while (n > 0) {
      if (n % 2 == 1) {
        res = res + p;
      }
      p = p + p;
      n = n / 2;
    }
    return res;
  }

  // display
  friend std::ostream &operator<<(std::ostream &os, const Point &p) {
    if (p.is_inf()) {
      os << "Point(inf)";
    } else {
      os << "Point(" << p.x << ", " << p.y << ")";
    }
    return os;
  }
};

class KeyPair {
 public:
  Point pub_key;
  BigInt priv_key;
  Curve curve;

  KeyPair(Point pub_key, const BigInt &priv_key, Curve curve) : pub_key(std::move(pub_key)), priv_key(priv_key), curve(std::move(curve)) {}

  static KeyPair make_key_pair(Curve &curve) {
    BigInt priv_key = big_random(10);
    Point g = Point(curve.field.g.first, curve.field.g.second, curve);
    Point pub_Key = g * priv_key;
    return KeyPair{pub_Key, priv_key, curve};
  }
};

BigInt make_shared_secret(const Point &pub_key, const BigInt &priv_key) {
  Point shared_secret = pub_key * priv_key;
  return shared_secret.x;
}

int main() {
  Curve curve = Curve::brainpoolP160r1();
  std::cout << curve << std::endl;
  KeyPair alice = KeyPair::make_key_pair(curve);
  KeyPair bob = KeyPair::make_key_pair(curve);
  BigInt shared_secret_alice = make_shared_secret(bob.pub_key, alice.priv_key);
  BigInt shared_secret_bob = make_shared_secret(alice.pub_key, bob.priv_key);
  std::cout << shared_secret_alice << std::endl;
  std::cout << shared_secret_bob << std::endl;
  return 0;
}


编码实现(Python)

再来水"亿"点点字数,来一个Python的版本吧。

import random
from typing import Union


def egcd(a: int, b: int) -> (int, int, int):
    if a == 0:
        return b, 01
    else:
        g, y, x = egcd(b % a, a)
        return g, x - (b // a) * y, y


def mod_inv(a: int, p: int) -> int:
    if a < 0:
        return p - mod_inv(-a, p)
    g, x, y = egcd(a, p)
    if g != 1:
        raise Exception('modular inverse does not exist')
    else:
        return x % p


class Inf:
    def __init__(self, curve: "Curve"):
        self.x = None
        self.y = None
        self.curve = curve

    def __eq__(self, other: Union["Point""Inf"]) -> bool:
        return isinstance(other, Inf) and self.curve == other.curve

    def __ne__(self, other):
        return not self.__eq__(other)

    def __add__(self, other):
        return other

    def __sub__(self, other):
        return other

    def __str__(self):
        return "Infinity on %s" % self.curve


class Point:
    def __init__(self, x: int, y: int, curve: "Curve"):
        self.x = x
        self.y = y
        self.curve = curve
        self.p = self.curve.field.p
        self.is_on_curve = True
        if not self.check_on_curve():
            self.is_on_curve = False

    def check_on_curve(self):
        return (self.y ** 2 - self.x ** 3 - self.curve.a * self.x - self.curve.b) % self.curve.field.p == 0

    def _lambda(self, p: "Point", q: "Point"):
        if p.x == q.x and p.y == q.y:
            return ((3 * p.x ** 2 + self.curve.a) * mod_inv(2 * p.y, self.p)) % self.p
        else:
            return ((q.y - p.y) * mod_inv(q.x - p.x, self.p)) % self.p

    def __eq__(self, other: Union["Point""Inf"]):
        if isinstance(other, Point):
            return self.x == other.x and self.y == other.y and self.curve == other.curve
        return False

    def __ne__(self, other: Union["Point""Inf"]):
        return not self.__eq__(other)

    def __add__(self, other: Union["Point""Inf"]):
        if isinstance(other, Inf):
            return self
        if not self.is_on_curve or not other.is_on_curve:
            raise Exception('point not on curve')
        if self.curve != other.curve:
            raise Exception('points not on the same curve')
        if self.x == other.x and self.y != other.y:
            return Inf(self.curve)
        m = self._lambda(self, other)
        x = (m ** 2 - self.x - other.x) % self.p
        y = (m * (self.x - x) - self.y) % self.p
        return Point(x, y, self.curve)

    def __sub__(self, other: Union["Point""Inf"]):
        if isinstance(other, Inf):
            return self.__add__(other)
        return self + Point(other.x, -other.y % self.p, self.curve)

    def __neg__(self):
        return Point(self.x, -self.y % self.p, self.curve)

    def __mul__(self, n: int):
        if not self.is_on_curve:
            raise Exception('point not on curve')
        if n == 0:
            return Inf(self.curve)
        if n < 0:
            return -self * (-n)
        res = Inf(self.curve)
        p = self
        while n > 1:
            if n % 2 == 1:
                res = res + p
            p = p + p
            n = n // 2
        return res + p

    def __rmul__(self, other: int):
        return self.__mul__(other)

    def __str__(self):
        return "(%d, %d) %s %s" % (self.x, self.y, "on" if self.is_on_curve else "off", self.curve)

    def __repr__(self):
        return self.__str__()


class Curve:
    def __init__(self, a, b, field: "SubGroup"):
        self.a = a
        self.b = b
        self.field = field

    def is_singular(self):
        return 4 * self.a ** 3 + 27 * self.b ** 2 == 0

    def __eq__(self, other: "Curve"):
        return self.a == other.a and self.b == other.b and self.field == other.field

    def __ne__(self, other: "Curve"):
        return not self.__eq__(other)

    def __str__(self):
        return "y^2 = x^3 + %dx + %d over %s" % (self.a, self.b, self.field)

    @staticmethod
    def brainpoolP160r1() -> "Curve":
        field = SubGroup(
            0xE95E4A5F737059DC60DFC7AD95B3D8139515620F,
            (0xBED5AF16EA3F6A4F62938C4631EB5AF7BDBCDBC30x1667CB477A1A8EC338F94741669C976316DA6321),
            0xE95E4A5F737059DC60DF5991D45029409E60FC09,
            0x01
        )
        return Curve(0x340E7BE2A280EB74E2BE61BADA745D97E8F7C3000x1E589A8595423412134FAA2DBDEC95C8D8675E58, field)


class SubGroup:
    def __init__(self, p: int, g: (int, int), n: int, h: int):
        self.p = p
        self.g = g
        self.n = n
        self.h = h

    def __eq__(self, other: "SubGroup"):
        return self.p == other.p and self.g == other.g and self.n == other.n and self.h == other.h

    def __ne__(self, other: "SubGroup"):
        return not self.__eq__(other)

    def __str__(self):
        return "SubGroup(p=%d, g=%s, n=%d, h=%d)" % (self.p, self.g, self.n, self.h)


class KeyPair:
    def __init__(self, curve, private_key, public_key):
        self.private_key = private_key
        self.public_key = public_key
        self.curve = curve

    @staticmethod
    def make_key_pair(curve: Curve) -> "KeyPair":
        private_key = random.randint(1, curve.field.n - 1)
        g = Point(curve.field.g[0], curve.field.g[1], curve)
        public_key = private_key * g
        return KeyPair(curve, private_key, public_key)


class ECDH:
    def __init__(self, key_pair):
        self.key_pair = key_pair

    @staticmethod
    def make_key_pair(curve: Curve) -> "KeyPair":
        return KeyPair.make_key_pair(curve)

    @staticmethod
    def make_shared_secret(key_pair: KeyPair, public_key: Point) -> int:
        return (key_pair.private_key * public_key).x


if __name__ == '__main__':
    _curve = Curve.brainpoolP160r1()
    alice = ECDH.make_key_pair(_curve)
    bob = ECDH.make_key_pair(_curve)
    alice_secret = ECDH.make_shared_secret(alice, bob.public_key)
    bob_secret = ECDH.make_shared_secret(bob, alice.public_key)
    print(alice_secret == bob_secret)

结束语

好了,这个其实在素域上面实现起来还是比较容易的,在其他的域上的实现,读者先自行脑补一下吧,溜了溜了 ~~,然后这些代码仅限于研究学习使用,生产用途还是建议使用标准库函数来做。

因为公式排版的问题,这里的支持并不好,所以放一份到语雀了,访问密码: euso,点击原文即可阅读。

参考资料

https://mp.weixin.qq.com/s/lLrnyKmufpBw_3chcva9KQ

  • https://github.com/faheel/BigInt[1]

Reference

[1]

https://github.com/faheel/BigInt: https://github.com/faheel/BigInt


原文始发于微信公众号(Coder小Q):【密码学】椭圆曲线代码实现

  • 左青龙
  • 微信扫一扫
  • weinxin
  • 右白虎
  • 微信扫一扫
  • weinxin
admin
  • 本文由 发表于 2022年12月22日11:52:22
  • 转载请保留本文链接(CN-SEC中文网:感谢原作者辛苦付出):
                   【密码学】椭圆曲线代码实现https://cn-sec.com/archives/1477752.html

发表评论

匿名网友 填写信息