From 27ed056431783ed0b0f8bf4cbb32660d7a7579b0 Mon Sep 17 00:00:00 2001 From: Toni Kangas Date: Fri, 15 Nov 2024 13:55:53 +0200 Subject: [PATCH] feat: allow using UUID prefix as argument (#341) Also refactors argument resolution to be more flexible. This allows later adding support for case-insensitive and wild-card matching. --- CHANGELOG.md | 1 + internal/commands/network/modify.go | 3 +- internal/commands/runcommand.go | 5 +- internal/commands/runcommand_test.go | 11 ++-- internal/namedargs/resolve.go | 3 +- internal/resolver/account.go | 16 ++---- internal/resolver/account_test.go | 11 ++-- internal/resolver/database.go | 17 ++----- internal/resolver/database_test.go | 22 ++++---- internal/resolver/gateway.go | 17 ++----- internal/resolver/gateway_test.go | 21 ++++---- internal/resolver/ipaddress.go | 17 ++----- internal/resolver/ipaddress_test.go | 41 +++++---------- internal/resolver/kubernetes.go | 17 ++----- internal/resolver/kubernetes_test.go | 21 ++++---- internal/resolver/loadbalancer.go | 17 ++----- internal/resolver/loadbalancer_test.go | 21 ++++---- internal/resolver/matchers.go | 19 ++++++- internal/resolver/network.go | 22 +++----- internal/resolver/network_test.go | 51 ++++++++++++++----- internal/resolver/networkpeering.go | 17 ++----- internal/resolver/objectstorage.go | 17 ++----- internal/resolver/resolver.go | 70 +++++++++++++++++++++++++- internal/resolver/router.go | 17 ++----- internal/resolver/router_test.go | 20 +++++--- internal/resolver/server.go | 18 +++---- internal/resolver/server_test.go | 30 ++++++----- internal/resolver/servergroup.go | 17 ++----- internal/resolver/servergroup_test.go | 20 +++++--- internal/resolver/storage.go | 24 +++------ internal/resolver/storage_test.go | 28 +++++------ 31 files changed, 332 insertions(+), 299 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ffaa14ac9..a6e9b1b5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Take server state into account in server completions. For example, do not offer started servers as completions for `server start` command. +- Allow using UUID prefix as an argument. For example, if there is only one network available that has an UUID starting with `0316`, details of that network can be listed with `upctl network show 0316` command. ## [3.11.1] - 2024-08-12 diff --git a/internal/commands/network/modify.go b/internal/commands/network/modify.go index b9718cf73..160f68f7d 100644 --- a/internal/commands/network/modify.go +++ b/internal/commands/network/modify.go @@ -96,7 +96,8 @@ func (s *modifyCommand) ExecuteSingleArgument(exec commands.Executor, arg string if err != nil { return commands.HandleError(exec, msg, fmt.Errorf("cannot get router resolver: %w", err)) } - routerUUID, err := routerResolver(s.attachRouter) + resolved := routerResolver(s.attachRouter) + routerUUID, err := resolved.GetOnly() if err != nil { return commands.HandleError(exec, msg, fmt.Errorf("cannot resolve router '%s': %w", s.attachRouter, err)) } diff --git a/internal/commands/runcommand.go b/internal/commands/runcommand.go index 432f9c9f4..63108ac61 100644 --- a/internal/commands/runcommand.go +++ b/internal/commands/runcommand.go @@ -78,8 +78,9 @@ func resolveArguments(nc Command, exec Executor, args []string) (out []resolvedA return nil, fmt.Errorf("cannot get resolver: %w", err) } for _, arg := range args { - resolved, err := argumentResolver(arg) - out = append(out, resolvedArgument{Resolved: resolved, Error: err, Original: arg}) + resolved := argumentResolver(arg) + value, err := resolved.GetOnly() + out = append(out, resolvedArgument{Resolved: value, Error: err, Original: arg}) } } else { for _, arg := range args { diff --git a/internal/commands/runcommand_test.go b/internal/commands/runcommand_test.go index ddfb5f707..31691f58c 100644 --- a/internal/commands/runcommand_test.go +++ b/internal/commands/runcommand_test.go @@ -99,11 +99,12 @@ type mockMultiResolver struct { } func (m *mockMultiResolver) Get(_ context.Context, _ internal.AllServices) (resolver.Resolver, error) { - return func(arg string) (uuid string, err error) { - if len(arg) > 5 { - return "", fmt.Errorf("MOCKTOOLONG") + return func(arg string) resolver.Resolved { + rv := resolver.Resolved{Arg: arg} + if len(arg) <= 5 { + rv.AddMatch("uuid:"+arg, resolver.MatchTypeExact) } - return fmt.Sprintf("uuid:%s", arg), nil + return rv }, nil } @@ -258,7 +259,7 @@ func TestExecute_Resolution(t *testing.T) { values[typedO.Value.(string)] = struct{}{} case output.Error: assert.Empty(t, typedO.Resolved) - assert.EqualError(t, typedO.Value, "cannot resolve argument: MOCKTOOLONG") + assert.EqualError(t, typedO.Value, "cannot resolve argument: nothing found matching 'failtoresolve'") } } assert.Equal(t, values, map[string]struct{}{ diff --git a/internal/namedargs/resolve.go b/internal/namedargs/resolve.go index b51eecd75..b14d1245f 100644 --- a/internal/namedargs/resolve.go +++ b/internal/namedargs/resolve.go @@ -14,5 +14,6 @@ func Resolve(provider resolver.ResolutionProvider, exec commands.Executor, arg s return "", fmt.Errorf("could not initialize resolver: %w", err) } - return resolver(arg) + resolved := resolver(arg) + return resolved.GetOnly() } diff --git a/internal/resolver/account.go b/internal/resolver/account.go index a6009b5d4..2520205c7 100644 --- a/internal/resolver/account.go +++ b/internal/resolver/account.go @@ -18,20 +18,12 @@ func (s CachingAccount) Get(ctx context.Context, svc internal.AllServices) (Reso if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, account := range accounts { - if MatchArgWithWhitespace(arg, account.Username) { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = account.Username - } + rv.AddMatch(account.Username, MatchArgWithWhitespace(arg, account.Username)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/account_test.go b/internal/resolver/account_test.go index 49ab41500..59961385c 100644 --- a/internal/resolver/account_test.go +++ b/internal/resolver/account_test.go @@ -43,9 +43,10 @@ func TestAccountResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, account := range allAccounts { - resolved, err := argResolver(account.Username) + resolved := argResolver(account.Username) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, account.Username, resolved) + assert.Equal(t, account.Username, value) } // make sure caching works, eg. we didn't call GetAccountList more than once mService.AssertNumberOfCalls(t, "GetAccountList", 1) @@ -60,12 +61,14 @@ func TestAccountResolution(t *testing.T) { assert.NoError(t, err) // not found - resolved, err := argResolver("notfound") + resolved := argResolver("notfound") + value, err := resolved.GetOnly() + if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetAccountList more than once mService.AssertNumberOfCalls(t, "GetAccountList", 1) diff --git a/internal/resolver/database.go b/internal/resolver/database.go index 121b30dc9..ca52bfdc6 100644 --- a/internal/resolver/database.go +++ b/internal/resolver/database.go @@ -19,20 +19,13 @@ func (s CachingDatabase) Get(ctx context.Context, svc internal.AllServices) (Res if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, db := range databases { - if MatchArgWithWhitespace(arg, db.Title) || db.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = db.UUID - } + rv.AddMatch(db.UUID, MatchArgWithWhitespace(arg, db.Title)) + rv.AddMatch(db.UUID, MatchUUID(arg, db.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/database_test.go b/internal/resolver/database_test.go index 688bd71c3..41a0051f7 100644 --- a/internal/resolver/database_test.go +++ b/internal/resolver/database_test.go @@ -27,9 +27,10 @@ func TestDatabaseResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, db := range mockDatabases { - resolved, err := argResolver(db.UUID) + resolved := argResolver(db.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) } // Make sure caching works, eg. we didn't call GetManagedDatabases more than once @@ -44,9 +45,10 @@ func TestDatabaseResolution(t *testing.T) { assert.NoError(t, err) db := mockDatabases[2] - resolved, err := argResolver(db.Title) + resolved := argResolver(db.Title) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) // Make sure caching works, eg. we didn't call GetManagedDatabases more than once mService.AssertNumberOfCalls(t, "GetManagedDatabases", 1) }) @@ -58,23 +60,25 @@ func TestDatabaseResolution(t *testing.T) { res := resolver.CachingDatabase{} argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - var resolved string // Ambiguous title - resolved, err = argResolver("asd") + resolved := argResolver("asd") + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError("asd")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Not found - resolved, err = argResolver("not-found") + resolved = argResolver("not-found") + value, err = resolved.GetOnly() + if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("not-found")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Make sure caching works, eg. we didn't call GetManagedDatabases more than once mService.AssertNumberOfCalls(t, "GetManagedDatabases", 1) diff --git a/internal/resolver/gateway.go b/internal/resolver/gateway.go index 0d07a5c25..d24d9627b 100644 --- a/internal/resolver/gateway.go +++ b/internal/resolver/gateway.go @@ -18,20 +18,13 @@ func (s CachingGateway) Get(ctx context.Context, svc internal.AllServices) (Reso if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, gtw := range gateways { - if MatchArgWithWhitespace(arg, gtw.Name) || gtw.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = gtw.UUID - } + rv.AddMatch(gtw.UUID, MatchArgWithWhitespace(arg, gtw.Name)) + rv.AddMatch(gtw.UUID, MatchUUID(arg, gtw.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/gateway_test.go b/internal/resolver/gateway_test.go index f09c7c29e..e4f878da3 100644 --- a/internal/resolver/gateway_test.go +++ b/internal/resolver/gateway_test.go @@ -27,9 +27,10 @@ func TestGatewayResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, db := range mockGateways { - resolved, err := argResolver(db.UUID) + resolved := argResolver(db.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) } // Make sure caching works, eg. we didn't call GetGateways more than once @@ -44,9 +45,10 @@ func TestGatewayResolution(t *testing.T) { assert.NoError(t, err) db := mockGateways[2] - resolved, err := argResolver(db.Name) + resolved := argResolver(db.Name) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) // Make sure caching works, eg. we didn't call GetGateways more than once mService.AssertNumberOfCalls(t, "GetGateways", 1) }) @@ -58,23 +60,24 @@ func TestGatewayResolution(t *testing.T) { res := resolver.CachingGateway{} argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - var resolved string // Ambiguous Name - resolved, err = argResolver("asd") + resolved := argResolver("asd") + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError("asd")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Not found - resolved, err = argResolver("not-found") + resolved = argResolver("not-found") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("not-found")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Make sure caching works, eg. we didn't call GetGateways more than once mService.AssertNumberOfCalls(t, "GetGateways", 1) diff --git a/internal/resolver/ipaddress.go b/internal/resolver/ipaddress.go index f99ea971f..9f426c8c6 100644 --- a/internal/resolver/ipaddress.go +++ b/internal/resolver/ipaddress.go @@ -18,20 +18,13 @@ func (s CachingIPAddress) Get(ctx context.Context, svc internal.AllServices) (Re if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, ipAddress := range ipaddresses.IPAddresses { - if ipAddress.PTRRecord == arg || ipAddress.Address == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = ipAddress.Address - } + rv.AddMatch(ipAddress.Address, MatchArgWithWhitespace(arg, ipAddress.PTRRecord)) + rv.AddMatch(ipAddress.Address, MatchArgWithWhitespace(arg, ipAddress.Address)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/ipaddress_test.go b/internal/resolver/ipaddress_test.go index fdf1fff41..2194f2048 100644 --- a/internal/resolver/ipaddress_test.go +++ b/internal/resolver/ipaddress_test.go @@ -58,17 +58,6 @@ func TestIPAddressResolution(t *testing.T) { Zone: "fi-hel1", } ipAddress5 := upcloud.IPAddress{ - Address: "94.237.117.154", // same IP as 4 (not sure if this is actually possible?) - Access: "public", - Family: "IPv4", - PartOfPlan: upcloud.FromBool(true), - PTRRecord: "94-237-117-155.fi-hel1.upcloud.host", - ServerUUID: "005ab220-7ff6-42c9-8615-e4c02eb4104e", - MAC: "ee:1b:db:ca:6b:84", - Floating: upcloud.FromBool(false), - Zone: "fi-hel1", - } - ipAddress6 := upcloud.IPAddress{ Address: "94.237.117.156", Access: "public", Family: "IPv4", @@ -81,7 +70,7 @@ func TestIPAddressResolution(t *testing.T) { } addresses := &upcloud.IPAddresses{IPAddresses: []upcloud.IPAddress{ - ipAddress1, ipAddress2, ipAddress3, ipAddress4, ipAddress5, ipAddress6, + ipAddress1, ipAddress2, ipAddress3, ipAddress4, ipAddress5, }} unambiguousAddresses := []upcloud.IPAddress{ipAddress1, ipAddress2, ipAddress3} @@ -92,9 +81,10 @@ func TestIPAddressResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, network := range unambiguousAddresses { - resolved, err := argResolver(network.Address) + resolved := argResolver(network.Address) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, network.Address, resolved) + assert.Equal(t, network.Address, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetIPAddresses", 1) @@ -107,9 +97,10 @@ func TestIPAddressResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, network := range unambiguousAddresses { - resolved, err := argResolver(network.PTRRecord) + resolved := argResolver(network.PTRRecord) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, network.Address, resolved) + assert.Equal(t, network.Address, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetIPAddresses", 1) @@ -122,29 +113,23 @@ func TestIPAddressResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - // ambiguous address - resolved, err := argResolver(ipAddress4.Address) - if !assert.Error(t, err) { - t.FailNow() - } - assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(ipAddress4.Address)) - assert.Equal(t, "", resolved) - // ambiguous ptr record - resolved, err = argResolver(ipAddress4.PTRRecord) + resolved := argResolver(ipAddress4.PTRRecord) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(ipAddress4.PTRRecord)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetIPAddresses", 1) diff --git a/internal/resolver/kubernetes.go b/internal/resolver/kubernetes.go index 0ef6c06e8..70a4d4988 100644 --- a/internal/resolver/kubernetes.go +++ b/internal/resolver/kubernetes.go @@ -20,20 +20,13 @@ func (s CachingKubernetes) Get(ctx context.Context, svc service.AllServices) (Re if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, cluster := range clusters { - if cluster.Name == arg || cluster.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = cluster.UUID - } + rv.AddMatch(cluster.UUID, MatchArgWithWhitespace(arg, cluster.Name)) + rv.AddMatch(cluster.UUID, MatchUUID(arg, cluster.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/kubernetes_test.go b/internal/resolver/kubernetes_test.go index 9966b1ad9..443c1583f 100644 --- a/internal/resolver/kubernetes_test.go +++ b/internal/resolver/kubernetes_test.go @@ -27,9 +27,10 @@ func TestKubernetesResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, db := range mockClusters { - resolved, err := argResolver(db.UUID) + resolved := argResolver(db.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) } // Make sure caching works, eg. we didn't call GetKubernetesClusters more than once @@ -44,9 +45,10 @@ func TestKubernetesResolution(t *testing.T) { assert.NoError(t, err) db := mockClusters[2] - resolved, err := argResolver(db.Name) + resolved := argResolver(db.Name) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) // Make sure caching works, eg. we didn't call GetKubernetesClusters more than once mService.AssertNumberOfCalls(t, "GetKubernetesClusters", 1) }) @@ -58,23 +60,24 @@ func TestKubernetesResolution(t *testing.T) { res := resolver.CachingKubernetes{} argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - var resolved string // Ambiguous Name - resolved, err = argResolver("asd") + resolved := argResolver("asd") + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError("asd")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Not found - resolved, err = argResolver("not-found") + resolved = argResolver("not-found") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("not-found")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Make sure caching works, eg. we didn't call GetKubernetesClusters more than once mService.AssertNumberOfCalls(t, "GetKubernetesClusters", 1) diff --git a/internal/resolver/loadbalancer.go b/internal/resolver/loadbalancer.go index c6de49380..9ef46c3d5 100644 --- a/internal/resolver/loadbalancer.go +++ b/internal/resolver/loadbalancer.go @@ -19,20 +19,13 @@ func (s CachingLoadBalancer) Get(ctx context.Context, svc internal.AllServices) if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, lb := range loadbalancers { - if lb.Name == arg || lb.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = lb.UUID - } + rv.AddMatch(lb.UUID, MatchArgWithWhitespace(arg, lb.Name)) + rv.AddMatch(lb.UUID, MatchUUID(arg, lb.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/loadbalancer_test.go b/internal/resolver/loadbalancer_test.go index 7a2219644..31049a5da 100644 --- a/internal/resolver/loadbalancer_test.go +++ b/internal/resolver/loadbalancer_test.go @@ -27,9 +27,10 @@ func TestLoadBalancerResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, db := range mockLoadBalancers { - resolved, err := argResolver(db.UUID) + resolved := argResolver(db.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) } // Make sure caching works, eg. we didn't call GetLoadBalancers more than once @@ -44,9 +45,10 @@ func TestLoadBalancerResolution(t *testing.T) { assert.NoError(t, err) db := mockLoadBalancers[2] - resolved, err := argResolver(db.Name) + resolved := argResolver(db.Name) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, db.UUID, resolved) + assert.Equal(t, db.UUID, value) // Make sure caching works, eg. we didn't call GetLoadBalancers more than once mService.AssertNumberOfCalls(t, "GetLoadBalancers", 1) }) @@ -58,23 +60,24 @@ func TestLoadBalancerResolution(t *testing.T) { res := resolver.CachingLoadBalancer{} argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - var resolved string // Ambiguous Name - resolved, err = argResolver("asd") + resolved := argResolver("asd") + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError("asd")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Not found - resolved, err = argResolver("not-found") + resolved = argResolver("not-found") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("not-found")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // Make sure caching works, eg. we didn't call GetLoadBalancers more than once mService.AssertNumberOfCalls(t, "GetLoadBalancers", 1) diff --git a/internal/resolver/matchers.go b/internal/resolver/matchers.go index addca6d43..4fe4dd836 100644 --- a/internal/resolver/matchers.go +++ b/internal/resolver/matchers.go @@ -1,10 +1,25 @@ package resolver import ( + "strings" + "github.com/UpCloudLtd/upcloud-cli/v3/internal/completion" ) // MatchStringWithWhitespace checks if arg that may include whitespace matches given value. This checks both quoted args and auto-completed args handled with completion.RemoveWordBreaks. -func MatchArgWithWhitespace(arg string, value string) bool { - return completion.RemoveWordBreaks(value) == arg || value == arg +func MatchArgWithWhitespace(arg, value string) MatchType { + if completion.RemoveWordBreaks(value) == arg || value == arg { + return MatchTypeExact + } + return MatchTypeNone +} + +func MatchUUID(arg, value string) MatchType { + if value == arg { + return MatchTypeExact + } + if strings.HasPrefix(value, arg) { + return MatchTypePrefix + } + return MatchTypeNone } diff --git a/internal/resolver/network.go b/internal/resolver/network.go index fdbf797d9..fb0f702f3 100644 --- a/internal/resolver/network.go +++ b/internal/resolver/network.go @@ -17,21 +17,14 @@ type CachingNetwork struct { // make sure we implement the ResolutionProvider interface var _ ResolutionProvider = &CachingNetwork{} -func networkMatcher(cached []upcloud.Network) func(arg string) (uuid string, err error) { - return func(arg string) (uuid string, err error) { - rv := "" +func networkMatcher(cached []upcloud.Network) func(arg string) Resolved { + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, network := range cached { - if MatchArgWithWhitespace(arg, network.Name) || network.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = network.UUID - } + rv.AddMatch(network.UUID, MatchArgWithWhitespace(arg, network.Name)) + rv.AddMatch(network.UUID, MatchUUID(arg, network.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv } } @@ -64,7 +57,8 @@ func (s *CachingNetwork) Resolve(arg string) (resolved string, err error) { return "", errors.New("caching network does not have a cache initialized") } - return networkMatcher(s.cached)(arg) + r := networkMatcher(s.cached)(arg) + return r.GetOnly() } // PositionalArgumentHelp implements resolver.ResolutionProvider diff --git a/internal/resolver/network_test.go b/internal/resolver/network_test.go index c93b29c20..46711d529 100644 --- a/internal/resolver/network_test.go +++ b/internal/resolver/network_test.go @@ -14,28 +14,28 @@ import ( var Network1 = upcloud.Network{ Name: "network-1", - UUID: "28e15cf5-8817-42ab-b017-970666be96ec", + UUID: "03e15cf5-8817-42ab-b017-970666be96ec", Type: upcloud.NetworkTypeUtility, Zone: "fi-hel1", } var Network2 = upcloud.Network{ Name: "network-2", - UUID: "f9f5ad16-a63a-4670-8449-c01d1e97281e", + UUID: "03f5ad16-a63a-4670-8449-c01d1e97281e", Type: upcloud.NetworkTypePrivate, Zone: "fi-hel1", } var Network3 = upcloud.Network{ Name: "network-3", - UUID: "e157ce0a-eeb0-49fc-9f2c-a05c3ac57066", + UUID: "0357ce0a-eeb0-49fc-9f2c-a05c3ac57066", Type: upcloud.NetworkTypeUtility, Zone: "uk-lon1", } var Network4 = upcloud.Network{ Name: Network1.Name, - UUID: "b3e49768-f13a-42c3-bea7-4e2471657f2f", + UUID: "03e49768-f13a-42c3-bea7-4e2471657f2f", Type: upcloud.NetworkTypePublic, Zone: "uk-lon1", } @@ -53,14 +53,29 @@ func TestNetworkResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, network := range networks.Networks { - resolved, err := argResolver(network.UUID) + resolved := argResolver(network.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, network.UUID, resolved) + assert.Equal(t, network.UUID, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetNetworks", 1) }) + t.Run("resolve uuid prefix", func(t *testing.T) { + mService := &smock.Service{} + mService.On("GetNetworks").Return(networks, nil) + res := resolver.CachingNetwork{} + argResolver, err := res.Get(context.TODO(), mService) + assert.NoError(t, err) + resolved := argResolver("035") + value, err := resolved.GetOnly() + assert.NoError(t, err) + assert.Equal(t, Network3.UUID, value) + // make sure caching works, eg. we didn't call GetServers more than once + mService.AssertNumberOfCalls(t, "GetNetworks", 1) + }) + t.Run("resolve name", func(t *testing.T) { mService := &smock.Service{} mService.On("GetNetworks").Return(networks, nil) @@ -68,9 +83,10 @@ func TestNetworkResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, network := range unambiguousNetworks { - resolved, err := argResolver(network.Name) + resolved := argResolver(network.Name) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, network.UUID, resolved) + assert.Equal(t, network.UUID, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetNetworks", 1) @@ -84,20 +100,31 @@ func TestNetworkResolution(t *testing.T) { assert.NoError(t, err) // ambiguous name - resolved, err := argResolver(Network1.Name) + resolved := argResolver(Network1.Name) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(Network1.Name)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) + + // ambiguous UUID prefix + resolved = argResolver("03") + value, err = resolved.GetOnly() + if !assert.Error(t, err) { + t.FailNow() + } + assert.ErrorIs(t, err, resolver.AmbiguousResolutionError("03")) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetNetworks", 1) diff --git a/internal/resolver/networkpeering.go b/internal/resolver/networkpeering.go index f9cae6c40..6e8bb08ef 100644 --- a/internal/resolver/networkpeering.go +++ b/internal/resolver/networkpeering.go @@ -18,20 +18,13 @@ func (s CachingNetworkPeering) Get(ctx context.Context, svc internal.AllServices if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, peering := range gateways { - if MatchArgWithWhitespace(arg, peering.Name) || peering.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = peering.UUID - } + rv.AddMatch(peering.UUID, MatchArgWithWhitespace(arg, peering.Name)) + rv.AddMatch(peering.UUID, MatchUUID(arg, peering.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/objectstorage.go b/internal/resolver/objectstorage.go index 817628c3e..2fad9d910 100644 --- a/internal/resolver/objectstorage.go +++ b/internal/resolver/objectstorage.go @@ -19,20 +19,13 @@ func (s CachingObjectStorage) Get(ctx context.Context, svc internal.AllServices) if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, objsto := range objectstorages { - if MatchArgWithWhitespace(arg, objsto.Name) || objsto.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = objsto.UUID - } + rv.AddMatch(objsto.UUID, MatchArgWithWhitespace(arg, objsto.Name)) + rv.AddMatch(objsto.UUID, MatchUUID(arg, objsto.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 23cb3004c..a520d28a1 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -8,11 +8,77 @@ import ( const helpUUIDTitle = "" -// Resolver represents the most basic argument resolver, a function that accepts and argument and returns an uuid (or error) -type Resolver func(arg string) (uuid string, err error) +// Resolver represents the most basic argument resolver, a function that accepts and argument and returns the resolved value(s). +type Resolver func(arg string) (resolved Resolved) // ResolutionProvider is an interface for commands that provide resolution, either custom or the built-in ones type ResolutionProvider interface { Get(ctx context.Context, svc service.AllServices) (Resolver, error) PositionalArgumentHelp() string } + +type MatchType int + +const ( + MatchTypeExact MatchType = 3 + MatchTypeCaseInsensitive MatchType = 2 + MatchTypeWildCard MatchType = 2 + MatchTypePrefix MatchType = 1 + MatchTypeNone MatchType = 0 +) + +type Resolved struct { + Arg string + matches map[string]MatchType +} + +// AddMatch adds a match to the resolved value. If the match is already present, the highest match type is kept. I.e., exact match is kept over case insensitive match. +func (r *Resolved) AddMatch(uuid string, matchType MatchType) { + if r.matches == nil { + r.matches = make(map[string]MatchType) + } + + current := r.matches[uuid] + r.matches[uuid] = max(current, matchType) +} + +// GetAll returns all matches with match-type that equals the highest available match-type for the resolved value. I.e., if there is an exact match, only exact matches are returned even if there would be case-insensitive matches. +func (r *Resolved) GetAll() ([]string, error) { + var all []string + for _, matchType := range []MatchType{ + MatchTypeExact, + MatchTypeCaseInsensitive, + MatchTypeWildCard, + MatchTypePrefix, + } { + for uuid, match := range r.matches { + if match == matchType { + all = append(all, uuid) + } + } + + if len(all) > 0 { + return all, nil + } + } + + var err error + if len(all) == 0 { + err = NotFoundError(r.Arg) + } + return all, err +} + +// GetOnly returns the only match if there is only one match. If there are no or multiple matches, an empty value and an error is returned. +func (r *Resolved) GetOnly() (string, error) { + all, err := r.GetAll() + if err != nil { + return "", err + } + + if len(all) > 1 { + return "", AmbiguousResolutionError(r.Arg) + } + + return all[0], nil +} diff --git a/internal/resolver/router.go b/internal/resolver/router.go index 13b1052f0..0f47afd36 100644 --- a/internal/resolver/router.go +++ b/internal/resolver/router.go @@ -24,20 +24,13 @@ func (s *CachingRouter) Get(ctx context.Context, svc internal.AllServices) (Reso return nil, err } s.cached = routers.Routers - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, router := range s.cached { - if MatchArgWithWhitespace(arg, router.Name) || router.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = router.UUID - } + rv.AddMatch(router.UUID, MatchArgWithWhitespace(arg, router.Name)) + rv.AddMatch(router.UUID, MatchUUID(arg, router.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/router_test.go b/internal/resolver/router_test.go index 88de57b3f..4193f0929 100644 --- a/internal/resolver/router_test.go +++ b/internal/resolver/router_test.go @@ -58,9 +58,10 @@ func TestRouterResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, router := range allRouters.Routers { - resolved, err := argResolver(router.UUID) + resolved := argResolver(router.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, router.UUID, resolved) + assert.Equal(t, router.UUID, value) } // make sure caching works, eg. we didn't call GetRouters more than once mService.AssertNumberOfCalls(t, "GetRouters", 1) @@ -73,9 +74,10 @@ func TestRouterResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range unambiguousRouters { - resolved, err := argResolver(srv.Name) + resolved := argResolver(srv.Name) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetRouters more than once mService.AssertNumberOfCalls(t, "GetRouters", 1) @@ -90,20 +92,22 @@ func TestRouterResolution(t *testing.T) { assert.NoError(t, err) // ambiguous name - resolved, err := argResolver(Router1.Name) + resolved := argResolver(Router1.Name) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(Router1.Name)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetRouters", 1) diff --git a/internal/resolver/server.go b/internal/resolver/server.go index 8e407aa1d..574a0f691 100644 --- a/internal/resolver/server.go +++ b/internal/resolver/server.go @@ -18,20 +18,14 @@ func (s CachingServer) Get(ctx context.Context, svc internal.AllServices) (Resol if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, server := range servers.Servers { - if MatchArgWithWhitespace(arg, server.Title) || server.Hostname == arg || server.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = server.UUID - } + rv.AddMatch(server.UUID, MatchArgWithWhitespace(arg, server.Title)) + rv.AddMatch(server.UUID, MatchArgWithWhitespace(arg, server.Hostname)) + rv.AddMatch(server.UUID, MatchUUID(arg, server.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/server_test.go b/internal/resolver/server_test.go index ff04a9255..c8c234144 100644 --- a/internal/resolver/server_test.go +++ b/internal/resolver/server_test.go @@ -104,9 +104,10 @@ func TestServerResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range allServers.Servers { - resolved, err := argResolver(srv.UUID) + resolved := argResolver(srv.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetServers", 1) @@ -119,9 +120,10 @@ func TestServerResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range unambiguousServers { - resolved, err := argResolver(srv.Hostname) + resolved := argResolver(srv.Hostname) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetServers", 1) @@ -134,9 +136,10 @@ func TestServerResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range unambiguousServers { - resolved, err := argResolver(srv.Title) + resolved := argResolver(srv.Title) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetServers", 1) @@ -151,28 +154,31 @@ func TestServerResolution(t *testing.T) { assert.NoError(t, err) // ambiguous hostname - resolved, err := argResolver(Server4.Hostname) + resolved := argResolver(Server4.Hostname) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(Server4.Hostname)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // ambiguous title - resolved, err = argResolver(Server1.Title) + resolved = argResolver(Server1.Title) + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(Server1.Title)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetServers", 1) diff --git a/internal/resolver/servergroup.go b/internal/resolver/servergroup.go index d13fe683a..b335645ca 100644 --- a/internal/resolver/servergroup.go +++ b/internal/resolver/servergroup.go @@ -20,20 +20,13 @@ func (s CachingServerGroup) Get(ctx context.Context, svc internal.AllServices) ( if err != nil { return nil, err } - return func(arg string) (uuid string, err error) { - rv := "" + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, serverGroup := range serverGroups { - if MatchArgWithWhitespace(arg, serverGroup.Title) || serverGroup.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = serverGroup.UUID - } + rv.AddMatch(serverGroup.UUID, MatchArgWithWhitespace(arg, serverGroup.Title)) + rv.AddMatch(serverGroup.UUID, MatchUUID(arg, serverGroup.UUID)) } - if rv != "" { - return rv, nil - } - return "", NotFoundError(arg) + return rv }, nil } diff --git a/internal/resolver/servergroup_test.go b/internal/resolver/servergroup_test.go index b59592cd5..eb67c6499 100644 --- a/internal/resolver/servergroup_test.go +++ b/internal/resolver/servergroup_test.go @@ -59,9 +59,10 @@ func TestServerGroupResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range allServerGroups { - resolved, err := argResolver(srv.UUID) + resolved := argResolver(srv.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetServerGroups more than once mService.AssertNumberOfCalls(t, "GetServerGroups", 1) @@ -74,9 +75,10 @@ func TestServerGroupResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, srv := range unambiguousServerGroups { - resolved, err := argResolver(srv.Title) + resolved := argResolver(srv.Title) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, srv.UUID, resolved) + assert.Equal(t, srv.UUID, value) } // make sure caching works, eg. we didn't call GetServerGroups more than once mService.AssertNumberOfCalls(t, "GetServerGroups", 1) @@ -91,20 +93,22 @@ func TestServerGroupResolution(t *testing.T) { assert.NoError(t, err) // ambiguous title - resolved, err := argResolver(ServerGroup1.Title) + resolved := argResolver(ServerGroup1.Title) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(ServerGroup1.Title)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServerGroups more than once mService.AssertNumberOfCalls(t, "GetServerGroups", 1) diff --git a/internal/resolver/storage.go b/internal/resolver/storage.go index 1ee9182a6..58da9c3bc 100644 --- a/internal/resolver/storage.go +++ b/internal/resolver/storage.go @@ -18,23 +18,14 @@ type CachingStorage struct { // make sure we implement the ResolutionProvider interface var _ ResolutionProvider = &CachingStorage{} -func storageMatcher(cached []upcloud.Storage) func(arg string) (uuid string, err error) { - return func(arg string) (uuid string, err error) { - rv := "" +func storageMatcher(cached []upcloud.Storage) func(arg string) Resolved { + return func(arg string) Resolved { + rv := Resolved{Arg: arg} for _, storage := range cached { - if MatchArgWithWhitespace(arg, storage.Title) || storage.UUID == arg { - if rv != "" { - return "", AmbiguousResolutionError(arg) - } - rv = storage.UUID - } + rv.AddMatch(storage.UUID, MatchArgWithWhitespace(arg, storage.Title)) + rv.AddMatch(storage.UUID, MatchUUID(arg, storage.UUID)) } - - if rv != "" { - return rv, nil - } - - return "", NotFoundError(arg) + return rv } } @@ -59,7 +50,8 @@ func (s *CachingStorage) Resolve(arg string) (resolved string, err error) { return "", errors.New("caching storage does not have a cache initialized") } - return storageMatcher(s.cachedStorages.Storages)(arg) + r := storageMatcher(s.cachedStorages.Storages)(arg) + return r.GetOnly() } // GetCached is a helper method for commands to use when they need to get an item from the cached results diff --git a/internal/resolver/storage_test.go b/internal/resolver/storage_test.go index 4e50a757e..09a3022c7 100644 --- a/internal/resolver/storage_test.go +++ b/internal/resolver/storage_test.go @@ -35,9 +35,10 @@ func TestStorageResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, storage := range unambiguousStorages { - resolved, err := argResolver(storage.UUID) + resolved := argResolver(storage.UUID) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, storage.UUID, resolved) + assert.Equal(t, storage.UUID, value) } // make sure caching works, eg. we didn't call GetStorages more than once mService.AssertNumberOfCalls(t, "GetStorages", 1) @@ -50,9 +51,10 @@ func TestStorageResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) for _, storage := range unambiguousStorages { - resolved, err := argResolver(storage.Title) + resolved := argResolver(storage.Title) + value, err := resolved.GetOnly() assert.NoError(t, err) - assert.Equal(t, storage.UUID, resolved) + assert.Equal(t, storage.UUID, value) } // make sure caching works, eg. we didn't call GetStorages more than once mService.AssertNumberOfCalls(t, "GetStorages", 1) @@ -66,29 +68,23 @@ func TestStorageResolution(t *testing.T) { argResolver, err := res.Get(context.TODO(), mService) assert.NoError(t, err) - // ambiguous uuid - resolved, err := argResolver(amb2.UUID) - if !assert.Error(t, err) { - t.FailNow() - } - assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(amb2.UUID)) - assert.Equal(t, "", resolved) - // ambiguous title - resolved, err = argResolver(amb1.Title) + resolved := argResolver(amb1.Title) + value, err := resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.AmbiguousResolutionError(amb1.Title)) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // not found - resolved, err = argResolver("notfound") + resolved = argResolver("notfound") + value, err = resolved.GetOnly() if !assert.Error(t, err) { t.FailNow() } assert.ErrorIs(t, err, resolver.NotFoundError("notfound")) - assert.Equal(t, "", resolved) + assert.Equal(t, "", value) // make sure caching works, eg. we didn't call GetServers more than once mService.AssertNumberOfCalls(t, "GetStorages", 1)