Skip to content

Commit

Permalink
Feat/FSRS-5 (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Aug 27, 2024
1 parent 3a13e86 commit 736741a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
go: [ 1.18.x, 1.19.x, 1.20.x ]
go: [ 1.22.x ]
env:
OS: ${{ matrix.os }}
GO: ${{ matrix.go }}
Expand Down
34 changes: 25 additions & 9 deletions fsrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ func (p *Parameters) Repeat(card Card, now time.Time) map[Rating]SchedulingInfo
s.Easy.ScheduledDays = uint64(easyInterval)
s.Easy.Due = now.Add(time.Duration(easyInterval) * 24 * time.Hour)
case Learning, Relearning:
interval := card.ElapsedDays
lastD := card.Difficulty
lastS := card.Stability
retrievability := p.forgettingCurve(float64(interval), lastS)
p.nextDS(s, lastD, lastS, retrievability, card.State)

hardInterval := 0.0
goodInterval := p.nextInterval(s.Good.Stability)
easyInterval := math.Max(p.nextInterval(s.Easy.Stability), goodInterval+1)

s.schedule(now, hardInterval, goodInterval, easyInterval)
case Review:
elapsedDays := float64(card.ElapsedDays)
lastD := card.Difficulty
lastS := card.Stability
retrievability := p.forgettingCurve(elapsedDays, lastS)
p.nextDS(s, lastD, lastS, retrievability)
p.nextDS(s, lastD, lastS, retrievability, card.State)

hardInterval := p.nextInterval(s.Hard.Stability)
goodInterval := p.nextInterval(s.Good.Stability)
Expand Down Expand Up @@ -135,22 +140,29 @@ func (p *Parameters) initDS(s *schedulingCards) {
s.Easy.Stability = p.initStability(Easy)
}

func (p *Parameters) nextDS(s *schedulingCards, lastD float64, lastS float64, retrievability float64) {
func (p *Parameters) nextDS(s *schedulingCards, lastD float64, lastS float64, retrievability float64, state State) {
s.Again.Difficulty = p.nextDifficulty(lastD, Again)
s.Again.Stability = p.nextForgetStability(lastD, lastS, retrievability)
s.Hard.Difficulty = p.nextDifficulty(lastD, Hard)
s.Hard.Stability = p.nextRecallStability(lastD, lastS, retrievability, Hard)
s.Good.Difficulty = p.nextDifficulty(lastD, Good)
s.Good.Stability = p.nextRecallStability(lastD, lastS, retrievability, Good)
s.Easy.Difficulty = p.nextDifficulty(lastD, Easy)
s.Easy.Stability = p.nextRecallStability(lastD, lastS, retrievability, Easy)
if state == Learning || state == Relearning {
s.Again.Stability = p.shortTermStability(lastS, Again)
s.Hard.Stability = p.shortTermStability(lastS, Hard)
s.Good.Stability = p.shortTermStability(lastS, Good)
s.Easy.Stability = p.shortTermStability(lastS, Easy)
} else if state == Review {
s.Again.Stability = p.nextForgetStability(lastD, lastS, retrievability)
s.Hard.Stability = p.nextRecallStability(lastD, lastS, retrievability, Hard)
s.Good.Stability = p.nextRecallStability(lastD, lastS, retrievability, Good)
s.Easy.Stability = p.nextRecallStability(lastD, lastS, retrievability, Easy)
}
}

func (p *Parameters) initStability(r Rating) float64 {
return math.Max(p.W[r-1], 0.1)
}
func (p *Parameters) initDifficulty(r Rating) float64 {
return constrainDifficulty(p.W[4] - p.W[5]*float64(r-3))
return constrainDifficulty(p.W[4] - math.Exp(p.W[5]*float64(r-1)) + 1)
}

func constrainDifficulty(d float64) float64 {
Expand All @@ -164,7 +176,11 @@ func (p *Parameters) nextInterval(s float64) float64 {

func (p *Parameters) nextDifficulty(d float64, r Rating) float64 {
nextD := d - p.W[6]*float64(r-3)
return constrainDifficulty(p.meanReversion(p.W[4], nextD))
return constrainDifficulty(p.meanReversion(p.initDifficulty(Easy), nextD))
}

func (p *Parameters) shortTermStability(s float64, r Rating) float64 {
return s * math.Exp(p.W[17]*(float64(r-3)+p.W[18]))
}

func (p *Parameters) meanReversion(init float64, current float64) float64 {
Expand Down
45 changes: 27 additions & 18 deletions fsrs_test.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,47 @@
package fsrs

import (
"encoding/json"
"fmt"
"math"
"reflect"
"testing"
"time"
)

var testWeights = Weights{
0.4197,
1.1869,
3.0412,
15.2441,
7.1434,
0.6477,
1.0007,
0.0674,
1.6597,
0.1712,
1.1178,
2.0225,
0.0904,
0.3025,
2.1214,
0.2498,
2.9466,
0.4891,
0.6468,
}

func roundFloat(val float64, precision uint) float64 {
ratio := math.Pow(10, float64(precision))
return math.Round(val*ratio) / ratio
}

func TestExample(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
p.W = testWeights
card := NewCard()
now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC)
var ivlList []uint64
var stateList []State
schedulingCards := p.Repeat(card, now)
schedule, _ := json.MarshalIndent(schedulingCards, "", " ")
fmt.Println(string(schedule))

var ratings = []Rating{Good, Good, Good, Good, Good, Good, Again, Again, Good, Good, Good, Good, Good}
var rating Rating
Expand All @@ -38,14 +55,9 @@ func TestExample(t *testing.T) {
stateList = append(stateList, revlog.State)
now = card.Due
schedulingCards = p.Repeat(card, now)
schedule, _ = json.MarshalIndent(schedulingCards, "", " ")
fmt.Println(string(schedule))
}

fmt.Println(ivlList)
fmt.Println(stateList)

wantIvlList := []uint64{0, 4, 15, 49, 143, 379, 0, 0, 15, 37, 85, 184, 376}
wantIvlList := []uint64{0, 4, 17, 62, 198, 563, 0, 0, 9, 27, 74, 190, 457}
if !reflect.DeepEqual(ivlList, wantIvlList) {
t.Errorf("excepted:%v, got:%v", wantIvlList, ivlList)
}
Expand All @@ -57,8 +69,7 @@ func TestExample(t *testing.T) {

func TestMemoState(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
p.W = testWeights
card := NewCard()
now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC)

Expand All @@ -72,9 +83,9 @@ func TestMemoState(t *testing.T) {
now = now.Add(time.Duration(ivlList[i]) * 24 * time.Hour)
schedulingCards = p.Repeat(card, now)
}
wantStability := 43.0554
wantStability := 71.4554
cardStability := roundFloat(schedulingCards[Good].Card.Stability, 4)
wantDifficulty := 7.7609
wantDifficulty := 5.0976
cardDifficulty := roundFloat(schedulingCards[Good].Card.Difficulty, 4)

if !reflect.DeepEqual(wantStability, cardStability) {
Expand All @@ -88,8 +99,6 @@ func TestMemoState(t *testing.T) {

func TestNextInterval(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
var ivlList []float64
for i := 1; i <= 10; i++ {
p.RequestRetention = float64(i) / 10
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/open-spaced-repetition/go-fsrs

go 1.18
go 1.22
23 changes: 20 additions & 3 deletions params.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package fsrs

import "math"

type Weights [17]float64
type Weights [19]float64

type Parameters struct {
RequestRetention float64 `json:"RequestRetention"`
Expand All @@ -25,6 +25,23 @@ func DefaultParam() Parameters {
}

func DefaultWeights() Weights {
return Weights{0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174,
0.0839, 0.3204, 1.4676, 0.219, 2.8237}
return Weights{0.4197,
1.1869,
3.0412,
15.2441,
7.1434,
0.6477,
1.0007,
0.0674,
1.6597,
0.1712,
1.1178,
2.0225,
0.0904,
0.3025,
2.1214,
0.2498,
2.9466,
0.4891,
0.6468}
}

0 comments on commit 736741a

Please sign in to comment.