From 913e48e9b74141279c3c4be25327625ce11b3267 Mon Sep 17 00:00:00 2001 From: Janis Erdmanis Date: Thu, 31 Oct 2024 01:57:31 +0200 Subject: [PATCH] fixes for OpenSSLGroups --- Project.toml | 2 +- src/Curves/ecpoint.jl | 7 ++-- src/macros.jl | 98 +++++++++++++++++++++++++------------------ src/spec.jl | 2 +- 4 files changed, 64 insertions(+), 45 deletions(-) diff --git a/Project.toml b/Project.toml index 88ef3a2..a92a6bd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CryptoGroups" uuid = "bc997328-bedd-407e-bcd3-5758e064a52d" authors = ["Janis Erdmanis "] -version = "0.5.0" +version = "0.5.1" [deps] CryptoPRG = "d846c407-34c1-46cb-aa27-d51818cc05e2" diff --git a/src/Curves/ecpoint.jl b/src/Curves/ecpoint.jl index 997b565..90153ad 100644 --- a/src/Curves/ecpoint.jl +++ b/src/Curves/ecpoint.jl @@ -35,10 +35,11 @@ Base.:-(u::P, v::P) where P <: AbstractPoint = u + (-v) Base.isless(x::P, y::P) where P <: AbstractPoint = gx(x) == gx(y) ? gx(x) < gx(y) : gy(x) < gy(y) -function validate(x::AbstractPoint, order::Integer, cofactor::Integer) +function validate(x::P, order::Integer, cofactor::Integer) where P <: AbstractPoint oncurve(x) || throw(ArgumentError("Point is not in curve")) - x * cofactor != zero(x) || throw(ArgumentError("Point is in cofactor subgroup")) + #x * cofactor != zero(P) || throw(ArgumentError("Point is in cofactor subgroup")) + !iszero(x * cofactor) || throw(ArgumentError("Point is in cofactor subgroup")) return end @@ -130,7 +131,7 @@ name(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = isnothing(S.name) ? eq(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = eq(P) field(::Type{ECPoint{P, S}}) where {P <: AbstractPoint, S} = field(P) - +field(::Type{ECPoint{P}}) where {P <: AbstractPoint} = field(P) """ zero(::Union{P, Type{P}}) where P <: AbstractPoint diff --git a/src/macros.jl b/src/macros.jl index 9261d2e..d58a97c 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -1,6 +1,5 @@ using .Specs: modp_spec -# TODO: Add support for @PGroup{p = _p, q = _q} where _q and _p are defined out of the scope macro PGroup(expr) if expr.head == :braces if length(expr.args) == 1 && !(expr.args[1] isa Expr) @@ -9,15 +8,18 @@ macro PGroup(expr) spec = modp_spec(name) group = concretize_type(PGroup, spec; name) return group + else - # Two-argument case: @PGroup{p=23, q=11} + # Two-argument case: @PGroup{p=23, q=11} or @PGroup{p=my_p, q=my_q} p = q = nothing for arg in expr.args if arg isa Expr && arg.head == :(=) - if arg.args[1] == :p - p = arg.args[2] - elseif arg.args[1] == :q - q = arg.args[2] + lhs = arg.args[1] + rhs = arg.args[2] + if lhs == :p + p = rhs + elseif lhs == :q + q = rhs end end end @@ -27,16 +29,22 @@ macro PGroup(expr) error("Both p and q must be specified in @PGroup{p=..., q=...}") end - spec = MODP(p; q) - group = concretize_type(PGroup, spec) - return group + # Properly escape both p and q values + return quote + local p_val = $(esc(p)) + local q_val = $(esc(q)) + local spec = MODP(p_val; q=q_val) + concretize_type( + PGroup, + spec + ) + end end else error("Invalid syntax. Use @PGroup{p=..., q=...} or @PGroup{some_name}") end end - Base.show(io::IO, ::Type{PGroup}) = print(io, "PGroup") function Base.show(io::IO, ::Type{G}) where G <: PGroup @@ -66,20 +74,51 @@ end macro ECGroup(expr) - if expr.head == :braces && length(expr.args) == 1 && !(expr.args[1] isa Expr) - # Single argument case: @PGroup{some_name} - some_name = expr.args[1] - - # If the curve can't be found error here - spec = curve(some_name) - group = concretize_type(ECGroup, spec) + if expr.head == :braces && length(expr.args) == 1 + arg = expr.args[1] + point_expr = Expr(:macrocall, Symbol("@ECPoint"), LineNumberNode(@__LINE__), + Expr(:braces, arg)) + # Use __module__ to get the ECGroup type from the defining module + return :(ECGroup{$(esc(point_expr))}) + else + error("Invalid syntax. Use @ECGroup{curve_name} or @ECGroup{Module.curve_name}") + end +end - return group +# First, let's modify @ECPoint to ensure it handles symbol quoting correctly +macro ECPoint(expr) + if expr.head == :braces && length(expr.args) == 1 + arg = expr.args[1] + + # Handle module-qualified names (e.g., OpenSSLGroups.SecP256k1) + if arg isa Expr && arg.head == :. + return quote + local P = $(esc(arg)) + concretize_type(ECPoint{P}, order(P), cofactor(P); name = nameof(P)) + end + # Handle simple symbols + elseif arg isa Symbol + # Important: Use QuoteNode here for the isdefined check + return quote + if isdefined($(__module__), $(QuoteNode(arg))) + # If defined, use the escaped symbol to access its value + local P = $(esc(arg)) + concretize_type(ECPoint{P}, order(P), cofactor(P); name = nameof(P)) + else + # If not defined, treat it as a curve name + local spec = curve($(QuoteNode(arg))) + concretize_type(ECPoint, spec) + end + end + else + error("Invalid syntax. Use @ECPoint{curve_name} or @ECPoint{Module.curve_name}") + end else - error("Invalid syntax. Use @ECGroup{curve_name}") + error("Invalid syntax. Use @ECPoint{curve_name} or @ECPoint{Module.curve_name}") end end + function Base.show(io::IO, g::G) where G <: ECGroup show(io, G) print(io, "(") @@ -101,23 +140,6 @@ function Base.show(io::IO, ::Type{G}) where G <: ECGroup end end - -macro ECPoint(expr) - if expr.head == :braces && length(expr.args) == 1 && !(expr.args[1] isa Expr) - # Single argument case: @PGroup{some_name} - some_name = expr.args[1] - - spec = curve(some_name) - # If the curve can't be found error here - _curve = concretize_type(ECPoint, spec) - - return _curve - else - error("Invalid syntax. Use @ECPoint{curve_name}") - end -end - - ### May need to do epoint seperatelly function Base.show(io::IO, ::Type{P}) where P <: ECPoint if @isdefined P @@ -131,7 +153,6 @@ function Base.show(io::IO, ::Type{P}) where P <: ECPoint end end - function Base.show(io::IO, p::P) where P <: ECPoint show(io, P) print(io, "(") @@ -144,6 +165,3 @@ end function Base.display(::Type{P}) where P <: ECPoint show(P) end - - - diff --git a/src/spec.jl b/src/spec.jl index fc6a34e..a867460 100644 --- a/src/spec.jl +++ b/src/spec.jl @@ -158,5 +158,5 @@ spec(g::ECGroup) = spec(g.x) spec(::Type{G}) where G <: PGroup = MODP(; p = modulus(G), q = order(G)) (::Type{P})() where P <: ECPoint = P(generator(curve(name(P)))) -(::Type{G})() where G <: ECGroup = G(generator(curve(name(G)))) +(::Type{ECGroup{P}})() where P <: ECPoint = ECGroup{P}(P()) (::Type{G})() where G <: PGroup = G(generator(modp_spec(name(G))))