系统认证风险预测-异常检测

本项目依据用户登录某一网站的一些基本信息,利用Paddle搭建一个神经网络模型,来预测出此次登录是否存在风险。

一、赛题背景

随着国家、企业对安全和效率越来越重视,作为安全基础设施之一的统一身份管理(IAM,Identity and Access Management)也得到越来越多的关注。 在IAM领域中,其主要安全防护手段是身份鉴别,既我们常见的口令验证、扫码验证、指纹验证等。它们一般分为三类,既用户所知(如口令)、所有(如身份证)、特征(如人脸)。但这些身份鉴别方式都有其各自的优缺点,比如口令,强度高了记不住,强度低了容易丢,又比如人脸,摇头晃脑做活体检测体验不好,静默检测如果算法不够,又很容易被照片、视频、人脸模型绕过。也因此,在等保2.0中对于三级以上系统要求必须使用两种以上的身份鉴别方式进行身份验证,来提高身份鉴别的可信度,这也被成为双因素认证。

但这对用户来说虽然在一定程度提高了安全性,但极大的降低了用户体验。也因此,IAM厂商开始参考UEBA、用户画像等行为分析技术,来探索一种既能确保用户体验,又能提高身份鉴别强度的方法。而在当前IAM的探索进程当中,目前最为可具落地的方法,是基于规则的行为分析技术,因为它的可理解性很高,且很容易与身份鉴别技术进行联动。

但基于规则但局限性也很明显,它是基于经验的,有宁错杀一千,不放过一个的特点,缺少从数据层面来证明是否有人正在尝试窃取或验证非法获取的身份信息,又或者正在使用窃取的身份信息,以此来提前进行风险预警和处置。

更多内容请前往比赛官网查看:系统认证风险预测-异常检测
在这里插入图片描述

比赛任务

本赛题中,参赛团队需要基于用户认证行为数据及风险异常标记结构,构建用户认证行为特征模型和风险异常评估模型,利用风险评估模型去判断当前用户认证行为是否存在风险。

  • 利用用户认证数据构建行为基线
  • 采用监督学习模型,基于用户认证行为特征,构建风险异常评估模型,判断当前用户认证行为是否存在风险

二、数据分析及处理

比赛数据是从竹云的风险分析产品日志库中摘录而来,主要涉及认证日志与风险日志数据。比赛数据经过数据脱敏和数据筛选等安全处理操作,供大家使用。其中认证日志是用户在访问应用系统产生的行为数据,包括登录、单点登录、退出等行为。

该比赛使用的数据已上传至AI Studio:https://aistudio.baidu.com/aistudio/datasetdetail/112151

数据概况

import csv
import numpy as np
import pandas as pd

train_dataset = r'data/data112151/train_dataset.csv'
test_dataset = r'data/data112151/test_dataset.csv'

with open(train_dataset, encoding = 'utf-8') as trainset:
    traindatas = np.loadtxt(trainset, str, delimiter = "\t", skiprows = 1)

with open(test_dataset, encoding = 'utf-8') as testset:
    testdatas = np.loadtxt(testset, str, delimiter = "\t", skiprows = 1)

print("在训练集中,共有{}条数据,其中每条数据有{}个特征".format(traindatas.shape[0], traindatas.shape[1]))
print("在测试集中,共有{}条数据,其中每条数据有{}个特征".format(testdatas.shape[0], testdatas.shape[1]))
traindata2 = pd.DataFrame(traindatas)
traindata2.columns = ['认证ID','认证时间','用户名','操作类型','首次认证方式','IP地址','IP类型','IP威胁级别','地点','客户端类型','浏览器来源',
'浏览器类型','浏览器版本号','操作系统类型','操作系统版本号','设备型号','应用系统编码','应用系统类目','风险标识']
testdata2 = pd.DataFrame(testdatas)
testdata2.columns = ['认证ID','认证时间','用户名','操作类型','首次认证方式','IP地址','IP类型','IP威胁级别','地点','客户端类型','浏览器来源',
'浏览器类型','浏览器版本号','操作系统类型','操作系统版本号','设备型号','应用系统编码','应用系统类目']
traindata2.head()
在训练集中,共有15016条数据,其中每条数据有19个特征
在测试集中,共有10000条数据,其中每条数据有18个特征
认证ID 认证时间 用户名 操作类型 首次认证方式 IP地址 IP类型 IP威胁级别 地点 客户端类型 浏览器来源 浏览器类型 浏览器版本号 操作系统类型 操作系统版本号 设备型号 应用系统编码 应用系统类目 风险标识
0 access:test_d:20180101111639:bBp1 2018/1/1 11:16 test_d login otp 192.168.100.101 内网 1级 {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... web desktop think_pad_e460 windows windows 10 chrome chrome 90 coremail management 0
1 access:test_d:20180101121524:OBSg 2018/1/1 12:15 test_d login qr 192.168.100.101 内网 1级 {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... web desktop think_pad_e460 windows windows 10 edge edge 93 order-mgnt sales 0
2 access:test_d:20180101151333:BpQN 2018/1/1 15:13 test_d login qr 192.168.100.101 内网 1级 {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... web desktop think_pad_e460 windows windows 10 chrome chrome 90 order-mgnt sales 0
3 access:test_d:20180101124502:hYQm 2018/1/1 12:45 test_d sso 192.168.100.101 内网 1级 {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... web desktop think_pad_e460 windows windows 10 edge edge 93 oa management 0
4 access:test_d:20180101202749:FkDK 2018/1/1 20:27 test_d sso 192.168.100.101 内网 1级 {"first_lvl":"成都分公司","sec_lvl":"9楼","third_lvl... web desktop think_pad_e460 windows windows 10 edge edge 93 order-mgnt sales 0

数据预处理

比赛提供的部分数据结构如下所示:

序号 字段名_中文 字段名_英文 含义 示例
3 操作类型 action login还是sso 1 : login 2 : sso
4 首次认证方式 auth_type 账密、短信、otp、二维码 1 : pwd 2 : sms 3 : otp 4 : qr
6 IP类型 ip_location_type_keyword 内网、家庭宽带、公共宽带、代理ip 1 : 家庭宽带 2 : 代理ip 3 : 内网 4 : 公共宽带
7 IP威胁级别 ip_risk_level 1级、2级、3级 1 : 1级 2 : 2级 3 : 3级
9 客户端类型 client_type app、web 1 : app 2 : web
10 浏览器来源 browser_source 桌面端、移动端 1 : desktop 2 : mobile
17 应用系统类目 op_target 系统所属的系统类别 1 : sales 2 : finance 3 : management 4 : hr
18 风险标识 risk_label 有、无 1 : 有;0 : 无
1 认证时间 op_date 认证时间 2018-1-1 11:16:00
2 用户名 user_name 用户名 test_a
5 IP地址 ip 用户认证ip地址 192.168.0.100
8 地点 location
11 浏览器类型 browser_type 浏览器类型 chrome
12 浏览器版本号 browser_version 浏览器版本号 chrome 90
13 操作系统类型 os_type 操作系统类型 windows
14 操作系统版本号 os_version 操作系统版本号 windows 10
15 设备型号 device_model 访问应用时所使用设备 think_pad_e460
16 应用系统编码 bus_system_code 所访问的应用系统编码 coremail
#合并数据
data = pd.concat([traindata2,testdata2])
print(data.shape)
(25016, 19)
# 删除无用特征
data.drop(['客户端类型','浏览器来源'], axis=1, inplace=True)
#特征集成
data['首次认证方式'].fillna('_', inplace=True)
data['登录类型'] = data['操作类型'] + data['首次认证方式']
data['登录类型'].value_counts()
sso         12600
loginpwd     3150
loginsms     3119
loginqr      3083
loginotp     3064
Name: 登录类型, dtype: int64
data['设备'] = data['设备型号'] + data['操作系统类型'] + data['操作系统版本号'] + data['浏览器类型'] + data['浏览器版本号']
data['设备'].value_counts()
chrome 90windows 10chromethink_pad_e460windows      9894
edge 93windows 10edgethink_pad_e460windows          9878
safari 13macOS Big Sur 11safarimacbookmacOS         3126
firefox 78windows 11firefoxthink_pad_t480windows     511
chrome 93windows 11chromethink_pad_t480windows       498
ie 11windows 11iethink_pad_t480windows               473
ie 9windows 7iethink_pad_l470windows                 333
chrome 77windows 7chromethink_pad_l470windows        303
Name: 设备, dtype: int64
data['应用系统相关'] = data['应用系统编码'] + data['应用系统类目']
data['应用系统相关'].value_counts()
order-mgntsales         8015
crmsales                4957
coremailmanagement      4486
oamanagement            2988
reimbursementfinance    2300
attendancehr            1811
salaryhr                 459
Name: 应用系统相关, dtype: int64
#处理地点列
import json
data['第一地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['first_lvl'])
data['第二地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['sec_lvl'])
data['第三地点'] = data['地点'].astype(str).apply(lambda x: json.loads(x)['third_lvl'])
pip uninstall tqdm -y
Found existing installation: tqdm 4.27.0
Uninstalling tqdm-4.27.0:
  Successfully uninstalled tqdm-4.27.0
Note: you may need to restart the kernel to use updated packages.
pip install tqdm
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tqdm
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8a/c4/d15f1e627fff25443ded77ea70a7b5532d6371498f9285d44d62587e209c/tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.4/78.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tqdm
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
paddlefsl 1.0.0 requires tqdm~=4.27.0, but you have tqdm 4.64.0 which is incompatible.[0m[31m
[0mSuccessfully installed tqdm-4.64.0
Note: you may need to restart the kernel to use updated packages.
#特征编码
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
features_encoder = [i for i in data.columns if i not in['认证ID','认证时间','风险标识']]
for col in tqdm(features_encoder):
    lbl = LabelEncoder()
    data[col] = lbl.fit_transform(data[col])
100%|██████████| 20/20 [00:00<00:00, 1011.76it/s]
# 处理时间
data['认证时间'] = pd.to_datetime(data['认证时间'])
data['认证时间转换'] = data['认证时间'].values.astype(np.int64) // 10 ** 9
#滞后历史特征
data = data.sort_values(by=['用户名', '认证时间转换']).reset_index(drop=True)
data['滞后'] = data.groupby(['用户名'])['认证时间转换'].shift(1)#该列下移一行,即上一次行为发生时间,用于计算行为发生间隔时间
data['时间间隔'] = data['认证时间转换'] - data['滞后']  #同一用户登录行为发生时间间隔
#统计不同用户的以下各列中不同值的个数(出现了多少种类别) 类别型统计特征(离散型)
for item in ['IP地址', '地点', '设备型号', '操作系统版本号', '浏览器版本号']:
    data[f'user_{item}_nunique'] = data.groupby(['用户名'])[item].transform('nunique')

data[f'user_browser_version_mode'] = data.groupby(['用户名'])['浏览器版本号'].transform(lambda x :x.mode()[0])
#统计不同分组下,时间间隔的统计特征  数值型统计特征'min', 
#可以尝试删除sum、prod,sum与mean冗余,prod计算所有元素的乘积 ,'prod', 'sum'
for method in ['mean', 'max', 'std', 'median']:
    for col in ['用户名', 'IP地址', '地点', '设备型号', '操作系统版本号', '浏览器版本号']:
        data[f'time_interval_{method}_' + str(col)] = data.groupby(col)['时间间隔'].transform(method)
# 各个特征的种类数 发现 客户端类型 和 浏览器来源 所有行均相同,不进行使用
data.nunique()
认证ID                            25016
认证时间                            25016
用户名                                 7
操作类型                                2
首次认证方式                              5
IP地址                                5
IP类型                                3
IP威胁级别                              3
地点                                  5
浏览器类型                               4
浏览器版本号                              2
操作系统类型                              4
操作系统版本号                             5
设备型号                                8
应用系统编码                              7
应用系统类目                              4
风险标识                                2
登录类型                                5
设备                                  8
应用系统相关                              7
第一地点                                3
第二地点                                4
第三地点                                4
认证时间转换                          25016
滞后                              25009
时间间隔                              450
user_IP地址_nunique                   2
user_地点_nunique                     2
user_设备型号_nunique                   2
user_操作系统版本号_nunique                1
user_浏览器版本号_nunique                 1
user_browser_version_mode           1
time_interval_mean_用户名              7
time_interval_mean_IP地址             5
time_interval_mean_地点               5
time_interval_mean_设备型号             8
time_interval_mean_操作系统版本号          5
time_interval_mean_浏览器版本号           2
time_interval_max_用户名               7
time_interval_max_IP地址              5
time_interval_max_地点                5
time_interval_max_设备型号              8
time_interval_max_操作系统版本号           5
time_interval_max_浏览器版本号            2
time_interval_std_用户名               7
time_interval_std_IP地址              5
time_interval_std_地点                5
time_interval_std_设备型号              8
time_interval_std_操作系统版本号           5
time_interval_std_浏览器版本号            2
time_interval_median_用户名            2
time_interval_median_IP地址           2
time_interval_median_地点             2
time_interval_median_设备型号           2
time_interval_median_操作系统版本号        2
time_interval_median_浏览器版本号         1
dtype: int64
data.columns
Index(['认证ID', '认证时间', '用户名', '操作类型', '首次认证方式', 'IP地址', 'IP类型', 'IP威胁级别', '地点',
       '浏览器类型', '浏览器版本号', '操作系统类型', '操作系统版本号', '设备型号', '应用系统编码', '应用系统类目',
       '风险标识', '登录类型', '设备', '应用系统相关', '第一地点', '第二地点', '第三地点', '认证时间转换', '滞后',
       '时间间隔', 'user_IP地址_nunique', 'user_地点_nunique', 'user_设备型号_nunique',
       'user_操作系统版本号_nunique', 'user_浏览器版本号_nunique',
       'user_browser_version_mode', 'time_interval_mean_用户名',
       'time_interval_mean_IP地址', 'time_interval_mean_地点',
       'time_interval_mean_设备型号', 'time_interval_mean_操作系统版本号',
       'time_interval_mean_浏览器版本号', 'time_interval_max_用户名',
       'time_interval_max_IP地址', 'time_interval_max_地点',
       'time_interval_max_设备型号', 'time_interval_max_操作系统版本号',
       'time_interval_max_浏览器版本号', 'time_interval_std_用户名',
       'time_interval_std_IP地址', 'time_interval_std_地点',
       'time_interval_std_设备型号', 'time_interval_std_操作系统版本号',
       'time_interval_std_浏览器版本号', 'time_interval_median_用户名',
       'time_interval_median_IP地址', 'time_interval_median_地点',
       'time_interval_median_设备型号', 'time_interval_median_操作系统版本号',
       'time_interval_median_浏览器版本号'],
      dtype='object')
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25016 entries, 0 to 25015
Data columns (total 56 columns):
 #   Column                        Non-Null Count  Dtype         
---  ------                        --------------  -----         
 0   认证ID                          25016 non-null  object        
 1   认证时间                          25016 non-null  datetime64[ns]
 2   用户名                           25016 non-null  int64         
 3   操作类型                          25016 non-null  int64         
 4   首次认证方式                        25016 non-null  int64         
 5   IP地址                          25016 non-null  int64         
 6   IP类型                          25016 non-null  int64         
 7   IP威胁级别                        25016 non-null  int64         
 8   地点                            25016 non-null  int64         
 9   浏览器类型                         25016 non-null  int64         
 10  浏览器版本号                        25016 non-null  int64         
 11  操作系统类型                        25016 non-null  int64         
 12  操作系统版本号                       25016 non-null  int64         
 13  设备型号                          25016 non-null  int64         
 14  应用系统编码                        25016 non-null  int64         
 15  应用系统类目                        25016 non-null  int64         
 16  风险标识                          15016 non-null  object        
 17  登录类型                          25016 non-null  int64         
 18  设备                            25016 non-null  int64         
 19  应用系统相关                        25016 non-null  int64         
 20  第一地点                          25016 non-null  int64         
 21  第二地点                          25016 non-null  int64         
 22  第三地点                          25016 non-null  int64         
 23  认证时间转换                        25016 non-null  int64         
 24  滞后                            25009 non-null  float64       
 25  时间间隔                          25009 non-null  float64       
 26  user_IP地址_nunique             25016 non-null  int64         
 27  user_地点_nunique               25016 non-null  int64         
 28  user_设备型号_nunique             25016 non-null  int64         
 29  user_操作系统版本号_nunique          25016 non-null  int64         
 30  user_浏览器版本号_nunique           25016 non-null  int64         
 31  user_browser_version_mode     25016 non-null  int64         
 32  time_interval_mean_用户名        25016 non-null  float64       
 33  time_interval_mean_IP地址       25016 non-null  float64       
 34  time_interval_mean_地点         25016 non-null  float64       
 35  time_interval_mean_设备型号       25016 non-null  float64       
 36  time_interval_mean_操作系统版本号    25016 non-null  float64       
 37  time_interval_mean_浏览器版本号     25016 non-null  float64       
 38  time_interval_max_用户名         25016 non-null  float64       
 39  time_interval_max_IP地址        25016 non-null  float64       
 40  time_interval_max_地点          25016 non-null  float64       
 41  time_interval_max_设备型号        25016 non-null  float64       
 42  time_interval_max_操作系统版本号     25016 non-null  float64       
 43  time_interval_max_浏览器版本号      25016 non-null  float64       
 44  time_interval_std_用户名         25016 non-null  float64       
 45  time_interval_std_IP地址        25016 non-null  float64       
 46  time_interval_std_地点          25016 non-null  float64       
 47  time_interval_std_设备型号        25016 non-null  float64       
 48  time_interval_std_操作系统版本号     25016 non-null  float64       
 49  time_interval_std_浏览器版本号      25016 non-null  float64       
 50  time_interval_median_用户名      25016 non-null  float64       
 51  time_interval_median_IP地址     25016 non-null  float64       
 52  time_interval_median_地点       25016 non-null  float64       
 53  time_interval_median_设备型号     25016 non-null  float64       
 54  time_interval_median_操作系统版本号  25016 non-null  float64       
 55  time_interval_median_浏览器版本号   25016 non-null  float64       
dtypes: datetime64[ns](1), float64(26), int64(27), object(2)
memory usage: 10.7+ MB
data.shape
(25016, 56)
cols = [i for i in data.columns if i not in ['风险标识','认证ID','认证时间','时间间隔','滞后']]
train = data[data['风险标识'].notna()]
test = data[data['风险标识'].isna()]
traindata = np.array(train[cols])
trainlabel = np.array(train['风险标识'])
testdata = np.array(test[cols])
print(traindata.shape, trainlabel.shape)
(15016, 51) (15016,)
# 转换为nparray
data_X = np.array(traindata[:12000,], dtype='float32')
data_Y = np.array(trainlabel[:12000,], dtype='float32')
test_data_X = np.array(traindata[12000:,], dtype='float32')
test_data_Y = np.array(trainlabel[12000:,], dtype='float32')
# 检查大小
print('data shape', data_X.shape, data_Y.shape)
print('testdata shape', test_data_X.shape,test_data_Y)
# aa = pd.DataFrame(test_data_X)
# aa.info()
from sklearn.preprocessing import MinMaxScaler
mm = MinMaxScaler()
data_X = mm.fit_transform(data_X)
test_data_X = mm.fit_transform(test_data_X)
testdata = mm.fit_transform(testdata)
data shape (12000, 51) (12000,)
testdata shape (3016, 51) [0. 0. 0. ... 0. 0. 0.]

三、模型组网

使用飞桨PaddlePaddle进行组网,在本基线系统中,只使用两层全连接层完成分类任务。

import paddle
import paddle.nn as nn

# 定义动态图
class Classification(paddle.nn.Layer):
    def __init__(self):
        super(Classification, self).__init__()
        self.drop = paddle.nn.Dropout(p=0.5)
        self.fc1 = paddle.nn.Linear(51, 32)
        self.fc2 = paddle.nn.Linear(32, 16)
        self.fc3 = paddle.nn.Linear(16, 2)
    
    # 网络的前向计算函数
    def forward(self, inputs):
        x = self.fc1(inputs)
        x = self.fc2(x)
        pred = self.fc3(x)
        return pred

四、配置参数及训练

记录日志

# 定义绘制训练过程的损失值变化趋势的方法draw_train_process
train_nums = []
train_costs = []
def draw_train_process(iters,train_costs):
    title="training cost"
    plt.title(title, fontsize=24)
    plt.xlabel("iter", fontsize=14)
    plt.ylabel("cost", fontsize=14)
    plt.plot(iters, train_costs,color='red',label='training cost') 
    plt.grid()
    plt.show()

定义损失函数

损失函数使用【R-Drop:摘下SOTA的Dropout正则化策略】里的kl_loss:

import paddle
import paddle.nn.functional as F

class kl_loss(paddle.nn.Layer):
    def __init__(self):
       super(kl_loss, self).__init__()

    def forward(self, p, q, label):
        ce_loss = 0.5 * (F.mse_loss(p, label=label)) + F.mse_loss(q, label=label)
        kl_loss = self.compute_kl_loss(p, q)

        # carefully choose hyper-parameters
        loss = ce_loss + 0.3 * kl_loss 

        return loss

    def compute_kl_loss(self, p, q):
        
        p_loss = F.kl_div(F.log_softmax(p, axis=-1), F.softmax(q, axis=-1), reduction='none')
        q_loss = F.kl_div(F.log_softmax(q, axis=-1), F.softmax(p, axis=-1), reduction='none')

        # You can choose whether to use function "sum" and "mean" depending on your task
        p_loss = p_loss.sum()
        q_loss = q_loss.sum()

        loss = (p_loss + q_loss) / 2

        return loss

模型训练

import paddle.nn.functional as F
y_preds = []
labels_list = []
BATCH_SIZE = 32
train_data = np.column_stack((data_X,data_Y))
test_data = np.column_stack((test_data_X,test_data_Y)) 
compute_kl_loss = kl_loss()

def train(model):
    print('start training ... ')
    # 开启模型训练模式
    model.train()
    EPOCH_NUM = 10
    train_num = 0
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00025, T_max=int(traindatas.shape[0]/BATCH_SIZE*EPOCH_NUM), verbose=False)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
    for epoch_id in range(EPOCH_NUM):
        # 在每轮迭代开始之前,将训练数据的顺序随机的打乱
        np.random.shuffle(train_data)
        # 将训练数据进行拆分,每个batch包含50条数据
        mini_batches = [train_data[k: k+BATCH_SIZE] for k in range(0, len(train_data), BATCH_SIZE)]
        # print(mini_batches[0].shape)
        for batch_id, data in enumerate(mini_batches):
            # print(batch_id)
            # print('data',data.shape)
            # print(data[:, -1:].shape)
            features_np = np.array(data[:, :51], np.float32)
            # print(data[:, -1:])
            labels_np = np.array(data[:,-1:], np.float32)

            features = paddle.to_tensor(features_np)
            labels = paddle.to_tensor(labels_np)

            #前向计算
            y_pred1 = model(features)
            # print(y_pred1[0])
            y_pred2 = model(features)
            cost = compute_kl_loss(y_pred1, y_pred2, label=labels)
            # cost = F.mse_loss(y_pred, label=labels)
            train_cost = cost.numpy()[0]
            #反向传播
            cost.backward()
            #最小化loss,更新参数
            optimizer.step()
            # 清除梯度
            optimizer.clear_grad()
            if batch_id % 500 == 0 and epoch_id % 1 == 0:
                print("Pass:%d,Cost:%0.5f"%(epoch_id, train_cost))

            train_num = train_num + BATCH_SIZE
            train_nums.append(train_num)
            train_costs.append(train_cost)

model = Classification()
train(model)
W0705 10:08:57.338639  1304 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0705 10:08:57.342248  1304 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.


start training ... 
Pass:0,Cost:7.29135
Pass:1,Cost:0.28412
Pass:2,Cost:0.21600
Pass:3,Cost:0.21525
Pass:4,Cost:0.19563
Pass:5,Cost:0.28115
Pass:6,Cost:0.48149
Pass:7,Cost:0.30064
Pass:8,Cost:0.28895
Pass:9,Cost:0.27548
def predict(model):
    print('predict....')
    model.eval()
    outputs = []
    mini_batches = [test_data[k: k+BATCH_SIZE] for k in range(0, len(test_data), BATCH_SIZE)]
    # print(len(mini_batches[0]))
    for data in mini_batches:
        features_np = np.array(data[:, :51], np.float32)
        features = paddle.to_tensor(features_np)
        pred = model(features)
        # print('pred',pred.shape)
        out = paddle.argmax(pred, axis=1)
        # print('out:',out.numpy())
        outputs.extend(out.numpy())
    return outputs
outputs = predict(model)
test_data_y = test_data[:,-1:].reshape(-1,)
outputs = np.array(outputs,np.float32)
np.sum(test_data_y == outputs) / test_data_y.shape[0]

predict....





0.7231432360742706

五、保存预测结果

模型预测:

predict_result = []
for infer_feature in testdata:
    # print(infer_feature)
    infer_feature = paddle.to_tensor(np.array(infer_feature, dtype='float32'))
    result = model(infer_feature)
    # print(result)
    predict_result.append(result)

将结果写入.CSV文件中:

import os
import pandas as pd

id_list = [item for item in range(1, 10001)]
label_list = []
csv_file = 'submission.csv'

for item in range(len(id_list)):
    label = np.argmax(predict_result[item])
    label_list.append(label)

data = {'id':id_list, 'ret':label_list}
df = pd.DataFrame(data)
list = []
csv_file = 'submission.csv'

for item in range(len(id_list)):
    label = np.argmax(predict_result[item])
    label_list.append(label)

data = {'id':id_list, 'ret':label_list}
df = pd.DataFrame(data)
df.to_csv(csv_file, index=False, encoding='utf8')

六、改进思路

  1. 可以考虑设计更复杂的模型来提高模型准确度
  2. 目前只使用了数据集中的7个特征,可以尝试对还未利用的特征进行编码
  3. 目前进行训练的数据较少,可以尝试进行数据增强

1.在原有特征的基础上添加特征由8维特征添加到51维

在这里插入图片描述

2.对模型进行更改,增加了一个隐含层

在这里插入图片描述

3.对数据进行归一化处理。

在这里插入图片描述

4.添加了预测函数。

在这里插入图片描述

开源链接https://aistudio.baidu.com/aistudio/projectdetail/4217617

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐