Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce filter in Fixed/Variable sized Array types #2678

Merged
merged 7 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 99 additions & 4 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,7 @@
interpreter *Interpreter,
arrayType ArrayStaticType,
address common.Address,
count uint64,
countOverestimate uint64,
values func() Value,
) *ArrayValue {
interpreter.ReportComputation(common.ComputationKindCreateArrayValue, 1)
Expand Down Expand Up @@ -1652,7 +1652,7 @@
return array
}
// must assign to v here for tracing to work properly
v = newArrayValueFromConstructor(interpreter, arrayType, count, constructor)
v = newArrayValueFromConstructor(interpreter, arrayType, countOverestimate, constructor)
return v
}

Expand All @@ -1669,14 +1669,14 @@
func newArrayValueFromConstructor(
gauge common.MemoryGauge,
staticType ArrayStaticType,
count uint64,
countOverestimate uint64,
constructor func() *atree.Array,
) (array *ArrayValue) {
var elementSize uint
if staticType != nil {
elementSize = staticType.ElementType().elementSize()
}
baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(count, elementSize)
baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(countOverestimate, elementSize)
common.UseMemory(gauge, baseUsage)
common.UseMemory(gauge, elementUsage)
common.UseMemory(gauge, dataSlabs)
Expand Down Expand Up @@ -2443,6 +2443,29 @@
)
},
)

case sema.ArrayTypeFilterFunctionName:
return NewHostFunctionValue(
interpreter,
sema.ArrayFilterFunctionType(
interpreter,
v.SemaType(interpreter).ElementType(false),
),
func(invocation Invocation) Value {
interpreter := invocation.Interpreter

funcArgument, ok := invocation.Arguments[0].(FunctionValue)
if !ok {
panic(errors.NewUnreachableError())
}

return v.Filter(
interpreter,
invocation.LocationRange,
funcArgument,
)
},
)
}

return nil
Expand Down Expand Up @@ -2946,6 +2969,78 @@
)
}

func (v *ArrayValue) Filter(
interpreter *Interpreter,
locationRange LocationRange,
procedure FunctionValue,
) Value {

iterationInvocation := func(arrayElement Value) Invocation {
darkdrag00nv2 marked this conversation as resolved.
Show resolved Hide resolved
invocation := NewInvocation(
interpreter,
nil,
nil,
[]Value{arrayElement},
[]sema.Type{v.semaType.ElementType(false)},
darkdrag00nv2 marked this conversation as resolved.
Show resolved Hide resolved
nil,
locationRange,
)
return invocation
}

iterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
}

return NewArrayValueWithIterator(
interpreter,
NewVariableSizedStaticType(interpreter, v.Type.ElementType()),
common.ZeroAddress,
uint64(v.Count()), // worst case estimation.
func() Value {

var value Value

for {
atreeValue, err := iterator.Next()
if err != nil {
panic(errors.NewExternalError(err))
}

Check warning on line 3009 in runtime/interpreter/value.go

View check run for this annotation

Codecov / codecov/patch

runtime/interpreter/value.go#L3009

Added line #L3009 was not covered by tests

// Also handles the end of array case since iterator.Next() returns nil for that.
if atreeValue == nil {
return nil
}

value = MustConvertStoredValue(interpreter, atreeValue)
if value == nil {
return nil
}

shouldInclude, ok := procedure.invoke(iterationInvocation(value)).(BoolValue)
if !ok {
panic(errors.NewUnreachableError())
}

Check warning on line 3024 in runtime/interpreter/value.go

View check run for this annotation

Codecov / codecov/patch

runtime/interpreter/value.go#L3024

Added line #L3024 was not covered by tests

// We found the next entry of the filtered array.
if shouldInclude {
break
}
}

return value.Transfer(
interpreter,
locationRange,
atree.Address{},

Check warning on line 3035 in runtime/interpreter/value.go

View check run for this annotation

Codecov / codecov/patch

runtime/interpreter/value.go#L3034-L3035

Added lines #L3034 - L3035 were not covered by tests
false,
nil,
nil,
)

Check warning on line 3039 in runtime/interpreter/value.go

View check run for this annotation

Codecov / codecov/patch

runtime/interpreter/value.go#L3039

Added line #L3039 was not covered by tests
},
)
}

// NumberValue
type NumberValue interface {
ComparableValue
Expand Down
57 changes: 57 additions & 0 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,13 @@ Returns a new array with contents in the reversed order.
Available if the array element type is not resource-kinded.
`

const ArrayTypeFilterFunctionName = "filter"

const arrayTypeFilterFunctionDocString = `
Returns a new array whose elements are filtered by applying the filter function on each element of the original array.
Available if the array element type is not resource-kinded.
`

func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {

members := map[string]MemberResolver{
Expand Down Expand Up @@ -1913,6 +1920,31 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {
)
},
},
ArrayTypeFilterFunctionName: {
Kind: common.DeclarationKindFunction,
Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member {

elementType := arrayType.ElementType(false)

if elementType.IsResourceType() {
report(
&InvalidResourceArrayMemberError{
Name: identifier,
DeclarationKind: common.DeclarationKindFunction,
Range: targetRange,
},
)
}

return NewPublicFunctionMember(
memoryGauge,
arrayType,
identifier,
ArrayFilterFunctionType(memoryGauge, elementType),
arrayTypeFilterFunctionDocString,
)
},
},
}

// TODO: maybe still return members but report a helpful error?
Expand Down Expand Up @@ -2232,6 +2264,31 @@ func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType {
}
}

func ArrayFilterFunctionType(memoryGauge common.MemoryGauge, elementType Type) *FunctionType {
// fun filter(_ function: ((T): Bool)): [T]
// funcType: elementType -> Bool
funcType := &FunctionType{
SupunS marked this conversation as resolved.
Show resolved Hide resolved
Parameters: []Parameter{
{
Identifier: "element",
TypeAnnotation: NewTypeAnnotation(elementType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(BoolType),
}

return &FunctionType{
Parameters: []Parameter{
{
Label: ArgumentLabelNotRequired,
Identifier: "f",
TypeAnnotation: NewTypeAnnotation(funcType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(NewVariableSizedType(memoryGauge, elementType)),
}
}

// VariableSizedType is a variable sized array type
type VariableSizedType struct {
Type Type
Expand Down
97 changes: 97 additions & 0 deletions runtime/tests/checker/arrays_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,103 @@ func TestCheckResourceArrayReverseInvalid(t *testing.T) {
assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayFilter(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun test() {
let x = [1, 2, 3]
let onlyEven =
fun (_ x: Int): Bool {
return x % 2 == 0
}

let y = x.filter(onlyEven)
}

fun testFixedSize() {
let x : [Int; 5] = [1, 2, 3, 21, 30]
let onlyEvenInt =
fun (_ x: Int): Bool {
return x % 2 == 0
}

let y = x.filter(onlyEvenInt)
}
`)

require.NoError(t, err)
}

func TestCheckArrayFilterInvalidArgs(t *testing.T) {

t.Parallel()

testInvalidArgs := func(code string, expectedErrors []sema.SemanticError) {
_, err := ParseAndCheck(t, code)

errs := RequireCheckerErrors(t, err, len(expectedErrors))

for i, e := range expectedErrors {
assert.IsType(t, e, errs[i])
}
}

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let y = x.filter(100)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let onlyEvenInt16 =
fun (_ x: Int16): Bool {
return x % 2 == 0
}

let y = x.filter(onlyEvenInt16)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)
}

func TestCheckResourceArrayFilterInvalid(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
resource X {}

fun test(): @[X] {
let xs <- [<-create X()]
let allResources =
fun (_ x: @X): Bool {
destroy x
return true
}

let filteredXs <-xs.filter(allResources)
destroy xs
return <- filteredXs
}
`)

errs := RequireCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayContains(t *testing.T) {

t.Parallel()
Expand Down
Loading
Loading