Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为, 所以无法保证它们是最准确的,并且反映了最新的 官方英文文档。如果您有改进此翻译的建议, 请提交 pull request 到 tensorflow/docs GitHub 仓库。要志愿地撰写或者审核译文,请加入 [email protected] Google Group。
本教程将向您展示如何使用 Estimators 解决 Tensorflow 中的鸢尾花(Iris)分类问题。Estimator 是 Tensorflow 完整模型的高级表示,它被设计用于轻松扩展和异步训练。更多细节请参阅 Estimators。
请注意,在 Tensorflow 2.0 中,Keras API 可以完成许多相同的任务,而且被认为是一个更易学习的 API。如果您刚刚开始入门,我们建议您从 Keras 开始。有关 Tensorflow 2.0 中可用高级 API 的更多信息,请参阅 Keras 标准化。
为了开始,您将首先导入 Tensorflow 和一系列您需要的库。
import tensorflow as tf
import pandas as pd
本文档中的示例程序构建并测试了一个模型,该模型根据花萼和花瓣的大小将鸢尾花分成三种物种。
您将使用鸢尾花数据集训练模型。该数据集包括四个特征和一个标签。这四个特征确定了单个鸢尾花的以下植物学特征:
- 花萼长度
- 花萼宽度
- 花瓣长度
- 花瓣宽度
根据这些信息,您可以定义一些有用的常量来解析数据:
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
接下来,使用 Keras 与 Pandas 下载并解析鸢尾花数据集。注意为训练和测试保留不同的数据集。
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
通过检查数据您可以发现有四列浮点型特征和一列 int32 型标签。
train.head()
<iframe src="/tutorials/estimator/premade_4ef55cf026eec3ed4d0c8562a0aea6d97ef7158cca81e5bee02dcca4014bb030.frame" class="framebox inherit-locale " allowfullscreen="" is-upgraded=""></iframe>
对于每个数据集都分割出标签,模型将被训练来预测这些标签。
train_y = train.pop('Species')
test_y = test.pop('Species')
# 标签列现已从数据中删除
train.head()
<iframe src="/tutorials/estimator/premade_cb6311d9578c260b2f77793d8cda49d8df64c7169d77c5cd07ed7ef07477a397.frame" class="framebox inherit-locale " allowfullscreen="" is-upgraded=""></iframe>
现在您已经设定好了数据,您可以使用 Tensorflow Estimator 定义模型。Estimator 是从 tf.estimator.Estimator
中派生的任何类。Tensorflow 提供了一组tf.estimator
(例如,LinearRegressor
)来实现常见的机器学习算法。此外,您可以编写您自己的自定义 Estimator。入门阶段我们建议使用预创建的 Estimator。
为了编写基于预创建的 Estimator 的 Tensorflow 项目,您必须完成以下工作:
- 创建一个或多个输入函数
- 定义模型的特征列
- 实例化一个 Estimator,指定特征列和各种超参数。
- 在 Estimator 对象上调用一个或多个方法,传递合适的输入函数以作为数据源。
我们来看看这些任务是如何在鸢尾花分类中实现的。
您必须创建输入函数来提供用于训练、评估和预测的数据。
输入函数是一个返回 tf.data.Dataset
对象的函数,此对象会输出下列含两个元素的元组:
为了向您展示输入函数的格式,请查看下面这个简单的实现:
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
您的输入函数可以以您喜欢的方式生成 features
字典与 label
列表。但是,我们建议使用 Tensorflow 的 Dataset API,该 API 可以用来解析各种类型的数据。
Dataset API 可以为您处理很多常见情况。例如,使用 Dataset API,您可以轻松地从大量文件中并行读取记录,并将它们合并为单个数据流。
为了简化此示例,我们将使用 pandas 加载数据,并利用此内存数据构建输入管道。
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
# 将输入转换为数据集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# 如果在训练模式下混淆并重复数据。
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
特征列(feature columns)是一个对象,用于描述模型应该如何使用特征字典中的原始输入数据。当您构建一个 Estimator 模型的时候,您会向其传递一个特征列的列表,其中包含您希望模型使用的每个特征。tf.feature_column
模块提供了许多为模型表示数据的选项。
对于鸢尾花问题,4 个原始特征是数值,因此我们将构建一个特征列的列表,以告知 Estimator 模型将 4 个特征都表示为 32 位浮点值。故创建特征列的代码如下所示:
# 特征列描述了如何使用输入。
my_feature_columns = []
for key in train.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
特征列可能比上述示例复杂得多。您可以从指南获取更多关于特征列的信息。
我们已经介绍了如何使模型表示原始特征,现在您可以构建 Estimator 了。
鸢尾花为题是一个经典的分类问题。幸运的是,Tensorflow 提供了几个预创建的 Estimator 分类器,其中包括:
tf.estimator.DNNClassifier
用于多类别分类的深度模型tf.estimator.DNNLinearCombinedClassifier
用于广度与深度模型tf.estimator.LinearClassifier
用于基于线性模型的分类器
对于鸢尾花问题,tf.estimator.DNNClassifier
似乎是最好的选择。您可以这样实例化该 Estimator:
# 构建一个拥有两个隐层,隐藏节点分别为 30 和 10 的深度神经网络。
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
# 隐层所含结点数量分别为 30 和 10.
hidden_units=[30, 10],
# 模型必须从三个类别中做出选择。
n_classes=3)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpkhwws8ja
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpkhwws8ja', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
我们已经有一个 Estimator 对象,现在可以调用方法来执行下列操作:
- 训练模型。
- 评估经过训练的模型。
- 使用经过训练的模型进行预测。
通过调用 Estimator 的 Train
方法来训练模型,如下所示:
# 训练模型。
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2\. The layer has dtype float32 because its dtype defaults to floatx.
If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.
To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/adagrad.py:83: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpkhwws8ja/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.6968713, step = 0
INFO:tensorflow:global_step/sec: 308.34
INFO:tensorflow:loss = 1.1691835, step = 100 (0.325 sec)
INFO:tensorflow:global_step/sec: 365.112
INFO:tensorflow:loss = 1.0332501, step = 200 (0.274 sec)
INFO:tensorflow:global_step/sec: 365.44
INFO:tensorflow:loss = 0.9807229, step = 300 (0.274 sec)
INFO:tensorflow:global_step/sec: 364.789
INFO:tensorflow:loss = 0.9437329, step = 400 (0.274 sec)
INFO:tensorflow:global_step/sec: 368.124
INFO:tensorflow:loss = 0.94162637, step = 500 (0.272 sec)
INFO:tensorflow:global_step/sec: 366.689
INFO:tensorflow:loss = 0.9129944, step = 600 (0.273 sec)
INFO:tensorflow:global_step/sec: 368.813
INFO:tensorflow:loss = 0.91519016, step = 700 (0.271 sec)
INFO:tensorflow:global_step/sec: 369.377
INFO:tensorflow:loss = 0.8866866, step = 800 (0.271 sec)
INFO:tensorflow:global_step/sec: 371.999
INFO:tensorflow:loss = 0.88594323, step = 900 (0.269 sec)
INFO:tensorflow:global_step/sec: 372.481
INFO:tensorflow:loss = 0.8859284, step = 1000 (0.269 sec)
INFO:tensorflow:global_step/sec: 369.793
INFO:tensorflow:loss = 0.87800217, step = 1100 (0.270 sec)
INFO:tensorflow:global_step/sec: 364.966
INFO:tensorflow:loss = 0.8652306, step = 1200 (0.274 sec)
INFO:tensorflow:global_step/sec: 368.742
INFO:tensorflow:loss = 0.8569569, step = 1300 (0.271 sec)
INFO:tensorflow:global_step/sec: 368.955
INFO:tensorflow:loss = 0.8538004, step = 1400 (0.271 sec)
INFO:tensorflow:global_step/sec: 371.44
INFO:tensorflow:loss = 0.8501439, step = 1500 (0.269 sec)
INFO:tensorflow:global_step/sec: 369.55
INFO:tensorflow:loss = 0.8453819, step = 1600 (0.271 sec)
INFO:tensorflow:global_step/sec: 366
INFO:tensorflow:loss = 0.83854586, step = 1700 (0.273 sec)
INFO:tensorflow:global_step/sec: 370.695
INFO:tensorflow:loss = 0.81984085, step = 1800 (0.270 sec)
INFO:tensorflow:global_step/sec: 371.791
INFO:tensorflow:loss = 0.8254725, step = 1900 (0.271 sec)
INFO:tensorflow:global_step/sec: 363.724
INFO:tensorflow:loss = 0.839285, step = 2000 (0.273 sec)
INFO:tensorflow:global_step/sec: 366.998
INFO:tensorflow:loss = 0.81192434, step = 2100 (0.273 sec)
INFO:tensorflow:global_step/sec: 362.578
INFO:tensorflow:loss = 0.80626756, step = 2200 (0.276 sec)
INFO:tensorflow:global_step/sec: 370.678
INFO:tensorflow:loss = 0.8144733, step = 2300 (0.270 sec)
INFO:tensorflow:global_step/sec: 367.415
INFO:tensorflow:loss = 0.80486006, step = 2400 (0.272 sec)
INFO:tensorflow:global_step/sec: 363.869
INFO:tensorflow:loss = 0.7996403, step = 2500 (0.275 sec)
INFO:tensorflow:global_step/sec: 366.247
INFO:tensorflow:loss = 0.78972137, step = 2600 (0.273 sec)
INFO:tensorflow:global_step/sec: 366.514
INFO:tensorflow:loss = 0.7898851, step = 2700 (0.273 sec)
INFO:tensorflow:global_step/sec: 363.635
INFO:tensorflow:loss = 0.7798088, step = 2800 (0.275 sec)
INFO:tensorflow:global_step/sec: 371.201
INFO:tensorflow:loss = 0.7830296, step = 2900 (0.269 sec)
INFO:tensorflow:global_step/sec: 372.843
INFO:tensorflow:loss = 0.78415155, step = 3000 (0.268 sec)
INFO:tensorflow:global_step/sec: 370.754
INFO:tensorflow:loss = 0.7710204, step = 3100 (0.270 sec)
INFO:tensorflow:global_step/sec: 373.092
INFO:tensorflow:loss = 0.7817295, step = 3200 (0.268 sec)
INFO:tensorflow:global_step/sec: 369.337
INFO:tensorflow:loss = 0.78129435, step = 3300 (0.271 sec)
INFO:tensorflow:global_step/sec: 368.646
INFO:tensorflow:loss = 0.78726315, step = 3400 (0.271 sec)
INFO:tensorflow:global_step/sec: 367.989
INFO:tensorflow:loss = 0.76692796, step = 3500 (0.273 sec)
INFO:tensorflow:global_step/sec: 365.108
INFO:tensorflow:loss = 0.7719732, step = 3600 (0.272 sec)
INFO:tensorflow:global_step/sec: 370.532
INFO:tensorflow:loss = 0.76764953, step = 3700 (0.270 sec)
INFO:tensorflow:global_step/sec: 362.993
INFO:tensorflow:loss = 0.75807786, step = 3800 (0.277 sec)
INFO:tensorflow:global_step/sec: 365.707
INFO:tensorflow:loss = 0.7590251, step = 3900 (0.272 sec)
INFO:tensorflow:global_step/sec: 368.977
INFO:tensorflow:loss = 0.7478892, step = 4000 (0.271 sec)
INFO:tensorflow:global_step/sec: 370.263
INFO:tensorflow:loss = 0.74537545, step = 4100 (0.270 sec)
INFO:tensorflow:global_step/sec: 370.648
INFO:tensorflow:loss = 0.7506561, step = 4200 (0.270 sec)
INFO:tensorflow:global_step/sec: 372.419
INFO:tensorflow:loss = 0.74983096, step = 4300 (0.268 sec)
INFO:tensorflow:global_step/sec: 370.771
INFO:tensorflow:loss = 0.74485517, step = 4400 (0.270 sec)
INFO:tensorflow:global_step/sec: 371.489
INFO:tensorflow:loss = 0.74746263, step = 4500 (0.269 sec)
INFO:tensorflow:global_step/sec: 370.063
INFO:tensorflow:loss = 0.7356381, step = 4600 (0.270 sec)
INFO:tensorflow:global_step/sec: 370.305
INFO:tensorflow:loss = 0.74623525, step = 4700 (0.270 sec)
INFO:tensorflow:global_step/sec: 365.488
INFO:tensorflow:loss = 0.7425093, step = 4800 (0.274 sec)
INFO:tensorflow:global_step/sec: 370.235
INFO:tensorflow:loss = 0.7342787, step = 4900 (0.270 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpkhwws8ja/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.7211363.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f16ef6d0cf8>
注意将 input_fn
调用封装在 lambda
中以获取参数,同时提供不带参数的输入函数,如 Estimator 所预期的那样。step
参数告知该方法在训练多少步后停止训练。
现在模型已经经过训练,您可以获取一些关于模型性能的统计信息。代码块将在测试数据上对经过训练的模型的准确率(accuracy)进行评估:
eval_result = classifier.evaluate(
input_fn=lambda: input_fn(test, test_y, training=False))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2\. The layer has dtype float32 because its dtype defaults to floatx.
If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.
To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-22T19:58:23Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpkhwws8ja/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.20579s
INFO:tensorflow:Finished evaluation at 2020-09-22-19:58:23
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.53333336, average_loss = 0.760622, global_step = 5000, loss = 0.760622
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmpkhwws8ja/model.ckpt-5000
Test set accuracy: 0.533
与对 train
方法的调用不同,我们没有传递 steps
参数来进行评估。用于评估的 input_fn
只生成一个 epoch 的数据。
eval_result
字典亦包含 average_loss
(每个样本的平均误差),loss
(每个 mini-batch 的平均误差)与 Estimator 的 global_step
(经历的训练迭代次数)值。
我们已经有一个经过训练的模型,可以生成准确的评估结果。我们现在可以使用经过训练的模型,根据一些无标签测量结果预测鸢尾花的品种。与训练和评估一样,我们使用单个函数调用进行预测:
# 由模型生成预测
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
def input_fn(features, batch_size=256):
"""An input function for prediction."""
# 将输入转换为无标签数据集。
return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
predictions = classifier.predict(
input_fn=lambda: input_fn(predict_x))
predict
方法返回一个 Python 可迭代对象,为每个样本生成一个预测结果字典。以下代码输出了一些预测及其概率:
for pred_dict, expec in zip(predictions, expected):
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpkhwws8ja/model.ckpt-5000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
Prediction is "Versicolor" (36.6%), expected "Setosa"
Prediction is "Virginica" (50.9%), expected "Versicolor"
Prediction is "Virginica" (62.6%), expected "Virginica"