Skip to content

Commit

Permalink
逐步淘汰高精度模式
Browse files Browse the repository at this point in the history
  • Loading branch information
TimLai666 committed Sep 14, 2024
1 parent 901bdba commit a9be767
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 61 deletions.
72 changes: 25 additions & 47 deletions datalist.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ type IDataList interface {
Upper() *DataList
Lower() *DataList
Capitalize() *DataList
// Statistics
Sum() interface{}
Max() interface{}
Min() interface{}
Mean(highPrecision ...bool) interface{}
Mean() float64
WeightedMean(weights interface{}) interface{}
GMean() interface{}
Median(highPrecision ...bool) interface{}
Expand Down Expand Up @@ -407,7 +408,7 @@ func (dl *DataList) ReplaceAll(oldValue, newValue interface{}) {

// ReplaceOutliers replaces outliers in the DataList with the specified replacement value (e.g., mean, median).
func (dl *DataList) ReplaceOutliers(stdDevs float64, replacement float64) *DataList {
mean := dl.Mean(false).(float64)
mean := dl.Mean()
stddev := dl.Stdev(false).(float64)
threshold := stdDevs * stddev

Expand Down Expand Up @@ -693,7 +694,7 @@ func (dl *DataList) ClearOutliers(stdDevs float64) *DataList {
go dl.updateTimestamp()
}()

mean := dl.Mean(false).(float64)
mean := dl.Mean()
stddev := dl.Stdev(false).(float64)
threshold := stdDevs * stddev

Expand Down Expand Up @@ -750,7 +751,7 @@ func (dl *DataList) Standardize() *DataList {
go reorganizeMemory(dl)
go dl.updateTimestamp()
}()
mean := dl.Mean(false).(float64)
mean := dl.Mean()
stddev := dl.Stdev(false).(float64)
dl.mu.Lock()
for i, v := range dl.Data() {
Expand All @@ -774,7 +775,7 @@ func (dl *DataList) FillNaNWithMean() *DataList {
}()
dlclone := dl.Clone()
dlNoNaN := dlclone.ClearNaNs()
mean := dlNoNaN.Mean(false).(float64)
mean := dlNoNaN.Mean()
dl.mu.Lock()
for i, v := range dl.Data() {
vfloat := conv.ParseF64(v)
Expand Down Expand Up @@ -1137,56 +1138,33 @@ func (dl *DataList) Min() interface{} {
}

// Mean calculates the arithmetic mean of the DataList.
// If highPrecision is true, it will calculate using big.Rat for high precision.
// Otherwise, it calculates using float64.
// Returns nil if the DataList is empty or if an invalid number of parameters is provided.
func (dl *DataList) Mean(highPrecision ...bool) interface{} {
// Returns math.NaN() if the DataList is empty or if no elements can be converted to float64.
func (dl *DataList) Mean() float64 {
mean := math.NaN()
if len(dl.data) == 0 {
LogWarning("DataList.Mean(): DataList is empty, returning nil.")
return nil
}

// 檢查參數數量
if len(highPrecision) > 1 {
LogWarning("DataList.Mean(): Too many arguments, returning nil.")
return nil
}

// 默認使用普通模式(float64),若有參數則使用參數設定
highPrecisionMode := false
if len(highPrecision) == 1 {
highPrecisionMode = highPrecision[0]
}

if highPrecisionMode {
sum := new(big.Rat)
count := big.NewRat(int64(len(dl.data)), 1)

for _, v := range dl.data {
if val, ok := ToFloat64Safe(v); ok {
ratValue := new(big.Rat).SetFloat64(val)
sum.Add(sum, ratValue)
} else {
LogWarning("DataList.Mean(): Data types cannot be compared, returning nil.")
return nil
}
}

mean := new(big.Rat).Quo(sum, count)
LogWarning("DataList.Mean(): DataList is empty.")
return mean
}

// 普通模式(float64)
var sum float64
var count int
for _, v := range dl.data {
if val, ok := ToFloat64Safe(v); ok {
sum += val
count++
} else {
LogWarning("DataList.Mean(): Data types cannot be compared, returning nil.")
return nil
LogWarning("DataList.Mean(): Element %v cannot be converted to float64, skipping.", val)
// 跳过无法转换的元素
continue
}
}
mean := sum / float64(len(dl.data))

if count == 0 {
LogWarning("DataList.Mean(): No elements could be converted to float64.")
return mean
}

mean = sum / float64(count)
return mean
}

Expand Down Expand Up @@ -1434,7 +1412,7 @@ func (dl *DataList) Var(highPrecision ...bool) interface{} {

if useHighPrecision {
// 使用 big.Rat 進行高精度計算
mean := dl.Mean(true).(*big.Rat)
mean := new(big.Rat).SetFloat64(dl.Mean())
denominator := new(big.Rat).SetFloat64(n - 1)
if denominator.Cmp(big.NewRat(0, 1)) == 0 {
LogWarning("DataList.Var(): Denominator is 0, returning nil.")
Expand All @@ -1457,7 +1435,7 @@ func (dl *DataList) Var(highPrecision ...bool) interface{} {
}

// 普通模式使用 float64 計算
mean := dl.Mean(false).(float64)
mean := dl.Mean()
denominator := n - 1
if denominator == 0 {
LogWarning("DataList.Var(): Denominator is 0, returning nil.")
Expand Down Expand Up @@ -1494,7 +1472,7 @@ func (dl *DataList) VarP(highPrecision ...bool) interface{} {

if useHighPrecision {
// 使用高精度计算
mean := dl.Mean(true).(*big.Rat)
mean := new(big.Rat).SetFloat64(dl.Mean())
numerator := new(big.Rat)
for i := 0; i < len(dl.data); i++ {
xi := new(big.Rat).SetFloat64(ToFloat64(dl.data[i]))
Expand Down
2 changes: 1 addition & 1 deletion insyra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func TestMean(t *testing.T) {
dl := NewDataList(1, 2, 3, 4)
mean := dl.Mean()

if v, ok := mean.(float64); !ok || !float64Equal(v, 2.5) {
if !float64Equal(mean, 2.5) {
t.Errorf("Expected mean 2.5, got %v", mean)
}
}
Expand Down
6 changes: 3 additions & 3 deletions stats/anova.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ func OneWayANOVA_WideFormat(dataTable insyra.IDataTable) *OneWayANOVAResult {
func() {
SSB = 0.0
for _, group := range groups {
groupMean := group.Mean().(float64)
groupMean := group.Mean()
SSB += float64(group.Len()) * math.Pow(groupMean-totalMean, 2)
}
},
func() {
SSW = 0.0
for _, group := range groups {
groupMean := group.Mean().(float64)
groupMean := group.Mean()
for i := 0; i < group.Len(); i++ {
value, _ := group.Get(i).(float64)
SSW += math.Pow(value-groupMean, 2)
Expand Down Expand Up @@ -305,7 +305,7 @@ func RepeatedMeasuresANOVA_WideFormat(dataTable insyra.IDataTable) *RepeatedMeas
ssBetweenFunc := func() {
ssBetween = 0.0
for i := 0; i < rowNum; i++ {
conditionMean := dataTable.GetRow(i).Mean().(float64)
conditionMean := dataTable.GetRow(i).Mean()
ssBetween += float64(colNum) * math.Pow(conditionMean-grandMean, 2)
}
}
Expand Down
4 changes: 2 additions & 2 deletions stats/correlation.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ const (
// Covariance calculates the covariance between two datasets.
// Always returns *big.Rat.
func Covariance(dlX, dlY insyra.IDataList) *big.Rat {
meanX := dlX.Mean(true).(*big.Rat)
meanY := dlY.Mean(true).(*big.Rat)
meanX := new(big.Rat).SetFloat64(dlX.Mean())
meanY := new(big.Rat).SetFloat64(dlY.Mean())

cov := new(big.Rat)
for i := 0; i < dlX.Len(); i++ {
Expand Down
4 changes: 2 additions & 2 deletions stats/linear_regression.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func LinearRegression(dlX, dlY insyra.IDataList) *LinearRegressionResult {
}

// 計算 X 和 Y 的均值
meanX := dlX.Mean(true).(*big.Rat)
meanY := dlY.Mean(true).(*big.Rat)
meanX := new(big.Rat).SetFloat64(dlX.Mean())
meanY := new(big.Rat).SetFloat64(dlY.Mean())

// 初始化變量
numerator := new(big.Rat)
Expand Down
2 changes: 1 addition & 1 deletion stats/moments.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func CalculateMoment(dl insyra.IDataList, n int, central bool) *big.Rat {
// 初始化均值
mean := new(big.Rat)
if central {
mean = dl.Mean(true).(*big.Rat) // 計算均值
mean = new(big.Rat).SetFloat64(dl.Mean()) // 計算均值
}

// 初始化 n 階矩
Expand Down
2 changes: 1 addition & 1 deletion stats/skewness.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func Skewness(sample interface{}, method ...int) interface{} {
func calculateSkewType1(dl *insyra.DataList, highPrecision ...bool) interface{} {
n := new(big.Rat).SetFloat64(conv.ParseF64(dl.Len()))
nReciprocal := new(big.Rat).Inv(n)
m1 := dl.Mean(true).(*big.Rat)
m1 := new(big.Rat).SetFloat64(dl.Mean())
toM2Fn := func() *big.Rat {
var m2Cal = new(big.Rat)
for _, v := range dl.Data() {
Expand Down
8 changes: 4 additions & 4 deletions stats/ttest.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func SingleSampleTTest(data insyra.IDataList, mu float64) *TTestResult {
}

// 計算樣本均值
mean := data.Mean(false).(float64)
mean := data.Mean()

// 計算標準差和標準誤差
stddev := data.Stdev(false).(float64)
Expand Down Expand Up @@ -56,8 +56,8 @@ func TwoSampleTTest(data1, data2 insyra.IDataList, equalVariance bool) *TTestRes
}

// 計算兩個樣本的均值
mean1 := data1.Mean(false).(float64)
mean2 := data2.Mean(false).(float64)
mean1 := data1.Mean()
mean2 := data2.Mean()

// 計算兩個樣本的標準差
stddev1 := data1.Stdev(false).(float64)
Expand Down Expand Up @@ -109,7 +109,7 @@ func PairedTTest(data1, data2 insyra.IDataList) *TTestResult {
}

// 計算差值的均值和標準差
meanDiff := insyra.NewDataList(diffs).Mean(false).(float64)
meanDiff := insyra.NewDataList(diffs).Mean()
stddevDiff := insyra.NewDataList(diffs).Stdev(false).(float64)

// 計算 t 值
Expand Down

0 comments on commit a9be767

Please sign in to comment.