tf.keras.utils.to_categorical(one hot编码) 作者:马育民 • 2019-10-30 21:20 • 阅读:10308 # 语法 ``` tf.keras.utils.to_categorical(y,num_classes,dtype) ``` ##### 参数 - y:要转换的值 - num_classes:类别数量 - dtype:数据类型,默认float32 ##### 返回值 - numpy类型,one hot编码 # 例子 ### 一维 ``` import tensorflow as tf import numpy as np arr=np.array([0,1,2,3]) train_label_onehot=tf.keras.utils.to_categorical(arr) print(train_label_onehot.shape) print(train_label_onehot) ``` 执行结果: ``` (4, 4) [[1. 0. 0. 0.] [0. 1. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.]] ``` ### 二维 ``` arr=np.array([[0,1,2,3],[3,2,1,0]]) train_label_onehot=tf.keras.utils.to_categorical(arr) train_label_onehot ``` 执行结果: ``` array([[[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]], [[0., 0., 0., 1.], [0., 0., 1., 0.], [0., 1., 0., 0.], [1., 0., 0., 0.]]], dtype=float32) ``` # 对独热编码 解码 - 使用```np.argmax()```函数 - 使用```tf.argmax()```函数 参见 https://www.malaoshi.top/show_1EF4LCNXBuQr.html ``` print(np.argmax(train_label_onehot,axis=1)) ``` 原文出处:http://malaoshi.top/show_1EF4LIYD8wg3.html