Skip to content

Commit

Permalink
Update parameters and handle json data
Browse files Browse the repository at this point in the history
This updates the methods to be lowercase (get & post) and also changes
the posted body from form data to json.
  • Loading branch information
danrjohnson committed Apr 29, 2024
1 parent 27f7fb4 commit c3556fa
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 15 deletions.
8 changes: 4 additions & 4 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
}

func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestType string) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var retErr error
start := rand.Int()
n := len(authd)
for i := 0; i < n; i++ {
a := authd[(i+start)%n]
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestType)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
if err != nil {
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
if retErr != nil {
Expand All @@ -97,7 +97,7 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
}

func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestType string) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var authState State
v := url.Values{}
v.Set("remote_ip", remoteIP)
Expand All @@ -117,7 +117,7 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin
}

client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
if httpRequestType == "POST" {
if httpRequestMethod == "post" {
if err := client.POSTFormV1(endpoint, v, &authState); err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions internal/http_api/api_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_api

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -121,13 +122,17 @@ retry:

func (c *Client) POSTFormV1(endpoint string, data url.Values, v interface{}) error {
retry:
req, err := http.NewRequest("POST", endpoint, strings.NewReader(data.Encode()))
reqBody, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
}
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(reqBody))
if err != nil {
return err
}

req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Type", "application/json")

resp, err := c.c.Do(req)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
}
n.clientTLSConfig = clientTLSConfig

if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
return nil, errors.New("--auth-http-request-method must be post or get")
}

for _, v := range opts.E2EProcessingLatencyPercentiles {
if v <= 0 || v > 1 {
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)
Expand Down
2 changes: 1 addition & 1 deletion nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func NewOptions() *Options {

NSQLookupdTCPAddresses: make([]string, 0),
AuthHTTPAddresses: make([]string, 0),
AuthHTTPRequestMethod: "GET",
AuthHTTPRequestMethod: "get",

HTTPClientConnectTimeout: 2 * time.Second,
HTTPClientRequestTimeout: 5 * time.Second,
Expand Down
28 changes: 20 additions & 8 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1476,7 +1477,7 @@ func TestClientAuth(t *testing.T) {
authSuccess := ""
tlsEnabled := false
commonName := ""
httpAuthRequestMethod := "GET"
httpAuthRequestMethod := "get"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// now one that will succeed
Expand All @@ -1493,7 +1494,7 @@ func TestClientAuth(t *testing.T) {
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// test POST based authentication
httpAuthRequestMethod = "POST"
httpAuthRequestMethod = "post"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

}
Expand All @@ -1509,12 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
test.Equal(t, httpAuthRequestMethod, r.Method)
r.ParseForm()
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))

var values url.Values

if r.Method == "POST" {
err = json.NewDecoder(r.Body).Decode(&values)
if err != nil {
t.Error(err)
}
} else {
r.ParseForm()
values = r.Form
}
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
test.Equal(t, expectedTLS, values.Get("tls"))
test.Equal(t, commonName, values.Get("common_name"))
test.Equal(t, authSecret, values.Get("secret"))
fmt.Fprint(w, authResponse)
}))
defer authd.Close()
Expand Down

0 comments on commit c3556fa

Please sign in to comment.