Skip to content

Commit

Permalink
Parse state aggregates
Browse files Browse the repository at this point in the history
  • Loading branch information
carlocamurri committed Jun 23, 2023
1 parent b4d1291 commit 2adfa5e
Show file tree
Hide file tree
Showing 10 changed files with 724 additions and 200 deletions.
30 changes: 14 additions & 16 deletions internal/common/database/db_testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package database

import (
"context"
"fmt"

"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
_ "github.com/jackc/pgx/v4/stdlib"
Expand Down Expand Up @@ -39,20 +37,20 @@ func WithTestDb(migrations []Migration, action func(db *pgxpool.Pool) error) err
return errors.WithStack(err)
}

defer func() {
// disconnect all db user before cleanup
_, err = db.Exec(ctx,
`SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity WHERE pg_stat_activity.datname = '`+dbName+`';`)
if err != nil {
fmt.Println("Failed to disconnect users")
}

_, err = db.Exec(ctx, "DROP DATABASE "+dbName)
if err != nil {
fmt.Println("Failed to drop database")
}
}()
//defer func() {
// // disconnect all db user before cleanup
// _, err = db.Exec(ctx,
// `SELECT pg_terminate_backend(pg_stat_activity.pid)
// FROM pg_stat_activity WHERE pg_stat_activity.datname = '`+dbName+`';`)
// if err != nil {
// fmt.Println("Failed to disconnect users")
// }
//
// _, err = db.Exec(ctx, "DROP DATABASE "+dbName)
// if err != nil {
// fmt.Println("Failed to drop database")
// }
//}()

err = UpdateDatabase(ctx, testDbPool, migrations)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/common/database/lookout/jobstates.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ const (
)

var (
// JobStates is an ordered list of states
JobStates = []JobState{
JobQueued,
JobLeased,
JobPending,
JobRunning,
JobSucceeded,
JobFailed,
JobCancelled,
JobPreempted,
}

JobStateMap = map[int]JobState{
JobLeasedOrdinal: JobLeased,
JobQueuedOrdinal: JobQueued,
Expand Down
131 changes: 131 additions & 0 deletions internal/lookoutv2/repository/aggregates.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package repository

import (
"fmt"
"github.com/armadaproject/armada/internal/common/database/lookout"
"github.com/armadaproject/armada/internal/common/util"
"github.com/armadaproject/armada/internal/lookoutv2/model"
"github.com/pkg/errors"
)

type QueryAggregator interface {
AggregateSql() (string, error)
}

type SqlFunctionAggregator struct {
queryCol *queryColumn
sqlFunction string
}

func NewSqlFunctionAggregator(queryCol *queryColumn, fn string) *SqlFunctionAggregator {
return &SqlFunctionAggregator{
queryCol: queryCol,
sqlFunction: fn,
}
}

func (qa *SqlFunctionAggregator) aggregateColName() string {
return qa.queryCol.name
}

func (qa *SqlFunctionAggregator) AggregateSql() (string, error) {
return fmt.Sprintf("%s(%s.%s) AS %s", qa.sqlFunction, qa.queryCol.abbrev, qa.queryCol.name, qa.aggregateColName()), nil
}

type StateCountAggregator struct {
queryCol *queryColumn
stateString string
}

func NewStateCountAggregator(queryCol *queryColumn, stateString string) *StateCountAggregator {
return &StateCountAggregator{
queryCol: queryCol,
stateString: stateString,
}
}

func (qa *StateCountAggregator) aggregateColName() string {
return fmt.Sprintf("%s_%s", qa.queryCol.name, qa.stateString)
}

func (qa *StateCountAggregator) AggregateSql() (string, error) {
stateInt, ok := lookout.JobStateOrdinalMap[lookout.JobState(qa.stateString)]
if !ok {
return "", errors.Errorf("state %s does not exist", qa.stateString)
}
return fmt.Sprintf(
"SUM(CASE WHEN %s.%s = %d THEN 1 ELSE 0 END) AS %s",
qa.queryCol.abbrev, qa.queryCol.name, stateInt, qa.aggregateColName(),
), nil
}

func GetAggregatorsForColumn(queryCol *queryColumn, aggregateType AggregateType, filters []*model.Filter) ([]QueryAggregator, error) {
switch aggregateType {
case Max:
return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "MAX")}, nil
case Average:
return []QueryAggregator{NewSqlFunctionAggregator(queryCol, "AVG")}, nil
case StateCounts:
states := GetStatesForFilter(filters)
aggregators := make([]QueryAggregator, len(states))
for i, state := range states {
aggregators[i] = NewStateCountAggregator(queryCol, state)
}
return aggregators, nil
default:
return nil, errors.Errorf("cannot determine aggregate type: %v", aggregateType)
}
}

// GetStatesForFilter returns a list of states as string if filter for state exists
// Will always return the states in the same order, irrespective of the ordering of the states in the filter
func GetStatesForFilter(filters []*model.Filter) []string {
var stateFilter *model.Filter
for _, f := range filters {
if f.Field == stateField {
stateFilter = f
}
}
allStates := util.Map(lookout.JobStates, func(jobState lookout.JobState) string { return string(jobState) })
if stateFilter == nil {
// If no state filter is specified, use all states
return allStates
}

switch stateFilter.Match {
case model.MatchExact:
return []string{fmt.Sprintf("%s", stateFilter.Value)}
case model.MatchAnyOf:
strSlice, err := toStringSlice(stateFilter.Value)
if err != nil {
return allStates
}
stateStringSet := util.StringListToSet(strSlice)
// Ensuring they are in the same order
var finalStates []string
for _, state := range allStates {
if _, ok := stateStringSet[state]; ok {
finalStates = append(finalStates, state)
}
}
return finalStates
default:
return allStates
}
}

func toStringSlice(val interface{}) ([]string, error) {
switch v := val.(type) {
case []string:
return v, nil
case []interface{}:
result := make([]string, len(v))
for i := 0; i < len(v); i++ {
str := fmt.Sprintf("%v", v[i])
result[i] = str
}
return result, nil
default:
return nil, errors.Errorf("failed to convert interface to string slice: %v of type %T", val, val)
}
}
120 changes: 120 additions & 0 deletions internal/lookoutv2/repository/fieldparser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package repository

import (
"fmt"
"github.com/armadaproject/armada/internal/common/database/lookout"
"github.com/armadaproject/armada/internal/lookoutv2/model"
"github.com/jackc/pgtype"
"github.com/pkg/errors"
"math"
"time"
)

type FieldParser interface {
GetField() string
GetVariableRef() interface{}
ParseValue() (interface{}, error)
}

type LastTransitionTimeParser struct {
variable pgtype.Numeric
}

func (fp *LastTransitionTimeParser) GetField() string {
return lastTransitionTimeField
}

func (fp *LastTransitionTimeParser) GetVariableRef() interface{} {
return &fp.variable
}

func (fp *LastTransitionTimeParser) ParseValue() (interface{}, error) {
var dst float64
err := fp.variable.AssignTo(&dst)
if err != nil {
return "", err
}
t := time.Unix(int64(math.Round(dst)), 0)
return t.Format(time.RFC3339), nil
}

type TimeParser struct {
field string
variable time.Time
}

func (fp *TimeParser) GetField() string {
return fp.field
}

func (fp *TimeParser) GetVariableRef() interface{} {
return &fp.variable
}

func (fp *TimeParser) ParseValue() (interface{}, error) {
return fp.variable.Format(time.RFC3339), nil
}

type StateParser struct {
variable int16
}

func (fp *StateParser) GetField() string {
return stateField
}

func (fp *StateParser) GetVariableRef() interface{} {
return &fp.variable
}

func (fp *StateParser) ParseValue() (interface{}, error) {
state, ok := lookout.JobStateMap[int(fp.variable)]
if !ok {
return "", errors.Errorf("state not found: %d", fp.variable)
}
return string(state), nil
}

type BasicParser[T any] struct {
field string
variable T
}

func (fp *BasicParser[T]) GetField() string {
return fp.field
}

func (fp *BasicParser[T]) GetVariableRef() interface{} {
return &fp.variable
}

func (fp *BasicParser[T]) ParseValue() (interface{}, error) {
return fp.variable, nil
}

func ParserForGroup(field string) FieldParser {
switch field {
case stateField:
return &StateParser{}
default:
return &BasicParser[string]{field: field}
}
}

func ParsersForAggregate(field string, filters []*model.Filter) ([]FieldParser, error) {
var parsers []FieldParser
switch field {
case lastTransitionTimeField:
parsers = append(parsers, &LastTransitionTimeParser{})
case submittedField:
parsers = append(parsers, &TimeParser{field: submittedField})
case stateField:
states := GetStatesForFilter(filters)
for _, state := range states {
parsers = append(parsers, &BasicParser[int]{field: fmt.Sprintf("%s%s", stateAggregatePrefix, state)})
}
default:
return nil, errors.Errorf("no aggregate found for field %s", field)
}
return parsers, nil
}
Loading

0 comments on commit 2adfa5e

Please sign in to comment.