From 2a5a7e25d7bfbfaa732b22cb39179467224e93a9 Mon Sep 17 00:00:00 2001 From: Emanuel Pargov Date: Mon, 30 Dec 2024 15:03:11 +0200 Subject: [PATCH] Add quic transport layer --- cmd/strawberry/main.go | 143 ++++++--- go.mod | 47 +-- go.sum | 100 +++--- pkg/network/cert/cert.go | 118 ++++--- pkg/network/cert/cert_test.go | 18 +- pkg/network/handlers/block_request.go | 89 ++++++ pkg/network/handlers/message.go | 83 +++++ pkg/network/network_test.go | 269 ++++++++++++++++ pkg/network/peer/peer.go | 39 +++ pkg/network/{transport => }/protocol/alpn.go | 2 +- .../{transport => }/protocol/alpn_test.go | 2 +- pkg/network/protocol/conn.go | 128 ++++++++ pkg/network/protocol/manager.go | 126 ++++++++ pkg/network/protocol/streams.go | 71 +++++ pkg/network/transport/conn.go | 81 +++++ pkg/network/transport/errors.go | 12 + pkg/network/transport/transport.go | 288 ++++++++++++++++++ 17 files changed, 1463 insertions(+), 153 deletions(-) create mode 100644 pkg/network/handlers/block_request.go create mode 100644 pkg/network/handlers/message.go create mode 100644 pkg/network/network_test.go create mode 100644 pkg/network/peer/peer.go rename pkg/network/{transport => }/protocol/alpn.go (99%) rename pkg/network/{transport => }/protocol/alpn_test.go (99%) create mode 100644 pkg/network/protocol/conn.go create mode 100644 pkg/network/protocol/manager.go create mode 100644 pkg/network/protocol/streams.go create mode 100644 pkg/network/transport/conn.go create mode 100644 pkg/network/transport/errors.go create mode 100644 pkg/network/transport/transport.go diff --git a/cmd/strawberry/main.go b/cmd/strawberry/main.go index e9ffadc..49eebb0 100644 --- a/cmd/strawberry/main.go +++ b/cmd/strawberry/main.go @@ -1,54 +1,121 @@ package main import ( - "encoding/json" + "context" + "crypto/ed25519" + "flag" + "fmt" "log" - "net/http" - "sync" + "time" - "github.com/eigerco/strawberry/internal/block" - "github.com/eigerco/strawberry/internal/state" - "github.com/eigerco/strawberry/internal/statetransition" + "github.com/eigerco/strawberry/pkg/network/cert" + "github.com/eigerco/strawberry/pkg/network/handlers" + "github.com/eigerco/strawberry/pkg/network/peer" + "github.com/eigerco/strawberry/pkg/network/protocol" + "github.com/eigerco/strawberry/pkg/network/transport" ) +// main starts a blockchain node. +// +// To run the first node (listener): +// +// go run main.go -addr localhost:9000 +// +// To run a second node that connects to the first node: +// +// go run main.go -addr localhost:9001 -connect localhost:9000 +// +// - The first node listens on port 9000. +// - The second node listens on port 9001 and connects to the first node's address (localhost:9000). func main() { - globalState := &state.State{} - mu := &sync.RWMutex{} - - // a simple http server that demonstrates the block import capabilities - // this will be replaced with proper p2p network communication in milestone 2 - mux := http.NewServeMux() - mux.HandleFunc("/block/import", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - newBlock := block.Block{} - if err := json.NewDecoder(r.Body).Decode(&newBlock); err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) - return - } + listenAddr := flag.String("addr", "", "Listen address (e.g., 0.0.0.0:9000)") + connectTo := flag.String("connect", "", "Address to connect to (optional)") + flag.Parse() + + if *listenAddr == "" { + log.Fatal("listen address is required") + } + + // Generate node keys + pub, priv, err := ed25519.GenerateKey(nil) + if err != nil { + log.Fatalf("Failed to generate keys: %v", err) + } + + // Create certificate + certGen := cert.NewGenerator(cert.Config{ + PublicKey: pub, + PrivateKey: priv, + CertValidityPeriod: 24 * time.Hour, + }) + tlsCert, err := certGen.GenerateCertificate() + if err != nil { + log.Fatalf("Failed to generate certificate: %v", err) + } + + // Create protocol manager + protoConfig := protocol.Config{ + ChainHash: "12345678", // Example chain hash + IsBuilder: false, + MaxBuilderSlots: 20, + } + protoManager, err := protocol.NewManager(protoConfig) + if err != nil { + log.Fatalf("Failed to create protocol manager: %v", err) + } - mu.Lock() - defer mu.Unlock() + // Register protocol handlers + protoManager.Registry.RegisterHandler(protocol.StreamKindBlockRequest, handlers.NewBlockRequestHandler()) - if err := statetransition.UpdateState(globalState, newBlock); err != nil { - jsonError(w, err.Error(), http.StatusBadRequest) - return + // Create transport with minimal config + transportConfig := transport.Config{ + PublicKey: pub, + PrivateKey: priv, + TLSCert: tlsCert, + ListenAddr: *listenAddr, + CertValidator: cert.NewValidator(), + Handler: protoManager, // Protocol manager implements ConnectionHandler + } + + tr, err := transport.NewTransport(transportConfig) + if err != nil { + log.Fatalf("Failed to create transport: %v", err) + } + + if err := tr.Start(); err != nil { + log.Fatalf("Failed to start transport: %v", err) + } + defer func() { + if err := tr.Stop(); err != nil { + fmt.Printf("Failed to stop transport: %v\n", err) } + }() + + log.Printf("Node listening on %s", *listenAddr) + + // If we have an address to connect to, make a request + if *connectTo != "" { + log.Printf("Connecting to peer at %s", *connectTo) - if err := json.NewEncoder(w).Encode(map[string]string{"status": "success"}); err != nil { - jsonError(w, err.Error(), http.StatusInternalServerError) - return + conn, err := tr.Connect(*connectTo) + if err != nil { + log.Fatalf("Failed to connect to peer: %v", err) } - }) - log.Println("demo server running on port :8080") - log.Fatal(http.ListenAndServe(":8080", mux)) -} -func jsonError(w http.ResponseWriter, message string, statusCode int) { - w.WriteHeader(statusCode) - if err := json.NewEncoder(w).Encode(map[string]string{ - "status": "error", - "message": message, - }); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + // Create a dummy block hash for the request + hash := [32]byte{1, 2, 3, 4} // Example hash + + // Create peer with protocol connection + p := peer.NewPeer(conn, conn.PeerKey(), protoManager) + ctx := context.Background() + blocks, err := p.RequestBlocks(ctx, hash, true) + if err != nil { + log.Fatalf("Failed to request blocks: %v", err) + } + fmt.Printf("blocks: %v\n", blocks) + log.Printf("Block request completed") } + + // Keep the node running + select {} } diff --git a/go.mod b/go.mod index 746580a..b487dfe 100644 --- a/go.mod +++ b/go.mod @@ -3,40 +3,49 @@ module github.com/eigerco/strawberry go 1.22.5 require ( - github.com/cockroachdb/pebble v1.1.0 + github.com/cockroachdb/pebble v1.1.2 github.com/ebitengine/purego v0.8.1 github.com/golang/mock v1.6.0 + github.com/quic-go/quic-go v0.48.2 github.com/stretchr/testify v1.10.0 - golang.org/x/crypto v0.27.0 + golang.org/x/crypto v0.31.0 ) require ( - github.com/DataDog/zstd v1.5.2 // indirect + github.com/DataDog/zstd v1.5.6 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/cockroachdb/errors v1.11.1 // indirect - github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cockroachdb/errors v1.11.3 // indirect + github.com/cockroachdb/fifo v0.0.0-20240606204812-0bbfbd93a7ce // indirect + github.com/cockroachdb/logtags v0.0.0-20241215232642-bb51bb14a506 // indirect github.com/cockroachdb/redact v1.1.5 // indirect github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/getsentry/sentry-go v0.18.0 // indirect + github.com/getsentry/sentry-go v0.30.0 // indirect + github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect - github.com/google/go-cmp v0.6.0 // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect + github.com/klauspost/compress v1.17.11 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/onsi/ginkgo/v2 v2.22.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/prometheus/client_golang v1.18.0 // indirect - github.com/prometheus/client_model v0.6.0 // indirect - github.com/prometheus/common v0.45.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect - github.com/rogpeppe/go-internal v1.10.0 // indirect - golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/text v0.18.0 // indirect - google.golang.org/protobuf v1.32.0 // indirect + github.com/prometheus/client_golang v1.20.5 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.61.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect + go.uber.org/mock v0.5.0 // indirect + golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.32.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/tools v0.28.0 // indirect + google.golang.org/protobuf v1.36.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 13f22ee..e1cbeb8 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,19 @@ -github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= -github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= +github.com/DataDog/zstd v1.5.6 h1:LbEglqepa/ipmmQJUDnSsfvA8e8IStVcGaFWDuxvGOY= +github.com/DataDog/zstd v1.5.6/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f h1:otljaYPt5hWxV3MUfO5dFPFiOXg9CyG5/kCfayTqsJ4= github.com/cockroachdb/datadriven v1.0.3-0.20230413201302-be42291fc80f/go.mod h1:a9RdTaap04u637JoCzcUoIcDmvwSUtcUFtT/C3kJlTU= -github.com/cockroachdb/errors v1.11.1 h1:xSEW75zKaKCWzR3OfxXUxgrk/NtT4G1MiOv5lWZazG8= -github.com/cockroachdb/errors v1.11.1/go.mod h1:8MUxA3Gi6b25tYlFEBGLf+D8aISL+M4MIpiWMSNRfxw= -github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b h1:r6VH0faHjZeQy818SGhaone5OnYfxFR/+AzdY3sf5aE= -github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b/go.mod h1:Vz9DsVWQQhf3vs21MhPMZpMGSht7O/2vFW2xusFUVOs= -github.com/cockroachdb/pebble v1.1.0 h1:pcFh8CdCIt2kmEpK0OIatq67Ln9uGDYY3d5XnE0LJG4= -github.com/cockroachdb/pebble v1.1.0/go.mod h1:sEHm5NOXxyiAoKWhoFxT8xMgd/f3RA6qUqQ1BXKrh2E= +github.com/cockroachdb/errors v1.11.3 h1:5bA+k2Y6r+oz/6Z/RFlNeVCesGARKuC6YymtcDrbC/I= +github.com/cockroachdb/errors v1.11.3/go.mod h1:m4UIW4CDjx+R5cybPsNrRbreomiFqt8o1h1wUVazSd8= +github.com/cockroachdb/fifo v0.0.0-20240606204812-0bbfbd93a7ce h1:giXvy4KSc/6g/esnpM7Geqxka4WSqI1SZc7sMJFd3y4= +github.com/cockroachdb/fifo v0.0.0-20240606204812-0bbfbd93a7ce/go.mod h1:9/y3cnZ5GKakj/H4y9r9GTjCvAFta7KLgSHPJJYc52M= +github.com/cockroachdb/logtags v0.0.0-20241215232642-bb51bb14a506 h1:ASDL+UJcILMqgNeV5jiqR4j+sTuvQNHdf2chuKj1M5k= +github.com/cockroachdb/logtags v0.0.0-20241215232642-bb51bb14a506/go.mod h1:Mw7HqKr2kdtu6aYGn3tPmAftiP3QPX63LdK/zcariIo= +github.com/cockroachdb/pebble v1.1.2 h1:CUh2IPtR4swHlEj48Rhfzw6l/d0qA31fItcIszQVIsA= +github.com/cockroachdb/pebble v1.1.2/go.mod h1:4exszw1r40423ZsmkG/09AFEG83I0uDgfujJdbL6kYU= github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwPJ30= github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo= @@ -21,10 +23,14 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE= github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= -github.com/getsentry/sentry-go v0.18.0 h1:MtBW5H9QgdcJabtZcuJG80BMOwaBpkRDZkxRkNC1sN0= -github.com/getsentry/sentry-go v0.18.0/go.mod h1:Kgon4Mby+FJ7ZWHFUAZgVaIa8sxHtnRJRLTXZr51aKQ= +github.com/getsentry/sentry-go v0.30.0 h1:lWUwDnY7sKHaVIoZ9wYqRHJ5iEmoc0pqcRqFkosKzBo= +github.com/getsentry/sentry-go v0.30.0/go.mod h1:WU9B9/1/sHDqeV8T+3VwwbjeR5MSXs/6aqG3mqZrezA= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -33,16 +39,22 @@ github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXi github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= +github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/onsi/ginkgo/v2 v2.22.0 h1:Yed107/8DjTr0lKCNt7Dn8yQ6ybuDRQoMGrNFKzMfHg= +github.com/onsi/ginkgo/v2 v2.22.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= +github.com/onsi/gomega v1.34.2 h1:pNCwDkzrsv7MS9kpaQvVb1aVLahQXyJ/Tv5oAZMI3i8= +github.com/onsi/gomega v1.34.2/go.mod h1:v1xfxRgk0KIsG+QOdm7p8UosrOzPYRo60fd3B/1Dukc= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -50,67 +62,79 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk= -github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA= -github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= -github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= +github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.61.0 h1:3gv/GThfX0cV2lpO7gkTUwZru38mxevy90Bj8YFSRQQ= +github.com/prometheus/common v0.61.0/go.mod h1:zr29OCN/2BsJRaFwG8QOBr41D6kkchKbpeNH7pAjb/s= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e h1:723BNChdd0c2Wk6WOE320qGBiPtYx0F0Bbm1kriShfE= -golang.org/x/exp v0.0.0-20240110193028-0dcbfd608b1e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e h1:4qufH0hlUYs6AO6XmZC3GqfDPGSXHVXUFR6OND+iJX4= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= +golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= -google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.0 h1:mjIs9gYtt56AzC4ZaffQuh88TZurBGhIJMBZGSxNerQ= +google.golang.org/protobuf v1.36.0/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/network/cert/cert.go b/pkg/network/cert/cert.go index 54bcfc7..56c6fbc 100644 --- a/pkg/network/cert/cert.go +++ b/pkg/network/cert/cert.go @@ -3,6 +3,7 @@ package cert import ( "crypto/ed25519" "crypto/rand" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base32" @@ -18,26 +19,82 @@ const ( var base32Encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567").WithPadding(base32.NoPadding) +// Generator handles certificate generation +type Generator struct { + config Config +} + type Config struct { PublicKey ed25519.PublicKey PrivateKey ed25519.PrivateKey CertValidityPeriod time.Duration } -type Generator struct { - config Config -} - func NewGenerator(config Config) *Generator { return &Generator{config: config} } -func encodePubKeyToDNS(pubKey ed25519.PublicKey) string { +// Validator implements the transport.CertValidator interface +type Validator struct{} + +func NewValidator() *Validator { + return &Validator{} +} + +// ValidateCertificate implements transport.CertValidator +func (v *Validator) ValidateCertificate(cert *x509.Certificate) error { + if cert.SignatureAlgorithm != x509.PureEd25519 { + return fmt.Errorf("invalid signature algorithm: expected Ed25519") + } + + pubKey, ok := cert.PublicKey.(ed25519.PublicKey) + if !ok { + return fmt.Errorf("certificate public key is not Ed25519") + } + + if len(cert.DNSNames) != 1 { + return fmt.Errorf("certificate must have exactly one DNS name") + } + dnsName := cert.DNSNames[0] + + if len(dnsName) != 53 || !strings.HasPrefix(dnsName, DNSNamePrefix) { + return fmt.Errorf("invalid DNS name format: %s (length: %d)", dnsName, len(dnsName)) + } + + // Generate expected DNS name + expectedDNSName := EncodePubKeyToDNS(pubKey) + + if dnsName != expectedDNSName { + return fmt.Errorf("DNS name does not match public key") + } + + // Check expiration + now := time.Now() + if now.Before(cert.NotBefore) { + return fmt.Errorf("certificate is not yet valid") + } + if now.After(cert.NotAfter) { + return fmt.Errorf("certificate has expired") + } + return nil +} + +// ExtractPublicKey implements transport.CertValidator +func (v *Validator) ExtractPublicKey(cert *x509.Certificate) (ed25519.PublicKey, error) { + pubKey, ok := cert.PublicKey.(ed25519.PublicKey) + if !ok { + return nil, fmt.Errorf("certificate public key is not an Ed25519 key") + } + return pubKey, nil +} + +// Helper functions used by both Generator and Validator +func EncodePubKeyToDNS(pubKey ed25519.PublicKey) string { return DNSNamePrefix + base32Encoding.EncodeToString(pubKey) } -func (g *Generator) GenerateCertificate() (*x509.Certificate, error) { - dnsName := encodePubKeyToDNS(g.config.PublicKey) +func (g *Generator) GenerateCertificate() (*tls.Certificate, error) { + dnsName := EncodePubKeyToDNS(g.config.PublicKey) serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { @@ -57,6 +114,8 @@ func (g *Generator) GenerateCertificate() (*x509.Certificate, error) { x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth, }, + SignatureAlgorithm: x509.PureEd25519, + PublicKeyAlgorithm: x509.Ed25519, BasicConstraintsValid: true, } @@ -70,44 +129,9 @@ func (g *Generator) GenerateCertificate() (*x509.Certificate, error) { return nil, fmt.Errorf("failed to parse certificate: %w", err) } - return cert, nil -} - -func ValidateCertificate(cert *x509.Certificate) error { - // Check signature algorithm - if cert.SignatureAlgorithm != x509.PureEd25519 { - return fmt.Errorf("invalid signature algorithm: expected Ed25519") - } - - // Check DNS names - if len(cert.DNSNames) != 1 { - return fmt.Errorf("certificate must have exactly one DNS name") - } - dnsName := cert.DNSNames[0] - - // Verify format - if len(dnsName) != 53 || !strings.HasPrefix(dnsName, DNSNamePrefix) { - return fmt.Errorf("invalid DNS name format: %s", dnsName) - } - - // Validate public key - pubKey, ok := cert.PublicKey.(ed25519.PublicKey) - if !ok { - return fmt.Errorf("certificate public key is not an Ed25519 key") - } - expectedDNSName := encodePubKeyToDNS(pubKey) - if dnsName != expectedDNSName { - return fmt.Errorf("DNS name does not match public key: got %s, expected %s", dnsName, expectedDNSName) - } - - // Check expiration - now := time.Now() - if now.Before(cert.NotBefore) { - return fmt.Errorf("certificate is not yet valid: valid from %v", cert.NotBefore) - } - if now.After(cert.NotAfter) { - return fmt.Errorf("certificate has expired: valid until %v", cert.NotAfter) - } - - return nil + return &tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: g.config.PrivateKey, + Leaf: cert, + }, nil } diff --git a/pkg/network/cert/cert_test.go b/pkg/network/cert/cert_test.go index 3af5762..4b33389 100644 --- a/pkg/network/cert/cert_test.go +++ b/pkg/network/cert/cert_test.go @@ -40,7 +40,7 @@ func TestValidateCertificateSuccess(t *testing.T) { cert, err := generator.GenerateCertificate() require.NoError(t, err, "Failed to generate certificate") - err = ValidateCertificate(cert) + err = NewValidator().ValidateCertificate(cert.Leaf) assert.NoError(t, err, "Valid certificate failed validation") } @@ -60,9 +60,9 @@ func TestValidateCertificateFailsForMismatchedPublicKey(t *testing.T) { // Tamper with the public key wrongPub, _, err := ed25519.GenerateKey(nil) require.NoError(t, err, "Failed to generate a new Ed25519 key pair") - cert.PublicKey = wrongPub + cert.Leaf.PublicKey = wrongPub - err = ValidateCertificate(cert) + err = NewValidator().ValidateCertificate(cert.Leaf) assert.Error(t, err, "Expected validation to fail for certificate with mismatched DNS name and public key") } @@ -79,8 +79,8 @@ func TestCertificateDNSNameFormat(t *testing.T) { cert, err := generator.GenerateCertificate() require.NoError(t, err, "Failed to generate certificate") - require.Len(t, cert.DNSNames, 1, "Certificate must have exactly one DNS name") - dnsName := cert.DNSNames[0] + require.Len(t, cert.Leaf.DNSNames, 1, "Certificate must have exactly one DNS name") + dnsName := cert.Leaf.DNSNames[0] assert.Equal(t, 53, len(dnsName), "DNS name should be 53 characters long") assert.True(t, dnsName[0] == 'e', "DNS name should start with 'e'") } @@ -98,7 +98,7 @@ func TestCertificateParseDER(t *testing.T) { cert, err := generator.GenerateCertificate() require.NoError(t, err, "Failed to generate certificate") - parsedCert, err := x509.ParseCertificate(cert.Raw) + parsedCert, err := x509.ParseCertificate(cert.Leaf.Raw) assert.NoError(t, err, "Failed to parse generated certificate DER") assert.NotNil(t, parsedCert, "Parsed certificate should not be nil") } @@ -118,7 +118,7 @@ func TestValidateCertificateExpired(t *testing.T) { require.NoError(t, err, "Failed to generate expired certificate") // Validate the certificate - err = ValidateCertificate(cert) + err = NewValidator().ValidateCertificate(cert.Leaf) assert.Error(t, err, "Expected validation to fail for expired certificate") assert.Contains(t, err.Error(), "certificate has expired", "Expected error message for expired certificate") } @@ -136,10 +136,10 @@ func TestValidateCertificateFutureStartDate(t *testing.T) { generator := NewGenerator(config) cert, err := generator.GenerateCertificate() require.NoError(t, err, "Failed to generate future-dated certificate") - cert.NotBefore = time.Now().Add(1 * time.Hour) // Adjust NotBefore to 1 hour from now + cert.Leaf.NotBefore = time.Now().Add(1 * time.Hour) // Adjust NotBefore to 1 hour from now // Validate the certificate - err = ValidateCertificate(cert) + err = NewValidator().ValidateCertificate(cert.Leaf) assert.Error(t, err, "Expected validation to fail for not-yet-valid certificate") assert.Contains(t, err.Error(), "certificate is not yet valid", "Expected error message for future-dated certificate") } diff --git a/pkg/network/handlers/block_request.go b/pkg/network/handlers/block_request.go new file mode 100644 index 0000000..6d8bf46 --- /dev/null +++ b/pkg/network/handlers/block_request.go @@ -0,0 +1,89 @@ +package handlers + +import ( + "context" + "encoding/binary" + "fmt" + "io" + + "github.com/eigerco/strawberry/internal/crypto" + "github.com/quic-go/quic-go" +) + +// BlockRequestHandler handles incoming block requests +type BlockRequestHandler struct{} + +func NewBlockRequestHandler() *BlockRequestHandler { + return &BlockRequestHandler{} +} + +// HandleStream processes incoming block requests (when others request blocks from us) +func (h *BlockRequestHandler) HandleStream(ctx context.Context, stream quic.Stream) error { + fmt.Println("Received block request, reading request message...") + + // Read the request message + msg, err := ReadMessageWithContext(ctx, stream) + if err != nil { + return fmt.Errorf("failed to read request message: %w", err) + } + + // Parse the message content into BlockRequestMessage + if len(msg.Content) < 37 { // 32 (hash) + 1 (direction) + 4 (maxBlocks) + return fmt.Errorf("message too short") + } + + request := BlockRequestMessage{ + Direction: msg.Content[32], // After hash + MaxBlocks: binary.LittleEndian.Uint32(msg.Content[33:37]), + } + copy(request.Hash[:], msg.Content[:32]) + + fmt.Printf("Got request for blocks: hash=%x, direction=%d, maxBlocks=%d\n", + request.Hash, request.Direction, request.MaxBlocks) + + // Test data: Pretend to send some blocks + response := []byte("test block response") + if err := WriteMessageWithContext(ctx, stream, response); err != nil { + return fmt.Errorf("failed to write response message: %w", err) + } + + fmt.Println("Sent block response") + return nil +} + +// BlockRequester makes outgoing block requests +type BlockRequester struct{} + +// TODO: Implement the RequestBlocks function +// RequestBlocks sends a request for blocks to a peer +func (r *BlockRequester) RequestBlocks(ctx context.Context, stream io.ReadWriter, headerHash [32]byte, ascending bool) ([]byte, error) { + direction := byte(0) + if !ascending { + direction = 1 + } + content := make([]byte, 37) + copy(content[:32], headerHash[:]) + content[32] = direction + binary.LittleEndian.PutUint32(content[33:], 10) + + // Write with context + if err := WriteMessageWithContext(ctx, stream, content); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + // Read with context + response, err := ReadMessageWithContext(ctx, stream) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + fmt.Printf("Received response: %s\n", response.Content) + return response.Content, nil +} + +// BlockRequestMessage represents a request for a sequence of blocks +type BlockRequestMessage struct { + Hash crypto.Hash + Direction byte // 0 for ascending, 1 for descending + MaxBlocks uint32 +} diff --git a/pkg/network/handlers/message.go b/pkg/network/handlers/message.go new file mode 100644 index 0000000..e094bf5 --- /dev/null +++ b/pkg/network/handlers/message.go @@ -0,0 +1,83 @@ +package handlers + +import ( + "context" + "encoding/binary" + "fmt" + "io" +) + +// Message represents a protocol message with its size and content +type Message struct { + Size uint32 + Content []byte +} + +// WriteMessage writes a message to a writer with context awareness. +func WriteMessageWithContext(ctx context.Context, w io.Writer, content []byte) error { + done := make(chan error, 1) + go func() { + size := uint32(len(content)) + + // Write size as little-endian uint32 + if err := binary.Write(w, binary.LittleEndian, size); err != nil { + done <- fmt.Errorf("failed to write message size: %w", err) + return + } + + // Write content + if _, err := w.Write(content); err != nil { + done <- fmt.Errorf("failed to write message content: %w", err) + return + } + + done <- nil + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// ReadMessage reads a message from a reader with context awareness. +func ReadMessageWithContext(ctx context.Context, r io.Reader) (*Message, error) { + done := make(chan struct { + msg *Message + err error + }, 1) + + go func() { + var size uint32 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + done <- struct { + msg *Message + err error + }{nil, fmt.Errorf("failed to read message size: %w", err)} + return + } + + content := make([]byte, size) + if _, err := io.ReadFull(r, content); err != nil { + done <- struct { + msg *Message + err error + }{nil, fmt.Errorf("failed to read message content: %w", err)} + return + } + + done <- struct { + msg *Message + err error + }{&Message{Size: size, Content: content}, nil} + }() + + select { + case result := <-done: + return result.msg, result.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/pkg/network/network_test.go b/pkg/network/network_test.go new file mode 100644 index 0000000..59f9a99 --- /dev/null +++ b/pkg/network/network_test.go @@ -0,0 +1,269 @@ +package network_test + +import ( + "context" + "crypto/ed25519" + "net" + "sync" + "testing" + "time" + + "github.com/eigerco/strawberry/pkg/network/cert" + "github.com/eigerco/strawberry/pkg/network/handlers" + "github.com/eigerco/strawberry/pkg/network/peer" + "github.com/eigerco/strawberry/pkg/network/protocol" + "github.com/eigerco/strawberry/pkg/network/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testNode represents a node instance for testing +type testNode struct { + transport *transport.Transport + protoManager *protocol.Manager + addr string + pubKey ed25519.PublicKey + privKey ed25519.PrivateKey +} + +// setupTestNode creates a new test node with all necessary components +func setupTestNode(t *testing.T) *testNode { + // Find available port + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + addr := listener.Addr().String() + listener.Close() + + // Generate keys + pub, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + + // Create certificate + certGen := cert.NewGenerator(cert.Config{ + PublicKey: pub, + PrivateKey: priv, + CertValidityPeriod: 24 * time.Hour, + }) + tlsCert, err := certGen.GenerateCertificate() + require.NoError(t, err) + + // Create protocol manager + protoConfig := protocol.Config{ + ChainHash: "12345678", + IsBuilder: false, + } + protoManager, err := protocol.NewManager(protoConfig) + require.NoError(t, err) + + // Register handlers + blockHandler := handlers.NewBlockRequestHandler() + protoManager.Registry.RegisterHandler(protocol.StreamKindBlockRequest, blockHandler) + + // Create transport + transportConfig := transport.Config{ + PublicKey: pub, + PrivateKey: priv, + TLSCert: tlsCert, + ListenAddr: addr, + CertValidator: cert.NewValidator(), + Handler: protoManager, + } + + tr, err := transport.NewTransport(transportConfig) + require.NoError(t, err) + + return &testNode{ + transport: tr, + protoManager: protoManager, + addr: addr, + pubKey: pub, + privKey: priv, + } +} + +// setupTestPair creates and connects two test nodes +func setupTestPair(t *testing.T) (*testNode, *testNode, *peer.Peer) { + node1 := setupTestNode(t) + node2 := setupTestNode(t) + + require.NoError(t, node1.transport.Start()) + require.NoError(t, node2.transport.Start()) + + conn, err := node2.transport.Connect(node1.addr) + require.NoError(t, err) + + p := peer.NewPeer(conn, conn.PeerKey(), node2.protoManager) + return node1, node2, p +} + +// Helper function to safely stop transports +func cleanupNodes(t *testing.T, nodes ...*testNode) { + for _, node := range nodes { + if err := node.transport.Stop(); err != nil { + t.Errorf("failed to stop transport: %v", err) + } + } +} + +// TestBasicBlockRequest tests a simple block request +func TestBasicBlockRequest(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) + require.NoError(t, err) + assert.Equal(t, "test block response", string(response), "unexpected response content") +} + +// TestConcurrentBlockRequests tests handling multiple concurrent requests +func TestConcurrentBlockRequests(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + var wg sync.WaitGroup + numRequests := 5 + type result struct { + response []byte + err error + } + results := make(chan result, numRequests) + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := p.RequestBlocks(ctx, [32]byte{byte(i)}, true) + results <- result{response, err} + }(i) + } + + wg.Wait() + close(results) + + successCount := 0 + for res := range results { + if assert.NoError(t, res.err) { + assert.Equal(t, "test block response", string(res.response)) + successCount++ + } + } + assert.Equal(t, numRequests, successCount, "all requests should succeed") +} + +// TestRequestTimeout tests proper handling of timeouts +func TestRequestTimeout(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + response, err := p.RequestBlocks(ctx, [32]byte{9, 9, 9, 9}, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + assert.Nil(t, response) +} + +// TestConnectionClosure tests behavior when connection is closed +func TestConnectionClosure(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + // Close node1's transport + require.NoError(t, node1.transport.Stop()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + response, err := p.RequestBlocks(ctx, [32]byte{}, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + assert.Nil(t, response) +} + +// TestNetworkPartition tests behavior during network issues +func TestNetworkPartition(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + // Simulate network partition by stopping node1 + require.NoError(t, node1.transport.Stop()) + + // Start node1 again + require.NoError(t, node1.transport.Start()) + + // Try request after reconnection + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) + assert.Error(t, err) // Should fail due to broken connection +} + +// TestReconnection tests reconnection behavior +func TestServerNodeRestart(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + // Make successful request + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := p.RequestBlocks(ctx, [32]byte{1}, true) + cancel() + require.NoError(t, err) + + // Close and restart node1's transport + require.NoError(t, node1.transport.Stop()) + require.NoError(t, node1.transport.Start()) + + conn, err := node2.transport.Connect(node1.addr) + require.NoError(t, err) + + // Create new peer with new connection + p = peer.NewPeer(conn, conn.PeerKey(), node2.protoManager) + + // Try request with longer timeout + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + + res, err := p.RequestBlocks(ctx, [32]byte{1}, true) + require.NoError(t, err) + assert.Equal(t, "test block response", string(res)) +} + +func TestClientNodeRestart(t *testing.T) { + node1, node2, p := setupTestPair(t) + defer cleanupNodes(t, node1, node2) + + // Make initial successful request + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + response1, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) + cancel() + require.NoError(t, err) + assert.Equal(t, "test block response", string(response1)) + + // Close connection from node2 side + require.NoError(t, node2.transport.Stop()) + + // Restart node2 + require.NoError(t, node2.transport.Start()) + + // Create new connection + conn, err := node2.transport.Connect(node1.addr) + require.NoError(t, err) + + // Create new peer with new connection + newPeer := peer.NewPeer(conn, conn.PeerKey(), node2.protoManager) + + // Try request with new peer + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + response2, err := newPeer.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) + cancel() + require.NoError(t, err) + assert.Equal(t, "test block response", string(response2)) +} diff --git a/pkg/network/peer/peer.go b/pkg/network/peer/peer.go new file mode 100644 index 0000000..459e3ae --- /dev/null +++ b/pkg/network/peer/peer.go @@ -0,0 +1,39 @@ +package peer + +import ( + "context" + "crypto/ed25519" + "fmt" + + "github.com/eigerco/strawberry/pkg/network/handlers" + "github.com/eigerco/strawberry/pkg/network/protocol" + "github.com/eigerco/strawberry/pkg/network/transport" +) + +type Peer struct { + // Core connection + conn *transport.Conn + protoConn *protocol.ProtocolConn + // Peer identity + pubKey ed25519.PublicKey +} + +func NewPeer(conn *transport.Conn, pubKey ed25519.PublicKey, protoManager *protocol.Manager) *Peer { + return &Peer{ + conn: conn, + protoConn: protoManager.WrapConnection(conn), + pubKey: pubKey, + } +} + +// RequestBlocks requests blocks from the peer +func (p *Peer) RequestBlocks(ctx context.Context, headerHash [32]byte, ascending bool) ([]byte, error) { + stream, err := p.protoConn.OpenStream(ctx, protocol.StreamKindBlockRequest) + if err != nil { + return nil, fmt.Errorf("failed to open stream: %w", err) + } + defer stream.Close() + + requester := &handlers.BlockRequester{} + return requester.RequestBlocks(ctx, stream, headerHash, ascending) +} diff --git a/pkg/network/transport/protocol/alpn.go b/pkg/network/protocol/alpn.go similarity index 99% rename from pkg/network/transport/protocol/alpn.go rename to pkg/network/protocol/alpn.go index d55c897..8df6bb3 100644 --- a/pkg/network/transport/protocol/alpn.go +++ b/pkg/network/protocol/alpn.go @@ -1,4 +1,4 @@ -package transport +package protocol import ( "fmt" diff --git a/pkg/network/transport/protocol/alpn_test.go b/pkg/network/protocol/alpn_test.go similarity index 99% rename from pkg/network/transport/protocol/alpn_test.go rename to pkg/network/protocol/alpn_test.go index aa19144..4a44a06 100644 --- a/pkg/network/transport/protocol/alpn_test.go +++ b/pkg/network/protocol/alpn_test.go @@ -1,4 +1,4 @@ -package transport +package protocol import ( "github.com/stretchr/testify/assert" diff --git a/pkg/network/protocol/conn.go b/pkg/network/protocol/conn.go new file mode 100644 index 0000000..0604d45 --- /dev/null +++ b/pkg/network/protocol/conn.go @@ -0,0 +1,128 @@ +package protocol + +import ( + "context" + "fmt" + "sync" + + "github.com/eigerco/strawberry/pkg/network/transport" + "github.com/quic-go/quic-go" +) + +// ProtocolConn wraps a transport connection with protocol-specific functionality +type ProtocolConn struct { + tConn *transport.Conn + mu sync.RWMutex + upStreams map[StreamKind]quic.Stream + registry *JAMNPRegistry +} + +// NewProtocolConn creates a new protocol-level connection +func NewProtocolConn(tConn *transport.Conn, registry *JAMNPRegistry) *ProtocolConn { + return &ProtocolConn{ + tConn: tConn, + upStreams: make(map[StreamKind]quic.Stream), + registry: registry, + } +} + +// OpenStream opens a new stream of the given kind +func (pc *ProtocolConn) OpenStream(ctx context.Context, kind StreamKind) (quic.Stream, error) { + // Use the passed context for opening the stream + stream, err := pc.tConn.OpenStream(ctx) + if err != nil { + return nil, err + } + + // Write stream kind + if err := writeWithContext(ctx, stream, []byte{byte(kind)}); err != nil { + stream.Close() + return nil, fmt.Errorf("failed to write stream kind: %w", err) + } + + return stream, nil +} + +// TODO: to be used in the future +// handleUPStream manages unique persistent streams +// func (pc *ProtocolConn) handleUPStream(kind StreamKind, stream quic.Stream) (quic.Stream, error) { +// pc.mu.Lock() +// defer pc.mu.Unlock() + +// if existing, exists := pc.upStreams[kind]; exists { +// // Keep stream with higher ID +// if existing.StreamID() > stream.StreamID() { +// stream.Close() +// return existing, nil +// } else { +// existing.Close() +// pc.upStreams[kind] = stream +// } +// } else { +// pc.upStreams[kind] = stream +// } +// return stream, nil +// } + +// AcceptStream accepts and handles an incoming stream +func (pc *ProtocolConn) AcceptStream() error { + stream, err := pc.tConn.AcceptStream() + if err != nil { + return err + } + + // Read stream kind + kind := make([]byte, 1) + if _, err := stream.Read(kind); err != nil { + stream.Close() + return fmt.Errorf("failed to read stream kind: %w", err) + } + + // Get handler for this stream kind + handler, err := pc.registry.GetHandler(kind[0]) + if err != nil { + stream.Close() + return err + } + + // Handle the stream + go func() { + if err := handler.HandleStream(pc.tConn.Context(), stream); err != nil { + fmt.Printf("stream handler error: %v\n", err) + } + }() + + return nil +} + +func writeWithContext(ctx context.Context, stream quic.Stream, p []byte) error { + done := make(chan error, 1) + + go func() { + _, err := stream.Write(p) + done <- err + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +// Close closes the protocol connection and all UP streams +func (pc *ProtocolConn) Close() error { + pc.mu.Lock() + defer pc.mu.Unlock() + + // Close all UP streams + for _, stream := range pc.upStreams { + if err := stream.Close(); err != nil { + fmt.Printf("Error closing stream: %v\n", err) + } + } + pc.upStreams = make(map[StreamKind]quic.Stream) + + return pc.tConn.Close() +} diff --git a/pkg/network/protocol/manager.go b/pkg/network/protocol/manager.go new file mode 100644 index 0000000..b5a0707 --- /dev/null +++ b/pkg/network/protocol/manager.go @@ -0,0 +1,126 @@ +// protocol/manager.go +package protocol + +import ( + "crypto/tls" + "fmt" + "strings" + + "github.com/eigerco/strawberry/pkg/network/transport" +) + +type Config struct { + ChainHash string + IsBuilder bool + MaxBuilderSlots int +} + +// Manager handles protocol-level connection management and implements transport.ConnectionHandler +type Manager struct { + Registry *JAMNPRegistry + config Config +} + +func NewManager(config Config) (*Manager, error) { + if config.ChainHash == "" { + return nil, fmt.Errorf("chain hash required") + } + + // Validate chain hash format upfront using the ALPN utilities + if err := ValidateALPNProtocol(NewProtocolID(config.ChainHash, false).String()); err != nil { + return nil, fmt.Errorf("invalid chain hash format: %w", err) + } + + return &Manager{ + Registry: NewJAMNPRegistry(), + config: config, + }, nil +} + +// OnConnection implements transport.ConnectionHandler +func (m *Manager) OnConnection(conn *transport.Conn) error { + // Protocol connection creation could fail due to invalid parameters + protoConn, err := m.setupProtocolConn(conn) + if err != nil { + return fmt.Errorf("protocol connection setup failed: %w", err) + } + go m.handleStreams(protoConn) + + return nil +} +func (m *Manager) handleStreams(protoConn *ProtocolConn) { + defer protoConn.Close() // Ensure proper cleanup of the connection + + for { + // Attempt to accept an incoming stream + streamErr := protoConn.AcceptStream() + if streamErr != nil { + // Check if the connection's context has been canceled + if protoConn.tConn.Context().Err() != nil { + fmt.Println("Connection closed: context done") + return + } + + // Explicitly handle QUIC timeout errors + if isTimeoutError(streamErr) { + fmt.Println("Connection timed out due to inactivity") + protoConn.Close() // Close the connection explicitly + return + } + + // Log other errors and continue listening + fmt.Printf("Stream accept error: %v\n", streamErr) + continue + } + } +} + +// isTimeoutError checks if the error is a timeout error +func isTimeoutError(err error) bool { + return err != nil && strings.Contains(err.Error(), "timeout: no recent network activity") +} + +func (m *Manager) setupProtocolConn(conn *transport.Conn) (*ProtocolConn, error) { + if conn == nil { + return nil, fmt.Errorf("invalid connection") + } + + protoConn := NewProtocolConn(conn, m.Registry) + + return protoConn, nil +} + +// GetProtocols implements transport.ConnectionHandler +func (m *Manager) GetProtocols() []string { + return AcceptableProtocols(m.config.ChainHash) +} + +// ValidateConnection implements transport.ConnectionHandler +func (m *Manager) ValidateConnection(tlsState tls.ConnectionState) error { + if tlsState.NegotiatedProtocol == "" { + return fmt.Errorf("no protocol negotiated") + } + + // Parse and validate the protocol format + protocolID, err := ParseProtocolID(tlsState.NegotiatedProtocol) + if err != nil { + return fmt.Errorf("invalid protocol: %w", err) + } + + // Verify chain hash matches our configuration + if protocolID.ChainHash != m.config.ChainHash { + return fmt.Errorf("chain hash mismatch: got %s, want %s", + protocolID.ChainHash, m.config.ChainHash) + } + + // Verify builder status matches our configuration + if protocolID.IsBuilder && !m.config.IsBuilder { + return fmt.Errorf("builder connections not accepted") + } + + return nil +} + +func (m *Manager) WrapConnection(conn *transport.Conn) *ProtocolConn { + return NewProtocolConn(conn, m.Registry) +} diff --git a/pkg/network/protocol/streams.go b/pkg/network/protocol/streams.go new file mode 100644 index 0000000..0adf29a --- /dev/null +++ b/pkg/network/protocol/streams.go @@ -0,0 +1,71 @@ +package protocol + +import ( + "fmt" + "github.com/eigerco/strawberry/pkg/network/transport" +) + +const ( + // UP (Unique Persistent) stream is 0 + StreamKindBlockAnnouncement StreamKind = 0 + + // CE (Common Ephemeral) streams start from 128 + StreamKindBlockRequest StreamKind = 128 + StreamKindStateRequest StreamKind = 129 + StreamKindTicketDistP2P StreamKind = 131 + StreamKindTicketDistBroadcast StreamKind = 132 + StreamKindWorkPackageSubmit StreamKind = 133 + StreamKindWorkPackageShare StreamKind = 134 + StreamKindWorkReportDist StreamKind = 135 + StreamKindWorkReportRequest StreamKind = 136 + StreamKindShardDist StreamKind = 137 + StreamKindAuditShardRequest StreamKind = 138 + StreamKindSegmentRequest StreamKind = 139 + StreamKindSegmentRequestJust StreamKind = 140 + StreamKindAssuranceDist StreamKind = 141 + StreamKindPreimageAnnounce StreamKind = 142 + StreamKindPreimageRequest StreamKind = 143 + StreamKindAuditAnnouncement StreamKind = 144 + StreamKindJudgmentPublish StreamKind = 145 +) + +type StreamKind byte + +type JAMNPRegistry struct { + handlers map[StreamKind]transport.StreamHandler +} + +func NewJAMNPRegistry() *JAMNPRegistry { + return &JAMNPRegistry{ + handlers: make(map[StreamKind]transport.StreamHandler), + } +} + +func (r *JAMNPRegistry) ValidateKind(kindByte byte) error { + kind := StreamKind(kindByte) + if kind < StreamKindBlockAnnouncement || kind > StreamKindJudgmentPublish { + return fmt.Errorf("invalid stream kind: %d", kind) + } + return nil +} + +// Add a method to register handlers +func (r *JAMNPRegistry) RegisterHandler(kind StreamKind, handler transport.StreamHandler) { + r.handlers[kind] = handler +} + +func (r *JAMNPRegistry) GetHandler(kindByte byte) (transport.StreamHandler, error) { + // Convert raw byte to protocol's StreamKind here + kind := StreamKind(kindByte) + + handler, ok := r.handlers[kind] + if !ok { + return nil, fmt.Errorf("no handler for kind %d", kind) + } + return handler, nil +} + +// IsUniquePersistent returns true if this is a UP (Unique Persistent) stream kind +func (k StreamKind) IsUniquePersistent() bool { + return k < 128 +} diff --git a/pkg/network/transport/conn.go b/pkg/network/transport/conn.go new file mode 100644 index 0000000..b739357 --- /dev/null +++ b/pkg/network/transport/conn.go @@ -0,0 +1,81 @@ +package transport + +import ( + "context" + "crypto/ed25519" + "fmt" + "github.com/quic-go/quic-go" + "time" +) + +const StreamTimeout = 5 * time.Second + +// Conn represents a basic QUIC connection +type Conn struct { + qConn quic.Connection + transport *Transport + peerKey ed25519.PublicKey + ctx context.Context + cancel context.CancelFunc +} + +// newConn creates a new connection +func newConn(qConn quic.Connection, transport *Transport) *Conn { + ctx, cancel := context.WithCancel(transport.ctx) + + conn := &Conn{ + qConn: qConn, + transport: transport, + ctx: ctx, + cancel: cancel, + } + + // Ensure cleanup when connection ends + go func() { + <-ctx.Done() + conn.cleanup() + }() + + return conn +} + +func (c *Conn) cleanup() { + if c.peerKey != nil { + c.transport.cleanup(c.peerKey) + } +} + +// OpenStream opens a raw QUIC stream +func (c *Conn) OpenStream(ctx context.Context) (quic.Stream, error) { + stream, err := c.qConn.OpenStreamSync(ctx) + if err != nil { + return nil, fmt.Errorf("failed to open QUIC stream: %w", err) + } + + return stream, nil +} + +// AcceptStream accepts a raw incoming stream +func (c *Conn) AcceptStream() (quic.Stream, error) { + stream, err := c.qConn.AcceptStream(c.ctx) + if err != nil { + return nil, fmt.Errorf("failed to accept QUIC stream: %w", err) + } + return stream, nil +} + +// PeerKey returns the peer's public key +func (c *Conn) PeerKey() ed25519.PublicKey { + return c.peerKey +} + +// Close closes the connection +func (c *Conn) Close() error { + c.cancel() + return c.qConn.CloseWithError(0, "") +} + +// Context returns the connection's context +func (c *Conn) Context() context.Context { + return c.ctx +} diff --git a/pkg/network/transport/errors.go b/pkg/network/transport/errors.go new file mode 100644 index 0000000..529b342 --- /dev/null +++ b/pkg/network/transport/errors.go @@ -0,0 +1,12 @@ +package transport + +import "errors" + +var ( + ErrInvalidCertificate = errors.New("invalid certificate") + ErrConnectionExists = errors.New("connection already exists") + ErrStreamClosed = errors.New("stream closed") + ErrListenerFailed = errors.New("failed to create QUIC listener") + ErrDialFailed = errors.New("failed to dial peer") + ErrConnFailed = errors.New("failed to establish connection") +) diff --git a/pkg/network/transport/transport.go b/pkg/network/transport/transport.go new file mode 100644 index 0000000..d5fc0ee --- /dev/null +++ b/pkg/network/transport/transport.go @@ -0,0 +1,288 @@ +package transport + +import ( + "context" + "crypto/ed25519" + "crypto/tls" + "crypto/x509" + "fmt" + "sync" + "time" + + "github.com/quic-go/quic-go" +) + +const MaxIdleTimeout = 30 * time.Minute + +type StreamHandler interface { + HandleStream(ctx context.Context, stream quic.Stream) error +} + +type StreamRegistry interface { + // Takes a raw byte instead of a StreamKind + GetHandler(kindByte byte) (StreamHandler, error) + ValidateKind(kindByte byte) error +} + +// CertValidator validates certificates and extracts public keys +type CertValidator interface { + ValidateCertificate(cert *x509.Certificate) error + ExtractPublicKey(cert *x509.Certificate) (ed25519.PublicKey, error) +} + +// ProtocolManager handles ALPN protocol identification +type ProtocolManager interface { + // AcceptableProtocols returns the set of acceptable protocol strings for a chain + AcceptableProtocols(chainHash string) []string + + // NewProtocolID creates a protocol identifier string + NewProtocolID(chainHash string, isBuilder bool) string + + // ValidateProtocol validates a protocol string + ValidateProtocol(protocol string) error +} + +type ConnectionHandler interface { + OnConnection(conn *Conn) error + GetProtocols() []string + ValidateConnection(tlsState tls.ConnectionState) error +} + +type Config struct { + PublicKey ed25519.PublicKey + PrivateKey ed25519.PrivateKey + TLSCert *tls.Certificate + ListenAddr string + CertValidator CertValidator + Handler ConnectionHandler +} + +// Transport manages QUIC connections and streams +type Transport struct { + config Config + + listener *quic.Listener + + mu sync.RWMutex + conns map[string]*Conn + + ctx context.Context + cancel context.CancelFunc +} + +// NewTransport creates a new QUIC transport +func NewTransport(config Config) (*Transport, error) { + if config.TLSCert == nil { + return nil, fmt.Errorf("TLS certificate required") + } + if config.CertValidator == nil { + return nil, fmt.Errorf("certificate validator required") + } + if config.Handler == nil { + return nil, fmt.Errorf("connection handler required") + } + + // Verify the certificate + if err := config.CertValidator.ValidateCertificate(config.TLSCert.Leaf); err != nil { + return nil, ErrInvalidCertificate + } + + return &Transport{ + config: config, + conns: make(map[string]*Conn), + }, nil +} + +// Start starts the transport listener +func (t *Transport) Start() error { + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*t.config.TLSCert}, + NextProtos: t.config.Handler.GetProtocols(), + ClientAuth: tls.RequireAnyClientCert, + MinVersion: tls.VersionTLS13, + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + c, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidCertificate, err) + } + return t.config.CertValidator.ValidateCertificate(c) + }, + } + + listener, err := quic.ListenAddr(t.config.ListenAddr, tlsConfig, &quic.Config{ + EnableDatagrams: true, + MaxIdleTimeout: MaxIdleTimeout, + }) + if err != nil { + return fmt.Errorf("%w: %v", ErrListenerFailed, err) + } + + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.listener = listener + go t.acceptLoop() + return nil +} + +// Stop stops the transport +func (t *Transport) Stop() error { + t.cancel() + if t.listener != nil { + return t.listener.Close() + } + return nil +} + +// acceptLoop accepts incoming QUIC connections +func (t *Transport) acceptLoop() { + defer t.cancel() + for { + select { + case <-t.ctx.Done(): + return + default: + conn, err := t.listener.Accept(t.ctx) + if err != nil { + // Only log if it's not due to context cancellation/listener closing + if t.ctx.Err() == nil { + fmt.Printf("Failed to accept connection: %v\n", err) + } + if t.ctx.Err() != nil { + return + } + continue + } + + go t.handleConnection(conn) + } + } +} + +// Connect initiates a connection to a peer +func (t *Transport) Connect(addr string) (*Conn, error) { + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{*t.config.TLSCert}, + NextProtos: t.config.Handler.GetProtocols(), + ClientAuth: tls.RequireAnyClientCert, + MinVersion: tls.VersionTLS13, + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + c, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidCertificate, err) + } + if err := t.config.CertValidator.ValidateCertificate(c); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidCertificate, err) + } + return nil + }, + } + + quicConn, err := quic.DialAddr(t.ctx, addr, tlsConf, &quic.Config{ + EnableDatagrams: true, + MaxIdleTimeout: MaxIdleTimeout, + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrDialFailed, err) + } + + conn := t.handleConnection(quicConn) + if conn == nil { + return nil, ErrConnFailed + } + return conn, nil +} + +func (t *Transport) handleConnection(qConn quic.Connection) *Conn { + tlsState := qConn.ConnectionState().TLS + + if err := t.verifyPeerCert(tlsState.PeerCertificates); err != nil { + if cerr := qConn.CloseWithError(0, ErrInvalidCertificate.Error()); cerr != nil { + fmt.Printf("Failed to close connection: %v\n", cerr) + } + return nil + } + + if err := t.config.Handler.ValidateConnection(tlsState); err != nil { + fmt.Printf("Failed to validate connection: %v\n", err) + if cerr := qConn.CloseWithError(0, err.Error()); cerr != nil { + fmt.Printf("Failed to close connection: %v\n", cerr) + } + return nil + } + + peerKey, err := t.config.CertValidator.ExtractPublicKey(tlsState.PeerCertificates[0]) + if err != nil { + fmt.Printf("Failed to extract peer key: %v\n", err) + if cerr := qConn.CloseWithError(0, fmt.Sprintf("%s: %v", ErrInvalidCertificate.Error(), err)); cerr != nil { + fmt.Printf("Failed to close connection: %v\n", cerr) + } + return nil + } + + t.mu.Lock() + if existingConn, exists := t.conns[string(peerKey)]; exists { + fmt.Println("Found existing connection, closing it") + // Close existing connection before replacing it + if err := existingConn.Close(); err != nil { + fmt.Printf("Failed to close existing connection: %v\n", err) + } + delete(t.conns, string(peerKey)) + } + + conn := newConn(qConn, t) + conn.peerKey = peerKey + + // Store connection + t.conns[string(peerKey)] = conn + t.mu.Unlock() + + if err := t.config.Handler.OnConnection(conn); err != nil { + t.cleanup(peerKey) + if cerr := qConn.CloseWithError(0, err.Error()); cerr != nil { + fmt.Printf("Failed to close connection: %v\n", cerr) + } + return nil + } + fmt.Printf("t.ListConnections(): %v\n", t.ListConnections()) + return conn +} + +// GetConnection returns a connection by peer key if it exists +func (t *Transport) GetConnection(peerKey string) (*Conn, bool) { + t.mu.RLock() + conn, ok := t.conns[peerKey] + t.mu.RUnlock() + return conn, ok +} + +// ListConnections returns all active connections +func (t *Transport) ListConnections() []*Conn { + t.mu.RLock() + defer t.mu.RUnlock() + + conns := make([]*Conn, 0, len(t.conns)) + for _, conn := range t.conns { + conns = append(conns, conn) + } + return conns +} + +// Cleanup removes a connection from the map +func (t *Transport) cleanup(peerKey ed25519.PublicKey) { + t.mu.Lock() + delete(t.conns, string(peerKey)) + t.mu.Unlock() +} + +// verifyPeerCert verifies the peer's certificate chain +func (t *Transport) verifyPeerCert(certs []*x509.Certificate) error { + if len(certs) == 0 { + return fmt.Errorf("%w: no certificates provided", ErrInvalidCertificate) + } + if err := t.config.CertValidator.ValidateCertificate(certs[0]); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidCertificate, err) + } + + return nil +}