2. Classifying 100+ kinds of flowers with a TPU

This post walks through an implementation from the Flower Classification with TPUs competition that runs a pretrained VGG16 model on a TPU: Getting started with 100+ flowers on TPU. It can also be viewed on Colab.

First comes some setup, mainly configuring the TPU.

import math, re, os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

print(tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE

# TPU or GPU detection
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:  # no TPU found, fall back to CPU/GPU
    strategy = tf.distribute.get_strategy()  # default strategy; works on CPU or a single GPU
    # strategy = tf.distribute.MirroredStrategy()  # single machine with one or more GPUs
    # strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()  # multi-machine GPU cluster

print("Number of accelerators:", strategy.num_replicas_in_sync)  # number of replicas (devices)

IMG_SIZE = [512, 512]

EPOCHS = 12
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
DATA_DIR = r'/content/work/flowers'
DATA_SIZE_SELECT = {
    192: DATA_DIR + '/tfrecords-jpeg-192x192',
    224: DATA_DIR + '/tfrecords-jpeg-224x224',
    331: DATA_DIR + '/tfrecords-jpeg-331x331',
    512: DATA_DIR + '/tfrecords-jpeg-512x512',
}

DATA_SELECT = DATA_SIZE_SELECT[IMG_SIZE[0]]
TRAINING_FILENAMES = tf.io.gfile.glob(DATA_SELECT + '/train/*.tfrec')
VALID_FILENAMES = tf.io.gfile.glob(DATA_SELECT + '/val/*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(DATA_SELECT + '/test/*.tfrec')

# flower class names
CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', # 00 - 09
'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', # 10 - 19
'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', # 20 - 29
'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', # 30 - 39
'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion', # 40 - 49
'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia', # 50 - 59
'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy', # 60 - 69
'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', # 70 - 79
'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily', # 80 - 89
'hippeastrum ', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', # 90 - 99
'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose'] # 100 - 103

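A quick sanity check on the setup above (illustrative, not part of the original notebook): the class list should contain 104 entries, matching the 104-way Dense output layer defined later, and the global batch size scales with the number of replicas (a Kaggle/Colab TPU v3-8 typically reports 8 replicas, so BATCH_SIZE becomes 128; on a single GPU or CPU it stays at 16).

# Sanity check (illustrative, not in the original notebook)
assert len(CLASSES) == 104, f"expected 104 classes, got {len(CLASSES)}"
print("classes:", len(CLASSES), "| global batch size:", BATCH_SIZE)
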
1. Visualization utilities

  1. Convert a batch of data into NumPy images and labels; for data without labels (such as the test data used later), set the labels to None.
np.set_printoptions(threshold=15, linewidth=80)  # threshold for summarizing long arrays when printing

def batch_to_numpy_images_and_labels(data):
    """Convert a batch to NumPy images and labels."""
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object:
        # the "labels" are binary strings (image IDs only), i.e. the data is unlabeled: set labels to None
        numpy_labels = [None for _ in enumerate(numpy_images)]
    return numpy_images, numpy_labels
  2. Compare the predicted label with the correct label: mark OK if they match, otherwise mark NO followed by an arrow (→) and the correct flower name. The resulting title format is:

predicted flower name [OK/NO → correct flower name] (a short illustration follows the function below).

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        # unlabeled data: just return the class name and True
        return CLASSES[label], True
    correct = (label == correct_label)  # is the prediction correct?
    # u"\u2192" is the right-arrow character (→)
    return "{} [{} {} {}]".format(CLASSES[label], 'OK' if correct else 'NO',
                                  u"\u2192" if not correct else '',
                                  CLASSES[correct_label] if not correct else ''), correct
  3. Display each flower image with its title: red if the prediction is wrong, black if it is correct. The last element of the subplot tuple (the index of the next subplot to draw) is then incremented by 1.
def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Display a single flower image."""
    plt.subplot(*subplot)
    plt.axis('off')  # hide the axes
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2),
                  color='red' if red else 'black',
                  fontdict={'verticalalignment': 'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
  4. This function is used several times: to visualize the training/test datasets and to visualize the prediction results.

    1. Convert the batch to NumPy format; if there are no labels, set labels to None.

    2. Lay out the figure via rows and cols: rows (the number of grid rows) is the square root of the number of images in the batch, the intent being a square grid such as 2x2, and cols is the total number of images // rows.

      Then compare rows and cols, which correspond to the figure's height and width: the side belonging to the larger of the two keeps the full FIGSIZE, and the other side is scaled proportionally, e.g. to FIGSIZE / cols * rows when cols is larger (see the small worked example after this list).

    3. Add the title information to each subplot.

    4. Display each flower.

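As a small worked example of the sizing logic (not from the original post): for a batch of 20 images the grid is 4x5 and, since rows < cols, the figure ends up 13.0 wide and 10.4 high.

# Grid sizing for a batch of 20 images (illustrative)
FIGSIZE = 13.0
n_images = 20
rows = int(math.sqrt(n_images))   # 4
cols = n_images // rows           # 5
print((FIGSIZE, FIGSIZE / cols * rows))  # (13.0, 10.4)
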
image-20210814001834703

def display_batch_of_images(databatch, predictions=None):
    """
    Display a batch of images.
    :param databatch: a batch of (images, labels)
    :param predictions: optional predicted labels
    :return: None
    """
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]

    # Arrange the images in a roughly square grid: a batch of 16 fits exactly into 4x4;
    # otherwise any images that do not fill the grid are dropped.
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows

    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)  # subplot spec: (rows, cols, index)
    if rows < cols:
        # fewer rows than columns: keep the full width and shrink the figure height
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))

    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
  5. Display the confusion matrix:
    • Plot the matrix and set the x/y tick values and their formatting.
    • Add a text annotation with the f1 score, precision, and recall.
def display_confusion_matrix(cmat, score, precision, recall):
    """Plot the confusion matrix."""
    plt.figure(figsize=(15, 15), dpi=200)
    ax = plt.gca()  # get current axes
    ax.matshow(cmat, cmap='rainbow')

    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    # rotate the x tick labels by 45 degrees
    plt.setp(ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')

    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    # rotate the y tick labels by 45 degrees
    plt.setp(ax.get_yticklabels(), rotation=45, ha='right', rotation_mode='anchor')

    # add the metric text to the plot
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f}'.format(score)
    if precision is not None:
        titlestring += '\n precision = {:.3f}'.format(precision)
    if recall is not None:
        titlestring += '\n recall = {:.3f}'.format(recall)
    if len(titlestring) > 0:
        # place the text near the top-right corner: x=101, y=1
        ax.text(101, 1, titlestring,
                fontdict={'fontsize': 18, 'horizontalalignment': 'right',
                          'verticalalignment': 'top', 'color': 'black'})
    plt.show()
  6. Plot the training-metric curves.
def display_training_curves(training, validation, title, subplot):
    """Plot training/validation metric curves."""
    if subplot % 10 == 1:  # first subplot: create the figure
        plt.subplots(figsize=(10, 10), dpi=150, facecolor="#F0F0F0")
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel("epoch")
    ax.legend(['train', 'valid'])

2. Building the datasets

def decode_image(image_data):
    """Decode a JPEG and reshape it to the target size."""
    # image data is uint8 in [0, 255]
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.reshape(image, [*IMG_SIZE, 3])  # explicit static size, required on TPU
    return image

def read_labeled_tfrecord(example):
    """Read a labeled example from a TFRecord and return (image, label)."""
    labeled_tfrec_format = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    # parse a single example according to the feature spec
    example = tf.io.parse_single_example(example, labeled_tfrec_format)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)  # cast to tf.int32
    return image, label

def read_unlabeled_tfrecord(example):
    """Read an unlabeled example from a TFRecord and return (image, id)."""
    unlabeled_tfrec_format = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    # used for the test dataset, whose flower classes we predict
    example = tf.io.parse_single_example(example, unlabeled_tfrec_format)
    image = decode_image(example['image'])
    idx = example['id']
    return image, idx

def load_dataset(filenames, labeled=True, ordered=True):
    """
    :param filenames: TFRecord file paths
    :param labeled: whether the data has labels
    :param ordered: whether to preserve element order (irrelevant for training data, which is shuffled anyway)
    :return: dataset
    """
    ignore_order = tf.data.Options()  # options that control how the input pipeline behaves
    if not ordered:
        # turn off deterministic output order to speed up reading
        ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.with_options(ignore_order)  # interleaving the file streams is faster than reading them in order
    dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord,
                          num_parallel_calls=AUTO)
    return dataset

def data_augment(image, label):
    """Data augmentation."""
    image = tf.image.random_flip_left_right(image)
    return image, label

def get_training_dataset():
    """Build the training dataset."""
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat()  # we train for several epochs, so the dataset must repeat
    dataset = dataset.shuffle(2048)  # buffer size > BATCH_SIZE
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_validation_dataset(ordered=False):
    """Build the validation dataset."""
    dataset = load_dataset(VALID_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()  # cache in memory
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_test_dataset(ordered=False):
    """Build the test dataset."""
    dataset = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def count_data_items(filenames):
    """Return the number of images. Filenames look like data/flowers00-230.tfrec:
    the trailing -230. encodes the number of items, so extract it from every file and sum."""
    # r"-([0-9]*)\." matches digits between '-' and '.'; group(1) returns just the captured digits
    n = [int(re.compile(r"-([0-9]*)\.").search(f).group(1)) for f in filenames]
    return np.sum(n)

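As a quick check of the regular expression used in count_data_items (using the filename pattern mentioned in its docstring):

# The item count is encoded in each TFRecord filename, e.g. flowers00-230.tfrec -> 230
m = re.compile(r"-([0-9]*)\.").search('data/flowers00-230.tfrec')
print(m.group(1))  # '230'
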
Use count_data_items(filenames) to check the size of each dataset.

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)   # number of training images
NUM_VALIDATION_IMAGES = count_data_items(VALID_FILENAMES)    # number of validation images
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)           # number of test images
STEP_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE           # training steps per epoch
# ceiling-division trick: floor division rounds negatives down, e.g. -9 // 2 == -5, so -(-9 // 2) == 5
VALIDATION_STEPS = -(-NUM_VALIDATION_IMAGES // BATCH_SIZE)
test_steps = -(-NUM_TEST_IMAGES // BATCH_SIZE)
print('Dataset:{} training images, {} validation images, {} unlabeled test images'.\
format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
===================================================================================
Dataset:12753 training images, 3712 validation images, 7382 unlabeled test images
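
The -(-n // b) idiom is ceiling division, so the last partial batch is not dropped. For instance, assuming a global batch size of 128 (8 replicas):

# Ceiling division vs. plain floor division (illustrative batch size of 128)
print(7382 // 128)      # 57 -> would miss the last 86 test images
print(-(-7382 // 128))  # 58 -> covers all 7382 test images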

3. Dataset visualization

print("Training data shapes:")
for image, label in get_training_dataset().take(3):
print(image.numpy().shape, label.numpy().shape)
print("Training data label examples:", label.numpy())

print("Validation data shapes:")
for image, label in get_validation_dataset().take(3):
print(image.numpy().shape, label.numpy().shape)
print("Validation data label examples:", label.numpy())

print("Test data shapes:")
for image, idx in get_test_dataset().take(3):
print(image.numpy().shape, idx.numpy().shape)
print("Test data IDs:", idx.numpy().astype('U'))
================================================================
Training data shapes:
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
Training data label examples: [87 48 50 ... 74 4 39]
Validation data shapes:
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
Validation data label examples: [73 45 76 ... 69 93 53]
Test data shapes:
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
(16, 512, 512, 3) (16,)
Test data IDs: ['4fb5992b3' '6557acff6' 'abfe5bd86' ... 'd4ae8d14a' 'b5422eec0' '6485b01ac']

Training data preview:

training_dataset = get_training_dataset()
training_dataset = training_dataset.unbatch().batch(20)
training_batch = iter(training_dataset)
display_batch_of_images(next(training_batch))

The batch contains 20 images in total; only 5 are shown here.

image-20210814003727467

test_dataset = get_test_dataset()
test_dataset = test_dataset.unbatch().batch(20)
test_batch = iter(test_dataset)
display_batch_of_images(next(test_batch))

4. Building and training the model

with strategy.scope():
    img_adjust_layer = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.vgg16.preprocess_input(tf.cast(data, tf.float32)),
        input_shape=[*IMG_SIZE, 3])
    pretrained_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False)
    pretrained_model.trainable = False
    model = tf.keras.Sequential([
        img_adjust_layer,
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

Compile the model and take a look at its parameters:

model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
model.summary()
=========================================================
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lambda (Lambda) (None, 512, 512, 3) 0
_________________________________________________________________
vgg16 (Functional) (None, None, None, 512) 14714688
_________________________________________________________________
global_average_pooling2d (Gl (None, 512) 0
_________________________________________________________________
dense (Dense) (None, 104) 53352
=================================================================
Total params: 14,768,040
Trainable params: 53,352
Non-trainable params: 14,714,688
_________________________________________________________________
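
The trainable parameter count comes entirely from the final Dense layer on top of the 512-dimensional pooled VGG16 features: 512 × 104 weights + 104 biases = 53,352, which matches the summary; the 14,714,688 VGG16 parameters stay frozen.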

Training:

history = model.fit(get_training_dataset(), steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCHS,\
validation_data=get_validation_dataset(), validation_steps=VALIDATION_STEPS)
=================================================================
Epoch 1/12
797/797 [==============================] - 169s 180ms/step - loss: 2.1719 - sparse_categorical_accuracy: 0.5329 - val_loss: 1.0612 - val_sparse_categorical_accuracy: 0.7416
Epoch 2/12
797/797 [==============================] - 141s 177ms/step - loss: 0.7915 - sparse_categorical_accuracy: 0.8009 - val_loss: 0.8042 - val_sparse_categorical_accuracy: 0.8025
Epoch 3/12
797/797 [==============================] - 141s 177ms/step - loss: 0.5611 - sparse_categorical_accuracy: 0.8560 - val_loss: 0.7537 - val_sparse_categorical_accuracy: 0.8182
Epoch 4/12
797/797 [==============================] - 141s 177ms/step - loss: 0.4409 - sparse_categorical_accuracy: 0.8843 - val_loss: 0.6929 - val_sparse_categorical_accuracy: 0.8308
Epoch 5/12
797/797 [==============================] - 141s 176ms/step - loss: 0.3606 - sparse_categorical_accuracy: 0.9047 - val_loss: 0.6742 - val_sparse_categorical_accuracy: 0.8367
Epoch 6/12
797/797 [==============================] - 141s 177ms/step - loss: 0.3046 - sparse_categorical_accuracy: 0.9168 - val_loss: 0.6446 - val_sparse_categorical_accuracy: 0.8483
Epoch 7/12
797/797 [==============================] - 141s 177ms/step - loss: 0.2619 - sparse_categorical_accuracy: 0.9290 - val_loss: 0.6639 - val_sparse_categorical_accuracy: 0.8443
Epoch 8/12
797/797 [==============================] - 141s 177ms/step - loss: 0.2316 - sparse_categorical_accuracy: 0.9366 - val_loss: 0.6505 - val_sparse_categorical_accuracy: 0.8505
Epoch 9/12
797/797 [==============================] - 141s 177ms/step - loss: 0.1987 - sparse_categorical_accuracy: 0.9460 - val_loss: 0.6629 - val_sparse_categorical_accuracy: 0.8473
Epoch 10/12
797/797 [==============================] - 141s 177ms/step - loss: 0.1850 - sparse_categorical_accuracy: 0.9495 - val_loss: 0.6968 - val_sparse_categorical_accuracy: 0.8448
Epoch 11/12
797/797 [==============================] - 141s 177ms/step - loss: 0.1624 - sparse_categorical_accuracy: 0.9568 - val_loss: 0.6953 - val_sparse_categorical_accuracy: 0.8389
Epoch 12/12
797/797 [==============================] - 141s 177ms/step - loss: 0.1412 - sparse_categorical_accuracy: 0.9613 - val_loss: 0.6772 - val_sparse_categorical_accuracy: 0.8543

Plot the training curves:

display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 211)
display_training_curves(history.history['sparse_categorical_accuracy'], \
history.history['val_sparse_categorical_accuracy'], 'accuracy', 212)

image-20210814004305884

5. Confusion matrix

# The validation set is split into separate image and label streams and iterated independently,
# so order must be preserved (ordered=True) to keep each image paired with its label.
confusion_dataset = get_validation_dataset(ordered=True)
image_ds = confusion_dataset.map(lambda image, label: image)
label_ds = confusion_dataset.map(lambda image, label: label).unbatch()

confusion_correct_labels = next(iter(label_ds.batch(NUM_VALIDATION_IMAGES))).numpy()  # collect all labels into one batch
confusion_probabilities = model.predict(image_ds, steps=VALIDATION_STEPS)
confusion_predictions = np.argmax(confusion_probabilities, axis=1)
print("Correct labels:", confusion_correct_labels.shape, confusion_correct_labels)
print("Predicted labels:", confusion_predictions.shape, confusion_predictions)
=====================================================================
Correct labels: (3712,) [69 95 48 ... 0 57 88]
Predicted labels: (3712,) [49 95 48 ... 0 57 88]

Plot the confusion matrix:

confusion_mat = confusion_matrix(confusion_correct_labels,\
confusion_predictions, labels=range(len(CLASSES)))
score = f1_score(confusion_correct_labels, confusion_predictions,\
labels=range(len(CLASSES)), average='macro')
precision = precision_score(confusion_correct_labels,\
confusion_predictions, labels=range(len(CLASSES)), average='macro')
recall = recall_score(confusion_correct_labels,\
confusion_predictions, labels=range(len(CLASSES)), average='macro')
cmat = (confusion_mat.T / confusion_mat.sum(axis=1)).T
display_confusion_matrix(cmat, score, precision, recall)
print('f1 score:{:.3f}, precision: {:.3f}, recall: {:.3f}'\
.format(score, precision, recall))
===================================================
f1 score:0.849, precision: 0.875, recall: 0.839
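
Note that (confusion_mat.T / confusion_mat.sum(axis=1)).T row-normalizes the matrix, so row i shows how the samples of true class i are distributed over the predicted classes. A minimal check with made-up values:

# Illustrative row normalization: each row of the result sums to 1
m = np.array([[8, 2],
              [1, 9]])
print((m.T / m.sum(axis=1)).T)  # [[0.8 0.2]
                                #  [0.1 0.9]]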

(confusion matrix figure)

6. Prediction

test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda image, idx: image)
prob = model.predict(test_images_ds, steps=test_steps)
pred = np.argmax(prob, axis=-1)
print(pred)
=======================================================
[ 68 48 45 ... 53 103 67]

Save the results as a CSV.

# generate the submission CSV
test_ids_ds = test_ds.map(lambda image, idx: idx).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U')
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, pred]),\
fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
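
With header='id,label' and comments='', submission.csv starts with the header line followed by one id,label pair per test image; schematically (values are placeholders, not from the actual run):

id,label
<test image id>,<predicted class index 0-103>
...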

Visualize the predictions:

## visualize predictions on a validation batch
dataset = get_validation_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

imgs, labels = next(batch)
prob = model.predict(tf.cast(imgs, tf.float32))
pred = np.argmax(prob, axis=-1)
display_batch_of_images((imgs, labels), pred)

image-20210814004852977