diff --git a/cns/nodesubnet/helper_for_ip_fetcher_test.go b/cns/nodesubnet/helper_for_ip_fetcher_test.go new file mode 100644 index 0000000000..f8eda641f4 --- /dev/null +++ b/cns/nodesubnet/helper_for_ip_fetcher_test.go @@ -0,0 +1,9 @@ +package nodesubnet + +import "time" + +// This method is in this file (_test.go) because it is a test helper method. +// The following method is built during tests, and is not part of the main code. +func (c *IPFetcher) SetSecondaryIPQueryInterval(interval time.Duration) { + c.secondaryIPQueryInterval = interval +} diff --git a/cns/nodesubnet/ip_fetcher.go b/cns/nodesubnet/ip_fetcher.go new file mode 100644 index 0000000000..5c2233786d --- /dev/null +++ b/cns/nodesubnet/ip_fetcher.go @@ -0,0 +1,77 @@ +package nodesubnet + +import ( + "context" + "log" + "net/netip" + "time" + + "github.com/Azure/azure-container-networking/nmagent" + "github.com/pkg/errors" +) + +var ErrRefreshSkipped = errors.New("refresh skipped due to throttling") + +// InterfaceRetriever is an interface is implemented by the NMAgent Client, and also a mock client for testing. +type InterfaceRetriever interface { + GetInterfaceIPInfo(ctx context.Context) (nmagent.Interfaces, error) +} + +type IPFetcher struct { + // Node subnet state + secondaryIPQueryInterval time.Duration // Minimum time between secondary IP fetches + secondaryIPLastRefreshTime time.Time // Time of last secondary IP fetch + + ipFectcherClient InterfaceRetriever +} + +func NewIPFetcher(nmaClient InterfaceRetriever, queryInterval time.Duration) *IPFetcher { + return &IPFetcher{ + ipFectcherClient: nmaClient, + secondaryIPQueryInterval: queryInterval, + } +} + +func (c *IPFetcher) RefreshSecondaryIPsIfNeeded(ctx context.Context) (ips []netip.Addr, err error) { + // If secondaryIPQueryInterval has elapsed since the last fetch, fetch secondary IPs + if time.Since(c.secondaryIPLastRefreshTime) < c.secondaryIPQueryInterval { + return nil, ErrRefreshSkipped + } + + c.secondaryIPLastRefreshTime = time.Now() + response, err := c.ipFectcherClient.GetInterfaceIPInfo(ctx) + if err != nil { + return nil, errors.Wrap(err, "getting interface IPs") + } + + res := flattenIPListFromResponse(&response) + return res, nil +} + +// Get the list of secondary IPs from fetched Interfaces +func flattenIPListFromResponse(resp *nmagent.Interfaces) (res []netip.Addr) { + // For each interface... + for _, intf := range resp.Entries { + if !intf.IsPrimary { + continue + } + + // For each subnet on the interface... + for _, s := range intf.InterfaceSubnets { + addressCount := 0 + // For each address in the subnet... + for _, a := range s.IPAddress { + // Primary addresses are reserved for the host. + if a.IsPrimary { + continue + } + + res = append(res, netip.Addr(a.Address)) + addressCount++ + } + log.Printf("Got %d addresses from subnet %s", addressCount, s.Prefix) + } + } + + return res +} diff --git a/cns/nodesubnet/ip_fetcher_test.go b/cns/nodesubnet/ip_fetcher_test.go new file mode 100644 index 0000000000..6a2e425126 --- /dev/null +++ b/cns/nodesubnet/ip_fetcher_test.go @@ -0,0 +1,86 @@ +package nodesubnet_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Azure/azure-container-networking/cns/nodesubnet" + "github.com/Azure/azure-container-networking/nmagent" +) + +// Mock client that simply tracks if refresh has been called +type TestClient struct { + fetchCalled bool +} + +// Mock refresh +func (c *TestClient) GetInterfaceIPInfo(_ context.Context) (nmagent.Interfaces, error) { + c.fetchCalled = true + return nmagent.Interfaces{}, nil +} + +func TestRefreshSecondaryIPsIfNeeded(t *testing.T) { + getTests := []struct { + name string + shouldCall bool + interval time.Duration + }{ + { + "fetch called", + true, + -1 * time.Second, // Negative timeout to force refresh + }, + { + "no refresh needed", + false, + 10 * time.Hour, // High timeout to avoid refresh + }, + } + + clientPtr := &TestClient{} + fetcher := nodesubnet.NewIPFetcher(clientPtr, 0) + + for _, test := range getTests { + test := test + t.Run(test.name, func(t *testing.T) { // Do not parallelize, as we are using a shared client + fetcher.SetSecondaryIPQueryInterval(test.interval) + ctx, cancel := testContext(t) + defer cancel() + clientPtr.fetchCalled = false + _, err := fetcher.RefreshSecondaryIPsIfNeeded(ctx) + + if test.shouldCall { + if err != nil && errors.Is(err, nodesubnet.ErrRefreshSkipped) { + t.Error("refresh expected, but didn't happen") + } + + checkErr(t, err, false) + } else if err == nil || !errors.Is(err, nodesubnet.ErrRefreshSkipped) { + t.Error("refresh not expected, but happened") + } + }) + } +} + +// testContext creates a context from the provided testing.T that will be +// canceled if the test suite is terminated. +func testContext(t *testing.T) (context.Context, context.CancelFunc) { + if deadline, ok := t.Deadline(); ok { + return context.WithDeadline(context.Background(), deadline) + } + return context.WithCancel(context.Background()) +} + +// checkErr is an assertion of the presence or absence of an error +func checkErr(t *testing.T, err error, shouldErr bool) { + t.Helper() + if err != nil && !shouldErr { + t.Fatal("unexpected error: err:", err) + } + + if err == nil && shouldErr { + t.Fatal("expected error but received none") + } +} diff --git a/nmagent/client.go b/nmagent/client.go index 8eea299bc6..71a0810978 100644 --- a/nmagent/client.go +++ b/nmagent/client.go @@ -44,9 +44,8 @@ type Client struct { httpClient *http.Client // config - host string - port uint16 - + host string + port uint16 enableTLS bool retrier interface { @@ -284,6 +283,37 @@ func (c *Client) GetHomeAz(ctx context.Context) (AzResponse, error) { return homeAzResponse, nil } +// GetInterfaceIPInfo fetches the node's interface IP information from nmagent +func (c *Client) GetInterfaceIPInfo(ctx context.Context) (Interfaces, error) { + req, err := c.buildRequest(ctx, &GetSecondaryIPsRequest{}) + var out Interfaces + + if err != nil { + return out, errors.Wrap(err, "building request") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return out, errors.Wrap(err, "submitting request") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return out, die(resp.StatusCode, resp.Header, resp.Body, req.URL.Path) + } + + if resp.StatusCode != http.StatusOK { + return out, die(resp.StatusCode, resp.Header, resp.Body, req.URL.Path) + } + + err = xml.NewDecoder(resp.Body).Decode(&out) + if err != nil { + return out, errors.Wrap(err, "decoding response") + } + + return out, nil +} + func die(code int, headers http.Header, body io.ReadCloser, path string) error { // nolint:errcheck // make a best effort to return whatever information we can // returning an error here without the code and source would diff --git a/nmagent/client_test.go b/nmagent/client_test.go index c45c46d8eb..e4e0b36ece 100644 --- a/nmagent/client_test.go +++ b/nmagent/client_test.go @@ -3,9 +3,11 @@ package nmagent_test import ( "context" "encoding/json" + "encoding/xml" "fmt" "net/http" "net/http/httptest" + "net/netip" "strings" "testing" @@ -809,3 +811,86 @@ func TestGetHomeAz(t *testing.T) { }) } } + +func TestGetInterfaceIPInfo(t *testing.T) { + tests := []struct { + name string + expURL string + response nmagent.Interfaces + respStr string + }{ + { + "happy path", + "/machine/plugins?comp=nmagent&type=getinterfaceinfov1", + nmagent.Interfaces{ + Entries: []nmagent.Interface{ + { + MacAddress: nmagent.MACAddress{0x00, 0x0D, 0x3A, 0xF9, 0xDC, 0xA6}, + IsPrimary: true, + InterfaceSubnets: []nmagent.InterfaceSubnet{ + { + Prefix: "10.240.0.0/16", + IPAddress: []nmagent.NodeIP{ + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 5})), + IsPrimary: true, + }, + { + Address: nmagent.IPAddress(netip.AddrFrom4([4]byte{10, 240, 0, 6})), + IsPrimary: false, + }, + }, + }, + }, + }, + }, + }, + "" + + "" + + "", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + var gotURL string + client := nmagent.NewTestClient(&TestTripper{ + RoundTripF: func(req *http.Request) (*http.Response, error) { + gotURL = req.URL.RequestURI() + rr := httptest.NewRecorder() + rr.WriteHeader(http.StatusOK) + err := xml.NewEncoder(rr).Encode(test.response) + if err != nil { + t.Fatal("unexpected error encoding response: err:", err) + } + return rr.Result(), nil + }, + }) + + ctx, cancel := testContext(t) + defer cancel() + + resp, err := client.GetInterfaceIPInfo(ctx) + checkErr(t, err, false) + + if gotURL != test.expURL { + t.Error("received URL differs from expected: got:", gotURL, "exp:", test.expURL) + } + + if got := resp; !cmp.Equal(got, test.response) { + t.Error("response differs from expectation: diff:", cmp.Diff(got, test.response)) + } + + var unmarshaled nmagent.Interfaces + err = xml.Unmarshal([]byte(test.respStr), &unmarshaled) + checkErr(t, err, false) + + if !cmp.Equal(resp, unmarshaled) { + t.Error("response differs from expected decoded string: diff:", cmp.Diff(resp, unmarshaled)) + } + }) + } +} diff --git a/nmagent/ipaddress.go b/nmagent/ipaddress.go new file mode 100644 index 0000000000..2090bdbe86 --- /dev/null +++ b/nmagent/ipaddress.go @@ -0,0 +1,52 @@ +package nmagent + +import ( + "encoding/xml" + "net/netip" + + "github.com/pkg/errors" +) + +type IPAddress netip.Addr + +func (h IPAddress) Equal(other IPAddress) bool { + return netip.Addr(h).Compare(netip.Addr(other)) == 0 +} + +func (h *IPAddress) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var ipStr string + if err := d.DecodeElement(&ipStr, &start); err != nil { + return errors.Wrap(err, "decoding IP address") + } + + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return errors.Wrap(err, "parsing IP address") + } + + *h = IPAddress(ip) + return nil +} + +func (h *IPAddress) UnmarshalXMLAttr(attr xml.Attr) error { + ipStr := attr.Value + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return errors.Wrap(err, "parsing IP address") + } + + *h = IPAddress(ip) + return nil +} + +func (h IPAddress) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + err := e.EncodeElement(netip.Addr(h).String(), start) + return errors.Wrap(err, "encoding IP address") +} + +func (h IPAddress) MarshalXMLAttr(name xml.Name) (xml.Attr, error) { + return xml.Attr{ + Name: name, + Value: netip.Addr(h).String(), + }, nil +} diff --git a/nmagent/macaddress.go b/nmagent/macaddress.go new file mode 100644 index 0000000000..97c5385162 --- /dev/null +++ b/nmagent/macaddress.go @@ -0,0 +1,66 @@ +package nmagent + +import ( + "encoding/hex" + "encoding/xml" + "net" + + "github.com/pkg/errors" +) + +const ( + MACAddressSize = 6 +) + +type MACAddress net.HardwareAddr + +func (h *MACAddress) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var macStr string + if err := d.DecodeElement(&macStr, &start); err != nil { + return errors.Wrap(err, "decoding MAC address") + } + + // Convert the string (without colons) into a valid MACAddress + mac, err := hex.DecodeString(macStr) + if err != nil { + return &net.ParseError{Type: "MAC address", Text: macStr} + } + + *h = MACAddress(mac) + return nil +} + +func (h *MACAddress) UnmarshalXMLAttr(attr xml.Attr) error { + macStr := attr.Value + mac, err := hex.DecodeString(macStr) + if err != nil { + return &net.ParseError{Type: "MAC address", Text: macStr} + } + + *h = MACAddress(mac) + return nil +} + +func (h MACAddress) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + if len(h) != MACAddressSize { + return &net.AddrError{Err: "invalid MAC address", Addr: hex.EncodeToString(h)} + } + + macStr := hex.EncodeToString(h) + err := e.EncodeElement(macStr, start) + return errors.Wrap(err, "encoding MAC address") +} + +func (h MACAddress) MarshalXMLAttr(name xml.Name) (xml.Attr, error) { + if len(h) != MACAddressSize { + return xml.Attr{}, &net.AddrError{Err: "invalid MAC address", Addr: hex.EncodeToString(h)} + } + + macStr := hex.EncodeToString(h) + attr := xml.Attr{ + Name: name, + Value: macStr, + } + + return attr, nil +} diff --git a/nmagent/requests.go b/nmagent/requests.go index 01182bfbb3..6a173080fa 100644 --- a/nmagent/requests.go +++ b/nmagent/requests.go @@ -536,3 +536,29 @@ func (g *GetHomeAzRequest) Path() string { func (g *GetHomeAzRequest) Validate() error { return nil } + +var _ Request = &GetSecondaryIPsRequest{} + +type GetSecondaryIPsRequest struct{} + +// Body is a no-op method to satisfy the Request interface while indicating +// that there is no body for a GetSecondaryIPsRequest Request. +func (g *GetSecondaryIPsRequest) Body() (io.Reader, error) { + return nil, nil +} + +// Method indicates that GetSecondaryIPsRequest requests are GET requests. +func (g *GetSecondaryIPsRequest) Method() string { + return http.MethodGet +} + +// Path returns the necessary URI path for invoking a GetSecondaryIPsRequest request. +func (g *GetSecondaryIPsRequest) Path() string { + return "getinterfaceinfov1" +} + +// Validate is a no-op method because parameters are hard coded in the path, +// no customization needed. +func (g *GetSecondaryIPsRequest) Validate() error { + return nil +} diff --git a/nmagent/requests_test.go b/nmagent/requests_test.go index f556efbefe..e1da51a5be 100644 --- a/nmagent/requests_test.go +++ b/nmagent/requests_test.go @@ -568,3 +568,16 @@ func TestNCVersionRequestValidate(t *testing.T) { }) } } + +func TestGetSecondaryIPsRequest(t *testing.T) { + const exp string = "getinterfaceinfov1" + req := nmagent.GetSecondaryIPsRequest{} + + if err := req.Validate(); err != nil { + t.Error("Validation failed on GetSecondaryIpsRequest ", req) + } + + if req.Path() != exp { + t.Error("unexpected path: exp:", exp, "got:", req.Path()) + } +} diff --git a/nmagent/responses.go b/nmagent/responses.go index e5324d59f9..4e917e8bc9 100644 --- a/nmagent/responses.go +++ b/nmagent/responses.go @@ -40,3 +40,26 @@ type NCVersionList struct { type AzResponse struct { HomeAz uint `json:"homeAz"` } + +type NodeIP struct { + Address IPAddress `xml:"Address,attr"` + IsPrimary bool `xml:"IsPrimary,attr"` +} + +type InterfaceSubnet struct { + IPAddress []NodeIP `xml:"IPAddress"` + Prefix string `xml:"Prefix,attr"` +} + +type Interface struct { + InterfaceSubnets []InterfaceSubnet `xml:"IPSubnet"` + MacAddress MACAddress `xml:"MacAddress,attr"` + IsPrimary bool `xml:"IsPrimary,attr"` +} + +// Response from NMAgent for getinterfaceinfov1 (interface IP information) +// If we change this name, we need to tell the XML encoder to look for +// "Interfaces" in the respose. +type Interfaces struct { + Entries []Interface `xml:"Interface"` +}