Skip to content

Commit

Permalink
Fix NC connection cancellation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bartoszWojciechO committed Dec 30, 2024
1 parent 7c2dc2b commit 96c2502
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 64 deletions.
36 changes: 22 additions & 14 deletions nc/nc.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ func (MqttClientBuilder) Build(opts *mqtt.ClientOptions) mqtt.Client {
return mqtt.NewClient(opts)
}

type TimeFunc func(int) time.Duration

// Client is a client for Notification center
type Client struct {
clientBuilder ClientBuilder
Expand All @@ -126,6 +128,7 @@ type Client struct {
subjectErr events.Publisher[error]
subjectPeerUpdate events.Publisher[[]string]
credsFetcher CredentialsGetter
timeFunc TimeFunc

startMu sync.Mutex
started bool
Expand All @@ -147,6 +150,7 @@ func NewClient(
subjectErr: subjectErr,
subjectPeerUpdate: subjectPeerUpdate,
credsFetcher: credsFetcher,
timeFunc: network.ExponentialBackoff,
}
}

Expand Down Expand Up @@ -305,7 +309,7 @@ func (c *Client) connectWithBackoff(client mqtt.Client,
client.Disconnect(0)
}
return client
case <-time.After(network.ExponentialBackoff(tries)):
case <-time.After(c.timeFunc(tries)):
}
}

Expand All @@ -321,6 +325,18 @@ func (c *Client) connectWithBackoff(client mqtt.Client,
return client
}

func (c *Client) connect(client mqtt.Client,
credentialsInvalidated bool,
connectionContext context.Context,
managementChan chan<- interface{},
connectedChan chan<- mqtt.Client) {
client = c.connectWithBackoff(client, credentialsInvalidated, managementChan, connectionContext)
select {
case connectedChan <- client:
case <-connectionContext.Done():
}
}

func (c *Client) sendDeliveryConfirmation(client mqtt.Client, messageID string) error {
payload, err := json.Marshal(ConfirmationPayload{
MessageID: messageID,
Expand Down Expand Up @@ -430,17 +446,9 @@ func (c *Client) ncClientManagementLoop(ctx context.Context) (<-chan any, error)
}()

connectedChan := make(chan mqtt.Client)
connect := func(client mqtt.Client, credentialsInvalidated bool, connectionContext context.Context) {
client = c.connectWithBackoff(client, credentialsInvalidated, managementChan, connectionContext)
select {
case connectedChan <- client:
case <-connectionContext.Done():
}
}

opts := c.createClientOptions(credentials, managementChan, connectionContext)
client = c.clientBuilder.Build(opts)
go connect(client, credentialsInvalidated, connectionContext)
go c.connect(client, credentialsInvalidated, connectionContext, managementChan, connectedChan)

log.Println(logPrefix, "starting initial connection loop")
CONNECTION_LOOP:
Expand All @@ -462,7 +470,7 @@ func (c *Client) ncClientManagementLoop(ctx context.Context) (<-chan any, error)
client.Disconnect(0)
c.credsFetcher.RevokeCredentials(false)
connectionContext, cancelConnectionFunc = context.WithCancel(ctx)
go connect(client, true, connectionContext)
go c.connect(client, true, connectionContext, managementChan, connectedChan)
}
}
log.Println(logPrefix, "initial connection established")
Expand All @@ -475,9 +483,9 @@ func (c *Client) ncClientManagementLoop(ctx context.Context) (<-chan any, error)
case event := <-managementChan:
switch ev := event.(type) {
case authLost:
go connect(client, true, connectionContext)
go c.connect(client, true, connectionContext, managementChan, connectedChan)
case connectionLost:
go connect(client, false, connectionContext)
go c.connect(client, false, connectionContext, managementChan, connectedChan)
case mqttMessage:
c.handleMessage(client, ev.message)
case time.Time:
Expand All @@ -491,7 +499,7 @@ func (c *Client) ncClientManagementLoop(ctx context.Context) (<-chan any, error)
client.Disconnect(0)
c.credsFetcher.RevokeCredentials(false)
connectionContext, cancelConnectionFunc = context.WithCancel(ctx)
go connect(client, true, connectionContext)
go c.connect(client, true, connectionContext, managementChan, connectedChan)
}
}
}()
Expand Down
118 changes: 68 additions & 50 deletions nc/nc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package nc
import (
"context"
"fmt"
"runtime"
"testing"
"time"

Expand Down Expand Up @@ -192,62 +191,81 @@ func TestStartStopNotificationClient(t *testing.T) {
}
}

func TestNCManagementLoop(t *testing.T) {
func TestConnectionCancellation(t *testing.T) {
category.Set(t, category.Unit)

goroutinesInitial := runtime.NumGoroutine()

cfg := config.Config{}
cfg.TokensData = make(map[int64]config.TokenData)
cfgManager := cfgmock.NewMockConfigManager()
cfgManager.Cfg = &cfg

connectionToken := mockMqttToken{
timesOut: false,
err: nil,
}
subscribeToken := mockMqttToken{
timesOut: false,
err: nil,
}
mockMqttClient := mockMqttClient{
connectToken: connectionToken,
subscribeToken: subscribeToken,
}
clientBuilderMock := ncmock.MockClientBuilder{
Client: &mockMqttClient,
tests := []struct {
name string
connectionErr error
fetchCredentialsErr error
tokenTimeout time.Duration // how long client will wait for connection to be established
}{
{
name: "connection success",
},
{
name: "connection failure",
connectionErr: fmt.Errorf("failed to connect"),
},
{
name: "connection auth failure",
connectionErr: mqttp.ErrorRefusedNotAuthorised,
},
{
name: "fetch credentails failure",

Check failure on line 220 in nc/nc_test.go

View workflow job for this annotation

GitHub Actions / lint

`credentails` is a misspelling of `credentials` (misspell)
fetchCredentialsErr: fmt.Errorf("failed to fetch credentials"),
},
{
name: "cancel while waiting for connection",
tokenTimeout: 10 * time.Second,
},
}

credsFetcher := NewCredsFetcher(&core.CredentialsAPIMock{
NotificationCredentialsError: nil,
}, cfgManager)
notificationClient := NewClient(&clientBuilderMock,
&subs.Subject[string]{},
&subs.Subject[error]{},
&subs.Subject[[]string]{},
credsFetcher)
notificationClient.Start()
time.Sleep(1 * time.Second)

goroutinesOnStartup := runtime.NumGoroutine()
assert.True(t, goroutinesOnStartup == goroutinesInitial+1 || goroutinesOnStartup == goroutinesInitial+2,
`On startup,
there should be at most two goroutines running(management loop goroutine and connection goroutine)`)

// give management loop time to start up properly
time.Sleep(1 * time.Second)
clientBuilderMock.CallConnectionLost(mqttp.ErrorRefusedNotAuthorised)
goroutinesAfterLoosingAuth := runtime.NumGoroutine()
assert.True(t, goroutinesAfterLoosingAuth == goroutinesInitial+1 || goroutinesAfterLoosingAuth == goroutinesInitial+2,
`After loosing authorization,
there should be at most two goroutines running(management loop goroutine and connection goroutine`)

time.Sleep(1 * time.Second)
normalOperationGoroutines := runtime.NumGoroutine()
assert.Equal(t, normalOperationGoroutines, goroutinesInitial+1,
"In normal operation, there should be only a single goroutine.")

notificationClient.Stop()
shutdownGoroutines := runtime.NumGoroutine()
assert.Equal(t, goroutinesInitial, shutdownGoroutines, "goroutines remaining after nc management loop was stopped.")
for _, test := range tests {
connectionToken := mockMqttToken{
timesOut: false,
err: test.connectionErr,
}
mockMqttClient := mockMqttClient{
connectToken: connectionToken,
}
clientBuilderMock := ncmock.MockClientBuilder{
Client: &mockMqttClient,
}

credsFetcher := NewCredsFetcher(&core.CredentialsAPIMock{
NotificationCredentialsError: test.fetchCredentialsErr,
}, cfgManager)

notificationClient := Client{
clientBuilder: &clientBuilderMock,
subjectInfo: &subs.Subject[string]{},
subjectErr: &subs.Subject[error]{},
subjectPeerUpdate: &subs.Subject[[]string]{},
credsFetcher: credsFetcher,
timeFunc: func(i int) time.Duration { return test.tokenTimeout },
}

t.Run(test.name, func(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
connectedChan := make(chan interface{})
go func() {
notificationClient.connect(&mockMqttClient, false, ctx, make(chan<- interface{}), make(chan<- mqtt.Client))
connectedChan <- true
}()

cancelFunc()

select {
case <-time.After(1 * time.Second):
assert.FailNow(t, "Time out when waiting for connect to finish.")
case <-connectedChan:
}
})
}
}

0 comments on commit 96c2502

Please sign in to comment.