From f64b9a6af1885cb2b4baf2b4dc299d22a0738965 Mon Sep 17 00:00:00 2001 From: Rahul Ghangas Date: Mon, 11 Jan 2021 12:13:10 +0530 Subject: [PATCH] feat: Add version negotiation handshake wrapper --- handshake/negotiate.go | 59 ++++++++++++++++++++++++++++++++++++++++++ wire/wire.go | 4 +++ 2 files changed, 63 insertions(+) create mode 100644 handshake/negotiate.go diff --git a/handshake/negotiate.go b/handshake/negotiate.go new file mode 100644 index 0000000..da956df --- /dev/null +++ b/handshake/negotiate.go @@ -0,0 +1,59 @@ +package handshake + +import ( + "bytes" + "fmt" + "net" + + "github.com/renproject/aw/codec" + "github.com/renproject/aw/wire" + "github.com/renproject/id" + "github.com/renproject/surge" +) + +const VERSION_BYTES = 2 +type versionType = uint16 + +func Negotiate(self id.Signatory, h Handshake) Handshake { + return func(conn net.Conn, enc codec.Encoder, dec codec.Decoder) (codec.Encoder, codec.Decoder, id.Signatory, error) { + e, d, remote, err := h(conn, enc, dec) + if err != nil { + return nil, nil, remote, fmt.Errorf("handshake before negotiating version: %v", err) + } + + var versionBytes [VERSION_BYTES]byte + var version versionType + cmp := bytes.Compare(self[:], remote[:]) + if cmp < 0 { + if _, err := dec(conn, versionBytes[:]); err != nil { + return nil, nil, remote, fmt.Errorf("decoding version: %v", err) + } + if _, _, err := surge.UnmarshalU16(&version, versionBytes[:], len(versionBytes)); err != nil { + return nil, nil, remote, fmt.Errorf("unmarshaling version: %v", err) + } + if _, err := enc(conn, versionBytes[:]); err != nil { + return nil, nil, remote, fmt.Errorf("encoding current version: %v", err) + } + if version == wire.CurrentVersion { + return e, d, remote, nil + } + return nil, nil, remote, fmt.Errorf("not current version: %v", err) + } + if _, _, err := surge.MarshalU16(wire.CurrentVersion, versionBytes[:], len(versionBytes)); err != nil { + return nil, nil, remote, fmt.Errorf("marshaling current version: %v", err) + } + if _, err := enc(conn, versionBytes[:]); err != nil { + return nil, nil, remote, fmt.Errorf("encoding version: %v", err) + } + if _, err := dec(conn, versionBytes[:]); err != nil { + return nil, nil, remote, fmt.Errorf("decoding version: %v", err) + } + if _, _, err := surge.UnmarshalU16(&version, versionBytes[:], len(versionBytes)); err != nil { + return nil, nil, remote, fmt.Errorf("unmarshaling version: %v", err) + } + if version == wire.CurrentVersion { + return e, d, remote, nil + } + return nil, nil, remote, fmt.Errorf("not current version: %v", err) + } +} diff --git a/wire/wire.go b/wire/wire.go index d76ae38..02b2465 100644 --- a/wire/wire.go +++ b/wire/wire.go @@ -11,8 +11,12 @@ import ( // Enumerate all valid MsgVersion values. const ( MsgVersion1 = uint16(1) + CurrentVersion = MsgVersion1 ) +var romanNumeralDict = map[uint16]struct{}{ +} + // Enumerate all valid MsgType values. const ( MsgTypePush = uint16(1)