Skip to content

Commit

Permalink
Addressing the comments
Browse files Browse the repository at this point in the history
  • Loading branch information
behzad-mir committed Feb 8, 2024
1 parent d8afe37 commit 5dca27d
Show file tree
Hide file tree
Showing 13 changed files with 38 additions and 66 deletions.
4 changes: 2 additions & 2 deletions cni/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ func (plugin *NetPlugin) Get(args *cniSkel.CmdArgs) error {
}

// Query the endpoint.
if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID, args.IfName); err != nil {
if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID); err != nil {
logger.Error("Failed to query endpoint", zap.Error(err))
return err
}
Expand Down Expand Up @@ -1051,7 +1051,7 @@ func (plugin *NetPlugin) Delete(args *cniSkel.CmdArgs) error {

endpointID := plugin.nm.GetEndpointID(args.ContainerID, args.IfName)
// Query the endpoint.
if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID, args.IfName); err != nil {
if epInfo, err = plugin.nm.GetEndpointInfo(networkID, endpointID); err != nil {
logger.Info("GetEndpoint",
zap.String("endpoint", endpointID),
zap.Error(err))
Expand Down
2 changes: 1 addition & 1 deletion cni/network/network_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ var (
func (plugin *NetPlugin) handleConsecutiveAdd(args *cniSkel.CmdArgs, endpointId string, networkId string,
nwInfo *network.NetworkInfo, nwCfg *cni.NetworkConfig,
) (*cniTypesCurr.Result, error) {
epInfo, _ := plugin.nm.GetEndpointInfo(networkId, endpointId, "")
epInfo, _ := plugin.nm.GetEndpointInfo(networkId, endpointId)
if epInfo == nil {
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion cnm/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func (plugin *netPlugin) endpointOperInfo(w http.ResponseWriter, r *http.Request
}

// Process request.
epInfo, err := plugin.nm.GetEndpointInfo(req.NetworkID, req.EndpointID, "")
epInfo, err := plugin.nm.GetEndpointInfo(req.NetworkID, req.EndpointID)
if err != nil {
plugin.SendErrorResponse(w, err)
return
Expand Down
1 change: 0 additions & 1 deletion cns/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,5 +363,4 @@ type GetHomeAzResponse struct {
type EndpointRequest struct {
HnsEndpointID string `json:"hnsEndpointID"`
HostVethName string `json:"hostVethName"`
IFName string `json:"IFName"`
}
16 changes: 3 additions & 13 deletions cns/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,20 +1024,11 @@ func (c *Client) GetHomeAz(ctx context.Context) (*cns.GetHomeAzResponse, error)
}

// GetEndpoint calls the EndpointHandlerAPI in CNS to retrieve the state of a given EndpointID
func (c *Client) GetEndpoint(ctx context.Context, endpointID, ifName string) (*restserver.GetEndpointResponse, error) {
func (c *Client) GetEndpoint(ctx context.Context, endpointID string) (*restserver.GetEndpointResponse, error) {
// build the request
getEndpoint := cns.EndpointRequest{
IFName: ifName,
}
var body bytes.Buffer

if err := json.NewEncoder(&body).Encode(getEndpoint); err != nil {
return nil, errors.Wrap(err, "failed to encode getEndpoint")
}

u := c.routes[cns.EndpointAPI]
uString := u.String() + endpointID
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uString, &body)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uString, http.NoBody)
if err != nil {
return nil, errors.Wrap(err, "failed to build request")
}
Expand Down Expand Up @@ -1068,12 +1059,11 @@ func (c *Client) GetEndpoint(ctx context.Context, endpointID, ifName string) (*r

// UpdateEndpoint calls the EndpointHandlerAPI in CNS
// to update the state of a given EndpointID with either HNSEndpointID or HostVethName
func (c *Client) UpdateEndpoint(ctx context.Context, endpointID, hnsID, vethName, ifName string) (*cns.Response, error) {
func (c *Client) UpdateEndpoint(ctx context.Context, endpointID, hnsID, vethName string) (*cns.Response, error) {
// build the request
updateEndpoint := cns.EndpointRequest{
HnsEndpointID: hnsID,
HostVethName: vethName,
IFName: ifName,
}
var body bytes.Buffer

Expand Down
15 changes: 2 additions & 13 deletions cns/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2777,7 +2777,6 @@ func TestUpdateEndpoint(t *testing.T) {
containerID string
hnsID string
vethName string
ifName string
response *RequestCapture
expReq *cns.EndpointRequest
shouldErr bool
Expand All @@ -2787,7 +2786,6 @@ func TestUpdateEndpoint(t *testing.T) {
"",
"",
"",
"",
&RequestCapture{
Next: &mockdo{},
},
Expand All @@ -2799,15 +2797,13 @@ func TestUpdateEndpoint(t *testing.T) {
"foo",
"bar",
"",
"too",
&RequestCapture{
Next: &mockdo{
httpStatusCodeToReturn: http.StatusOK,
},
},
&cns.EndpointRequest{
HnsEndpointID: "bar",
IFName: "too",
},
false,
},
Expand All @@ -2816,15 +2812,13 @@ func TestUpdateEndpoint(t *testing.T) {
"foo",
"",
"bar",
"too",
&RequestCapture{
Next: &mockdo{
httpStatusCodeToReturn: http.StatusOK,
},
},
&cns.EndpointRequest{
HostVethName: "bar",
IFName: "too",
},
false,
},
Expand All @@ -2833,7 +2827,6 @@ func TestUpdateEndpoint(t *testing.T) {
"foo",
"",
"bar",
"",
&RequestCapture{
Next: &mockdo{
httpStatusCodeToReturn: http.StatusBadRequest,
Expand All @@ -2858,7 +2851,7 @@ func TestUpdateEndpoint(t *testing.T) {
}

// execute the method under test
res, err := client.UpdateEndpoint(context.TODO(), test.containerID, test.hnsID, test.vethName, test.ifName)
res, err := client.UpdateEndpoint(context.TODO(), test.containerID, test.hnsID, test.vethName)
if err != nil && !test.shouldErr {
t.Fatal("unexpected error: err: ", err, res.Message)
}
Expand Down Expand Up @@ -2904,14 +2897,12 @@ func TestGetEndpoint(t *testing.T) {
getEndpointTests := []struct {
name string
containerID string
ifName string
response *RequestCapture
shouldErr bool
}{
{
"empty",
"",
"",
&RequestCapture{
Next: &mockdo{},
},
Expand All @@ -2920,7 +2911,6 @@ func TestGetEndpoint(t *testing.T) {
{
"with EndpointID",
"foo",
"foo",
&RequestCapture{
Next: &mockdo{
httpStatusCodeToReturn: http.StatusOK,
Expand All @@ -2931,7 +2921,6 @@ func TestGetEndpoint(t *testing.T) {
{
"Bad Request",
"foo",
"foo",
&RequestCapture{
Next: &mockdo{
httpStatusCodeToReturn: http.StatusBadRequest,
Expand All @@ -2953,7 +2942,7 @@ func TestGetEndpoint(t *testing.T) {
}

// execute the method under test
res, err := client.GetEndpoint(context.TODO(), test.containerID, test.ifName)
res, err := client.GetEndpoint(context.TODO(), test.containerID)
if err != nil && !test.shouldErr {
t.Fatal("unexpected error: err: ", err, res.Response.Message)
}
Expand Down
17 changes: 10 additions & 7 deletions cns/cnireconciler/podinfoprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ import (
"k8s.io/utils/exec"
)

const InterfaceName = "eth0"

// NewCNIPodInfoProvider returns an implementation of cns.PodInfoByIPProvider
// that execs out to the CNI and uses the response to build the PodInfo map.
func NewCNIPodInfoProvider() (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) {
return newCNIPodInfoProvider(exec.New())
// if stateMigration flag is set to true it will also returns a map of containerID->EndpointInfo
func NewCNIPodInfoProvider(stateMigration bool) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) {
return newCNIPodInfoProvider(exec.New(), stateMigration)
}

func NewCNSPodInfoProvider(endpointStore store.KeyValueStore) (cns.PodInfoByIPProvider, error) {
Expand All @@ -44,7 +43,7 @@ func newCNSPodInfoProvider(endpointStore store.KeyValueStore) (cns.PodInfoByIPPr
}), nil
}

func newCNIPodInfoProvider(exc exec.Interface) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) {
func newCNIPodInfoProvider(exc exec.Interface, stateMigration bool) (cns.PodInfoByIPProvider, map[string]*restserver.EndpointInfo, error) {
cli := client.New(exc)
state, err := cli.GetEndpointState()
if err != nil {
Expand All @@ -54,7 +53,11 @@ func newCNIPodInfoProvider(exc exec.Interface) (cns.PodInfoByIPProvider, map[str
logger.Printf("state dump from CNI: [%+v], [%+v]", containerID, endpointInfo)
}
var endpointState map[string]*restserver.EndpointInfo
endpointState, err = cniStateToCnsEndpointState(state)
if stateMigration {
endpointState, err = cniStateToCnsEndpointState(state)
} else {
endpointState = nil
}
return cns.PodInfoByIPProviderFunc(func() (map[string]cns.PodInfo, error) {
return cniStateToPodInfoByIP(state)
}), endpointState, err
Expand Down Expand Up @@ -152,7 +155,7 @@ func cniStateToCnsEndpointState(state *api.AzureCNIState) (map[string]*restserve

// extractEndpointInfo extract Interface Name and endpointID for each endpoint based the CNI state
func extractEndpointInfo(epID, containerID string) (endpointID, interfaceName string) {
ifName := InterfaceName
ifName := restserver.InterfaceName
if strings.Contains(epID, "-eth") {
ifName = epID[len(epID)-4:]
}
Expand Down
2 changes: 1 addition & 1 deletion cns/cnireconciler/podinfoprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestNewCNIPodInfoProvider(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
got, endpointState, err := newCNIPodInfoProvider(tt.exec)
got, endpointState, err := newCNIPodInfoProvider(tt.exec, true)
if tt.wantErr {
assert.Error(t, err)
return
Expand Down
24 changes: 7 additions & 17 deletions cns/restserver/ipam.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ var (
ErrExistingIpconfigFound = errors.New("Found existing ipconfig for infra container")
)

const ContainerIDLength = 8
const (
ContainerIDLength = 8
InterfaceName = "eth0"
)

// requestIPConfigHandlerHelper validates the request, assign IPs and return the IPConfigs
func (service *HTTPRestService) requestIPConfigHandlerHelper(ctx context.Context, ipconfigsRequest cns.IPConfigsRequest) (*cns.IPConfigsResponse, error) {
Expand Down Expand Up @@ -1011,22 +1014,9 @@ func (service *HTTPRestService) EndpointHandlerAPI(w http.ResponseWriter, r *htt
// GetEndpointHandler handles the incoming GetEndpoint requests with http Get method
func (service *HTTPRestService) GetEndpointHandler(w http.ResponseWriter, r *http.Request) {
logger.Printf("[GetEndpointState] GetEndpoint for %s", r.URL.Path)
var req cns.EndpointRequest
err := service.Listener.Decode(w, r, &req)
endpointID := strings.TrimPrefix(r.URL.Path, cns.EndpointPath)
logger.Request(service.Name, &req, err)
endpointInfo, err := service.GetEndpointHelper(endpointID)
// Check if the request is valid
if err != nil || req.IFName == "" {
response := cns.Response{
ReturnCode: types.InvalidRequest,
Message: fmt.Sprintf("[getEndpoint] getEndpoint failed with error: %s", err.Error()),
}
w.Header().Set(cnsReturnCode, response.ReturnCode.String())
err = service.Listener.Encode(w, &response)
logger.Response(service.Name, response, response.ReturnCode, err)
return
}
endpointInfo, err := service.GetEndpointHelper(endpointID, req)
if err != nil {
response := GetEndpointResponse{
Response: Response{
Expand Down Expand Up @@ -1060,7 +1050,7 @@ func (service *HTTPRestService) GetEndpointHandler(w http.ResponseWriter, r *htt
}

// GetEndpointHelper returns the state of the given endpointId
func (service *HTTPRestService) GetEndpointHelper(endpointID string, req cns.EndpointRequest) (*EndpointInfo, error) {
func (service *HTTPRestService) GetEndpointHelper(endpointID string) (*EndpointInfo, error) {
logger.Printf("[GetEndpointState] Get endpoint state for infra container %s", endpointID)

// Skip if a store is not provided.
Expand All @@ -1084,7 +1074,7 @@ func (service *HTTPRestService) GetEndpointHelper(endpointID string, req cns.End
logger.Warnf("[GetEndpointState] Found existing endpoint state for container %s", endpointID)
return endpointInfo, nil
}
legacyEndpointID := endpointID[:ContainerIDLength] + "-" + req.IFName
legacyEndpointID := endpointID[:ContainerIDLength] + "-" + InterfaceName
if endpointInfo, ok := service.EndpointState[legacyEndpointID]; ok {
logger.Warnf("[GetEndpointState] Found existing endpoint state for container %s", legacyEndpointID)
return endpointInfo, nil
Expand Down
4 changes: 2 additions & 2 deletions cns/service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ func InitializeCRDState(ctx context.Context, httpRestService cns.HTTPService, cn

case cnsconfig.InitializeFromCNI:
logger.Printf("Initializing from CNI")
podInfoByIPProvider, _, err = cnireconciler.NewCNIPodInfoProvider()
podInfoByIPProvider, _, err = cnireconciler.NewCNIPodInfoProvider(false)
if err != nil {
return errors.Wrap(err, "failed to create CNI PodInfoProvider")
}
Expand Down Expand Up @@ -1525,7 +1525,7 @@ func InitializeStateFromCNS(cnsconfig *configuration.CNSConfig, endpointStateSto
logger.Printf("StatelessCNI Migration is enabled")
logger.Printf("initializing from Statefull CNI")
var endpointState map[string]*restserver.EndpointInfo
podInfoByIPProvider, endpointState, err = cnireconciler.NewCNIPodInfoProvider()
podInfoByIPProvider, endpointState, err = cnireconciler.NewCNIPodInfoProvider(cnsconfig.StatelessCNIMigration)
if err != nil {
return nil, errors.Wrap(err, "failed to create CNI PodInfoProvider")
}
Expand Down
2 changes: 1 addition & 1 deletion network/endpoint_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func getDefaultGateway(routes []RouteInfo) net.IP {
}

// GetEndpointInfoByIPImpl returns an endpointInfo that contains corresponding HostVethName.
// TODO: It needs to be tested to see if HostVethName is required for SingleTenancy
// TODO: It needs to be tested to see if HostVethName is required for SingleTenancy, WorkItem: 26606939
func (epInfo *EndpointInfo) GetEndpointInfoByIPImpl(_ []net.IPNet, _ string) (*EndpointInfo, error) {
return epInfo, nil
}
13 changes: 7 additions & 6 deletions network/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type NetworkManager interface {

CreateEndpoint(client apipaClient, networkID string, epInfo []*EndpointInfo) error
DeleteEndpoint(networkID string, endpointID string, epInfo *EndpointInfo) error
GetEndpointInfo(networkID string, endpointID string, ifName string) (*EndpointInfo, error)
GetEndpointInfo(networkID string, endpointID string) (*EndpointInfo, error)
GetAllEndpoints(networkID string) (map[string]*EndpointInfo, error)
GetEndpointInfoBasedOnPODDetails(networkID string, podName string, podNameSpace string, doExactMatchForPodName bool) (*EndpointInfo, error)
AttachEndpoint(networkID string, endpointID string, sandboxKey string) (*endpoint, error)
Expand Down Expand Up @@ -412,7 +412,7 @@ func (nm *networkManager) CreateEndpoint(cli apipaClient, networkID string, epIn
// It will add HNSEndpointID or HostVeth name to the endpoint state
func (nm *networkManager) UpdateEndpointState(ep *endpoint) error {
logger.Info("Calling cns updateEndpoint API with ", zap.String("containerID: ", ep.ContainerID), zap.String("HnsId: ", ep.HnsId), zap.String("HostIfName: ", ep.HostIfName))
response, err := nm.CnsClient.UpdateEndpoint(context.TODO(), ep.ContainerID, ep.HnsId, ep.HostIfName, ep.IfName)
response, err := nm.CnsClient.UpdateEndpoint(context.TODO(), ep.ContainerID, ep.HnsId, ep.HostIfName)
if err != nil {
return errors.Wrapf(err, "Update endpoint API returend with error")
}
Expand All @@ -421,8 +421,9 @@ func (nm *networkManager) UpdateEndpointState(ep *endpoint) error {
}

// GetEndpointState will make a call to CNS GetEndpointState API in the stateless CNI mode to fetch the endpointInfo
func (nm *networkManager) GetEndpointState(networkID, endpointID, ifName string) (*EndpointInfo, error) {
endpointResponse, err := nm.CnsClient.GetEndpoint(context.TODO(), endpointID, ifName)
// TODO unit tests need to be added, WorkItem: 26606939
func (nm *networkManager) GetEndpointState(networkID, endpointID string) (*EndpointInfo, error) {
endpointResponse, err := nm.CnsClient.GetEndpoint(context.TODO(), endpointID)
if err != nil {
return nil, errors.Wrapf(err, "Get endpoint API returend with error")
}
Expand Down Expand Up @@ -507,13 +508,13 @@ func (nm *networkManager) DeleteEndpointState(networkID string, epInfo *Endpoint
}

// GetEndpointInfo returns information about the given endpoint.
func (nm *networkManager) GetEndpointInfo(networkID, endpointID, ifName string) (*EndpointInfo, error) {
func (nm *networkManager) GetEndpointInfo(networkID, endpointID string) (*EndpointInfo, error) {
nm.Lock()
defer nm.Unlock()

if nm.IsStatelessCNIMode() {
logger.Info("calling cns getEndpoint API")
epInfo, err := nm.GetEndpointState(networkID, endpointID, ifName)
epInfo, err := nm.GetEndpointState(networkID, endpointID)

return epInfo, err
}
Expand Down
2 changes: 1 addition & 1 deletion network/manager_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (nm *MockNetworkManager) GetAllEndpoints(networkID string) (map[string]*End
}

// GetEndpointInfo mock
func (nm *MockNetworkManager) GetEndpointInfo(_, endpointID, _ string) (*EndpointInfo, error) {
func (nm *MockNetworkManager) GetEndpointInfo(_, endpointID string) (*EndpointInfo, error) {
if info, exists := nm.TestEndpointInfoMap[endpointID]; exists {
return info, nil
}
Expand Down

0 comments on commit 5dca27d

Please sign in to comment.