tensorflow2实现梯度下降

梯度下降(手写代码实现)参见:
https://www.malaoshi.top/show_1EF4G05NJ24L.html

概述

tensorflow2中,通过 tf.GradientTape() 实现自动微分求导,通过tf.keras.optimizers包下的优化器实现梯度下降

所以,使用tensorflow2实现梯度下降,非常简单

代码

导入模块

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np

实现损失函数

由于损失函数较为复杂,为了便于理解,这里定义一个简单的函数,模拟损失函数

  1. # 损失函数
  2. def loss_fun(x):
  3. return (x-50)**2

绘制损失函数图像

  1. x_arr=np.linspace(0,100,101)
  2. y_arr=loss_fun(x_arr)
  3. x_arr,y_arr
  4. plt.plot(x_arr,y_arr)

执行如下图:

定义一个初始点

该点在函数图像上任意取一点即可

  1. # 定义一个初始点
  2. init_x = tf.constant(3.0)
  3. init_y=loss_fun(init_x)
  4. init_x,init_y

绘制函数图像和初始点

  1. plt.plot(x_arr,y_arr)
  2. plt.scatter(init_x.numpy(),init_y.numpy())

执行如下:

实现求导函数

tensorflow2通过 tf.GradientTape() 实现自动微分求导

  1. def grads(x):
  2. with tf.GradientTape() as tape:
  3. # tape.watch(x)
  4. result=loss_fun(x)
  5. #由于sources是list类型,所以gradients也是list类型
  6. gradients=tape.gradient(target=result,sources=[x])
  7. # print("cg:%s,x:%s"%(cg,x))
  8. # print("gradients:",gradients)
  9. return gradients

注意: 由于tape.gradient()传入[x],所以返回值gradients是list类型

详见:tape.gradient()函数

测试求导函数

  1. x = tf.Variable(3.0)
  2. grads(x)

执行结果:

  1. [<tf.Tensor: id=72, shape=(), dtype=float32, numpy=-94.0>]

运算过程:
函数(x-50)**2的导数是2(x-50),代入x=3,结果为:-94

梯度下降

通过 tf.keras.optimizers.Adam 实现梯度下降

  1. points=[]
  2. adam=tf.keras.optimizers.Adam(1)
  3. temp_x=tf.Variable(init_x)
  4. last_x=tf.Variable(init_x)
  5. for i in range(200):
  6. last_x.assign(temp_x)
  7. last_y=loss_fun(last_x)
  8. points.append((last_x.numpy(),last_y.numpy()))
  9. g=grads(temp_x)
  10. # 梯度下降,此时temp_x已经修改。
  11. # 由于g是list类型,所以要做zip()处理,转成[(g,v)],详见adam文档
  12. adam.apply_gradients(zip(g,[temp_x]))
  13. print("第%s次:::x=%s,last_x=%s,导数=%s"%(i,temp_x.numpy(),last_x.numpy(),g))
  14. """
  15. r=loss_fun(temp_x)-last_y
  16. if abs(r)<0.001:
  17. print('r:',r)
  18. break
  19. """
  20. print("最小值x:",temp_x)
  21. print("最小值y:",loss_fun(temp_x))

绘制历史点图像

  1. points_l=list(zip(*points))
  2. plt.plot(x_arr,y_arr)
  3. # 绘制将历史点
  4. plt.scatter(points_l[0],points_l[1],color='blue')
  5. # 绘制初始点
  6. plt.scatter(init_x.numpy(),init_y.numpy(),color='red')
  7. # 极值点
  8. plt.scatter(temp_x.numpy(),loss_fun(temp_x.numpy()),color='red')

感谢

https://blog.csdn.net/xierhacker/article/details/53174558


原文出处:http://malaoshi.top/show_1EF4grPGjcvs.html