From a9be767761f7ff2e257c3748e6afd93b296a0b93 Mon Sep 17 00:00:00 2001 From: TimLai666 <43640816+TimLai666@users.noreply.github.com> Date: Sat, 14 Sep 2024 14:07:19 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=90=E6=AD=A5=E6=B7=98=E6=B1=B0=E9=AB=98?= =?UTF-8?q?=E7=B2=BE=E5=BA=A6=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datalist.go | 72 +++++++++++++------------------------- insyra_test.go | 2 +- stats/anova.go | 6 ++-- stats/correlation.go | 4 +-- stats/linear_regression.go | 4 +-- stats/moments.go | 2 +- stats/skewness.go | 2 +- stats/ttest.go | 8 ++--- 8 files changed, 39 insertions(+), 61 deletions(-) diff --git a/datalist.go b/datalist.go index c910074..142a100 100644 --- a/datalist.go +++ b/datalist.go @@ -71,10 +71,11 @@ type IDataList interface { Upper() *DataList Lower() *DataList Capitalize() *DataList + // Statistics Sum() interface{} Max() interface{} Min() interface{} - Mean(highPrecision ...bool) interface{} + Mean() float64 WeightedMean(weights interface{}) interface{} GMean() interface{} Median(highPrecision ...bool) interface{} @@ -407,7 +408,7 @@ func (dl *DataList) ReplaceAll(oldValue, newValue interface{}) { // ReplaceOutliers replaces outliers in the DataList with the specified replacement value (e.g., mean, median). func (dl *DataList) ReplaceOutliers(stdDevs float64, replacement float64) *DataList { - mean := dl.Mean(false).(float64) + mean := dl.Mean() stddev := dl.Stdev(false).(float64) threshold := stdDevs * stddev @@ -693,7 +694,7 @@ func (dl *DataList) ClearOutliers(stdDevs float64) *DataList { go dl.updateTimestamp() }() - mean := dl.Mean(false).(float64) + mean := dl.Mean() stddev := dl.Stdev(false).(float64) threshold := stdDevs * stddev @@ -750,7 +751,7 @@ func (dl *DataList) Standardize() *DataList { go reorganizeMemory(dl) go dl.updateTimestamp() }() - mean := dl.Mean(false).(float64) + mean := dl.Mean() stddev := dl.Stdev(false).(float64) dl.mu.Lock() for i, v := range dl.Data() { @@ -774,7 +775,7 @@ func (dl *DataList) FillNaNWithMean() *DataList { }() dlclone := dl.Clone() dlNoNaN := dlclone.ClearNaNs() - mean := dlNoNaN.Mean(false).(float64) + mean := dlNoNaN.Mean() dl.mu.Lock() for i, v := range dl.Data() { vfloat := conv.ParseF64(v) @@ -1137,56 +1138,33 @@ func (dl *DataList) Min() interface{} { } // Mean calculates the arithmetic mean of the DataList. -// If highPrecision is true, it will calculate using big.Rat for high precision. -// Otherwise, it calculates using float64. -// Returns nil if the DataList is empty or if an invalid number of parameters is provided. -func (dl *DataList) Mean(highPrecision ...bool) interface{} { +// Returns math.NaN() if the DataList is empty or if no elements can be converted to float64. +func (dl *DataList) Mean() float64 { + mean := math.NaN() if len(dl.data) == 0 { - LogWarning("DataList.Mean(): DataList is empty, returning nil.") - return nil - } - - // 檢查參數數量 - if len(highPrecision) > 1 { - LogWarning("DataList.Mean(): Too many arguments, returning nil.") - return nil - } - - // 默認使用普通模式(float64),若有參數則使用參數設定 - highPrecisionMode := false - if len(highPrecision) == 1 { - highPrecisionMode = highPrecision[0] - } - - if highPrecisionMode { - sum := new(big.Rat) - count := big.NewRat(int64(len(dl.data)), 1) - - for _, v := range dl.data { - if val, ok := ToFloat64Safe(v); ok { - ratValue := new(big.Rat).SetFloat64(val) - sum.Add(sum, ratValue) - } else { - LogWarning("DataList.Mean(): Data types cannot be compared, returning nil.") - return nil - } - } - - mean := new(big.Rat).Quo(sum, count) + LogWarning("DataList.Mean(): DataList is empty.") return mean } - // 普通模式(float64) var sum float64 + var count int for _, v := range dl.data { if val, ok := ToFloat64Safe(v); ok { sum += val + count++ } else { - LogWarning("DataList.Mean(): Data types cannot be compared, returning nil.") - return nil + LogWarning("DataList.Mean(): Element %v cannot be converted to float64, skipping.", val) + // 跳过无法转换的元素 + continue } } - mean := sum / float64(len(dl.data)) + + if count == 0 { + LogWarning("DataList.Mean(): No elements could be converted to float64.") + return mean + } + + mean = sum / float64(count) return mean } @@ -1434,7 +1412,7 @@ func (dl *DataList) Var(highPrecision ...bool) interface{} { if useHighPrecision { // 使用 big.Rat 進行高精度計算 - mean := dl.Mean(true).(*big.Rat) + mean := new(big.Rat).SetFloat64(dl.Mean()) denominator := new(big.Rat).SetFloat64(n - 1) if denominator.Cmp(big.NewRat(0, 1)) == 0 { LogWarning("DataList.Var(): Denominator is 0, returning nil.") @@ -1457,7 +1435,7 @@ func (dl *DataList) Var(highPrecision ...bool) interface{} { } // 普通模式使用 float64 計算 - mean := dl.Mean(false).(float64) + mean := dl.Mean() denominator := n - 1 if denominator == 0 { LogWarning("DataList.Var(): Denominator is 0, returning nil.") @@ -1494,7 +1472,7 @@ func (dl *DataList) VarP(highPrecision ...bool) interface{} { if useHighPrecision { // 使用高精度计算 - mean := dl.Mean(true).(*big.Rat) + mean := new(big.Rat).SetFloat64(dl.Mean()) numerator := new(big.Rat) for i := 0; i < len(dl.data); i++ { xi := new(big.Rat).SetFloat64(ToFloat64(dl.data[i])) diff --git a/insyra_test.go b/insyra_test.go index 3009189..3aafc23 100644 --- a/insyra_test.go +++ b/insyra_test.go @@ -162,7 +162,7 @@ func TestMean(t *testing.T) { dl := NewDataList(1, 2, 3, 4) mean := dl.Mean() - if v, ok := mean.(float64); !ok || !float64Equal(v, 2.5) { + if !float64Equal(mean, 2.5) { t.Errorf("Expected mean 2.5, got %v", mean) } } diff --git a/stats/anova.go b/stats/anova.go index f705997..f47f76d 100644 --- a/stats/anova.go +++ b/stats/anova.go @@ -45,14 +45,14 @@ func OneWayANOVA_WideFormat(dataTable insyra.IDataTable) *OneWayANOVAResult { func() { SSB = 0.0 for _, group := range groups { - groupMean := group.Mean().(float64) + groupMean := group.Mean() SSB += float64(group.Len()) * math.Pow(groupMean-totalMean, 2) } }, func() { SSW = 0.0 for _, group := range groups { - groupMean := group.Mean().(float64) + groupMean := group.Mean() for i := 0; i < group.Len(); i++ { value, _ := group.Get(i).(float64) SSW += math.Pow(value-groupMean, 2) @@ -305,7 +305,7 @@ func RepeatedMeasuresANOVA_WideFormat(dataTable insyra.IDataTable) *RepeatedMeas ssBetweenFunc := func() { ssBetween = 0.0 for i := 0; i < rowNum; i++ { - conditionMean := dataTable.GetRow(i).Mean().(float64) + conditionMean := dataTable.GetRow(i).Mean() ssBetween += float64(colNum) * math.Pow(conditionMean-grandMean, 2) } } diff --git a/stats/correlation.go b/stats/correlation.go index 6101430..8b57c31 100644 --- a/stats/correlation.go +++ b/stats/correlation.go @@ -25,8 +25,8 @@ const ( // Covariance calculates the covariance between two datasets. // Always returns *big.Rat. func Covariance(dlX, dlY insyra.IDataList) *big.Rat { - meanX := dlX.Mean(true).(*big.Rat) - meanY := dlY.Mean(true).(*big.Rat) + meanX := new(big.Rat).SetFloat64(dlX.Mean()) + meanY := new(big.Rat).SetFloat64(dlY.Mean()) cov := new(big.Rat) for i := 0; i < dlX.Len(); i++ { diff --git a/stats/linear_regression.go b/stats/linear_regression.go index 8ae3e86..5fb4c1e 100644 --- a/stats/linear_regression.go +++ b/stats/linear_regression.go @@ -30,8 +30,8 @@ func LinearRegression(dlX, dlY insyra.IDataList) *LinearRegressionResult { } // 計算 X 和 Y 的均值 - meanX := dlX.Mean(true).(*big.Rat) - meanY := dlY.Mean(true).(*big.Rat) + meanX := new(big.Rat).SetFloat64(dlX.Mean()) + meanY := new(big.Rat).SetFloat64(dlY.Mean()) // 初始化變量 numerator := new(big.Rat) diff --git a/stats/moments.go b/stats/moments.go index 162d827..c89605e 100644 --- a/stats/moments.go +++ b/stats/moments.go @@ -20,7 +20,7 @@ func CalculateMoment(dl insyra.IDataList, n int, central bool) *big.Rat { // 初始化均值 mean := new(big.Rat) if central { - mean = dl.Mean(true).(*big.Rat) // 計算均值 + mean = new(big.Rat).SetFloat64(dl.Mean()) // 計算均值 } // 初始化 n 階矩 diff --git a/stats/skewness.go b/stats/skewness.go index 09b92ca..060fce2 100644 --- a/stats/skewness.go +++ b/stats/skewness.go @@ -62,7 +62,7 @@ func Skewness(sample interface{}, method ...int) interface{} { func calculateSkewType1(dl *insyra.DataList, highPrecision ...bool) interface{} { n := new(big.Rat).SetFloat64(conv.ParseF64(dl.Len())) nReciprocal := new(big.Rat).Inv(n) - m1 := dl.Mean(true).(*big.Rat) + m1 := new(big.Rat).SetFloat64(dl.Mean()) toM2Fn := func() *big.Rat { var m2Cal = new(big.Rat) for _, v := range dl.Data() { diff --git a/stats/ttest.go b/stats/ttest.go index f2109fa..6584d9f 100644 --- a/stats/ttest.go +++ b/stats/ttest.go @@ -26,7 +26,7 @@ func SingleSampleTTest(data insyra.IDataList, mu float64) *TTestResult { } // 計算樣本均值 - mean := data.Mean(false).(float64) + mean := data.Mean() // 計算標準差和標準誤差 stddev := data.Stdev(false).(float64) @@ -56,8 +56,8 @@ func TwoSampleTTest(data1, data2 insyra.IDataList, equalVariance bool) *TTestRes } // 計算兩個樣本的均值 - mean1 := data1.Mean(false).(float64) - mean2 := data2.Mean(false).(float64) + mean1 := data1.Mean() + mean2 := data2.Mean() // 計算兩個樣本的標準差 stddev1 := data1.Stdev(false).(float64) @@ -109,7 +109,7 @@ func PairedTTest(data1, data2 insyra.IDataList) *TTestResult { } // 計算差值的均值和標準差 - meanDiff := insyra.NewDataList(diffs).Mean(false).(float64) + meanDiff := insyra.NewDataList(diffs).Mean() stddevDiff := insyra.NewDataList(diffs).Stdev(false).(float64) // 計算 t 值