From 3dcf69e47649eadf360ebef8d32fdcb7d98d1dd2 Mon Sep 17 00:00:00 2001 From: TimLai666 <43640816+TimLai666@users.noreply.github.com> Date: Sun, 1 Sep 2024 18:15:00 +0800 Subject: [PATCH] Update datalist.go --- datalist.go | 77 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 19 deletions(-) diff --git a/datalist.go b/datalist.go index ff06ed5..4e369d5 100644 --- a/datalist.go +++ b/datalist.go @@ -64,9 +64,9 @@ type IDataList interface { Median(highPrecision ...bool) interface{} Mode() interface{} Stdev(highPrecision ...bool) interface{} - StdevP() interface{} + StdevP(highPrecision ...bool) interface{} Var(highPrecision ...bool) interface{} - VarP() interface{} + VarP(highPrecision ...bool) interface{} Range() interface{} Quartile(int) interface{} IQR() interface{} @@ -1045,7 +1045,7 @@ func (dl *DataList) Stdev(highPrecision ...bool) interface{} { } if useHighPrecision { - // 高精度模式下使用 big.Rat 進行開方運算 + // 高精度模式下使用 SqrtRat 進行開方運算 varianceRat := variance.(*big.Rat) sqrtVariance := SqrtRat(varianceRat) return sqrtVariance @@ -1058,17 +1058,33 @@ func (dl *DataList) Stdev(highPrecision ...bool) interface{} { // StdevP calculates the standard deviation(population) of the DataList. // Returns the standard deviation. // Returns nil if the DataList is empty or the standard deviation cannot be calculated. -func (dl *DataList) StdevP() interface{} { +func (dl *DataList) StdevP(highPrecision ...bool) interface{} { if len(dl.data) == 0 { LogWarning("DataList.StdevP(): DataList is empty, returning nil.") return nil } - varianceP := dl.VarP() + if len(highPrecision) > 1 { + LogWarning("DataList.StdevP(): Too many arguments, returning nil.") + return nil + } + var varianceP interface{} + if len(highPrecision) == 1 && highPrecision[0] { + // 使用 big.Rat 進行高精度計算 + varianceP = dl.VarP(true) + } else { + varianceP = dl.VarP() + } + if varianceP == nil { LogWarning("DataList.StdevP(): Variance calculation failed, returning nil.") return nil } - return math.Sqrt(ToFloat64(varianceP)) + + if !highPrecision[0] { + return math.Sqrt(ToFloat64(varianceP)) + } else { + return SqrtRat(varianceP.(*big.Rat)) + } } // Var calculates the variance(sample) of the DataList. @@ -1134,28 +1150,51 @@ func (dl *DataList) Var(highPrecision ...bool) interface{} { // VarP calculates the variance(population) of the DataList. // Returns the variance. // Returns nil if the DataList is empty or the variance cannot be calculated. -func (dl *DataList) VarP() interface{} { +func (dl *DataList) VarP(highPrecision ...bool) interface{} { + if len(highPrecision) > 1 { + LogWarning("VarP(): More than one highPrecision argument, returning nil.") + return nil + } + + useHighPrecision := len(highPrecision) == 1 && highPrecision[0] + n := float64(dl.Len()) if n == 0.0 { LogWarning("DataList.VarP(): DataList is empty, returning nil.") return nil } - m := dl.Mean() - mean, ok := ToFloat64Safe(m) - if !ok { - LogWarning("DataList.VarP(): Mean is not a float64, returning nil.") - return nil - } - numerator := 0.0 - for i := 0; i < len(dl.data); i++ { - xi, ok := ToFloat64Safe(dl.data[i]) + + if useHighPrecision { + // 使用高精度计算 + mean := dl.Mean(true).(*big.Rat) + numerator := new(big.Rat) + for i := 0; i < len(dl.data); i++ { + xi := new(big.Rat).SetFloat64(ToFloat64(dl.data[i])) + diff := new(big.Rat).Sub(xi, mean) + diffSquared := new(big.Rat).Mul(diff, diff) + numerator.Add(numerator, diffSquared) + } + denominator := new(big.Rat).SetFloat64(n) + variance := new(big.Rat).Quo(numerator, denominator) + return variance + } else { + // 使用普通精度计算 + mean, ok := ToFloat64Safe(dl.Mean()) if !ok { - LogWarning("DataList.VarP(): Element is not a float64, returning nil.") + LogWarning("DataList.VarP(): Mean is not a float64, returning nil.") return nil } - numerator += math.Pow(xi-mean, 2) + numerator := 0.0 + for i := 0; i < len(dl.data); i++ { + xi, ok := ToFloat64Safe(dl.data[i]) + if !ok { + LogWarning("DataList.VarP(): Element is not a float64, returning nil.") + return nil + } + numerator += math.Pow(xi-mean, 2) + } + return numerator / n } - return numerator / n } // Range calculates the range of the DataList.