tensorflow2实现梯度下降 作者:马育民 • 2019-12-27 15:44 • 阅读:10323 梯度下降(手写代码实现)参见: https://www.malaoshi.top/show_1EF4G05NJ24L.html # 概述 tensorflow2中,通过 [tf.GradientTape()](https://www.malaoshi.top/show_1EF4gqBjRRIF.html "tf.GradientTape()") 实现自动微分求导,通过```tf.keras.optimizers```包下的优化器实现梯度下降 所以,使用tensorflow2实现梯度下降,非常简单 # 代码 ### 导入模块 ``` import tensorflow as tf import matplotlib.pyplot as plt import numpy as np ``` ### 实现损失函数 由于损失函数较为复杂,为了便于理解,这里定义一个简单的函数,模拟损失函数 ``` # 损失函数 def loss_fun(x): return (x-50)**2 ``` ### 绘制损失函数图像 ``` x_arr=np.linspace(0,100,101) y_arr=loss_fun(x_arr) x_arr,y_arr plt.plot(x_arr,y_arr) ``` 执行如下图: [![](https://www.malaoshi.top/upload/0/0/1EF4Fztbmhg9.png)](https://www.malaoshi.top/upload/0/0/1EF4Fztbmhg9.png) ### 定义一个初始点 该点在函数图像上任意取一点即可 ``` # 定义一个初始点 init_x = tf.constant(3.0) init_y=loss_fun(init_x) init_x,init_y ``` ### 绘制函数图像和初始点 ``` plt.plot(x_arr,y_arr) plt.scatter(init_x.numpy(),init_y.numpy()) ``` 执行如下: [![](https://www.malaoshi.top/upload/0/0/1EF4Fzu4YWsC.png)](https://www.malaoshi.top/upload/0/0/1EF4Fzu4YWsC.png) ### 实现求导函数 tensorflow2通过 [tf.GradientTape()](https://www.malaoshi.top/show_1EF4gqBjRRIF.html "tf.GradientTape()") 实现自动微分求导 ``` def grads(x): with tf.GradientTape() as tape: # tape.watch(x) result=loss_fun(x) #由于sources是list类型,所以gradients也是list类型 gradients=tape.gradient(target=result,sources=[x]) # print("cg:%s,x:%s"%(cg,x)) # print("gradients:",gradients) return gradients ``` **注意:** 由于```tape.gradient()```传入```[x]```,所以返回值gradients是list类型 详见:tape.gradient()函数 ### 测试求导函数 ``` x = tf.Variable(3.0) grads(x) ``` 执行结果: ``` [] ``` **运算过程:** 函数```(x-50)**2```的导数是```2(x-50)```,代入```x=3```,结果为:```-94``` ### 梯度下降 通过 [tf.keras.optimizers.Adam](https://www.malaoshi.top/show_1EF4grhYWhnF.html "tf.keras.optimizers.Adam") 实现梯度下降 ``` points=[] adam=tf.keras.optimizers.Adam(1) temp_x=tf.Variable(init_x) last_x=tf.Variable(init_x) for i in range(200): last_x.assign(temp_x) last_y=loss_fun(last_x) points.append((last_x.numpy(),last_y.numpy())) g=grads(temp_x) # 梯度下降,此时temp_x已经修改。 # 由于g是list类型,所以要做zip()处理,转成[(g,v)],详见adam文档 adam.apply_gradients(zip(g,[temp_x])) print("第%s次:::x=%s,last_x=%s,导数=%s"%(i,temp_x.numpy(),last_x.numpy(),g)) """ r=loss_fun(temp_x)-last_y if abs(r)<0.001: print('r:',r) break """ print("最小值x:",temp_x) print("最小值y:",loss_fun(temp_x)) ``` ### 绘制历史点图像 ``` points_l=list(zip(*points)) plt.plot(x_arr,y_arr) # 绘制将历史点 plt.scatter(points_l[0],points_l[1],color='blue') # 绘制初始点 plt.scatter(init_x.numpy(),init_y.numpy(),color='red') # 极值点 plt.scatter(temp_x.numpy(),loss_fun(temp_x.numpy()),color='red') ``` [![](https://www.malaoshi.top/upload/0/0/1EF4grP74R7r.png)](https://www.malaoshi.top/upload/0/0/1EF4grP74R7r.png) # 感谢 https://blog.csdn.net/xierhacker/article/details/53174558 原文出处:http://malaoshi.top/show_1EF4grPGjcvs.html