tf.keras.metrics.Recall 召回率 作者:马育民 • 2020-05-28 17:05 • 阅读:10551 # 介绍 计算召回率 ### 算法 详见: https://www.malaoshi.top/show_1EF5bae9qIQI.html ### 常用别名 - tf.metrics.Recall # 方法 ### 创建对象 ``` tf.keras.metrics.Recall( thresholds=None, top_k=None, class_id=None, name=None, dtype=None ) ``` ##### 参数: - thresholds:(可选)浮点值或[0,1]中的浮点阈值的python列表/元组。将阈值与预测值进行比较,以确定预测的真值(即,在阈值之上为true,在之下为false)。为每个阈值生成一个度量值。如果既未设置阈值也未设置top_k,则默认设置为使用计算召回率thresholds=0.5。 - top_k:(可选)默认情况下未设置。一个整数值,指定在计算召回率时要考虑的前k个预测。 - class_id:(可选)我们想要二进制指标的整数类ID。该值必须在半开放时间间隔[0, num_classes),即 num_classes预测的最后一个维度。 - dtype:(可选)指标结果的数据类型。 ### update_state 传入 y_true,y_pred ,且形状相同 ``` update_state( y_true, y_pred, sample_weight=None ) ``` ### reset_states() 重置所有变量。 在训练期间评估指标时,在各个时期/步骤之间调用此功能。 ``` reset_states() ``` ### result 计算并返回结果 ``` result() ``` # 例子 ``` m = tf.keras.metrics.Recall() m.update_state([0, 1, 1, 1], [1, 0, 1, 1]) print('Final result: ', m.result().numpy()) ``` 执行结果: ``` Final result: 0.66 ``` 后2个数字完全匹配,TP是2, 第2个数字label是1,预测成了0,FN是1 所以:`2/(2+1)=0.66` ### 与 tf.keras API 配合使用: ``` model = tf.keras.Model(inputs, outputs) model.compile('sgd', loss='mse', metrics=[tf.keras.metrics.Recall()]) ``` 原文出处:http://malaoshi.top/show_1EF5bb260yDt.html