Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(upgrades_registry): small refactoring of upgrades registry #23

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/pkg/provider/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (lp *Provider) GetVersionsByHeight(ctx context.Context, height uint64) ([]*
return provider.PostProcessVersions(filtered, urproto.ProviderType_LOCAL, lp.priority), nil
}

func (lp *Provider) StoreState(_ context.Context, state *sm.State) error {
func (lp *Provider) StoreState(state *sm.State) error {
lp.lock.Lock()
defer lp.lock.Unlock()

Expand Down Expand Up @@ -254,7 +254,7 @@ func (lp *Provider) checkUniqueKey(data *localProviderData) error {
return nil
}

func (lp *Provider) RestoreState(_ context.Context) (*sm.State, error) {
func (lp *Provider) RestoreState() (*sm.State, error) {
data, err := lp.readData(true)
if err != nil {
return nil, err
Expand Down
3 changes: 1 addition & 2 deletions internal/pkg/provider/local/local_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package local

import (
"context"
"path"
"sync"
"testing"
Expand Down Expand Up @@ -133,7 +132,7 @@ func TestLoadFailing(t *testing.T) {
lock: &sync.RWMutex{},
}

_, err := lp.RestoreState(context.Background())
_, err := lp.RestoreState()
require.Error(t, err)
assert.Contains(t, err.Error(), tt.err)
})
Expand Down
15 changes: 7 additions & 8 deletions internal/pkg/state_machine/state_machine.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package state_machine

import (
"context"
"fmt"
"slices"
"sync"
Expand All @@ -10,7 +9,7 @@ import (
urproto "blazar/internal/pkg/proto/upgrades_registry"
)

// The rule are are as follows:
// The rule are as follows:
// 1. upgrades coming from the providers have one of the following statuses (`upgrade.Status`)
// - UNKNOWN
// - SCHEDULED
Expand Down Expand Up @@ -44,8 +43,8 @@ func init() {
}

type StateMachineStorage interface {
StoreState(context.Context, *State) error
RestoreState(context.Context) (*State, error)
StoreState(*State) error
RestoreState() (*State, error)
}

type State struct {
Expand All @@ -56,7 +55,7 @@ type State struct {
PostCheckStatus map[int64]map[checksproto.PostCheck]checksproto.CheckStatus `json:"post_check_status"`
}

// Simple, unsphisitcated state machine for managing upgrades
// StateMachine a simple unsophisticated state machine for managing upgrades
type StateMachine struct {
lock *sync.RWMutex
state *State
Expand Down Expand Up @@ -272,13 +271,13 @@ func (sm *StateMachine) GetPostCheckStatus(height int64, check checksproto.PostC
return checksproto.CheckStatus_PENDING
}

func (sm *StateMachine) Restore(ctx context.Context) error {
func (sm *StateMachine) Restore() error {
if sm.storage == nil {
// if it wasn't configured then we don't need to restore the state
return nil
}

state, err := sm.storage.RestoreState(ctx)
state, err := sm.storage.RestoreState()
if err != nil {
return err
}
Expand Down Expand Up @@ -358,6 +357,6 @@ func (sm *StateMachine) persist() {
// TODO: For now we ignore writing to the storage errors because this is not a critical operation
// NOTE: The caller must hold the lock
if sm.storage != nil {
_ = sm.storage.StoreState(context.TODO(), sm.state)
_ = sm.storage.StoreState(sm.state)
}
}
66 changes: 32 additions & 34 deletions internal/pkg/upgrades_registry/upgrades_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type UpgradeRegistry struct {
// lock for the registry
lock *sync.RWMutex

// a list of latst fetched upgrades
// a list of latest fetched upgrades
upgrades map[int64]*urproto.Upgrade

// a list of versions fetched from providers
Expand Down Expand Up @@ -76,11 +76,14 @@ func NewUpgradeRegistry(providers map[urproto.ProviderType]provider.UpgradeProvi
}

func NewUpgradesRegistryFromConfig(cfg *config.Config) (*UpgradeRegistry, error) {
providers := make(map[urproto.ProviderType]provider.UpgradeProvider, 0)
providers := make(map[urproto.ProviderType]provider.UpgradeProvider)

if cfg.UpgradeRegistry.Provider.Chain != nil && slices.Contains(
cfg.UpgradeRegistry.SelectedProviders, urproto.ProviderType_name[int32(urproto.ProviderType_CHAIN)],
) {
isEnabled := func(providerType urproto.ProviderType) bool {
name := urproto.ProviderType_name[int32(providerType)]
return slices.Contains(cfg.UpgradeRegistry.SelectedProviders, name)
}

if cfg.UpgradeRegistry.Provider.Chain != nil && isEnabled(urproto.ProviderType_CHAIN) {
cosmosClient, err := cosmos.NewClient(cfg.Clients.Host, cfg.Clients.GrpcPort, cfg.Clients.CometbftPort, cfg.Clients.Timeout)
if err != nil {
return nil, errors.Wrapf(err, "failed to create cosmos client")
Expand All @@ -90,35 +93,31 @@ func NewUpgradesRegistryFromConfig(cfg *config.Config) (*UpgradeRegistry, error)
return nil, errors.Wrapf(err, "failed to start cometbft client")
}

provider := chain.NewProvider(cosmosClient, cfg.UpgradeRegistry.Network, cfg.UpgradeRegistry.Provider.Chain.DefaultPriority)
providers[provider.Type()] = provider
p := chain.NewProvider(cosmosClient, cfg.UpgradeRegistry.Network, cfg.UpgradeRegistry.Provider.Chain.DefaultPriority)
providers[p.Type()] = p
}

if cfg.UpgradeRegistry.Provider.Database != nil && slices.Contains(
cfg.UpgradeRegistry.SelectedProviders, urproto.ProviderType_name[int32(urproto.ProviderType_DATABASE)],
) {
provider, err := database.NewDatabaseProvider(
if cfg.UpgradeRegistry.Provider.Database != nil && isEnabled(urproto.ProviderType_DATABASE) {
p, err := database.NewDatabaseProvider(
cfg.UpgradeRegistry.Provider.Database,
cfg.UpgradeRegistry.Network,
)
if err != nil {
return nil, errors.Wrapf(err, "failed to create database provider")
}
providers[provider.Type()] = provider
providers[p.Type()] = p
}

if cfg.UpgradeRegistry.Provider.Local != nil && slices.Contains(
cfg.UpgradeRegistry.SelectedProviders, urproto.ProviderType_name[int32(urproto.ProviderType_LOCAL)],
) {
provider, err := local.NewProvider(
if cfg.UpgradeRegistry.Provider.Local != nil && isEnabled(urproto.ProviderType_LOCAL) {
p, err := local.NewProvider(
cfg.UpgradeRegistry.Provider.Local.ConfigPath,
cfg.UpgradeRegistry.Network,
cfg.UpgradeRegistry.Provider.Local.DefaultPriority,
)
if err != nil {
return nil, errors.Wrapf(err, "failed to create local provider")
}
providers[provider.Type()] = provider
providers[p.Type()] = p
}

versionProviders := make([]urproto.ProviderType, 0)
Expand Down Expand Up @@ -149,8 +148,7 @@ func NewUpgradesRegistryFromConfig(cfg *config.Config) (*UpgradeRegistry, error)
stateMachine = state_machine.NewStateMachine(nil)
}

// TODO: context in constructor aint great
err := stateMachine.Restore(context.Background())
err := stateMachine.Restore()
if err != nil {
return nil, errors.Wrapf(err, "failed to restore state machine")
}
Expand Down Expand Up @@ -206,7 +204,7 @@ func (ur *UpgradeRegistry) GetUpcomingUpgradesWithCache(height int64, allowedSta
ur.lock.RLock()
defer ur.lock.RUnlock()

upcomingUpgrades := sortAndfilterUpgradesByStatus(ur.upgrades, ur.stateMachine, height, allowedStatus...)
upcomingUpgrades := sortAndFilterUpgradesByStatus(ur.upgrades, ur.stateMachine, height, allowedStatus...)

return copyList(upcomingUpgrades)
}
Expand All @@ -224,7 +222,7 @@ func (ur *UpgradeRegistry) GetUpcomingUpgrades(ctx context.Context, useCache boo
return nil, err
}

return sortAndfilterUpgradesByStatus(resolvedUpgrades, ur.stateMachine, height, allowedStatus...), nil
return sortAndFilterUpgradesByStatus(resolvedUpgrades, ur.stateMachine, height, allowedStatus...), nil
}

func (ur *UpgradeRegistry) GetUpgradeWithCache(height int64) *urproto.Upgrade {
Expand Down Expand Up @@ -352,8 +350,8 @@ func (ur *UpgradeRegistry) UpdateVersions(ctx context.Context, commit bool) (map
// https://tip.golang.org/doc/go1.22#language

g.Go(func() error {
if provider, ok := ur.providers[providerName].(provider.VersionResolver); ok {
versions, err := provider.GetVersions(ctx)
if p, ok := ur.providers[providerName].(provider.VersionResolver); ok {
versions, err := p.GetVersions(ctx)
if err != nil {
return errors.Wrapf(err, "%s provider failed to fetch versions", providerName)
}
Expand Down Expand Up @@ -395,7 +393,7 @@ func (ur *UpgradeRegistry) UpdateUpgrades(ctx context.Context, currentHeight int
results := make([][]*urproto.Upgrade, len(ur.providers))

i := 0
for _, provider := range ur.providers {
for _, p := range ur.providers {
// from go 1.22 the copy of the loop variable (provider) is not needed anymore
// https://tip.golang.org/doc/go1.22#language
//
Expand All @@ -404,13 +402,13 @@ func (ur *UpgradeRegistry) UpdateUpgrades(ctx context.Context, currentHeight int
ii := i

g.Go(func() error {
upgrades, err := provider.GetUpgrades(ctx)
upgrades, err := p.GetUpgrades(ctx)
if err != nil {
return errors.Wrapf(err, "%s provider failed to fetch upgrades", provider.Type())
return errors.Wrapf(err, "%s provider failed to fetch upgrades", p.Type())
}

if err := checkDuplicates(upgrades, provider.Type()); err != nil {
return errors.Wrapf(err, "%s provider returned duplicate upgrades", provider.Type())
if err := checkDuplicates(upgrades, p.Type()); err != nil {
return errors.Wrapf(err, "%s provider returned duplicate upgrades", p.Type())
}

results[ii] = upgrades
Expand Down Expand Up @@ -523,7 +521,7 @@ func (ur *UpgradeRegistry) AddUpgrade(ctx context.Context, upgrade *urproto.Upgr
ur.lock.RLock()
defer ur.lock.RUnlock()

// The use case for cancelled status is for user to create and upgrade with higher proiority to cancel the existing upgrade
// The use case for cancelled status is for user to create and upgrade with higher priority to cancel the existing upgrade
if upgrade.Status != urproto.UpgradeStatus_UNKNOWN && upgrade.Status != urproto.UpgradeStatus_CANCELLED {
return errors.New("status is not allowed to be set manually")
}
Expand All @@ -537,15 +535,15 @@ func (ur *UpgradeRegistry) AddUpgrade(ctx context.Context, upgrade *urproto.Upgr
return errors.New("add upgrade is not supported for chain provider")

case urproto.ProviderType_DATABASE:
if provider, ok := ur.providers[urproto.ProviderType_DATABASE]; ok {
return provider.AddUpgrade(ctx, upgrade, overwrite)
if p, ok := ur.providers[urproto.ProviderType_DATABASE]; ok {
return p.AddUpgrade(ctx, upgrade, overwrite)
} else {
return errors.New("database provider is not configured")
}

case urproto.ProviderType_LOCAL:
if provider, ok := ur.providers[urproto.ProviderType_LOCAL]; ok {
return provider.AddUpgrade(ctx, upgrade, overwrite)
if p, ok := ur.providers[urproto.ProviderType_LOCAL]; ok {
return p.AddUpgrade(ctx, upgrade, overwrite)
} else {
return errors.New("local provider is not configured")
}
Expand Down Expand Up @@ -614,7 +612,7 @@ func checkDuplicates[T interface {
return nil
}

func sortAndfilterUpgradesByStatus(upgrades map[int64]*urproto.Upgrade, sm *state_machine.StateMachine, height int64, allowedStatus ...urproto.UpgradeStatus) []*urproto.Upgrade {
func sortAndFilterUpgradesByStatus(upgrades map[int64]*urproto.Upgrade, sm *state_machine.StateMachine, height int64, allowedStatus ...urproto.UpgradeStatus) []*urproto.Upgrade {
upcomingUpgrades := make([]*urproto.Upgrade, 0)
for _, upgrade := range upgrades {
currentStatus := sm.GetStatus(upgrade.Height)
Expand Down
Loading