梯度下降-代码实现 作者:马育民 • 2019-10-16 13:44 • 阅读:10144 # 概述 本文通过代码感受梯度下降的魅力 # 代码 ### 实现损失函数 ``` # 损失函数 def loss_fun(x): return (x-50)**2 ``` ### 实现求导函数 根据上面的损失函数,实现求导函数 ``` # 求导 def d_fun(x): return 2*(x-50) ``` ### 绘制损失函数图像 ``` import numpy as np import matplotlib.pyplot as plt x_arr=np.linspace(0,100,101) y_arr=loss_fun(x_arr) x_arr,y_arr plt.plot(x_arr,y_arr) ``` 执行如下图: [![](https://www.malaoshi.top/upload/0/0/1EF4Fztbmhg9.png)](https://www.malaoshi.top/upload/0/0/1EF4Fztbmhg9.png) ### 定义一个初始点 该点在函数图像上任意取一点即可 ``` # 定义一个初始点 init_x=10 init_y=loss_fun(init_x) init_x,init_y ``` ### 绘制函数图像和初始点 ``` plt.plot(x_arr,y_arr) plt.scatter(init_x,init_y) ``` 执行如下: [![](https://www.malaoshi.top/upload/0/0/1EF4Fzu4YWsC.png)](https://www.malaoshi.top/upload/0/0/1EF4Fzu4YWsC.png) ### 求出该初始点的导数 ``` d=d_fun(init_x) d ``` ### 绘制初始点的切线 ``` qiexian_x=np.linspace(0,30,10) qiexian_x qiexian_y=qiexian_x*d+2400 qiexian_y plt.plot(qiexian_x,qiexian_y) plt.plot(x_arr,y_arr) plt.scatter(init_x,init_y) ``` [![](https://www.malaoshi.top/upload/0/0/1EF4FzvjfgrF.png)](https://www.malaoshi.top/upload/0/0/1EF4FzvjfgrF.png) ### 梯度下降(后面会优化) ``` # 将每次梯度下降的点的x坐标记录下来 x_history=[init_x] # 定义步长 step=0.1 temp_x=init_x while True: #求导 d=d_fun(temp_x) last_temp_x=temp_x # 找到下一个点的x坐标 temp_x=temp_x-step*d x_history.append(temp_x) """ 人眼可看出x为50时,y最小,即0,但在运算时,很难达到正好是0 所以让两个点的y值相减,小于某个精度,即认为找到最小值 """ if abs(loss_fun(temp_x)-loss_fun(last_temp_x))<0.00001: break print('最小值:',temp_x,loss_fun(temp_x)) print("导数:",d_fun(temp_x)) print('点数量:',len(x_history)) ``` 执行结果如下: ``` 最小值: 49.99659717633079 1.1579208923751433e-05 导数: -0.006805647338424592 点数量: 43 ``` **总结:** 从结果上看,极值点的y坐标很接近0,该点的导数接近0 ### 绘制历史点图像 ``` plt.plot(qiexian_x,qiexian_y) plt.plot(x_arr,y_arr) plt.scatter(init_x,init_y) y_history=loss_fun(np.array(x_history)) # 绘制将历史点 plt.scatter(x_history,y_history,color='blue') # 绘制将历史点的直线图 plt.plot(x_history,y_history,color='blue') # 极值点 plt.scatter(temp_x,loss_fun(temp_x),color='red') ``` [![](https://www.malaoshi.top/upload/0/0/1EF4G061JKEO.png)](https://www.malaoshi.top/upload/0/0/1EF4G061JKEO.png) ### 为什么开始下降很快,后来下降很慢 因为此损失函数图像一开始很陡峭,x减小,对应的y变化很大 函数图像越到中间,越平缓,所以 x减小,对应的y变化很小 # 封装 将 **梯度下** 降封装成函数, 该函数两个参数: 1. 初始点x坐标 2. 步长 返回值: 1. 极值点x坐标 2. 历史点x坐标 ``` def gd(init_x,step): x_history=[init_x] temp_x=init_x while True: d=d_fun(temp_x) last_temp_x=temp_x temp_x=temp_x-step*d x_history.append(temp_x) if abs(loss_fun(temp_x)-loss_fun(last_temp_x))<0.00001: break return temp_x,x_history ``` # 测试 ### 减小步长到0.01 将步长设置为 **0.01**,执行该函数 ``` min_x,x_history=gd(init_x,0.01) print('最小值:',min_x,loss_fun(min_x)) print("导数:",d_fun(min_x)) print('点数量:',len(x_history)) ``` 执行结果如下: ``` 最小值: 49.98454733335889 0.0002387849063212773 导数: -0.030905333282220226 点数量: 390 ``` **总结:** 由于步长设置0.01,找到极值点需要更多的步数,历史点也更多 [![](https://www.malaoshi.top/upload/0/0/1EF4G0fenZIi.png)](https://www.malaoshi.top/upload/0/0/1EF4G0fenZIi.png) ### 增大步长到0.9 将步长设置为 **0.9**,执行该函数 ``` min_x,x_history=gd(init_x,0.9) print('最小值:',min_x,loss_fun(min_x)) print("导数:",d_fun(min_x)) print('点数量:',len(x_history)) ``` 执行结果如下: ``` 最小值: 49.99659717633079 1.1579208923751433e-05 导数: -0.006805647338424592 点数量: 43 ``` 绘制图像如下: [![](https://www.malaoshi.top/upload/0/0/1EF4G0eido5f.png)](https://www.malaoshi.top/upload/0/0/1EF4G0eido5f.png) 由于步长过大,会在左右两侧跳跃 # 增大步长到1.1 将步长设置为 **1.1**,执行该函数 ``` min_x,x_history=gd(init_x,1.1) print('最小值:',min_x,loss_fun(min_x)) print("导数:",d_fun(min_x)) print('点数量:',len(x_history)) ``` 执行报错 [![](https://www.malaoshi.top/upload/0/0/1EF4G0l4ppCn.png)](https://www.malaoshi.top/upload/0/0/1EF4G0l4ppCn.png) **原因:** 由于步长设置1.1,那么在梯度下降函数中,x值可能越来越大,直到大到一定程度,导致损失函数求其y值时,抛出 ```result too large``` 异常 ### 解决上面的bug 在损失函数加上try语句 ``` #损失函数 def loss_fun(x): try: return(x-50)**2 except: #表示无穷大 return float('inf') ``` 当抛出异常时返回无穷大 ### 进入死循环 再次执行,会发现程序进入死循环,是因为: ``` if abs(loss_fun(temp_x)-loss_fun(last_temp_x))<0.00001: break ``` 一个无穷大 减去 另一个无穷大,得到的是 **nan**,即:“ not a number ”,所以永远也不会出发这个条件 ### 解决死循环bug 将封装的梯度下降改造如下: ``` def gd(init_x,step): x_history=[init_x] temp_x=init_x for item in range(10000): d=d_fun(temp_x) last_temp_x=temp_x temp_x=temp_x-step*d x_history.append(temp_x) if abs(loss_fun(temp_x)-loss_fun(last_temp_x))<0.00001: break return temp_x,x_history ``` 再次执行 ``` min_x,x_history=gd(init_x,1.1) print('最小值:',min_x,loss_fun(min_x)) print("导数:",d_fun(min_x)) print('点数量:',len(x_history)) ``` 结果如下: ``` 最小值: nan nan 导数: nan 点数量: 10001 ``` 说明步长取值过大 ### 绘制前5个点 ``` min_x,x_history=gd(init_x,1.1) print('最小值:',min_x,loss_fun(min_x)) print("导数:",d_fun(min_x)) print('点数量:',len(x_history)) x_history=x_history[0:5] ``` [![](https://www.malaoshi.top/upload/0/0/1EF4G1GnycJ7.png)](https://www.malaoshi.top/upload/0/0/1EF4G1GnycJ7.png) 原文出处:http://malaoshi.top/show_1EF4G05NJ24L.html