科大讯飞-电商图像检索挑战赛:基础思路CNN相似度 线上0.56得分
https://aistudio.baidu.com/aistudio/projectdetail/2798206
·
转载自AI Studio
项目链接https://aistudio.baidu.com/aistudio/projectdetail/2798206
赛题介绍
赛事背景
在电商应用中每天商家都会上传数以百万的商品图像,商品图像可能是从不同角度拍摄的,也有可能是不同款式的商品图像。对于消费者而言,很难通过肉眼去找到相似的商品。如果有一种人工智能算法,能够找到相同商品的相同图像,则是非常有用的一项技术。
赛事任务
给定一批电商商品(主要是服务商品)的图像,找到属于同一个商品的图像。任务可以视为一个图像检索问题,或者一个图像聚类问题,需要将同一个商品的图像聚类到一起。
- 训练集:约7千张商品图像,并且给定了相同商品对应的商品图像集合。具体可以在train.csv标注文件查找到,每行为一个图片对应的商品图像集合。
- 测试集:约4.8千张商品图像,需要选手识别出相同商品的图片集合。
数据集介绍
赛题数据由训练集和测试集组成,train.csv为训练集标注数据,第一列为图片名称,第二列图片对应商品下所有的图片集合。
name,label
008233.jpg,008233.jpg 006688.jpg
006688.jpg,008233.jpg 006688.jpg
000232.jpg,000232.jpg 003552.jpg
003552.jpg,000232.jpg 003552.jpg
评估指标
本次竞赛的评价标准采用图像集合交叉比,最高分为1。评估代码参考:
def set_iou(label, predict):
interset = set(label.split()) & set(predict.split())
unionset = set(label.split()) | set(predict.split())
return len(interset) *1.0 / len(unionset) *1.0
# 数据集解压
!cp data/data117293/电商图像检索_数据集.zip .pip/
!echo y | unzip -O CP936 data/data117293/电商图像检索_数据集.zip > /dev/null
replace 电商图像检索_数据集/sample_submit.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: replace 电商图像检索_数据集/test/000005.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: NULL
(EOF or read error, treating as "[N]one" ...)
)
import pandas as pd
import cv2
import glob
import os, sys, codecs, glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import numpy as np
import pandas as pd
import cv2
%pylab inline
Populating the interactive namespace from numpy and matplotlib
train_df = pd.read_csv('./电商图像检索_数据集/train.csv')
train_df.head()
name | label | |
---|---|---|
0 | 008233.jpg | 008233.jpg 006688.jpg |
1 | 006688.jpg | 008233.jpg 006688.jpg |
2 | 000232.jpg | 000232.jpg 003552.jpg |
3 | 003552.jpg | 000232.jpg 003552.jpg |
4 | 000814.jpg | 000814.jpg 013765.jpg |
样本可视化
%pylab inline
def show_image(paths):
plt.figure(figsize=(10, 8))
for idx, path in enumerate(paths):
plt.subplot(1, len(paths), idx+1)
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
Populating the interactive namespace from numpy and matplotlib
show_image(
['./电商图像检索_数据集/train/'+x
for x in train_df['label'].iloc[0].split()]
)
show_image(
['./电商图像检索_数据集/train/'+x
for x in train_df['label'].iloc[200].split()]
)
加载预训练模型
import paddle
from paddle.vision.models import resnet18
# 使用预训练模型
def make_model():
model = resnet18(pretrained=True)
model.fc = paddle.nn.Identity()
return model
model = make_model()
W1118 14:07:56.395839 137 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1118 14:07:56.401476 137 device_context.cc:465] device: 0, cuDNN Version: 7.6.
INFO:paddle.utils.download:unique_endpoints {''}
INFO:paddle.utils.download:Downloading resnet18.pdparams from https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams
100%|██████████| 69183/69183 [00:02<00:00, 30911.69it/s]
INFO:paddle.utils.download:File /home/aistudio/.cache/paddle/hapi/weights/resnet18.pdparams md5 checking...
提取图片特征
from paddle.io import DataLoader, Dataset
from PIL import Image
# 读取数据集
class MyDataset(Dataset):
def __init__(self, paths):
super(MyDataset, self).__init__()
self.paths = paths
def __getitem__(self, index):
img = Image.open(self.paths[index]).resize((224, 224))
return np.asarray(img).reshape(3, 224, 224).astype(np.float32)/255, 1
def __len__(self):
return len(self.paths)
train_dataset = MyDataset(
['./电商图像检索_数据集/train/' + x for x in train_df['name']]
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
test_dataset = MyDataset(
glob.glob('./电商图像检索_数据集/test/*.jpg')
)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
train_feat = []
for data in train_loader:
feat = model(data[0]).numpy()
train_feat.append(feat)
test_feat = []
for data in test_loader:
feat = model(data[0]).numpy()
test_feat.append(feat)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:653: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
计算相似度
import numpy as np
from sklearn.preprocessing import normalize
# 拼接图像特征
train_feat = np.vstack(train_feat)
train_feat = normalize(train_feat)
test_feat = np.vstack(test_feat)
test_feat = normalize(test_feat)
train_feat = paddle.to_tensor(train_feat)
test_feat = paddle.to_tensor(test_feat)
# 计算图像相似度
train_ids = paddle.matmul(train_feat, train_feat.T)
test_ids = paddle.matmul(test_feat, test_feat.T)
train_ids = train_ids.numpy()
test_ids = test_ids.numpy()
# 评价函数
def set_iou(label, predict):
interset = set(label.split()) & set(predict.split())
unionset = set(label.split()) | set(predict.split())
return len(interset) *1.0 / len(unionset) *1.0
idx = 10
dis = test_ids[idx]
ids = np.argsort(test_ids[idx])[::-1][:2]
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_550/2426643090.py in <module>
1 idx = 10
----> 2 dis = test_ids[idx]
3 ids = np.argsort(test_ids[idx])[::-1][:2]
NameError: name 'test_ids' is not defined
show_image(
np.array(glob.glob('./电商图像检索_数据集/test/*.jpg'))[ids]
)
总结与改进
- CNN是很好的特征提取器,可以很好的完成相似度计算的过程。
- 相同商品的模态和衣服都存在相似,且主要的区别是衣服的款式,所以应该消除模特所带来的影响。
- 可以考虑使用度量学习对模型进行训练,然后再计算相似度。
更多推荐
已为社区贡献1438条内容
所有评论(0)