Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monad instance for Vector and Matrix #607

Merged
merged 3 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions src/FSharpPlus.TypeLevel/Data/Matrix.fs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ module Vector =

let inline apply (f: Vector<'a -> 'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> = map2 id f v

/// <description>
/// Converts the vector of vectors to a square matrix and returns its diagonal.
/// </description>
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
let join (vv: Vector<Vector<'a, 'n>, 'n>): Vector<'a, 'n> =
{ Items = Array.init (Array.length vv.Items) (fun i -> vv.Items.[i].Items.[i]) }

let inline bind (f: 'a -> Vector<'b, 'n>) (v: Vector<'a, 'n>) : Vector<'b, 'n> =
v |> map f |> join

let inline norm (v: Vector< ^a, ^n >) : ^a =
v |> toArray |> Array.sumBy (fun x -> x * x) |> sqrt
let inline maximumNorm (v: Vector< ^a, ^n >) : ^a =
Expand Down Expand Up @@ -327,6 +338,20 @@ module Matrix =
for j = 0 to Array2D.length2 m1.Items - 1 do
f i j m1.Items.[i, j] m2.Items.[i, j]

let inline apply (f: Matrix<'a -> 'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = map2 id f m

/// <description>
/// Converts the matrix of matrices to a 3D cube matrix and returns its diagonal.
/// </description>
/// <seealso href="https://stackoverflow.com/questions/5802628/monad-instance-of-a-number-parameterised-vector" />
[<MethodImpl(MethodImplOptions.AggressiveInlining)>]
let join (m: Matrix<Matrix<'a, 'm, 'n>, 'm, 'n>) : Matrix<'a, 'm, 'n> =
{ Items =
Array2D.init (Array2D.length1 m.Items) (Array2D.length2 m.Items)
(fun i j -> m.Items.[i, j].Items.[i, j] ) }

let inline bind (f: 'a -> Matrix<'b, 'm, 'n>) (m: Matrix<'a, 'm, 'n>) : Matrix<'b, 'm, 'n> = m |> map f |> join

let inline rowLength (_: Matrix<'a, 'm, 'n>) : 'm = Singleton<'m>
let inline colLength (_: Matrix<'a, 'm, 'n>) : 'n = Singleton<'n>
let inline rowLength' (_: Matrix<'a, ^m, 'n>) : int = RuntimeValue (Singleton< ^m >)
Expand Down Expand Up @@ -571,8 +596,10 @@ type Matrix<'Item, 'Row, 'Column> with

static member inline Return (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
static member inline Pure (x: 'x) : Matrix<'x, 'm, 'n> = Matrix.replicate Singleton Singleton x
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.map2 id f x
static member inline ( <*> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
static member inline ( <.> ) (f: Matrix<'x -> 'y, 'm, 'n>, x: Matrix<'x, 'm, 'n>) = Matrix.apply f x
static member inline Join (x: Matrix<Matrix<'x, 'm, 'n>, 'm, 'n>) = Matrix.join x
static member inline ( >>= ) (x: Matrix<'x, 'm, 'n>, f: 'x -> Matrix<'y, 'm, 'n>) = Matrix.bind f x
static member inline get_Zero () : Matrix<'a, 'm, 'n> = Matrix.zero
static member inline ( + ) (m1, m2) = Matrix.map2 (+) m1 m2
static member inline ( - ) (m1, m2) = Matrix.map2 (-) m1 m2
Expand Down Expand Up @@ -607,6 +634,8 @@ type Vector<'Item, 'Length> with
static member inline Pure (x: 'x) : Vector<'x, 'n> = Vector.replicate Singleton x
static member inline ( <*> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
static member inline ( <.> ) (f: Vector<'x -> 'y, 'n>, x: Vector<'x, 'n>) : Vector<'y, 'n> = Vector.apply f x
static member inline Join (x: Vector<Vector<'x, 'n>, 'n>) : Vector<'x, 'n> = Vector.join x
static member inline ( >>= ) (x: Vector<'x, 'n>, f: 'x -> Vector<'y, 'n>) = Vector.bind f x

[<EditorBrowsable(EditorBrowsableState.Never)>]
static member inline Zip (x, y) = Vector.zip x y
Expand Down
1 change: 1 addition & 0 deletions tests/FSharpPlus.Tests/FSharpPlus.Tests.fsproj
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<Compile Include="Lens.fs" />
<Compile Include="Extensions.fs" />
<Compile Include="BifoldableTests.fs" />
<Compile Include="Matrix.fs" />
<Compile Include="TypeLevel.fs" />
</ItemGroup>
<ItemGroup>
Expand Down
98 changes: 98 additions & 0 deletions tests/FSharpPlus.Tests/Matrix.fs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
namespace FSharpPlus.Tests

open System
open NUnit.Framework
open Helpers

open FSharpPlus
open FSharpPlus.Data
open FSharpPlus.TypeLevel

module VectorTests =
[<Test>]
let constructorAndDeconstructorWorks() =
let v1 = vector (1,2,3,4,5)
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
let (Vector(_,_,_,_,_)) = v1
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2
()

[<Test>]
let applicativeWorks() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

[<Test>]
let satisfiesApplicativeLaws() =
let u = vector ((fun i -> i - 1), (fun i -> i * 2))
let v = vector ((fun i -> i + 1), (fun i -> i * 3))
let w = vector (1, 1)

areEqual (result id <*> v) v
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
areEqual (result 2) ((result (fun i -> i + 1) : Vector<int -> int, S<S<Z>>>) <*> result 1)
areEqual (u <*> result 1) (result ((|>) 1) <*> u)

[<Test>]
let satisfiesMonadLaws() =
let k = fun (a: int) -> vector (a - 1, a * 2)
let h = fun (a: int) -> vector (a + 1, a * 3)
let m = vector (1, 2)

areEqual (result 2 >>= k) (k 2)
areEqual (m >>= result) m
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)

module MatrixTests =
[<Test>]
let constructorAndDeconstructorWorks() =
let m1 =
matrix (
(1,0,0,0),
(0,1,0,0),
(0,0,1,0)
)
let m2 =
matrix (
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
)
let (Matrix(_x1,_x2,_x3)) = m1
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
()

[<Test>]
let satisfiesApplicativeLaws() =
let u = matrix (
((fun i -> i - 1), (fun i -> i * 2)),
((fun i -> i + 1), (fun i -> i * 3))
)
let v = matrix (
((fun i -> i - 2), (fun i -> i * 5)),
((fun i -> i + 2), (fun i -> i * 7))
)
let w = matrix ((1, 1), (1, 2))

areEqual (result id <*> v) v
areEqual (result (<<) <*> u <*> v <*> w) (u <*> (v <*> w))
areEqual ((result (fun i -> i + 1) : Matrix<int -> int, S<S<Z>>, S<S<Z>>>) <*> result 1) (result 2)
areEqual (u <*> result 1) (result ((|>) 1) <*> u)

[<Test>]
let satisfiesMonadLaws() =
let k = fun (a: int) -> matrix ((a - 1, a * 2), (a + 1, a * 3))
let h = fun (a: int) -> matrix ((a - 2, a * 5), (a + 2, a * 7))
let m = matrix ((1, 1), (1, 2))

areEqual (result 2 >>= k) (k 2)
areEqual (m >>= result) m
areEqual (m >>= (fun x -> k x >>= h)) ((m >>= k) >>= h)
52 changes: 1 addition & 51 deletions tests/FSharpPlus.Tests/TypeLevel.fs
Original file line number Diff line number Diff line change
Expand Up @@ -150,38 +150,8 @@ module NatTests =
Assert (g2 =^ S(S(S(S(S(S Z))))))


open FSharpPlus.Data

module MatrixTests =
[<Test>]
let matrixTests =
let v1 = vector (1,2,3,4,5)
let v2 = vector (1,2,3,4,5,6,7,8,9,0,1,2,3,4,5)
let (Vector(_,_,_,_,_)) = v1
let (Vector(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)) = v2

let m1 =
matrix (
(1,0,0,0),
(0,1,0,0),
(0,0,1,0)
)
let m2 =
matrix (
(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0),
(0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0)
)
let (Matrix(_x1,_x2,_x3)) = m1
let (Matrix(_y1: int*int*int*int*int*int*int*int*int*int*int*int*int*int*int*int,_y2,_y3,_y4,_y5,_y6,_y7,_y8)) = m2
()

open Helpers
open FSharpPlus.Data

module TypeProviderTests =
type ``0`` = TypeNat<0>
Expand All @@ -206,23 +176,3 @@ module TypeProviderTests =
Assert (Matrix.colLength row1 =^ (Z |> S |> S |> S))
areEqual 5 (Matrix.get Z (S Z) row1)
areEqual [3; 6; 9] (Vector.toList col2)

module TestFunctors1 =
[<Test>]
let applicativeOperatorWorks() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)

module TestFunctors2 =
open FSharpPlus

[<Test>]
let applicativeWorksWithoutSubsumption() =
let v = vector ((fun i -> i + 1), (fun i -> i * 2))
let u = vector (2, 3)
let vu = v <*> u
NUnit.Framework.Assert.IsInstanceOf<Option<Vector<int,S<S<Z>>>>> (Some vu)
CollectionAssert.AreEqual ([|3; 6|], Vector.toArray vu)
Loading