Skip to content

Commit

Permalink
More complete validation of endpoint for gobgp grpc
Browse files Browse the repository at this point in the history
  • Loading branch information
kurojishi committed Jan 12, 2024
1 parent c7f4d4f commit af88179
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
31 changes: 18 additions & 13 deletions pkg/gobgp_exporter/router_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"crypto/tls"
"fmt"
"net"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -98,21 +100,24 @@ func validAddress(s string, logger log.Logger) error {
return fmt.Errorf("empty address")
}

host, strport, err := net.SplitHostPort(s)
uri_schema_check := regexp.MustCompile(`(.+://).*`)
endpoint := s

if uri_schema_check.MatchString(s) {
uri, err := url.Parse(s)
if err != nil {
return err
} else if !(uri.Scheme == "http" || uri.Scheme == "https" || uri.Scheme == "dns") {
// those are all valid addresses, as grpc works on top of http2
return fmt.Errorf("invalid scheme for grpc in %s", s)
}
endpoint = uri.Host
}
host, strport, err := net.SplitHostPort(endpoint)
if err != nil {
return err
} else if host != "" {
if addr := net.ParseIP(host); addr == nil {
return fmt.Errorf("invalid IP address in %s", s)
}
} else if !strings.HasPrefix(s, "dns://") {
return fmt.Errorf("invalid address format in %s", s)
} else {
// "dns://" prefix for hostname is allowed per go grpc documentation
// see https://pkg.go.dev/google.golang.org/grpc#DialContext
idx := strings.LastIndex(s, ":")
host = s[0:idx]
strport = s[idx+1:]
} else if host == "" {
return fmt.Errorf("invalid endpoint format, hostname endpoint can't be empty %s", s)
}

level.Debug(logger).Log(
Expand Down
59 changes: 59 additions & 0 deletions pkg/gobgp_exporter/router_node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2018 Paul Greenberg (greenpau@outlook.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exporter

import (
"testing"

"github.com/prometheus/common/promlog"
)

func TestValidAddress(t *testing.T) {
allowedLogLevel := &promlog.AllowedLevel{}
if err := allowedLogLevel.Set("debug"); err != nil {
t.Fatalf("%s", err)
}

promlogConfig := &promlog.Config{
Level: allowedLogLevel,
}

logger := promlog.New(promlogConfig)

cases := []struct {
address string
ok bool
}{
{address: "127.0.0.1:50051", ok: true},
{address: "", ok: false},
{address: "127.0.0.1:500511", ok: false},
{address: "localaddress:50051", ok: true},
{address: "https://localaddress:50051", ok: true},
{address: "http://localaddress:50051", ok: true},
{address: "fuuuu://localaddress:50051", ok: false},
{address: "dns:///localhost:50051", ok: false},
{address: "[::1]:50051", ok: true},
{address: "::1:50051", ok: false},
}
for _, test := range cases {
err := validAddress(test.address, logger)
if test.ok && err != nil {
t.Errorf("expected no error w/ %q, but got %q", test.address, err)
}
if !test.ok && err == nil {
t.Errorf("expected error w/ %q, but got %q", test.address, err)
}
}
}

0 comments on commit af88179

Please sign in to comment.