Chinese Population Forecasting with the PaddleTS LSTNet Time-Series Model
1. Introduction ✨
1.1 Project Overview 🎄
This is a machine-learning project that walks through the full workflow of building an LSTNet model with PaddleTS on a Chinese population dataset: data preprocessing, model construction, training, prediction, and visualization of the results.

We will use eight feature fields from the dataset, such as 出生人口(万) (annual births, 10k), 中国人均GPA(美元计) (per-capita GDP in USD; the column name misspells GDP as "GPA"), 中国性别比例(按照女生=100) (sex ratio, females = 100), and 自然增长率(%) (natural growth rate, %), to predict the single label field 总人口(万人) (total population, 10k). This makes it a multi-input, single-output recurrent forecasting task.
PaddleTS, the tool used in this project, is an easy-to-use Python library for deep time-series modeling. Built on the PaddlePaddle deep-learning framework, it focuses on state-of-the-art deep models and aims to give domain experts and industry users scalable time-series modeling capabilities with a convenient user experience.

Its built-in models include forecasting architectures such as NBEATS, NHiTS, LSTNet, TCN, Transformer, DeepAR (probabilistic forecasting), and Informer, as well as representation models such as TS2Vec. This project uses the LSTNet model.
🌈 Since their introduction, LSTMs (Long Short-Term Memory networks) have shown strong potential for capturing dependencies in time series. An LSTM is a special kind of RNN designed to mitigate the vanishing- and exploding-gradient problems that arise when training on long sequences. Its cell state is controlled by three gates: the forget gate, the input gate, and the output gate. The cell state runs through the whole cell like a conveyor belt, with only a few minor interactions, so information can flow along the sequence largely unchanged.
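The gate mechanics described above can be sketched with plain numpy. This is a toy single-step LSTM cell for illustration only (weight shapes and names are my own, not PaddleTS code):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM time step. W: (4h, d) input weights, U: (4h, h) recurrent
    weights, b: (4h,) bias; rows are stacked [forget, input, candidate, output]."""
    hdim = h_prev.shape[0]
    z = W @ x + U @ h_prev + b
    f = sigmoid(z[0*hdim:1*hdim])          # forget gate: what to erase from c
    i = sigmoid(z[1*hdim:2*hdim])          # input gate: what to write
    g = np.tanh(z[2*hdim:3*hdim])          # candidate cell values
    o = sigmoid(z[3*hdim:4*hdim])          # output gate: what to expose
    c = f * c_prev + i * g                 # the "conveyor belt" cell state
    h = o * np.tanh(c)                     # hidden state passed to the next step
    return h, c

rng = np.random.default_rng(0)
d, hdim = 8, 4                             # 8 features, hidden size 4
W = rng.normal(size=(4*hdim, d))
U = rng.normal(size=(4*hdim, hdim))
b = np.zeros(4*hdim)
h = c = np.zeros(hdim)
for t in range(5):                         # unroll over a length-5 window
    h, c = lstm_step(rng.normal(size=d), h, c, W, U, b)
print(h.shape)  # (4,)
```

Note how `c` is updated only by an elementwise multiply and add, which is what lets gradients flow across many steps.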
🌈 LSTNet was formally proposed in 2018. It is a deep network designed specifically for time-series forecasting and can be seen as an attempt to push the forecasting ability of LSTM-style recurrent models further by combining them with convolution, skip connections, and (in one variant) attention. Its components are:
- Convolutional component
  - The first layer of LSTNet is a convolutional network without pooling. It extracts short-term patterns along the time dimension and local dependencies between variables. The layer consists of multiple filters of width ω and height n, where the height equals the number of variables.
- Recurrent component
  - The output of the convolutional layer is fed simultaneously into the recurrent component and the recurrent-skip component (described below). The recurrent component is a recurrent layer of gated recurrent units (GRU) that uses ReLU as the hidden-update activation function.
- Recurrent-skip component
  - Recurrent layers built from GRU and LSTM cells are designed to memorize historical information and thus learn relatively long-term dependencies. In practice, however, vanishing gradients often keep GRU and LSTM from capturing very long-term correlations; the recurrent-skip component adds recurrent connections that jump across a fixed period to alleviate this.
- Temporal attention layer
  - An attention mechanism that learns a weighted combination of the hidden representations at each window position of the input matrix.
- Autoregressive component
  - LSTNet decomposes its final prediction into a linear part, which mainly handles the local scale of the signal, and a nonlinear part containing recurring patterns. The classic autoregressive (AR) model serves as the linear component.
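Putting the pieces together, LSTNet's final forecast is the sum of the nonlinear (CNN + RNN + skip-RNN) prediction and the linear AR prediction. A minimal numpy sketch of just that final combination step (shapes and names are illustrative, not the PaddleTS implementation):

```python
import numpy as np

def ar_component(window, ar_len, w, b):
    """Classic autoregressive part: a linear map of the last `ar_len`
    values of each variable, applied to every variable independently."""
    recent = window[-ar_len:, :]               # (ar_len, n_vars)
    return recent.T @ w + b                    # (n_vars,)

rng = np.random.default_rng(0)
n_steps, n_vars, ar_len = 24, 8, 6
window = rng.normal(size=(n_steps, n_vars))    # one input window

# Stand-in for the output of the CNN / GRU / skip-GRU stack.
nonlinear_pred = rng.normal(size=n_vars)

w, b = rng.normal(size=ar_len), 0.0
linear_pred = ar_component(window, ar_len, w, b)

# Final LSTNet forecast: linear part (local scale) + nonlinear part (patterns).
y_hat = linear_pred + nonlinear_pred
print(y_hat.shape)  # (8,)
```

The AR path gives the network a direct linear route from recent inputs to the output, which helps it track level shifts that the nonlinear stack alone responds to slowly.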
1.2 Dataset 🌲
The dataset used in this project is a Chinese population dataset with 10 fields: 8 feature fields, 1 label field, and 1 row-index field. The data type of each field is listed below:
- Some fields are int64 and must be converted before they can be fed to the model for training.
- The data contains only 50 samples, so the training, validation, and test splits have to be chosen carefully.
| Field | Type |
|---|---|
| 年份 | int64 |
| 出生人口(万) | int64 |
| 总人口(万人) | int64 |
| 中国人均GPA(美元计) | int64 |
| 中国性别比例(按照女生=100) | float64 |
| 自然增长率(%) | float64 |
| 城镇人口(城镇+乡村=100) | float64 |
| 乡村人口 | float64 |
| 美元兑换人民币汇率 | float64 |
| 中国就业人口(万人) | int64 |
Environment used for this project: V100 GPU; 32 GB RAM; 32 GB VRAM; 100 GB disk; 4-core CPU.
2. Environment Setup
Install and import the required packages.
2.1 Install Dependencies
# Install the paddlets library
!pip install paddlets -q
!pip install seaborn -q
2.2 Import Libraries
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import paddlets
from paddlets import TSDataset
from paddlets import TimeSeries
from paddlets.models.forecasting import MLPRegressor, LSTNetRegressor
from paddlets.transform import Fill, StandardScaler
from paddlets.metrics import MSE, MAE
3. Data Processing
3.1 Load the Data
population = pd.read_csv("data/data140190/人口.csv")
population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1 | 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 2 | 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 3 | 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 4 | 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |
population['date'] = pd.to_datetime(population['年份'], format='%Y')
3.2 Inspect Field Types
Some fields in the dataset are int64 and need to be cast to float before they can be used for training; otherwise training raises an error.
population.dtypes
年份 int64
出生人口(万) int64
总人口(万人) int64
中国人均GPA(美元计) int64
中国性别比例(按照女生=100) float64
自然增长率(%) float64
城镇人口(城镇+乡村=100) float64
乡村人口 float64
美元兑换人民币汇率 float64
中国就业人口(万人) int64
date datetime64[ns]
dtype: object
3.3 Data Visualization
3.3.1 Feature Line Plots
Plot each feature against the year index for a first look at the data.
Because the dataset contains Chinese field names, the following font setup is needed for them to render correctly in plots:
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())
titles = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]
feature_keys = [
    "出生人口(万)",
    "总人口(万人)",
    "中国人均GPA(美元计)",
    "中国性别比例(按照女生=100)",
    "自然增长率(%)",
    "城镇人口(城镇+乡村=100)",
    "乡村人口",
    "美元兑换人民币汇率",
    "中国就业人口(万人)",
]
colors = [
    "blue",
    "chocolate",
    "green",
    "red",
    "purple",
    "brown",
    "darkblue",
    "black",
    "magenta",
]
date_time_key = "年份"
def show_raw_visualization(data):
    time_data = data[date_time_key]
    fig, axes = plt.subplots(
        nrows=3, ncols=3, figsize=(15, 15), dpi=100, facecolor="w", edgecolor="k"
    )
    for i in range(len(feature_keys)):
        key = feature_keys[i]
        c = colors[i % len(colors)]
        t_data = data[key]
        t_data.index = time_data
        ax = t_data.plot(
            ax=axes[i // 3, i % 3],
            color=c,
            title=titles[i],
            rot=25,
        )
        ax.legend([titles[i]])
    plt.tight_layout()

show_raw_visualization(population)
3.3.2 Box Plots
To inspect the distribution of part of the data, the four fields 出生人口(万), 总人口(万人), 中国人均GPA(美元计), and 中国就业人口(万人) are shown as box plots below.
from pylab import mpl
from matplotlib.font_manager import FontProperties
myfont=FontProperties(fname=r'/usr/share/fonts/fangzheng/FZSYJW.TTF',size=12)
sns.set(font=myfont.get_name())
plt.figure(figsize=(15,8),dpi=100)
plt.subplot(1,4,1)
sns.boxplot(y="出生人口(万)", data=population, saturation=0.9)
plt.subplot(1,4,2)
sns.boxplot(y="总人口(万人)", data=population, saturation=0.9)
plt.subplot(1,4,3)
sns.boxplot(y="中国人均GPA(美元计)", data=population, saturation=0.9)
plt.subplot(1,4,4)
sns.boxplot(y="中国就业人口(万人)", data=population, saturation=0.9)
plt.tight_layout()
3.3.3 Correlation Analysis
Examine the pairwise correlations between variables. When two features are strongly correlated, one of them can be dropped; since this dataset has few feature fields, none are removed here.
corr = population.corr()
# Draw a heatmap of the pairwise correlations
plt.figure(figsize=(10,10),dpi=100)
sns.heatmap(corr, square=True, linewidths=0.1, annot=True)
3.4 Data Preprocessing
3.4.1 Print the Selected Features
Because the number of features is limited, all feature fields in the dataset are kept. The selected features are printed below.
print(
    "Selected feature columns:",
    ", ".join([titles[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8]]),
)
selected_features = [feature_keys[i] for i in [0, 1, 2, 3, 4, 5, 6, 7, 8]]
features = population[selected_features]
features.index = population[date_time_key]
features.head()
Selected feature columns: 出生人口(万), 总人口(万人), 中国人均GPA(美元计), 中国性别比例(按照女生=100), 自然增长率(%), 城镇人口(城镇+乡村=100), 乡村人口, 美元兑换人民币汇率, 中国就业人口(万人)
| 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
|---|---|---|---|---|---|---|---|---|---|
| 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |
3.4.2 Duplicate Check
Check whether the features contain duplicate rows. A return value of False means there are no duplicates, so nothing needs to be dropped.
features.duplicated().any()
False
3.4.3 Missing-Value Check
Check whether any samples are missing. True for every column means there are no missing values, so no further handling is needed.
pd.notnull(features).all()
出生人口(万) True
总人口(万人) True
中国人均GPA(美元计) True
中国性别比例(按照女生=100) True
自然增长率(%) True
城镇人口(城镇+乡村=100) True
乡村人口 True
美元兑换人民币汇率 True
中国就业人口(万人) True
dtype: bool
3.4.4 Convert Field Types
The int fields must be cast to float before they can be used for training; otherwise an error is raised.
- First, inspect the type of each field with the statement below.
- Then convert the int64 columns to float64 and replace the original fields.
population.dtypes
年份 int64
出生人口(万) int64
总人口(万人) int64
中国人均GPA(美元计) int64
中国性别比例(按照女生=100) float64
自然增长率(%) float64
城镇人口(城镇+乡村=100) float64
乡村人口 float64
美元兑换人民币汇率 float64
中国就业人口(万人) int64
dtype: object
population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1970 | 2710 | 82542 | 113 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432 |
| 1 | 1971 | 2551 | 84779 | 118 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520 |
| 2 | 1972 | 2550 | 86727 | 131 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854 |
| 3 | 1973 | 2447 | 88761 | 157 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652 |
| 4 | 1974 | 2226 | 90409 | 160 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369 |
population['出生人口(万)'] = population['出生人口(万)'].astype('float64')
population['总人口(万人)'] = population['总人口(万人)'].astype('float64')
population['中国人均GPA(美元计)'] = population['中国人均GPA(美元计)'].astype('float64')
population['中国就业人口(万人)'] = population['中国就业人口(万人)'].astype('float64')
# population['date'] = population['年份'].astype('str')
# population['date'] = pd.to_datetime(population['date'])
Display the processed data with the command below. If the 出生人口(万) column shows NaN, re-read the data and rerun the preprocessing steps above.
population.head()
| | 年份 | 出生人口(万) | 总人口(万人) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1970 | 2710.0 | 82542.0 | 113.0 | 105.90 | 25.95 | 17.38 | 82.62 | 2.4618 | 34432.0 |
| 1 | 1971 | 2551.0 | 84779.0 | 118.0 | 105.82 | 23.40 | 17.26 | 82.74 | 2.2673 | 35520.0 |
| 2 | 1972 | 2550.0 | 86727.0 | 131.0 | 105.78 | 22.27 | 17.13 | 82.87 | 2.2401 | 35854.0 |
| 3 | 1973 | 2447.0 | 88761.0 | 157.0 | 105.86 | 20.99 | 17.20 | 82.80 | 2.0202 | 36652.0 |
| 4 | 1974 | 2226.0 | 90409.0 | 160.0 | 105.88 | 17.57 | 17.16 | 82.84 | 1.8397 | 37369.0 |
3.5 Build the TSDataset
TSDataset is one of the core classes in PaddleTS, designed to represent most kinds of time-series data. Time-series data generally falls into two categories:
- Univariate data: a single target column, optionally with one or more covariate columns.
- Multivariate data: multiple target columns, optionally with one or more covariate columns.
A TSDataset must have a time_index attribute; both pandas.DatetimeIndex and pandas.RangeIndex are supported.
target_cov_dataset = TSDataset.load_from_dataframe(
    population,
    time_col='年份',
    target_cols='总人口(万人)',
    observed_cov_cols=['出生人口(万)', '中国人均GPA(美元计)', '中国性别比例(按照女生=100)', '自然增长率(%)',
                       '城镇人口(城镇+乡村=100)', '乡村人口', '美元兑换人民币汇率', '中国就业人口(万人)'],
    fill_missing_dates=True,
    fillna_method='pre'
)
target_cov_dataset.plot(['总人口(万人)', '出生人口(万)', '中国人均GPA(美元计)', '中国性别比例(按照女生=100)', '自然增长率(%)',
'城镇人口(城镇+乡村=100)', '乡村人口', '美元兑换人民币汇率', '中国就业人口(万人)'])
target_cov_dataset
总人口(万人) 出生人口(万) 中国人均GPA(美元计) 中国性别比例(按照女生=100) 自然增长率(%) \
1970 82542.0 2710.0 113.0 105.90 25.95
1971 84779.0 2551.0 118.0 105.82 23.40
1972 86727.0 2550.0 131.0 105.78 22.27
1973 88761.0 2447.0 157.0 105.86 20.99
1974 90409.0 2226.0 160.0 105.88 17.57
1975 91970.0 2102.0 178.0 106.04 15.77
1976 93267.0 1849.0 165.0 106.15 12.72
1977 94774.0 1783.0 185.0 106.17 12.12
1978 96159.0 1733.0 156.0 106.16 12.00
1979 97542.0 1715.0 183.0 106.00 11.61
1980 98705.0 1776.0 194.0 105.98 11.87
1981 100072.0 2064.0 197.0 106.11 14.55
1982 101654.0 2230.0 203.0 106.19 15.68
1983 103008.0 2052.0 225.0 106.61 13.29
1984 104357.0 2050.0 250.0 106.61 13.08
1985 105851.0 2196.0 294.0 107.04 14.26
1986 107507.0 2374.0 281.0 107.04 15.57
1987 109300.0 2508.0 251.0 106.19 16.61
1988 111026.0 2445.0 283.0 106.27 15.73
1989 112704.0 2396.0 310.0 106.40 15.04
1990 114333.0 2374.0 317.0 106.27 14.39
1991 115823.0 2250.0 333.0 105.52 12.98
1992 117171.0 2113.0 366.0 104.27 11.50
1993 118517.0 2120.0 377.0 104.18 11.45
1994 119850.0 2098.0 473.0 104.51 11.21
1995 121121.0 2052.0 609.0 104.21 10.55
1996 122389.0 2057.0 709.0 103.34 10.42
1997 123626.0 2028.0 781.0 104.36 10.06
1998 124761.0 1934.0 828.0 105.13 9.14
1999 125786.0 1827.0 873.0 105.89 8.18
2000 126743.0 1765.0 959.0 106.74 7.58
2001 127627.0 1696.0 1053.0 106.00 6.95
2002 128453.0 1641.0 1148.0 106.06 6.45
2003 129227.0 1594.0 1288.0 106.20 6.01
2004 129988.0 1588.0 1508.0 106.29 5.87
2005 130756.0 1612.0 1753.0 106.30 5.89
2006 131448.0 1581.0 2099.0 106.29 5.28
2007 132129.0 1591.0 2693.0 106.19 5.17
2008 132802.0 1604.0 3468.0 106.07 5.08
2009 133450.0 1587.0 3832.0 105.93 4.87
2010 134091.0 1588.0 4550.0 105.21 4.79
2011 134735.0 1600.0 5618.0 105.18 4.79
2012 135404.0 1635.0 6316.0 105.13 4.95
2013 136072.0 1640.0 7050.0 105.10 4.92
2014 136782.0 1687.0 7678.0 105.06 5.21
2015 137462.0 1655.0 8066.0 105.02 4.96
2016 138271.0 1786.0 8147.0 104.98 5.86
2017 139008.0 1723.0 8879.0 104.81 5.32
2018 139538.0 1523.0 9976.0 104.64 3.81
2019 140005.0 1465.0 10261.0 104.60 3.34
城镇人口(城镇+乡村=100) 乡村人口 美元兑换人民币汇率 中国就业人口(万人)
1970 17.38 82.62 2.4618 34432.0
1971 17.26 82.74 2.2673 35520.0
1972 17.13 82.87 2.2401 35854.0
1973 17.20 82.80 2.0202 36652.0
1974 17.16 82.84 1.8397 37369.0
1975 17.34 82.66 1.9663 38168.0
1976 17.44 82.55 1.8803 38834.0
1977 17.55 82.45 1.7300 39377.0
1978 17.92 82.08 1.5771 40152.0
1979 18.96 81.04 1.4962 41024.0
1980 19.39 80.61 1.5303 42361.0
1981 20.16 79.84 1.7051 43725.0
1982 21.13 78.87 1.8926 45295.0
1983 21.52 78.38 1.9757 46436.0
1984 23.01 76.99 2.3270 48197.0
1985 23.71 76.29 2.9367 49873.0
1986 24.52 75.48 3.4528 51282.0
1987 25.32 74.68 3.7221 52783.0
1988 25.81 74.19 3.7221 54334.0
1989 25.21 73.79 3.7659 55329.0
1990 25.41 73.59 4.7838 64749.0
1991 26.94 73.06 5.3227 65491.0
1992 27.46 72.54 5.5149 66152.0
1993 27.99 72.01 5.7619 66808.0
1994 28.51 71.49 8.6187 67455.0
1995 29.04 70.95 8.3507 68065.0
1996 30.48 69.52 8.3142 68950.0
1997 31.91 68.09 8.2898 69820.0
1998 33.35 66.55 8.2791 70537.0
1999 34.78 65.22 8.2796 71394.0
2000 36.22 63.78 8.2784 72085.0
2001 37.66 62.34 8.2770 72797.0
2002 39.09 60.91 8.2770 73280.0
2003 40.53 59.47 8.2774 73736.0
2004 41.76 58.24 8.2780 74264.0
2005 42.99 57.01 8.1013 74547.0
2006 44.34 55.56 7.8087 74978.0
2007 45.89 54.11 7.3872 75321.0
2008 46.99 53.01 6.8500 75564.0
2009 48.34 51.66 6.8100 75828.0
2010 49.95 50.05 6.6220 76105.0
2011 51.27 48.73 6.6100 76420.0
2012 52.57 47.43 6.2500 76704.0
2013 53.73 46.27 6.0700 76977.0
2014 54.77 45.23 6.0500 77253.0
2015 56.10 43.90 6.2000 77451.0
2016 57.35 42.65 6.5600 77603.0
2017 58.52 41.48 6.5000 77640.0
2018 59.58 40.42 6.8600 77586.0
2019 60.60 39.40 6.8967 77471.0
target_cov_dataset.summary()
| | 总人口(万人) | 出生人口(万) | 中国人均GPA(美元计) | 中国性别比例(按照女生=100) | 自然增长率(%) | 城镇人口(城镇+乡村=100) | 乡村人口 | 美元兑换人民币汇率 | 中国就业人口(万人) |
|---|---|---|---|---|---|---|---|---|---|
| missing | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| count | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 | 50.000000 |
| mean | 116769.640000 | 1943.420000 | 2120.840000 | 105.673600 | 10.741200 | 33.584800 | 66.368800 | 5.219768 | 61200.560000 |
| std | 17393.764444 | 334.886157 | 3010.987419 | 0.814879 | 5.537814 | 14.148318 | 14.125293 | 2.580723 | 15633.538166 |
| min | 82542.000000 | 1465.000000 | 113.000000 | 103.340000 | 3.340000 | 17.130000 | 39.400000 | 1.496200 | 34432.000000 |
| 25% | 101992.500000 | 1640.250000 | 208.500000 | 105.107500 | 5.455000 | 21.227500 | 54.472500 | 2.282225 | 45580.250000 |
| 50% | 120485.500000 | 1838.000000 | 541.000000 | 105.955000 | 10.880000 | 28.775000 | 71.220000 | 6.060000 | 67760.000000 |
| 75% | 131958.750000 | 2177.000000 | 2544.500000 | 106.190000 | 14.357500 | 45.502500 | 78.747500 | 7.703325 | 75235.250000 |
| max | 140005.000000 | 2710.000000 | 10261.000000 | 107.040000 | 25.950000 | 60.600000 | 82.870000 | 8.618700 | 77640.000000 |
3.6 Split the Dataset
Training : validation : test = 0.6 : 0.2 : 0.2
train_dataset, val_test_dataset = target_cov_dataset.split(0.6)
val_dataset, test_dataset = val_test_dataset.split(0.5)
train_dataset.plot(add_data=[val_dataset,test_dataset], labels=['Val', 'Test'])
3.7 Standardization
scaler = StandardScaler()
scaler.fit(train_dataset)
train_dataset_scaled = scaler.transform(train_dataset)
val_test_dataset_scaled = scaler.transform(val_test_dataset)
val_dataset_scaled = scaler.transform(val_dataset)
test_dataset_scaled = scaler.transform(test_dataset)
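The scaler is fitted on the training split only and then reused for the validation and test splits, which avoids leaking future statistics into training. PaddleTS's StandardScaler performs a per-column z-score; the same idea in plain numpy (toy values taken from the table above, not the paddlets internals):

```python
import numpy as np

# A few training rows (总人口, 出生人口) and one later "test" row.
train = np.array([[82542.0, 2710.0],
                  [84779.0, 2551.0],
                  [86727.0, 2550.0]])
test = np.array([[96159.0, 1733.0]])

mu = train.mean(axis=0)                 # statistics fitted on train only
sigma = train.std(axis=0)
train_scaled = (train - mu) / sigma     # z-score each column
test_scaled = (test - mu) / sigma       # reuse the train statistics

print(train_scaled.mean(axis=0))        # ~[0, 0]
```

After this transform every training column has mean 0 and standard deviation 1, which keeps features with large magnitudes (e.g. 总人口) from dominating the loss.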
train_dataset_scaled
总人口(万人) 出生人口(万) 中国人均GPA(美元计) 中国性别比例(按照女生=100) 自然增长率(%) \
1970 -1.785855 2.118316 -0.978375 0.191621 2.796566
1971 -1.611715 1.512902 -0.955117 0.101799 2.182756
1972 -1.460072 1.509094 -0.894647 0.056887 1.910754
1973 -1.301734 1.116907 -0.773707 0.146710 1.602645
1974 -1.173445 0.275419 -0.759752 0.169165 0.779418
1975 -1.051928 -0.196728 -0.676024 0.348810 0.346141
1976 -0.950963 -1.160061 -0.736494 0.472316 -0.388024
1977 -0.833650 -1.411365 -0.643464 0.494771 -0.532450
1978 -0.725834 -1.601747 -0.778358 0.483543 -0.561335
1979 -0.618173 -1.670284 -0.652767 0.303899 -0.655212
1980 -0.527639 -1.438018 -0.601600 0.281443 -0.592627
1981 -0.421224 -0.341418 -0.587645 0.427404 0.052475
1982 -0.298073 0.290650 -0.559736 0.517227 0.324477
1983 -0.192670 -0.387110 -0.457402 0.988794 -0.250820
1984 -0.087657 -0.394725 -0.341113 0.988794 -0.301369
1985 0.028645 0.161190 -0.136445 1.471589 -0.017331
1986 0.157557 0.838950 -0.196915 1.471589 0.297999
1987 0.297134 1.349173 -0.336462 0.517227 0.548337
1988 0.431495 1.109292 -0.187612 0.607049 0.336512
1989 0.562120 0.922718 -0.062021 0.753010 0.170422
1990 0.688930 0.838950 -0.029460 0.607049 0.013961
1991 0.804920 0.366803 0.044965 -0.235035 -0.325440
1992 0.909855 -0.154844 0.198466 -1.638508 -0.681690
1993 1.014635 -0.128191 0.249633 -1.739558 -0.693725
1994 1.118403 -0.211959 0.696181 -1.369042 -0.751496
1995 1.217345 -0.387110 1.328791 -1.705875 -0.910364
1996 1.316053 -0.368072 1.793945 -2.682693 -0.941656
1997 1.412348 -0.478493 2.128857 -1.537458 -1.028312
1998 1.500702 -0.836411 2.347479 -0.672919 -1.249765
1999 1.580494 -1.243829 2.556799 0.180393 -1.480846
城镇人口(城镇+乡村=100) 乡村人口 美元兑换人民币汇率 中国就业人口(万人)
1970 -1.118440 1.124925 -0.586832 -1.335278
1971 -1.140859 1.147199 -0.664354 -1.250396
1972 -1.165148 1.171330 -0.675195 -1.224339
1973 -1.152069 1.158336 -0.762839 -1.162081
1974 -1.159543 1.165761 -0.834781 -1.106143
1975 -1.125913 1.132349 -0.784322 -1.043808
1976 -1.107230 1.111931 -0.818599 -0.991848
1977 -1.086678 1.093369 -0.878503 -0.949485
1978 -1.017550 1.024689 -0.939444 -0.889022
1979 -0.823245 0.831644 -0.971688 -0.820992
1980 -0.742907 0.751827 -0.958097 -0.716683
1981 -0.599046 0.608898 -0.888428 -0.610268
1982 -0.417819 0.428846 -0.813696 -0.487782
1983 -0.344954 0.337892 -0.780576 -0.398765
1984 -0.066575 0.079879 -0.640559 -0.261377
1985 0.064208 -0.050056 -0.397553 -0.130621
1986 0.215542 -0.200409 -0.191853 -0.020695
1987 0.365008 -0.348905 -0.084519 0.096408
1988 0.456555 -0.439860 -0.084519 0.217412
1989 0.344456 -0.514108 -0.067061 0.295039
1990 0.381823 -0.551232 0.338640 1.029957
1991 0.667676 -0.649611 0.553427 1.087846
1992 0.764829 -0.746134 0.630032 1.139415
1993 0.863850 -0.844513 0.728478 1.190594
1994 0.961002 -0.941036 1.867103 1.241071
1995 1.060023 -1.041271 1.760287 1.288661
1996 1.329062 -1.306709 1.745739 1.357706
1997 1.596232 -1.572147 1.736014 1.425580
1998 1.865270 -1.858003 1.731750 1.481518
1999 2.132440 -2.104879 1.731949 1.548379
4. Build the Network Model
LSTNet, proposed in 2018, is a forecasting model that combines the strengths of convolutional and recurrent layers to extract local dependency patterns among multiple variables and to capture complex long-term dependencies.
We set the maximum number of training epochs to 2000, the model's input sequence length to 5, and its output sequence length to 5. The parameters are:
- in_chunk_len (int) – length of the time series fed into the model.
- out_chunk_len (int) – length of the time series produced by the model.
- skip_chunk_len (int) – optional; length of the gap skipped between the input and output sequences, used neither as features nor as prediction targets. Defaults to 0.
- sampling_stride (int) – sampling interval between adjacent samples.
- loss_fn (Callable[..., paddle.Tensor]|None) – loss function.
- optimizer_fn (Callable[..., Optimizer]) – optimization algorithm.
- optimizer_params (Dict[str, Any]) – optimizer parameters.
- eval_metrics (List[str]) – evaluation metrics tracked during training.
- callbacks (List[Callback]) – custom callback functions.
- batch_size (int) – batch size for training and evaluation data.
- max_epochs (int) – maximum number of training epochs.
- verbose (int) – interval for printing log messages during training.
- patience (int) – number of epochs without improvement in the evaluation metric before training stops early.
- seed (int|None) – global random seed; note: guarantees identical model parameter initialization across runs.
- skip_size (int) – period length used by the recurrent-skip component (Skip RNN) to capture periodicity in the series.
- channels (int) – number of channels in the first Conv1D layer.
- kernel_size (int) – kernel size of the first Conv1D layer.
- rnn_cell_type (str) – RNN cell type; GRU or LSTM is supported.
- rnn_num_cells (int) – number of units in the RNN layer.
- skip_rnn_cell_type (str) – Skip RNN cell type; GRU or LSTM is supported.
- skip_rnn_num_cells (int) – number of units in the Skip RNN layer.
- dropout_rate (float) – dropout probability.
- output_activation (str|None) – activation of the output layer; None (no activation), sigmoid, or tanh.
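With in_chunk_len = 5 and out_chunk_len = 5, each training sample pairs five consecutive years of inputs with the following five years of targets. A sketch of how such windows are carved out of a series (assuming sampling_stride = 1; paddlets does this internally, the helper below is only illustrative):

```python
import numpy as np

def make_windows(series, in_len, out_len, stride=1):
    """Slice (input, target) window pairs out of a 1-D series."""
    xs, ys = [], []
    for start in range(0, len(series) - in_len - out_len + 1, stride):
        xs.append(series[start:start + in_len])                    # model input
        ys.append(series[start + in_len:start + in_len + out_len]) # forecast target
    return np.array(xs), np.array(ys)

series = np.arange(30, dtype=float)      # stand-in for 30 years of one column
X, Y = make_windows(series, in_len=5, out_len=5)
print(X.shape, Y.shape)  # (21, 5) (21, 5)
```

This also makes the data budget concrete: with only 30 training years and a 5+5 window, just 21 samples per stride are available, which is why the epoch count matters more than the batch size here.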
lstm = LSTNetRegressor(
    in_chunk_len=5,
    out_chunk_len=5,
    max_epochs=2000
)
5. Model Training
Pass the standardized training and validation data to the model and train it.
lstm.fit(train_dataset_scaled, val_dataset_scaled)
[2022-11-30 16:10:43,402] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 000| loss: 1.335296| val_0_mae: 2.026456| 0:00:00s
[2022-11-30 16:10:43,411] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 001| loss: 1.431604| val_0_mae: 2.010468| 0:00:00s
[2022-11-30 16:10:43,419] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 002| loss: 1.398407| val_0_mae: 1.994482| 0:00:00s
[2022-11-30 16:10:43,428] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 003| loss: 1.504742| val_0_mae: 1.978191| 0:00:00s
[2022-11-30 16:10:43,436] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 004| loss: 1.397552| val_0_mae: 1.961923| 0:00:00s
[2022-11-30 16:10:43,445] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 005| loss: 1.433584| val_0_mae: 1.945455| 0:00:00s
[2022-11-30 16:10:43,453] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 006| loss: 1.365905| val_0_mae: 1.929153| 0:00:00s
[2022-11-30 16:10:43,461] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 007| loss: 1.234721| val_0_mae: 1.913647| 0:00:00s
[2022-11-30 16:10:43,469] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 008| loss: 1.225891| val_0_mae: 1.898627| 0:00:00s
[2022-11-30 16:10:43,478] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 009| loss: 1.268852| val_0_mae: 1.883862| 0:00:00s
[2022-11-30 16:10:43,486] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 010| loss: 1.266971| val_0_mae: 1.869215| 0:00:00s
[2022-11-30 16:10:43,494] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 011| loss: 1.260334| val_0_mae: 1.854911| 0:00:00s
[2022-11-30 16:10:43,502] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 012| loss: 1.268996| val_0_mae: 1.840751| 0:00:00s
[2022-11-30 16:10:43,511] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 013| loss: 1.205134| val_0_mae: 1.826880| 0:00:00s
[2022-11-30 16:10:43,519] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 014| loss: 1.238074| val_0_mae: 1.812879| 0:00:00s
[2022-11-30 16:10:43,527] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 015| loss: 1.247744| val_0_mae: 1.798924| 0:00:00s
[2022-11-30 16:10:43,537] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 016| loss: 1.228621| val_0_mae: 1.785040| 0:00:00s
[2022-11-30 16:10:43,546] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 017| loss: 1.187945| val_0_mae: 1.771226| 0:00:00s
[2022-11-30 16:10:43,554] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 018| loss: 1.175421| val_0_mae: 1.757583| 0:00:00s
[2022-11-30 16:10:43,562] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 019| loss: 1.145478| val_0_mae: 1.744175| 0:00:00s
[2022-11-30 16:10:43,571] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 020| loss: 1.162446| val_0_mae: 1.730875| 0:00:00s
[2022-11-30 16:10:43,579] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 021| loss: 1.209757| val_0_mae: 1.717458| 0:00:00s
[2022-11-30 16:10:43,587] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 022| loss: 1.158366| val_0_mae: 1.704297| 0:00:00s
[2022-11-30 16:10:43,596] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 023| loss: 1.198785| val_0_mae: 1.691088| 0:00:00s
[2022-11-30 16:10:43,604] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 024| loss: 1.077143| val_0_mae: 1.678203| 0:00:00s
[2022-11-30 16:10:43,612] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 025| loss: 1.109716| val_0_mae: 1.665533| 0:00:00s
[2022-11-30 16:10:43,620] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 026| loss: 1.094385| val_0_mae: 1.653193| 0:00:00s
[2022-11-30 16:10:43,628] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 027| loss: 1.080788| val_0_mae: 1.641071| 0:00:00s
[2022-11-30 16:10:43,639] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 028| loss: 1.083594| val_0_mae: 1.629065| 0:00:00s
[2022-11-30 16:10:43,647] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 029| loss: 1.116938| val_0_mae: 1.616999| 0:00:00s
[2022-11-30 16:10:43,655] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 030| loss: 1.128399| val_0_mae: 1.604847| 0:00:00s
[2022-11-30 16:10:43,665] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 031| loss: 1.057482| val_0_mae: 1.592970| 0:00:00s
[2022-11-30 16:10:43,673] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 032| loss: 0.988231| val_0_mae: 1.581546| 0:00:00s
[2022-11-30 16:10:43,681] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 033| loss: 1.047795| val_0_mae: 1.573081| 0:00:00s
[2022-11-30 16:10:43,689] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 034| loss: 1.000459| val_0_mae: 1.564725| 0:00:00s
[2022-11-30 16:10:43,697] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 035| loss: 0.981792| val_0_mae: 1.556305| 0:00:00s
[2022-11-30 16:10:43,705] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 036| loss: 1.006375| val_0_mae: 1.547980| 0:00:00s
[2022-11-30 16:10:43,714] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 037| loss: 1.042193| val_0_mae: 1.539739| 0:00:00s
[2022-11-30 16:10:43,722] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 038| loss: 1.035661| val_0_mae: 1.531963| 0:00:00s
[2022-11-30 16:10:43,730] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 039| loss: 0.991913| val_0_mae: 1.524082| 0:00:00s
[2022-11-30 16:10:43,738] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 040| loss: 1.005429| val_0_mae: 1.516243| 0:00:00s
[2022-11-30 16:10:43,746] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 041| loss: 0.994254| val_0_mae: 1.508542| 0:00:00s
[2022-11-30 16:10:43,754] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 042| loss: 0.898892| val_0_mae: 1.500784| 0:00:00s
[2022-11-30 16:10:43,762] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 043| loss: 0.959420| val_0_mae: 1.493016| 0:00:00s
[2022-11-30 16:10:43,770] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 044| loss: 0.922088| val_0_mae: 1.485151| 0:00:00s
[2022-11-30 16:10:43,778] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 045| loss: 0.922496| val_0_mae: 1.477284| 0:00:00s
[2022-11-30 16:10:43,786] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 046| loss: 0.899761| val_0_mae: 1.469546| 0:00:00s
[2022-11-30 16:10:43,794] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 047| loss: 0.887426| val_0_mae: 1.461833| 0:00:00s
[2022-11-30 16:10:43,802] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 048| loss: 0.933051| val_0_mae: 1.454159| 0:00:00s
[2022-11-30 16:10:43,811] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 049| loss: 0.909708| val_0_mae: 1.446734| 0:00:00s
[2022-11-30 16:10:43,819] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 050| loss: 0.858118| val_0_mae: 1.439384| 0:00:00s
[2022-11-30 16:10:43,827] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 051| loss: 0.896318| val_0_mae: 1.432108| 0:00:00s
[2022-11-30 16:10:43,835] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 052| loss: 0.873903| val_0_mae: 1.425059| 0:00:00s
[2022-11-30 16:10:43,843] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 053| loss: 0.904945| val_0_mae: 1.417843| 0:00:00s
[2022-11-30 16:10:43,852] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 054| loss: 0.877615| val_0_mae: 1.410705| 0:00:00s
[2022-11-30 16:10:43,860] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 055| loss: 0.833624| val_0_mae: 1.403709| 0:00:00s
[2022-11-30 16:10:43,868] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 056| loss: 0.854565| val_0_mae: 1.396727| 0:00:00s
[2022-11-30 16:10:43,876] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 057| loss: 0.840873| val_0_mae: 1.389806| 0:00:00s
[2022-11-30 16:10:43,888] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 058| loss: 0.794746| val_0_mae: 1.382859| 0:00:00s
[2022-11-30 16:10:43,896] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 059| loss: 0.812057| val_0_mae: 1.375864| 0:00:00s
[2022-11-30 16:10:43,905] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 060| loss: 0.771674| val_0_mae: 1.368605| 0:00:00s
[2022-11-30 16:10:43,913] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 061| loss: 0.775358| val_0_mae: 1.361753| 0:00:00s
[2022-11-30 16:10:43,921] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 062| loss: 0.785709| val_0_mae: 1.354705| 0:00:00s
[2022-11-30 16:10:43,929] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 063| loss: 0.761827| val_0_mae: 1.347852| 0:00:00s
[2022-11-30 16:10:43,937] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 064| loss: 0.809033| val_0_mae: 1.341352| 0:00:00s
[2022-11-30 16:10:43,945] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 065| loss: 0.780791| val_0_mae: 1.334946| 0:00:00s
[2022-11-30 16:10:43,954] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 066| loss: 0.731239| val_0_mae: 1.328409| 0:00:00s
[2022-11-30 16:10:43,962] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 067| loss: 0.730782| val_0_mae: 1.321399| 0:00:00s
[2022-11-30 16:10:43,970] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 068| loss: 0.809380| val_0_mae: 1.314795| 0:00:00s
[2022-11-30 16:10:43,978] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 069| loss: 0.742965| val_0_mae: 1.308291| 0:00:00s
[2022-11-30 16:10:43,986] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 070| loss: 0.727096| val_0_mae: 1.301758| 0:00:00s
[2022-11-30 16:10:43,994] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 071| loss: 0.697883| val_0_mae: 1.295217| 0:00:00s
[2022-11-30 16:10:44,002] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 072| loss: 0.713083| val_0_mae: 1.288782| 0:00:00s
[2022-11-30 16:10:44,010] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 073| loss: 0.709170| val_0_mae: 1.282597| 0:00:00s
[2022-11-30 16:10:44,019] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 074| loss: 0.696470| val_0_mae: 1.276140| 0:00:00s
[2022-11-30 16:10:44,027] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 075| loss: 0.694704| val_0_mae: 1.269327| 0:00:00s
[2022-11-30 16:10:44,035] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 076| loss: 0.667143| val_0_mae: 1.262488| 0:00:00s
[2022-11-30 16:10:44,043] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 077| loss: 0.712231| val_0_mae: 1.256098| 0:00:00s
[2022-11-30 16:10:44,051] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 078| loss: 0.631457| val_0_mae: 1.249475| 0:00:00s
[2022-11-30 16:10:44,060] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 079| loss: 0.668609| val_0_mae: 1.242805| 0:00:00s
[2022-11-30 16:10:44,069] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 080| loss: 0.668821| val_0_mae: 1.236076| 0:00:00s
[2022-11-30 16:10:44,077] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 081| loss: 0.711819| val_0_mae: 1.229457| 0:00:00s
[2022-11-30 16:10:44,086] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 082| loss: 0.694410| val_0_mae: 1.222744| 0:00:00s
[2022-11-30 16:10:44,094] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 083| loss: 0.665902| val_0_mae: 1.216351| 0:00:00s
[2022-11-30 16:10:44,103] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 084| loss: 0.635740| val_0_mae: 1.209996| 0:00:00s
[2022-11-30 16:10:44,111] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 085| loss: 0.614305| val_0_mae: 1.204036| 0:00:00s
[2022-11-30 16:10:44,120] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 086| loss: 0.620907| val_0_mae: 1.198223| 0:00:00s
[2022-11-30 16:10:44,133] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 087| loss: 0.561215| val_0_mae: 1.192517| 0:00:00s
[2022-11-30 16:10:44,142] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 088| loss: 0.563731| val_0_mae: 1.186862| 0:00:00s
[2022-11-30 16:10:44,150] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 089| loss: 0.661129| val_0_mae: 1.182760| 0:00:00s
[2022-11-30 16:10:44,158] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 090| loss: 0.616183| val_0_mae: 1.178985| 0:00:00s
[2022-11-30 16:10:44,167] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 091| loss: 0.574570| val_0_mae: 1.175775| 0:00:00s
[2022-11-30 16:10:44,175] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 092| loss: 0.602581| val_0_mae: 1.173194| 0:00:00s
[2022-11-30 16:10:44,183] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 093| loss: 0.585698| val_0_mae: 1.170198| 0:00:00s
[2022-11-30 16:10:44,191] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 094| loss: 0.629226| val_0_mae: 1.167235| 0:00:00s
[2022-11-30 16:10:44,200] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 095| loss: 0.555163| val_0_mae: 1.164508| 0:00:00s
[2022-11-30 16:10:44,208] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 096| loss: 0.573185| val_0_mae: 1.161562| 0:00:00s
[2022-11-30 16:10:44,216] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 097| loss: 0.562863| val_0_mae: 1.158835| 0:00:00s
[2022-11-30 16:10:44,224] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 098| loss: 0.574777| val_0_mae: 1.156082| 0:00:00s
[2022-11-30 16:10:44,232] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 099| loss: 0.555662| val_0_mae: 1.153205| 0:00:00s
[2022-11-30 16:10:44,241] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 100| loss: 0.601033| val_0_mae: 1.150310| 0:00:00s
[2022-11-30 16:10:44,249] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 101| loss: 0.580240| val_0_mae: 1.147768| 0:00:00s
[2022-11-30 16:10:44,261] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 102| loss: 0.591648| val_0_mae: 1.145656| 0:00:00s
[2022-11-30 16:10:44,270] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 103| loss: 0.542775| val_0_mae: 1.144088| 0:00:00s
[2022-11-30 16:10:44,278] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 104| loss: 0.539643| val_0_mae: 1.142433| 0:00:00s
[2022-11-30 16:10:44,286] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 105| loss: 0.515184| val_0_mae: 1.141113| 0:00:00s
[2022-11-30 16:10:44,294] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 106| loss: 0.570678| val_0_mae: 1.140380| 0:00:00s
[2022-11-30 16:10:44,302] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 107| loss: 0.559144| val_0_mae: 1.139631| 0:00:00s
[2022-11-30 16:10:44,311] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 108| loss: 0.527249| val_0_mae: 1.139258| 0:00:00s
[2022-11-30 16:10:44,319] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 109| loss: 0.599206| val_0_mae: 1.138611| 0:00:00s
[2022-11-30 16:10:44,327] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 110| loss: 0.535672| val_0_mae: 1.137478| 0:00:00s
[2022-11-30 16:10:44,336] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 111| loss: 0.541393| val_0_mae: 1.136143| 0:00:00s
[2022-11-30 16:10:44,344] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 112| loss: 0.488889| val_0_mae: 1.134465| 0:00:00s
[2022-11-30 16:10:44,353] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 113| loss: 0.517146| val_0_mae: 1.132642| 0:00:00s
[2022-11-30 16:10:44,361] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 114| loss: 0.524021| val_0_mae: 1.130486| 0:00:00s
[2022-11-30 16:10:44,369] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 115| loss: 0.499090| val_0_mae: 1.127806| 0:00:00s
[2022-11-30 16:10:44,378] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 116| loss: 0.449979| val_0_mae: 1.125503| 0:00:00s
[2022-11-30 16:10:44,389] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 117| loss: 0.488240| val_0_mae: 1.123149| 0:00:01s
[2022-11-30 16:10:44,398] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 118| loss: 0.479148| val_0_mae: 1.120922| 0:00:01s
[2022-11-30 16:10:44,406] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 119| loss: 0.444483| val_0_mae: 1.118181| 0:00:01s
[2022-11-30 16:10:44,414] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 120| loss: 0.505287| val_0_mae: 1.115203| 0:00:01s
[2022-11-30 16:10:44,423] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 121| loss: 0.444154| val_0_mae: 1.112558| 0:00:01s
[2022-11-30 16:10:44,431] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 122| loss: 0.453889| val_0_mae: 1.109772| 0:00:01s
[2022-11-30 16:10:44,439] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 123| loss: 0.427750| val_0_mae: 1.106739| 0:00:01s
[2022-11-30 16:10:44,447] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 124| loss: 0.451959| val_0_mae: 1.103423| 0:00:01s
[2022-11-30 16:10:44,455] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 125| loss: 0.560345| val_0_mae: 1.101498| 0:00:01s
[2022-11-30 16:10:44,463] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 126| loss: 0.447812| val_0_mae: 1.099758| 0:00:01s
[2022-11-30 16:10:44,471] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 127| loss: 0.502624| val_0_mae: 1.098869| 0:00:01s
[2022-11-30 16:10:44,479] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 128| loss: 0.419782| val_0_mae: 1.098326| 0:00:01s
[2022-11-30 16:10:44,488] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 129| loss: 0.433613| val_0_mae: 1.098895| 0:00:01s
[2022-11-30 16:10:44,495] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 130| loss: 0.425311| val_0_mae: 1.099732| 0:00:01s
[2022-11-30 16:10:44,502] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 131| loss: 0.447280| val_0_mae: 1.100813| 0:00:01s
[2022-11-30 16:10:44,510] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 132| loss: 0.424777| val_0_mae: 1.102612| 0:00:01s
[2022-11-30 16:10:44,519] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 133| loss: 0.437228| val_0_mae: 1.105103| 0:00:01s
[2022-11-30 16:10:44,527] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 134| loss: 0.430760| val_0_mae: 1.107998| 0:00:01s
[2022-11-30 16:10:44,534] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 135| loss: 0.515581| val_0_mae: 1.111710| 0:00:01s
[2022-11-30 16:10:44,542] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 136| loss: 0.430533| val_0_mae: 1.115131| 0:00:01s
[2022-11-30 16:10:44,549] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 137| loss: 0.463041| val_0_mae: 1.117901| 0:00:01s
[2022-11-30 16:10:44,557] [paddlets.models.common.callbacks.callbacks] [INFO] epoch 138| loss: 0.399685| val_0_mae: 1.120615| 0:00:01s
[2022-11-30 16:10:44,558] [paddlets.models.common.callbacks.callbacks] [INFO]
Early stopping occurred at epoch 138 with best_epoch = 128 and best_val_0_mae = 1.098326
[2022-11-30 16:10:44,559] [paddlets.models.common.callbacks.callbacks] [INFO] Best weights from best epoch are automatically used!
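The log above shows training halting at epoch 138 with best_epoch = 128, i.e. the validation MAE had not improved for 10 consecutive epochs. The stopping rule can be sketched in plain Python (this illustrates the patience logic only; `early_stop_epoch` and the toy curve are hypothetical, not the PaddleTS callback implementation):

```python
def early_stop_epoch(val_maes, patience=10):
    """Return (stop_epoch, best_epoch, best_mae): stop once the best
    validation MAE has not improved for `patience` epochs."""
    best, best_epoch = float("inf"), 0
    for epoch, mae in enumerate(val_maes):
        if mae < best:
            best, best_epoch = mae, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch, best
    return len(val_maes) - 1, best_epoch, best

# Toy curve: the metric bottoms out at epoch 3, then slowly worsens.
maes = [1.5, 1.3, 1.2, 1.1] + [1.1 + 0.01 * i for i in range(1, 15)]
stop, best_epoch, best = early_stop_epoch(maes, patience=10)
```

With this toy curve, training stops 10 epochs after the best epoch, mirroring the 138/128 pair in the log.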
6. Model Prediction
Make predictions on the normalized validation data. After the dataset split, the validation set comes first, followed by the test set.
- Feeding the validation data into the trained LSTNet model yields a prediction for the sequence that follows it (new values inferred from the existing sequence).
- For the validation set, that following sequence lies in the test set, so the predicted values can be compared against the real test data.
- To obtain the overlapping portion of the test set and the predictions, the arrays must be split accordingly.
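The array splitting described above amounts to truncating the test series to the length of the forecast; in plain-Python terms (toy lists standing in for the real TSDataset objects):

```python
# Toy stand-ins for the scaled test series and the model's forecast.
test_values = [130756, 131448, 132129, 132802, 133450, 134091]
pred_values = [130612, 131390, 132255, 132700]

# Keep only the overlap: the first len(pred_values) points of the
# test set, so true and predicted values can be compared pointwise.
overlap = test_values[:len(pred_values)]
```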
val_dataset_scaled
总人口(万人) 中国人均GPA(美元计) 中国就业人口(万人) 中国性别比例(按照女生=100) 乡村人口 \
2000 1.654992 2.956831 1.602288 1.134755 -2.372173
2001 1.723807 3.394077 1.657836 0.303899 -2.639467
2002 1.788108 3.835973 1.695518 0.371266 -2.904905
2003 1.848360 4.487189 1.731094 0.528455 -3.172199
2004 1.907601 5.510529 1.772287 0.629505 -3.400513
2005 1.967386 6.650157 1.794366 0.640732 -3.628827
2006 2.021255 8.259592 1.827991 0.629505 -3.897977
2007 2.074268 11.022609 1.854751 0.517227 -4.167128
2008 2.126658 14.627556 1.873709 0.382493 -4.371310
2009 2.177102 16.320717 1.894305 0.225304 -4.621899
出生人口(万) 城镇人口(城镇+乡村=100) 美元兑换人民币汇率 自然增长率(%)
2000 -1.479902 2.401478 1.731471 -1.625272
2001 -1.742629 2.670516 1.730913 -1.776919
2002 -1.952049 2.937686 1.730913 -1.897274
2003 -2.131009 3.206725 1.731072 -2.003186
2004 -2.153854 3.436528 1.731311 -2.036886
2005 -2.062471 3.666332 1.660884 -2.032072
2006 -2.180508 3.918555 1.544264 -2.178905
2007 -2.142431 4.208145 1.376268 -2.205383
2008 -2.092932 4.413661 1.162158 -2.227046
2009 -2.157662 4.665884 1.146215 -2.277596
subset_test_pred_dataset = lstm.predict(val_dataset_scaled)
subset_test_dataset, _ = test_dataset_scaled.split(len(subset_test_pred_dataset.target))
subset_test_dataset.plot(add_data=subset_test_pred_dataset, labels=['Pred'])
<AxesSubplot:>
7. Model Evaluation
The prediction results are evaluated with the Mean Absolute Error (MAE), computed as:

$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right|$$
mae = MAE()
mae(subset_test_dataset, subset_test_pred_dataset)
{'总人口(万人)': 1.3227954641250552}
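As a sanity check, MAE can be computed by hand with numpy; the arrays below are toy values, not the actual series:

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """MAE = (1/n) * sum(|y_i - y_hat_i|)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))

# Toy (scaled) population values.
y_true = [2.1, 2.2, 2.3, 2.4]
y_pred = [2.0, 2.1, 2.5, 2.3]
mae_value = mean_absolute_error(y_true, y_pred)  # 0.125
```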
8. Inverse Normalization of Predictions
Because the data fed into the model for training was standardized/normalized, the model's predictions are also on the standardized scale; they must be inverse-transformed to view the predictions at their original magnitude.
- Inverse-transform the predicted values
- Inverse-transform the ground-truth labels in the test set
- Visualize and compare the two inverse-transformed results
8.1 Inverse Transform
subset_test_pred_dataset_new = scaler.inverse_transform(subset_test_pred_dataset)
subset_test_dataset_new = scaler.inverse_transform(subset_test_dataset)
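Under the hood, `inverse_transform` of a standard scaler simply reverses the z-score: multiply by the stored standard deviation and add back the mean. A numpy sketch of that arithmetic, using toy numbers rather than the real label column:

```python
import numpy as np

# Toy values standing in for the raw label column.
raw = np.array([126743.0, 127627.0, 128453.0, 129227.0])

# Standardization, as done by the scaler's fit_transform before training.
mu, sigma = raw.mean(), raw.std()
scaled = (raw - mu) / sigma

# Inverse transform: undo the z-score to recover the original units.
restored = scaled * sigma + mu

print(np.allclose(restored, raw))  # prints True
```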
8.2 Visual Comparison of Results
subset_test_dataset, _ = test_dataset_scaled.split(len(subset_test_pred_dataset.target))
subset_test_dataset_new.plot(add_data=subset_test_pred_dataset_new, labels=['Pred'])
<AxesSubplot:>
9. Summary 🌟
- This project walked through a complete forecasting workflow with the LSTNet model in PaddleTS, and the trained model achieved fairly good results.
- Since this dataset contains a limited number of samples, future experiments could use datasets with more abundant data.
- A follow-up could build an LSTM forecasting network from scratch.
Thanks again to project mentor 顾茜 for her guidance.
This project was completed jointly by 3 members of the 北京科技大学飞桨领航团.
This article is a repost.
Original project link