-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFT三周期拟合(plotly展示)
392 lines (343 loc) · 17.3 KB
/
FFT三周期拟合(plotly展示)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import numpy as np
from scipy import signal, interpolate
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import matplotlib as mpl
import calendar
import datetime
#字符串转数字
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import locale
import plotly.graph_objects as go
from plotly.offline import plot
from plotly.subplots import make_subplots
# #print(mpl.matplotlib_fname())
# 获取上证综指历史数据,这里我使用的是Tushare,如果有的话可以去申请一个,没有也可以用自己的数据,把下面这段改一下
# import tushare as ts
#
# #ts.set_token('9b5781c35d09ed01f8c24414a0196e778d8e52e727b6f4bece2e494b')
# ts.set_token('6815b3ea0ec3cc19474beff10e82d6b59bc651c9557c45d15dc5c0d2')
# pro = ts.pro_api()
# shanghai = pro.index_monthly(ts_code='399006.SZ', start_date='20100701', end_date='20210701', \
# fields='trade_date,close').sort_index(ascending=False)
# '''shanghai = pro.stock_monthly(ts_code='600196.SH', start_date='20050901', end_date='20200301', \
# fields='trade_date,close').sort_index(ascending=False)'''
# shanghai = shanghai.rename(columns={'close': '上证综指'})
# #shanghai = shanghai.rename(columns={'close': 'shanghai'})
# a_seq = np.array(shanghai['上证综指'])
#日线转月线
def daily2monthly(df: pd.DataFrame, column_date: str, column_close: str):
# 将日期数据列转换为日期时间格式
df[column_date] = pd.to_datetime(df[column_date])
# 获取每个日期是否是当月的最后一天的布尔值
is_last_day = df[column_date].dt.is_month_end
# 筛选出最后一天的数据
last_day_data = df[is_last_day]
# 创建新的 DataFrame,添加日期和收盘价的值
# 设置地域设置为当前系统设置,用于转换字符串为数字
locale.setlocale(locale.LC_ALL, '')
new_df = pd.DataFrame(
{column_date: last_day_data[column_date],
column_close: last_day_data[column_close].apply(lambda x :locale.atof(x))})#转换为数字
return new_df
def interpolation(data):
'''
:param data:
:return: 对data立面的NaN值插值处理
'''
index_data = list(range(len(data)))
pd_a = pd.Series(~np.isnan(data)) # 由True和False组成的序列,True表明是非NaN值
ind_a = pd_a[pd_a].index.tolist() # 找出非NaN的索引值
pd_nan = pd.Series(np.isnan(data))
ind_nan = pd_nan[pd_nan].index.tolist() # 找出NaN的索引值
y_data = data[ind_a] # 非NaN值的序列
y1_data = interpolate.interp1d(ind_a, y_data, kind='cubic') # 立方插值
y2_data = data
for i in index_data:
if i in ind_nan:
y2_data[i] = y1_data(i)
return y2_data
def period_mean_fft(data, nfft, peak_num, figure_flag):
'''
:param data:对数同比序列
:param nfft:补零后长度
:param peak_num:要找几个峰值
:param figure_flag:是否画频谱图
:return:傅里叶变换后的序列
'''
data_n = np.array(data)
if len(data_n.shape) > 1 and data_n.shape[-1] > data_n.shape[0]:
data_n = data_n.T
Y_fft = np.fft.fft(data_n, nfft)
nfft = len(Y_fft)
Y_fft = np.delete(Y_fft, 0)
power = np.abs(np.square(Y_fft[:int(np.floor(nfft / 2))])) # 求功率谱
amplitude = np.abs(Y_fft[:int(np.floor(nfft / 2))]) # 求振幅
nyquist = 1 / 2
freq = [nyquist * i / np.floor(nfft / 2) for i in range(1, int(np.floor(nfft / 2) + 1))]
freq = np.array(freq)
period = 1 / freq
loc_raw = list(signal.find_peaks(power)[0])
peak_raw = power[loc_raw]
data_dict = {key: value for key, value in zip(loc_raw, peak_raw)}
temp = sorted(data_dict.items(), key=lambda x: x[1], reverse=True)
loc_peaks = temp[:peak_num]
posi = [loc_peaks[i][0] for i in range(peak_num)]
T_fft = period[posi]
df_freq_domain = pd.DataFrame({'周期(月)':period, '振幅':power, 'amplitude':amplitude})
# if figure_flag == 1:
# mpl.rcParams['font.family'] = 'sans-serif'
#
# mpl.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
# #mpl.rcParams['font.sans-serif'] = ['/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc']
# mpl.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
# plt.figure(figsize=(20, 10), dpi=200)
# plt.plot(period, power)
# plt.grid(False)
# plt.xlabel('周期(月)')
# plt.ylabel('振幅')
# plt.title('周期-振幅图')
# for i in range(peak_num):
# plt.text(period[loc_peaks[i][0]], loc_peaks[i][1],
# '[' + str(round(period[loc_peaks[i][0]], 2)) + ',' + str(round(loc_peaks[i][1])) + ']')
# plt.xlim(12, 300)
# plt.savefig('周期振幅图')
return T_fft, df_freq_domain
def gauss_wave_predict(wave, period, n_fft, n_predict, gauss_alpha):
'''
高斯滤波
'''
# 1、填充0
wave_pad = np.concatenate([np.zeros(n_fft - len(wave), ), wave])
# 2、进行FFT变换
wave_fft = np.fft.fft(wave_pad, n_fft)
# 3、生成高斯滤波频率响应,注意这里只刻画了低频部分,后续做共轭对称处理
gauss_index = [i for i in range(int(n_fft))]
center_frequency = n_fft / period + 1
gauss_win = np.exp(-np.power((np.array(gauss_index) - center_frequency), 2) / (gauss_alpha ** 2))
# 4、频域滤波,因为时域为实数,所以频域序列有共轭对称的属性
wave_filter = np.multiply(wave_fft, gauss_win)
if np.mod(n_fft, 2) == 0:
wave_filter[int(n_fft / 2 + 1):] = np.conj(wave_filter[int(n_fft / 2 - 1):0:-1])
else:
wave_filter[int((n_fft - 1) / 2 + 1):] = np.conj(wave_filter[int((n_fft - 1) / 2):0:-1])
# 5、逆傅里叶变换得到时域还原序列,外延预测本质上是在延拓主值序列
ret = np.fft.ifft(wave_filter).real
output = np.concatenate([ret[int(len(ret) - len(wave)):], ret[:int(n_predict)]])
return output
def regress_predict_output_f(seq, predict_len, pad_to_len, gauss_alpha, mean_flag, period_flag, peak_num, figure_flag):
'''
回归、预测、输出
'''
df_freq_domain = pd.DataFrame()
if mean_flag == 0:
seq_len = len(seq) # 输入序列的长度
trend_a_seq = np.zeros((seq_len, 1)) # 输出去除方式,不处理的意思即为全0序列
d_a_seq = seq
predict_trend_seq = np.zeros((seq_len + predict_len, 1)) # 输出序列长度为输入序列长度+预测长度,因为没处理原数据,基准值为0
elif mean_flag == 1:
seq_len = len(seq)
trend_a_seq = np.nanmean(seq) # 均值,常量
d_a_seq = seq - trend_a_seq
predict_trend_seq = np.ones((seq_len + predict_len, 1)) * np.nanmean(seq) # 基准值为均值
else:
seq_len = len(seq)
d_a_seq = signal.detrend(seq) # 去趋势后的序列
trend_a_seq = seq - d_a_seq # 趋势值,长度为trend_a_seq的序列
predict_trend_seq = np.ones((seq_len + predict_len)) * np.nan
predict_trend_seq[:seq_len] = trend_a_seq # 输出序列原数据长度部分为 趋势值序列
predict_trend_seq = interpolation(predict_trend_seq) # 输出序列预测长度部分 为用趋势值序列插值的结果
if period_flag == '固定周期':
#period = [30, 40, 113]
period = [28, 42, 132] #上证50得出
print('固定周期')
#period = [42, 70, 200]
else:
period, df_freq_domain = period_mean_fft(d_a_seq, pad_to_len, peak_num, figure_flag)
if len(period) == 2:
period.append(200)
elif len(period) == 1:
if period[0] < 60:
period.extend([100, 200])
else:
period.extend([40, 200])
# 在对数同比序列前面补0,提升频域分辨率
d_seq_pad = np.zeros((pad_to_len,))
d_seq_pad[-seq_len:] = d_a_seq
# 高斯滤波获取三周期对应的序列以及预测结果
filter_result = np.zeros((pad_to_len + predict_len, len(period)))
for iPeriod in range(len(period)):
filter_result[:, iPeriod] = gauss_wave_predict(d_seq_pad, period[iPeriod], pad_to_len, predict_len, gauss_alpha)
filter_result = filter_result[-(seq_len + predict_len):, :]
Y = pd.DataFrame(d_a_seq)
regress_result = np.zeros((4, 6))
# 单变量回归
for iPeriod in range(len(period)):
X = pd.DataFrame(filter_result[:seq_len, iPeriod])
X = sm.add_constant(X)
est = sm.OLS(Y, X).fit()
regress_result[iPeriod, :2] = np.array(est.params)
regress_result[iPeriod, 4] = est.rsquared
regress_result[iPeriod, 5] = est.f_pvalue
# 多变量回归
X = pd.DataFrame(filter_result[:seq_len, :])
X = sm.add_constant(X)
est = sm.OLS(Y, X).fit()
regress_result[iPeriod, :4] = np.array(est.params)
regress_result[iPeriod, 4] = est.rsquared
regress_result[iPeriod, 5] = est.f_pvalue
# pdb.set_trace()
predict_result_temp = np.dot(np.concatenate([np.ones((seq_len + predict_len, 1)), filter_result], axis=1),
np.array(est.params))
predict_result = predict_result_temp + predict_trend_seq
# pdb.set_trace()
return d_a_seq, trend_a_seq, filter_result, predict_trend_seq, predict_result_temp, predict_result, period, regress_result, df_freq_domain
def plot_graph(df_seq:pd.DataFrame, df_predict:pd.DataFrame, freq_domain:pd.DataFrame, to_html:str):
#fig = make_subplots(rows=1, cols=2, specs=[[{'secondary_y': True}], {None}])
# row_height两行高度分配, vertical_spacing 两行间隔宽度
fig = make_subplots(rows=2, subplot_titles=('指数同比拟合', '周期-振幅'), row_heights=[0.6, 0.4], vertical_spacing=0.1)
# 绘制原始数据
# fig.add_trace(go.Scatter(x=df_seq['交易时间'], y=df_seq['收盘对数同比'], name='对数同比', mode='lines'))
fig.add_trace(go.Scatter(x=df_predict['交易时间'], y=df_seq['收盘对数同比'], name='对数同比', mode='lines'), row=1, col=1)
# 绘制拟合曲线
fig.add_trace(go.Scatter(x=df_predict['交易时间'], y=df_predict['预测值'], mode='lines', name='拟合曲线'), row=1, col=1)
# 添加额外的坐标轴和数据集
if len(freq_domain):
fig.add_trace(go.Scatter(x=freq_domain['周期(月)'], y=freq_domain['振幅'], name='周期-振幅', mode='lines'), row=2, col=1)
# 设置图表布局
fig.update_layout(title='指数同比拟合/周期-振幅',
xaxis_title='日期',
yaxis_title='指数收盘对数同比',
legend=dict(x=0.8, y=0.95))
# 第二行使用x对数坐标轴
fig.update_xaxes(type='log', row=2, col=1)
# 显示图表
if to_html:
plot(fig, filename='中国指数同比拟合.html', include_plotlyjs='include_plotlyjs')
fig.show()
#在同一图标通过按钮切换,尚有bug
def plot_graph1(df_seq:pd.DataFrame, df_predict:pd.DataFrame, freq_domain:pd.DataFrame, to_html:str):
# 创建图表布局
fig = make_subplots(rows=1, cols=1, specs=[[{'secondary_y': True}]])
# 添加原始数据和拟合曲线
fig.add_trace(go.Scatter(x=df_predict['交易时间'], y=df_seq['收盘对数同比'], name='收盘对数同比', mode='lines'))
fig.add_trace(go.Scatter(x=df_predict['交易时间'], y=df_predict['预测值'], mode='lines', name='拟合预测'), secondary_y=True)
# 添加额外的坐标轴和数据集
fig.add_trace(go.Scatter(x=freq_domain['周期(月)'], y=freq_domain['振幅'], name='周期-振幅', mode='lines', xaxis='x2', yaxis='y2'))
# 设置图表布局
fig.update_layout(
title='指数同比拟合',
xaxis_title='日期(月)',
yaxis_title='收盘对数同比/周期振幅',
legend=dict(x=0.7, y=0.95),
updatemenus=[
dict(
buttons=[
dict(
label="收盘对数同比",
method="update",
args=[{"visible": [True, True, False],
"xaxis": dict(title='日期(月)'),
"yaxis": dict(title='收盘对数同比/频域振幅'),
"xaxis2":dict(visible=False),
"yaxis2": dict(visible=True)}] # 显示第一个数据集,隐藏第二,三个数据集
),
dict(
label="周期-振幅",
method="update",
args=[{"visible": [False, False, True],
"xaxis2": dict(title='频域周期',domain=[0,1000]),
"yaxis2": dict(title='频域振幅'),
"xaxis": dict(visible=False),
"yaxis": dict(visible=False)}] # 显示第三个数据集,隐藏第一,二个数据集
)
],
direction="down",
showactive=False,
x=0,
xanchor="left",
y=1,
yanchor="top"
)
]
)
# 显示图表
if to_html:
#fig.to_html(to_html)
plot(fig, filename='中国指数同比拟合.html', include_plotlyjs='include_plotlyjs')
fig.show()
#文件路径
bars = pd.read_csv(r'D:\China_macro\K线导出_000001_日线数据.csv')
#a_seq = np.array(shanghai['收盘价'])
df_seq = daily2monthly(bars, '交易时间', '收盘价')
a_seq = np.array(df_seq['收盘价'])
# 求对数同序列
log_a_seq = np.log(a_seq[12:]) - np.log(a_seq[:-12]) # 对数同比序列
# 因为同比所有前12个月没有数据
df_seq = df_seq[12:]
df_seq['收盘对数同比'] = log_a_seq
df_seq = df_seq.reset_index()
predict_len = 12 * 5 # 预测长度,单位为月
pad_to_len = 4096 # 填0后长度,填0是为了提升频谱分辨率
gauss_alpha = 1 # 高斯滤波器带宽,推荐设置为1
mean_flag = 1 # 数据处理方式1:去均值项 0:不处理 2:去趋势项
#period_flag = '变化周期' # 中心频率选择方式:固定周期取42,100,200 ;变化周期由傅里叶变换则计算得出
period_flag = '固定周期'
# 我:固定周期采取3个:30个月为变换计算得出,42个月基钦周期,122个月朱格拉周期
peak_num = 3 # 提取三大周期。
figure_flag = 1 # 0:傅里叶变换不画图 1:画图
#输出预测结果
[d_a_seq, trend_a_seq, filter_result, predict_trend_seq, predict_result_temp, predict_result, period, regress_result, freq_domain]\
= regress_predict_output_f(log_a_seq[0:], predict_len, pad_to_len, gauss_alpha, mean_flag, period_flag, peak_num, figure_flag)
pre_result = predict_result_temp + predict_trend_seq.reshape(len(predict_trend_seq))
#df_seq['交易时间'] = df_seq['交易时间'].apply(lambda x: pd.to_datetime(x).date)
#创建预测序列时间和值
df_predict = pd.DataFrame({'预测值': pre_result})
#将对象户序列的时间赋值给预测序列
df_predict['交易时间'] = pd.Series(df_seq['交易时间'].values[:len(df_seq)])
# 获取"交易时间"列的最后一个有效日期
last_date = pd.to_datetime(df_predict['交易时间'].dropna().iloc[-1])
# 生成与缺失行数相同的日期序列,从最后一个日期的这个月开始
new_dates = pd.date_range(start=last_date + pd.offsets.MonthEnd(), periods=len(df_predict) - len(df_predict.dropna()), freq='M')
# 使用生成的日期序列填充"交易时间"列的NaT值
df_predict.loc[df_predict['交易时间'].isna(), '交易时间'] = new_dates
plot_graph(df_seq, df_predict, freq_domain, '上证指数预测拟合.html')
# date = [datetime.datetime.strptime(x, '%Y%m%d') for x in list(data['trade_date'])]
# max_months = max(len(log_a_seq), len(pre_result))
# while (True):
# if len(date) < max_months:
# if date[-1].month != 12:
# year = date[-1].year
# month = date[-1].month
# else:
# year = date[-1].year + 1
# month = 1
# date.append(date[-1] + datetime.timedelta(days=calendar.monthrange(year, month)[1]))
# else:
# break
# date = [str(x)[:4] + str(x)[5:7] for x in date]
# fig = plt.figure(figsize=(16, 8), dpi=200)
# ax = fig.add_axes([0, 0, 1, 1])
# l1, = ax.plot(range(0, len(log_a_seq)), log_a_seq, 'b', label='同比序列')
# l3, = ax.plot(range(0, len(pre_result)), pre_result, 'y', label='预测走势')
# xticks = [x for x in
# range(0, (datetime.date.today().year - datetime.datetime.strptime(date[0], '%Y%m').year) * 12 + 6, 12)]
# xticks.extend([x for x in
# range((datetime.date.today().year - datetime.datetime.strptime(date[0], '%Y%m').year) * 12 + 6,
# len(date), 6)])
# date_display = date[0:(datetime.date.today().year - datetime.datetime.strptime(date[0], '%Y%m').year) * 12 + 6:12]
# date_display.extend(date[(datetime.date.today().year - datetime.datetime.strptime(date[0], '%Y%m').year) * 12 + 6::6])
# ax.set_xticks(xticks)
# ax.set_xticklabels(date_display, rotation=30)
# ax2 = ax.twinx()
# l2, = ax2.plot(range(0, len(log_a_seq)), data['上证综指'], 'r', label='上证综指')
# ax.grid(False)
# handles1, labels1 = ax.get_legend_handles_labels()
# handles2, labels2 = ax2.get_legend_handles_labels()
# plt.legend(handles=[l1, l2, l3], loc='upper center', fontsize=20, frameon=False, ncol=2, columnspacing=10)
# plt.savefig('预测', bbox_inches='tight')
# plt.show()