From 5dca27d29b6d77ed09f2744998768899ec57561b Mon Sep 17 00:00:00 2001 From: AzureAhai Date: Thu, 1 Feb 2024 14:02:40 -0800 Subject: [PATCH] Addressing the comments --- cni/network/network.go | 4 ++-- cni/network/network_windows.go | 2 +- cnm/network/network.go | 2 +- cns/api.go | 1 - cns/client/client.go | 16 +++------------ cns/client/client_test.go | 15 ++------------ cns/cnireconciler/podinfoprovider.go | 17 +++++++++------- cns/cnireconciler/podinfoprovider_test.go | 2 +- cns/restserver/ipam.go | 24 +++++++---------------- cns/service/main.go | 4 ++-- network/endpoint_linux.go | 2 +- network/manager.go | 13 ++++++------ network/manager_mock.go | 2 +- 13 files changed, 38 insertions(+), 66 deletions(-) diff --git a/cni/network/network.go b/cni/network/network.go index 65d6e0c17b..2381b8ffc4 100644 --- a/cni/network/network.go +++ b/cni/network/network.go @@ -889,7 +889,7 @@ func (plugin *NetPlugin) Get(args *cniSkel.CmdArgs) error { } // Query the endpoint. - if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID, args.IfName); err != nil { + if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID); err != nil { logger.Error("Failed to query endpoint", zap.Error(err)) return err } @@ -1051,7 +1051,7 @@ func (plugin *NetPlugin) Delete(args *cniSkel.CmdArgs) error { endpointID := plugin.nm.GetEndpointID(args.ContainerID, args.IfName) // Query the endpoint. - if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID, args.IfName); err != nil { + if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID); err != nil { logger.Info("GetEndpoint", zap.String("endpoint", endpointID), zap.Error(err)) diff --git a/cni/network/network_windows.go b/cni/network/network_windows.go index 3fe382bf18..ba0fdfaf4a 100644 --- a/cni/network/network_windows.go +++ b/cni/network/network_windows.go @@ -40,7 +40,7 @@ var ( func (plugin *NetPlugin) handleConsecutiveAdd(args *cniSkel.CmdArgs, endpointId string, networkId string, nwInfo *network.NetworkInfo, nwCfg *cni.NetworkConfig, ) (*cniTypesCurr.Result, error) { - epInfo, _ := plugin.nm.GetEndpointInfo(networkId, endpointId, "") + epInfo, _ := plugin.nm.GetEndpointInfo(networkId, endpointId) if epInfo == nil { return nil, nil } diff --git a/cnm/network/network.go b/cnm/network/network.go index ec0957d6b2..4358d1a485 100644 --- a/cnm/network/network.go +++ b/cnm/network/network.go @@ -357,7 +357,7 @@ func (plugin *netPlugin) endpointOperInfo(w http.ResponseWriter, r *http.Request } // Process request. - epInfo, err := plugin.nm.GetEndpointInfo(req.NetworkID, req.EndpointID, "") + epInfo, err := plugin.nm.GetEndpointInfo(req.NetworkID, req.EndpointID) if err != nil { plugin.SendErrorResponse(w, err) return diff --git a/cns/api.go b/cns/api.go index 969e437753..3894e9aff6 100644 --- a/cns/api.go +++ b/cns/api.go @@ -363,5 +363,4 @@ type GetHomeAzResponse struct { type EndpointRequest struct { HnsEndpointID string `json:"hnsEndpointID"` HostVethName string `json:"hostVethName"` - IFName string `json:"IFName"` } diff --git a/cns/client/client.go b/cns/client/client.go index 125a135079..2ee524cfb9 100644 --- a/cns/client/client.go +++ b/cns/client/client.go @@ -1024,20 +1024,11 @@ func (c *Client) GetHomeAz(ctx context.Context) (*cns.GetHomeAzResponse, error) } // GetEndpoint calls the EndpointHandlerAPI in CNS to retrieve the state of a given EndpointID -func (c *Client) GetEndpoint(ctx context.Context, endpointID, ifName string) (*restserver.GetEndpointResponse, error) { +func (c *Client) GetEndpoint(ctx context.Context, endpointID string) (*restserver.GetEndpointResponse, error) { // build the request - getEndpoint := cns.EndpointRequest{ - IFName: ifName, - } - var body bytes.Buffer - - if err := json.NewEncoder(&body).Encode(getEndpoint); err != nil { - return nil, errors.Wrap(err, "failed to encode getEndpoint") - } - u := c.routes[cns.EndpointAPI] uString := u.String() + endpointID - req, err := http.NewRequestWithContext(ctx, http.MethodGet, uString, &body) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uString, http.NoBody) if err != nil { return nil, errors.Wrap(err, "failed to build request") } @@ -1068,12 +1059,11 @@ func (c *Client) GetEndpoint(ctx context.Context, endpointID, ifName string) (*r // UpdateEndpoint calls the EndpointHandlerAPI in CNS // to update the state of a given EndpointID with either HNSEndpointID or HostVethName -func (c *Client) UpdateEndpoint(ctx context.Context, endpointID, hnsID, vethName, ifName string) (*cns.Response, error) { +func (c *Client) UpdateEndpoint(ctx context.Context, endpointID, hnsID, vethName string) (*cns.Response, error) { // build the request updateEndpoint := cns.EndpointRequest{ HnsEndpointID: hnsID, HostVethName: vethName, - IFName: ifName, } var body bytes.Buffer diff --git a/cns/client/client_test.go b/cns/client/client_test.go index 1f53829796..871773eded 100644 --- a/cns/client/client_test.go +++ b/cns/client/client_test.go @@ -2777,7 +2777,6 @@ func TestUpdateEndpoint(t *testing.T) { containerID string hnsID string vethName string - ifName string response *RequestCapture expReq *cns.EndpointRequest shouldErr bool @@ -2787,7 +2786,6 @@ func TestUpdateEndpoint(t *testing.T) { "", "", "", - "", &RequestCapture{ Next: &mockdo{}, }, @@ -2799,7 +2797,6 @@ func TestUpdateEndpoint(t *testing.T) { "foo", "bar", "", - "too", &RequestCapture{ Next: &mockdo{ httpStatusCodeToReturn: http.StatusOK, @@ -2807,7 +2804,6 @@ func TestUpdateEndpoint(t *testing.T) { }, &cns.EndpointRequest{ HnsEndpointID: "bar", - IFName: "too", }, false, }, @@ -2816,7 +2812,6 @@ func TestUpdateEndpoint(t *testing.T) { "foo", "", "bar", - "too", &RequestCapture{ Next: &mockdo{ httpStatusCodeToReturn: http.StatusOK, @@ -2824,7 +2819,6 @@ func TestUpdateEndpoint(t *testing.T) { }, &cns.EndpointRequest{ HostVethName: "bar", - IFName: "too", }, false, }, @@ -2833,7 +2827,6 @@ func TestUpdateEndpoint(t *testing.T) { "foo", "", "bar", - "", &RequestCapture{ Next: &mockdo{ httpStatusCodeToReturn: http.StatusBadRequest, @@ -2858,7 +2851,7 @@ func TestUpdateEndpoint(t *testing.T) { } // execute the method under test - res, err := client.UpdateEndpoint(context.TODO(), test.containerID, test.hnsID, test.vethName, test.ifName) + res, err := client.UpdateEndpoint(context.TODO(), test.containerID, test.hnsID, test.vethName) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err: ", err, res.Message) } @@ -2904,14 +2897,12 @@ func TestGetEndpoint(t *testing.T) { getEndpointTests := []struct { name string containerID string - ifName string response *RequestCapture shouldErr bool }{ { "empty", "", - "", &RequestCapture{ Next: &mockdo{}, }, @@ -2920,7 +2911,6 @@ func TestGetEndpoint(t *testing.T) { { "with EndpointID", "foo", - "foo", &RequestCapture{ Next: &mockdo{ httpStatusCodeToReturn: http.StatusOK, @@ -2931,7 +2921,6 @@ func TestGetEndpoint(t *testing.T) { { "Bad Request", "foo", - "foo", &RequestCapture{ Next: &mockdo{ httpStatusCodeToReturn: http.StatusBadRequest, @@ -2953,7 +2942,7 @@ func TestGetEndpoint(t *testing.T) { } // execute the method under test - res, err := client.GetEndpoint(context.TODO(), test.containerID, test.ifName) + res, err := client.GetEndpoint(context.TODO(), test.containerID) if err != nil && !test.shouldErr { t.Fatal("unexpected error: err: ", err, res.Response.Message) } diff --git a/cns/cnireconciler/podinfoprovider.go b/cns/cnireconciler/podinfoprovider.go index 3330b8e0d3..47fbeb269f 100644 --- a/cns/cnireconciler/podinfoprovider.go +++ b/cns/cnireconciler/podinfoprovider.go @@ -15,12 +15,11 @@ import ( "k8s.io/utils/exec" ) -const InterfaceName = "eth0" - // NewCNIPodInfoProvider returns an implementation of cns.PodInfoByIPProvider // that execs out to the CNI and uses the response to build the PodInfo map. -func NewCNIPodInfoProvider() (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) { - return newCNIPodInfoProvider(exec.New()) +// if stateMigration flag is set to true it will also returns a map of containerID->EndpointInfo +func NewCNIPodInfoProvider(stateMigration bool) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) { + return newCNIPodInfoProvider(exec.New(), stateMigration) } func NewCNSPodInfoProvider(endpointStore store.KeyValueStore) (cns.PodInfoByIPProvider, error) { @@ -44,7 +43,7 @@ func newCNSPodInfoProvider(endpointStore store.KeyValueStore) (cns.PodInfoByIPPr }), nil } -func newCNIPodInfoProvider(exc exec.Interface) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) { +func newCNIPodInfoProvider(exc exec.Interface, stateMigration bool) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) { cli := client.New(exc) state, err := cli.GetEndpointState() if err != nil { @@ -54,7 +53,11 @@ func newCNIPodInfoProvider(exc exec.Interface) (cns.PodInfoByIPProvider, map[str logger.Printf("state dump from CNI: [%+v], [%+v]", containerID, endpointInfo) } var endpointState map[string]*restserver.EndpointInfo - endpointState, err = cniStateToCnsEndpointState(state) + if stateMigration { + endpointState, err = cniStateToCnsEndpointState(state) + } else { + endpointState = nil + } return cns.PodInfoByIPProviderFunc(func() (map[string]cns.PodInfo, error) { return cniStateToPodInfoByIP(state) }), endpointState, err @@ -152,7 +155,7 @@ func cniStateToCnsEndpointState(state *api.AzureCNIState) (map[string]*restserve // extractEndpointInfo extract Interface Name and endpointID for each endpoint based the CNI state func extractEndpointInfo(epID, containerID string) (endpointID, interfaceName string) { - ifName := InterfaceName + ifName := restserver.InterfaceName if strings.Contains(epID, "-eth") { ifName = epID[len(epID)-4:] } diff --git a/cns/cnireconciler/podinfoprovider_test.go b/cns/cnireconciler/podinfoprovider_test.go index da51b840e8..022c805251 100644 --- a/cns/cnireconciler/podinfoprovider_test.go +++ b/cns/cnireconciler/podinfoprovider_test.go @@ -76,7 +76,7 @@ func TestNewCNIPodInfoProvider(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - got, endpointState, err := newCNIPodInfoProvider(tt.exec) + got, endpointState, err := newCNIPodInfoProvider(tt.exec, true) if tt.wantErr { assert.Error(t, err) return diff --git a/cns/restserver/ipam.go b/cns/restserver/ipam.go index dbb5bbffca..66459c5a68 100644 --- a/cns/restserver/ipam.go +++ b/cns/restserver/ipam.go @@ -30,7 +30,10 @@ var ( ErrExistingIpconfigFound = errors.New("Found existing ipconfig for infra container") ) -const ContainerIDLength = 8 +const ( + ContainerIDLength = 8 + InterfaceName = "eth0" +) // requestIPConfigHandlerHelper validates the request, assign IPs and return the IPConfigs func (service *HTTPRestService) requestIPConfigHandlerHelper(ctx context.Context, ipconfigsRequest cns.IPConfigsRequest) (*cns.IPConfigsResponse, error) { @@ -1011,22 +1014,9 @@ func (service *HTTPRestService) EndpointHandlerAPI(w http.ResponseWriter, r *htt // GetEndpointHandler handles the incoming GetEndpoint requests with http Get method func (service *HTTPRestService) GetEndpointHandler(w http.ResponseWriter, r *http.Request) { logger.Printf("[GetEndpointState] GetEndpoint for %s", r.URL.Path) - var req cns.EndpointRequest - err := service.Listener.Decode(w, r, &req) endpointID := strings.TrimPrefix(r.URL.Path, cns.EndpointPath) - logger.Request(service.Name, &req, err) + endpointInfo, err := service.GetEndpointHelper(endpointID) // Check if the request is valid - if err != nil || req.IFName == "" { - response := cns.Response{ - ReturnCode: types.InvalidRequest, - Message: fmt.Sprintf("[getEndpoint] getEndpoint failed with error: %s", err.Error()), - } - w.Header().Set(cnsReturnCode, response.ReturnCode.String()) - err = service.Listener.Encode(w, &response) - logger.Response(service.Name, response, response.ReturnCode, err) - return - } - endpointInfo, err := service.GetEndpointHelper(endpointID, req) if err != nil { response := GetEndpointResponse{ Response: Response{ @@ -1060,7 +1050,7 @@ func (service *HTTPRestService) GetEndpointHandler(w http.ResponseWriter, r *htt } // GetEndpointHelper returns the state of the given endpointId -func (service *HTTPRestService) GetEndpointHelper(endpointID string, req cns.EndpointRequest) (*EndpointInfo, error) { +func (service *HTTPRestService) GetEndpointHelper(endpointID string) (*EndpointInfo, error) { logger.Printf("[GetEndpointState] Get endpoint state for infra container %s", endpointID) // Skip if a store is not provided. @@ -1084,7 +1074,7 @@ func (service *HTTPRestService) GetEndpointHelper(endpointID string, req cns.End logger.Warnf("[GetEndpointState] Found existing endpoint state for container %s", endpointID) return endpointInfo, nil } - legacyEndpointID := endpointID[:ContainerIDLength] + "-" + req.IFName + legacyEndpointID := endpointID[:ContainerIDLength] + "-" + InterfaceName if endpointInfo, ok := service.EndpointState[legacyEndpointID]; ok { logger.Warnf("[GetEndpointState] Found existing endpoint state for container %s", legacyEndpointID) return endpointInfo, nil diff --git a/cns/service/main.go b/cns/service/main.go index 2c155a1ec2..e274c847f5 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -1241,7 +1241,7 @@ func InitializeCRDState(ctx context.Context, httpRestService cns.HTTPService, cn case cnsconfig.InitializeFromCNI: logger.Printf("Initializing from CNI") - podInfoByIPProvider, _, err = cnireconciler.NewCNIPodInfoProvider() + podInfoByIPProvider, _, err = cnireconciler.NewCNIPodInfoProvider(false) if err != nil { return errors.Wrap(err, "failed to create CNI PodInfoProvider") } @@ -1525,7 +1525,7 @@ func InitializeStateFromCNS(cnsconfig *configuration.CNSConfig, endpointStateSto logger.Printf("StatelessCNI Migration is enabled") logger.Printf("initializing from Statefull CNI") var endpointState map[string]*restserver.EndpointInfo - podInfoByIPProvider, endpointState, err = cnireconciler.NewCNIPodInfoProvider() + podInfoByIPProvider, endpointState, err = cnireconciler.NewCNIPodInfoProvider(cnsconfig.StatelessCNIMigration) if err != nil { return nil, errors.Wrap(err, "failed to create CNI PodInfoProvider") } diff --git a/network/endpoint_linux.go b/network/endpoint_linux.go index 766d432603..179d28cc8c 100644 --- a/network/endpoint_linux.go +++ b/network/endpoint_linux.go @@ -533,7 +533,7 @@ func getDefaultGateway(routes []RouteInfo) net.IP { } // GetEndpointInfoByIPImpl returns an endpointInfo that contains corresponding HostVethName. -// TODO: It needs to be tested to see if HostVethName is required for SingleTenancy +// TODO: It needs to be tested to see if HostVethName is required for SingleTenancy, WorkItem: 26606939 func (epInfo *EndpointInfo) GetEndpointInfoByIPImpl(_ []net.IPNet, _ string) (*EndpointInfo, error) { return epInfo, nil } diff --git a/network/manager.go b/network/manager.go index bf95467811..19e46498f0 100644 --- a/network/manager.go +++ b/network/manager.go @@ -101,7 +101,7 @@ type NetworkManager interface { CreateEndpoint(client apipaClient, networkID string, epInfo []*EndpointInfo) error DeleteEndpoint(networkID string, endpointID string, epInfo *EndpointInfo) error - GetEndpointInfo(networkID string, endpointID string, ifName string) (*EndpointInfo, error) + GetEndpointInfo(networkID string, endpointID string) (*EndpointInfo, error) GetAllEndpoints(networkID string) (map[string]*EndpointInfo, error) GetEndpointInfoBasedOnPODDetails(networkID string, podName string, podNameSpace string, doExactMatchForPodName bool) (*EndpointInfo, error) AttachEndpoint(networkID string, endpointID string, sandboxKey string) (*endpoint, error) @@ -412,7 +412,7 @@ func (nm *networkManager) CreateEndpoint(cli apipaClient, networkID string, epIn // It will add HNSEndpointID or HostVeth name to the endpoint state func (nm *networkManager) UpdateEndpointState(ep *endpoint) error { logger.Info("Calling cns updateEndpoint API with ", zap.String("containerID: ", ep.ContainerID), zap.String("HnsId: ", ep.HnsId), zap.String("HostIfName: ", ep.HostIfName)) - response, err := nm.CnsClient.UpdateEndpoint(context.TODO(), ep.ContainerID, ep.HnsId, ep.HostIfName, ep.IfName) + response, err := nm.CnsClient.UpdateEndpoint(context.TODO(), ep.ContainerID, ep.HnsId, ep.HostIfName) if err != nil { return errors.Wrapf(err, "Update endpoint API returend with error") } @@ -421,8 +421,9 @@ func (nm *networkManager) UpdateEndpointState(ep *endpoint) error { } // GetEndpointState will make a call to CNS GetEndpointState API in the stateless CNI mode to fetch the endpointInfo -func (nm *networkManager) GetEndpointState(networkID, endpointID, ifName string) (*EndpointInfo, error) { - endpointResponse, err := nm.CnsClient.GetEndpoint(context.TODO(), endpointID, ifName) +// TODO unit tests need to be added, WorkItem: 26606939 +func (nm *networkManager) GetEndpointState(networkID, endpointID string) (*EndpointInfo, error) { + endpointResponse, err := nm.CnsClient.GetEndpoint(context.TODO(), endpointID) if err != nil { return nil, errors.Wrapf(err, "Get endpoint API returend with error") } @@ -507,13 +508,13 @@ func (nm *networkManager) DeleteEndpointState(networkID string, epInfo *Endpoint } // GetEndpointInfo returns information about the given endpoint. -func (nm *networkManager) GetEndpointInfo(networkID, endpointID, ifName string) (*EndpointInfo, error) { +func (nm *networkManager) GetEndpointInfo(networkID, endpointID string) (*EndpointInfo, error) { nm.Lock() defer nm.Unlock() if nm.IsStatelessCNIMode() { logger.Info("calling cns getEndpoint API") - epInfo, err := nm.GetEndpointState(networkID, endpointID, ifName) + epInfo, err := nm.GetEndpointState(networkID, endpointID) return epInfo, err } diff --git a/network/manager_mock.go b/network/manager_mock.go index 2f397c3aca..188d2bb2ad 100644 --- a/network/manager_mock.go +++ b/network/manager_mock.go @@ -99,7 +99,7 @@ func (nm *MockNetworkManager) GetAllEndpoints(networkID string) (map[string]*End } // GetEndpointInfo mock -func (nm *MockNetworkManager) GetEndpointInfo(_, endpointID, _ string) (*EndpointInfo, error) { +func (nm *MockNetworkManager) GetEndpointInfo(_, endpointID string) (*EndpointInfo, error) { if info, exists := nm.TestEndpointInfoMap[endpointID]; exists { return info, nil }