【AI 达人特训营】基于飞桨实现系统认证风险预测-异常检测
本项目依据用户登录某一网站的一些基本信息,利用Paddle搭建一个神经网络模型,来预测出此次登录是否存在风险。
系统认证风险预测-异常检测
本项目依据用户登录某一网站的一些基本信息,利用Paddle搭建一个神经网络模型,来预测出此次登录是否存在风险。
一、赛题背景
随着国家、企业对安全和效率越来越重视,作为安全基础设施之一的统一身份管理(IAM,Identity and Access Management)也得到越来越多的关注。 在IAM领域中,其主要安全防护手段是身份鉴别,既我们常见的口令验证、扫码验证、指纹验证等。它们一般分为三类,既用户所知(如口令)、所有(如身份证)、特征(如人脸)。但这些身份鉴别方式都有其各自的优缺点,比如口令,强度高了记不住,强度低了容易丢,又比如人脸,摇头晃脑做活体检测体验不好,静默检测如果算法不够,又很容易被照片、视频、人脸模型绕过。也因此,在等保2.0中对于三级以上系统要求必须使用两种以上的身份鉴别方式进行身份验证,来提高身份鉴别的可信度,这也被成为双因素认证。
但这对用户来说虽然在一定程度提高了安全性,但极大的降低了用户体验。也因此,IAM厂商开始参考UEBA、用户画像等行为分析技术,来探索一种既能确保用户体验,又能提高身份鉴别强度的方法。而在当前IAM的探索进程当中,目前最为可具落地的方法,是基于规则的行为分析技术,因为它的可理解性很高,且很容易与身份鉴别技术进行联动。
但基于规则但局限性也很明显,它是基于经验的,有宁错杀一千,不放过一个的特点,缺少从数据层面来证明是否有人正在尝试窃取或验证非法获取的身份信息,又或者正在使用窃取的身份信息,以此来提前进行风险预警和处置。
更多内容请前往比赛官网查看:系统认证风险预测-异常检测
比赛任务
本赛题中,参赛团队需要基于用户认证行为数据及风险异常标记结构,构建用户认证行为特征模型和风险异常评估模型,利用风险评估模型去判断当前用户认证行为是否存在风险。
- 利用用户认证数据构建行为基线
- 采用监督学习模型,基于用户认证行为特征,构建风险异常评估模型,判断当前用户认证行为是否存在风险
二、数据分析及处理
比赛数据是从竹云的风险分析产品日志库中摘录而来,主要涉及认证日志与风险日志数据。比赛数据经过数据脱敏和数据筛选等安全处理操作,供大家使用。其中认证日志是用户在访问应用系统产生的行为数据,包括登录、单点登录、退出等行为。
该比赛使用的数据已上传至AI Studio:https://aistudio.baidu.com/aistudio/datasetdetail/112151
数据概况
import csv
import numpy as np
import pandas as pd
train_dataset = r'data/data112151/train_dataset.csv'
test_dataset = r'data/data112151/test_dataset.csv'
with open(train_dataset, encoding = 'utf-8') as trainset:
traindatas = np.loadtxt(trainset, str, delimiter = "\t", skiprows = 1)
with open(test_dataset, encoding = 'utf-8') as testset:
testdatas = np.loadtxt(testset, str, delimiter = "\t", skiprows = 1)
print("在训练集中,共有{}条数据,其中每条数据有{}个特征".format(traindatas.shape[0], traindatas.shape[1]))
print("在测试集中,共有{}条数据,其中每条数据有{}个特征".format(testdatas.shape[0], testdatas.shape[1]))
traindata2 = pd.DataFrame(traindatas)
traindata2.columns = ['认证ID','认证时间','用户名','操作类型','首次认证方式','IP地址','IP类型','IP威胁级别','地点','客户端类型','浏览器来源',
'浏览器类型','浏览器版本号','操作系统类型','操作系统版本号','设备型号','应用系统编码','应用系统类目','风险标识']
testdata2 = pd.DataFrame(testdatas)
testdata2.columns = ['认证ID','认证时间','用户名','操作类型','首次认证方式','IP地址','IP类型','IP威胁级别','地点','客户端类型','浏览器来源',
'浏览器类型','浏览器版本号','操作系统类型','操作系统版本号','设备型号','应用系统编码','应用系统类目']
traindata2.head()
在训练集中,共有15016条数据,其中每条数据有19个特征
在测试集中,共有10000条数据,其中每条数据有18个特征
认证ID | 认证时间 | 用户名 | 操作类型 | 首次认证方式 | IP地址 | IP类型 | IP威胁级别 | 地点 | 客户端类型 | 浏览器来源 | 浏览器类型 | 浏览器版本号 | 操作系统类型 | 操作系统版本号 | 设备型号 | 应用系统编码 | 应用系统类目 | 风险标识 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | access:test_d:20180101111639:bBp1 | 2018/1/1 11:16 | test_d | login | otp | 192.168.100.101 | 内网 | 1级 | {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... | web | desktop | think_pad_e460 | windows | windows 10 | chrome | chrome 90 | coremail | management | 0 |
1 | access:test_d:20180101121524:OBSg | 2018/1/1 12:15 | test_d | login | qr | 192.168.100.101 | 内网 | 1级 | {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... | web | desktop | think_pad_e460 | windows | windows 10 | edge | edge 93 | order-mgnt | sales | 0 |
2 | access:test_d:20180101151333:BpQN | 2018/1/1 15:13 | test_d | login | qr | 192.168.100.101 | 内网 | 1级 | {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... | web | desktop | think_pad_e460 | windows | windows 10 | chrome | chrome 90 | order-mgnt | sales | 0 |
3 | access:test_d:20180101124502:hYQm | 2018/1/1 12:45 | test_d | sso | 192.168.100.101 | 内网 | 1级 | {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... | web | desktop | think_pad_e460 | windows | windows 10 | edge | edge 93 | oa | management | 0 | |
4 | access:test_d:20180101202749:FkDK | 2018/1/1 20:27 | test_d | sso | 192.168.100.101 | 内网 | 1级 | {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... | web | desktop | think_pad_e460 | windows | windows 10 | edge | edge 93 | order-mgnt | sales | 0 |
数据预处理
比赛提供的部分数据结构如下所示:
序号 | 字段名_中文 | 字段名_英文 | 含义 | 示例 |
---|---|---|---|---|
3 | 操作类型 | action | login还是sso | 1 : login 2 : sso |
4 | 首次认证方式 | auth_type | 账密、短信、otp、二维码 | 1 : pwd 2 : sms 3 : otp 4 : qr |
6 | IP类型 | ip_location_type_keyword | 内网、家庭宽带、公共宽带、代理ip | 1 : 家庭宽带 2 : 代理ip 3 : 内网 4 : 公共宽带 |
7 | IP威胁级别 | ip_risk_level | 1级、2级、3级 | 1 : 1级 2 : 2级 3 : 3级 |
9 | 客户端类型 | client_type | app、web | 1 : app 2 : web |
10 | 浏览器来源 | browser_source | 桌面端、移动端 | 1 : desktop 2 : mobile |
17 | 应用系统类目 | op_target | 系统所属的系统类别 | 1 : sales 2 : finance 3 : management 4 : hr |
18 | 风险标识 | risk_label | 有、无 | 1 : 有;0 : 无 |
1 | 认证时间 | op_date | 认证时间 | 2018-1-1 11:16:00 |
2 | 用户名 | user_name | 用户名 | test_a |
5 | IP地址 | ip | 用户认证ip地址 | 192.168.0.100 |
8 | 地点 | location | 无 | 无 |
11 | 浏览器类型 | browser_type | 浏览器类型 | chrome |
12 | 浏览器版本号 | browser_version | 浏览器版本号 | chrome 90 |
13 | 操作系统类型 | os_type | 操作系统类型 | windows |
14 | 操作系统版本号 | os_version | 操作系统版本号 | windows 10 |
15 | 设备型号 | device_model | 访问应用时所使用设备 | think_pad_e460 |
16 | 应用系统编码 | bus_system_code | 所访问的应用系统编码 | coremail |
#合并数据
data = pd.concat([traindata2,testdata2])
print(data.shape)
(25016, 19)
# 删除无用特征
data.drop(['客户端类型','浏览器来源'], axis=1, inplace=True)
#特征集成
data['首次认证方式'].fillna('_', inplace=True)
data['登录类型'] = data['操作类型'] + data['首次认证方式']
data['登录类型'].value_counts()
sso 12600
loginpwd 3150
loginsms 3119
loginqr 3083
loginotp 3064
Name: 登录类型, dtype: int64
data['设备'] = data['设备型号'] + data['操作系统类型'] + data['操作系统版本号'] + data['浏览器类型'] + data['浏览器版本号']
data['设备'].value_counts()
chrome 90windows 10chromethink_pad_e460windows 9894
edge 93windows 10edgethink_pad_e460windows 9878
safari 13macOS Big Sur 11safarimacbookmacOS 3126
firefox 78windows 11firefoxthink_pad_t480windows 511
chrome 93windows 11chromethink_pad_t480windows 498
ie 11windows 11iethink_pad_t480windows 473
ie 9windows 7iethink_pad_l470windows 333
chrome 77windows 7chromethink_pad_l470windows 303
Name: 设备, dtype: int64
data['应用系统相关'] = data['应用系统编码'] + data['应用系统类目']
data['应用系统相关'].value_counts()
order-mgntsales 8015
crmsales 4957
coremailmanagement 4486
oamanagement 2988
reimbursementfinance 2300
attendancehr 1811
salaryhr 459
Name: 应用系统相关, dtype: int64
#处理地点列
import json
data['第一地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['first_lvl'])
data['第二地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['sec_lvl'])
data['第三地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['third_lvl'])
pip uninstall tqdm -y
Found existing installation: tqdm 4.27.0
Uninstalling tqdm-4.27.0:
Successfully uninstalled tqdm-4.27.0
Note: you may need to restart the kernel to use updated packages.
pip install tqdm
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tqdm
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8a/c4/d15f1e627fff25443ded77ea70a7b5532d6371498f9285d44d62587e209c/tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.4/78.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tqdm
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
paddlefsl 1.0.0 requires tqdm~=4.27.0, but you have tqdm 4.64.0 which is incompatible.[0m[31m
[0mSuccessfully installed tqdm-4.64.0
Note: you may need to restart the kernel to use updated packages.
#特征编码
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
features_encoder = [i for i in data.columns if i not in['认证ID','认证时间','风险标识']]
for col in tqdm(features_encoder):
lbl = LabelEncoder()
data[col] = lbl.fit_transform(data[col])
100%|██████████| 20/20 [00:00<00:00, 1011.76it/s]
# 处理时间
data['认证时间'] = pd.to_datetime(data['认证时间'])
data['认证时间转换'] = data['认证时间'].values.astype(np.int64) // 10 ** 9
#滞后历史特征
data = data.sort_values(by=['用户名', '认证时间转换']).reset_index(drop=True)
data['滞后'] = data.groupby(['用户名'])['认证时间转换'].shift(1)#该列下移一行,即上一次行为发生时间,用于计算行为发生间隔时间
data['时间间隔'] = data['认证时间转换'] - data['滞后'] #同一用户登录行为发生时间间隔
#统计不同用户的以下各列中不同值的个数(出现了多少种类别) 类别型统计特征(离散型)
for item in ['IP地址', '地点', '设备型号', '操作系统版本号', '浏览器版本号']:
data[f'user_{item}_nunique'] = data.groupby(['用户名'])[item].transform('nunique')
data[f'user_browser_version_mode'] = data.groupby(['用户名'])['浏览器版本号'].transform(lambda x :x.mode()[0])
#统计不同分组下,时间间隔的统计特征 数值型统计特征'min',
#可以尝试删除sum、prod,sum与mean冗余,prod计算所有元素的乘积 ,'prod', 'sum'
for method in ['mean', 'max', 'std', 'median']:
for col in ['用户名', 'IP地址', '地点', '设备型号', '操作系统版本号', '浏览器版本号']:
data[f'time_interval_{method}_' + str(col)] = data.groupby(col)['时间间隔'].transform(method)
# 各个特征的种类数 发现 客户端类型 和 浏览器来源 所有行均相同,不进行使用
data.nunique()
认证ID 25016
认证时间 25016
用户名 7
操作类型 2
首次认证方式 5
IP地址 5
IP类型 3
IP威胁级别 3
地点 5
浏览器类型 4
浏览器版本号 2
操作系统类型 4
操作系统版本号 5
设备型号 8
应用系统编码 7
应用系统类目 4
风险标识 2
登录类型 5
设备 8
应用系统相关 7
第一地点 3
第二地点 4
第三地点 4
认证时间转换 25016
滞后 25009
时间间隔 450
user_IP地址_nunique 2
user_地点_nunique 2
user_设备型号_nunique 2
user_操作系统版本号_nunique 1
user_浏览器版本号_nunique 1
user_browser_version_mode 1
time_interval_mean_用户名 7
time_interval_mean_IP地址 5
time_interval_mean_地点 5
time_interval_mean_设备型号 8
time_interval_mean_操作系统版本号 5
time_interval_mean_浏览器版本号 2
time_interval_max_用户名 7
time_interval_max_IP地址 5
time_interval_max_地点 5
time_interval_max_设备型号 8
time_interval_max_操作系统版本号 5
time_interval_max_浏览器版本号 2
time_interval_std_用户名 7
time_interval_std_IP地址 5
time_interval_std_地点 5
time_interval_std_设备型号 8
time_interval_std_操作系统版本号 5
time_interval_std_浏览器版本号 2
time_interval_median_用户名 2
time_interval_median_IP地址 2
time_interval_median_地点 2
time_interval_median_设备型号 2
time_interval_median_操作系统版本号 2
time_interval_median_浏览器版本号 1
dtype: int64
data.columns
Index(['认证ID', '认证时间', '用户名', '操作类型', '首次认证方式', 'IP地址', 'IP类型', 'IP威胁级别', '地点',
'浏览器类型', '浏览器版本号', '操作系统类型', '操作系统版本号', '设备型号', '应用系统编码', '应用系统类目',
'风险标识', '登录类型', '设备', '应用系统相关', '第一地点', '第二地点', '第三地点', '认证时间转换', '滞后',
'时间间隔', 'user_IP地址_nunique', 'user_地点_nunique', 'user_设备型号_nunique',
'user_操作系统版本号_nunique', 'user_浏览器版本号_nunique',
'user_browser_version_mode', 'time_interval_mean_用户名',
'time_interval_mean_IP地址', 'time_interval_mean_地点',
'time_interval_mean_设备型号', 'time_interval_mean_操作系统版本号',
'time_interval_mean_浏览器版本号', 'time_interval_max_用户名',
'time_interval_max_IP地址', 'time_interval_max_地点',
'time_interval_max_设备型号', 'time_interval_max_操作系统版本号',
'time_interval_max_浏览器版本号', 'time_interval_std_用户名',
'time_interval_std_IP地址', 'time_interval_std_地点',
'time_interval_std_设备型号', 'time_interval_std_操作系统版本号',
'time_interval_std_浏览器版本号', 'time_interval_median_用户名',
'time_interval_median_IP地址', 'time_interval_median_地点',
'time_interval_median_设备型号', 'time_interval_median_操作系统版本号',
'time_interval_median_浏览器版本号'],
dtype='object')
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25016 entries, 0 to 25015
Data columns (total 56 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 认证ID 25016 non-null object
1 认证时间 25016 non-null datetime64[ns]
2 用户名 25016 non-null int64
3 操作类型 25016 non-null int64
4 首次认证方式 25016 non-null int64
5 IP地址 25016 non-null int64
6 IP类型 25016 non-null int64
7 IP威胁级别 25016 non-null int64
8 地点 25016 non-null int64
9 浏览器类型 25016 non-null int64
10 浏览器版本号 25016 non-null int64
11 操作系统类型 25016 non-null int64
12 操作系统版本号 25016 non-null int64
13 设备型号 25016 non-null int64
14 应用系统编码 25016 non-null int64
15 应用系统类目 25016 non-null int64
16 风险标识 15016 non-null object
17 登录类型 25016 non-null int64
18 设备 25016 non-null int64
19 应用系统相关 25016 non-null int64
20 第一地点 25016 non-null int64
21 第二地点 25016 non-null int64
22 第三地点 25016 non-null int64
23 认证时间转换 25016 non-null int64
24 滞后 25009 non-null float64
25 时间间隔 25009 non-null float64
26 user_IP地址_nunique 25016 non-null int64
27 user_地点_nunique 25016 non-null int64
28 user_设备型号_nunique 25016 non-null int64
29 user_操作系统版本号_nunique 25016 non-null int64
30 user_浏览器版本号_nunique 25016 non-null int64
31 user_browser_version_mode 25016 non-null int64
32 time_interval_mean_用户名 25016 non-null float64
33 time_interval_mean_IP地址 25016 non-null float64
34 time_interval_mean_地点 25016 non-null float64
35 time_interval_mean_设备型号 25016 non-null float64
36 time_interval_mean_操作系统版本号 25016 non-null float64
37 time_interval_mean_浏览器版本号 25016 non-null float64
38 time_interval_max_用户名 25016 non-null float64
39 time_interval_max_IP地址 25016 non-null float64
40 time_interval_max_地点 25016 non-null float64
41 time_interval_max_设备型号 25016 non-null float64
42 time_interval_max_操作系统版本号 25016 non-null float64
43 time_interval_max_浏览器版本号 25016 non-null float64
44 time_interval_std_用户名 25016 non-null float64
45 time_interval_std_IP地址 25016 non-null float64
46 time_interval_std_地点 25016 non-null float64
47 time_interval_std_设备型号 25016 non-null float64
48 time_interval_std_操作系统版本号 25016 non-null float64
49 time_interval_std_浏览器版本号 25016 non-null float64
50 time_interval_median_用户名 25016 non-null float64
51 time_interval_median_IP地址 25016 non-null float64
52 time_interval_median_地点 25016 non-null float64
53 time_interval_median_设备型号 25016 non-null float64
54 time_interval_median_操作系统版本号 25016 non-null float64
55 time_interval_median_浏览器版本号 25016 non-null float64
dtypes: datetime64[ns](1), float64(26), int64(27), object(2)
memory usage: 10.7+ MB
data.shape
(25016, 56)
cols = [i for i in data.columns if i not in ['风险标识','认证ID','认证时间','时间间隔','滞后']]
train = data[data['风险标识'].notna()]
test = data[data['风险标识'].isna()]
traindata = np.array(train[cols])
trainlabel = np.array(train['风险标识'])
testdata = np.array(test[cols])
print(traindata.shape, trainlabel.shape)
(15016, 51) (15016,)
# 转换为nparray
data_X = np.array(traindata[:12000,], dtype='float32')
data_Y = np.array(trainlabel[:12000,], dtype='float32')
test_data_X = np.array(traindata[12000:,], dtype='float32')
test_data_Y = np.array(trainlabel[12000:,], dtype='float32')
# 检查大小
print('data shape', data_X.shape, data_Y.shape)
print('testdata shape', test_data_X.shape,test_data_Y)
# aa = pd.DataFrame(test_data_X)
# aa.info()
from sklearn.preprocessing import MinMaxScaler
mm = MinMaxScaler()
data_X = mm.fit_transform(data_X)
test_data_X = mm.fit_transform(test_data_X)
testdata = mm.fit_transform(testdata)
data shape (12000, 51) (12000,)
testdata shape (3016, 51) [0. 0. 0. ... 0. 0. 0.]
三、模型组网
使用飞桨PaddlePaddle进行组网,在本基线系统中,只使用两层全连接层完成分类任务。
import paddle
import paddle.nn as nn
# 定义动态图
class Classification(paddle.nn.Layer):
def __init__(self):
super(Classification, self).__init__()
self.drop = paddle.nn.Dropout(p=0.5)
self.fc1 = paddle.nn.Linear(51, 32)
self.fc2 = paddle.nn.Linear(32, 16)
self.fc3 = paddle.nn.Linear(16, 2)
# 网络的前向计算函数
def forward(self, inputs):
x = self.fc1(inputs)
x = self.fc2(x)
pred = self.fc3(x)
return pred
四、配置参数及训练
记录日志
# 定义绘制训练过程的损失值变化趋势的方法draw_train_process
train_nums = []
train_costs = []
def draw_train_process(iters,train_costs):
title="training cost"
plt.title(title, fontsize=24)
plt.xlabel("iter", fontsize=14)
plt.ylabel("cost", fontsize=14)
plt.plot(iters, train_costs,color='red',label='training cost')
plt.grid()
plt.show()
定义损失函数
损失函数使用【R-Drop:摘下SOTA的Dropout正则化策略】里的kl_loss:
import paddle
import paddle.nn.functional as F
class kl_loss(paddle.nn.Layer):
def __init__(self):
super(kl_loss, self).__init__()
def forward(self, p, q, label):
ce_loss = 0.5 * (F.mse_loss(p, label=label)) + F.mse_loss(q, label=label)
kl_loss = self.compute_kl_loss(p, q)
# carefully choose hyper-parameters
loss = ce_loss + 0.3 * kl_loss
return loss
def compute_kl_loss(self, p, q):
p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')
q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')
# You can choose whether to use function "sum" and "mean" depending on your task
p_loss = p_loss.sum()
q_loss = q_loss.sum()
loss = (p_loss + q_loss) / 2
return loss
模型训练
import paddle.nn.functional as F
y_preds = []
labels_list = []
BATCH_SIZE = 32
train_data = np.column_stack((data_X,data_Y))
test_data = np.column_stack((test_data_X,test_data_Y))
compute_kl_loss = kl_loss()
def train(model):
print('start training ... ')
# 开启模型训练模式
model.train()
EPOCH_NUM = 10
train_num = 0
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00025, T_max=int(traindatas.shape[0]/BATCH_SIZE*EPOCH_NUM), verbose=False)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
for epoch_id in range(EPOCH_NUM):
# 在每轮迭代开始之前,将训练数据的顺序随机的打乱
np.random.shuffle(train_data)
# 将训练数据进行拆分,每个batch包含50条数据
mini_batches = [train_data[k: k+BATCH_SIZE] for k in range(0, len(train_data), BATCH_SIZE)]
# print(mini_batches[0].shape)
for batch_id, data in enumerate(mini_batches):
# print(batch_id)
# print('data',data.shape)
# print(data[:, -1:].shape)
features_np = np.array(data[:, :51], np.float32)
# print(data[:, -1:])
labels_np = np.array(data[:,-1:], np.float32)
features = paddle.to_tensor(features_np)
labels = paddle.to_tensor(labels_np)
#前向计算
y_pred1 = model(features)
# print(y_pred1[0])
y_pred2 = model(features)
cost = compute_kl_loss(y_pred1, y_pred2, label=labels)
# cost = F.mse_loss(y_pred, label=labels)
train_cost = cost.numpy()[0]
#反向传播
cost.backward()
#最小化loss,更新参数
optimizer.step()
# 清除梯度
optimizer.clear_grad()
if batch_id % 500 == 0 and epoch_id % 1 == 0:
print("Pass:%d,Cost:%0.5f"%(epoch_id, train_cost))
train_num = train_num + BATCH_SIZE
train_nums.append(train_num)
train_costs.append(train_cost)
model = Classification()
train(model)
W0705 10:08:57.338639 1304 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0705 10:08:57.342248 1304 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.
start training ...
Pass:0,Cost:7.29135
Pass:1,Cost:0.28412
Pass:2,Cost:0.21600
Pass:3,Cost:0.21525
Pass:4,Cost:0.19563
Pass:5,Cost:0.28115
Pass:6,Cost:0.48149
Pass:7,Cost:0.30064
Pass:8,Cost:0.28895
Pass:9,Cost:0.27548
def predict(model):
print('predict....')
model.eval()
outputs = []
mini_batches = [test_data[k: k+BATCH_SIZE] for k in range(0, len(test_data), BATCH_SIZE)]
# print(len(mini_batches[0]))
for data in mini_batches:
features_np = np.array(data[:, :51], np.float32)
features = paddle.to_tensor(features_np)
pred = model(features)
# print('pred',pred.shape)
out = paddle.argmax(pred, axis=1)
# print('out:',out.numpy())
outputs.extend(out.numpy())
return outputs
outputs = predict(model)
test_data_y = test_data[:,-1:].reshape(-1,)
outputs = np.array(outputs,np.float32)
np.sum(test_data_y == outputs) / test_data_y.shape[0]
predict....
0.7231432360742706
五、保存预测结果
模型预测:
predict_result = []
for infer_feature in testdata:
# print(infer_feature)
infer_feature = paddle.to_tensor(np.array(infer_feature, dtype='float32'))
result = model(infer_feature)
# print(result)
predict_result.append(result)
将结果写入.CSV文件中:
import os
import pandas as pd
id_list = [item for item in range(1, 10001)]
label_list = []
csv_file = 'submission.csv'
for item in range(len(id_list)):
label = np.argmax(predict_result[item])
label_list.append(label)
data = {'id':id_list, 'ret':label_list}
df = pd.DataFrame(data)
list = []
csv_file = 'submission.csv'
for item in range(len(id_list)):
label = np.argmax(predict_result[item])
label_list.append(label)
data = {'id':id_list, 'ret':label_list}
df = pd.DataFrame(data)
df.to_csv(csv_file, index=False, encoding='utf8')
六、改进思路
- 可以考虑设计更复杂的模型来提高模型准确度
- 目前只使用了数据集中的7个特征,可以尝试对还未利用的特征进行编码
- 目前进行训练的数据较少,可以尝试进行数据增强
1.在原有特征的基础上添加特征由8维特征添加到51维
2.对模型进行更改,增加了一个隐含层
3.对数据进行归一化处理。
4.添加了预测函数。
开源链接:https://aistudio.baidu.com/aistudio/projectdetail/4217617
更多推荐
所有评论(0)