Created
February 2, 2019 02:24
-
-
Save XUJiahua/10ab26a859fc5d61501ba933bc646699 to your computer and use it in GitHub Desktop.
梯度下降法求函数最低点
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 计算函数的最低点,也就是最优化问题 | |
# f(x) = (x-1)^2 + 2,理论解为(1,2) | |
def f(x): | |
return (x-1)*(x-1) + 2 | |
# 计算导数(梯度),数值计算的方式 | |
def derivative(x): | |
deltax = 1e-10 | |
return (f(x+deltax)-f(x-deltax))/2/deltax | |
# 不能太低,迭代会很慢 | |
learning_rate = 1e-5 | |
# 随机选一个点,初始位置 | |
import numpy as np | |
x0 = np.random.randint(-100, 100) | |
print("random initial point ", x0, f(x0)) | |
# 误差值 | |
epsilon = 1e-5 | |
# 迭代计算 | |
while True: | |
d = derivative(x0) | |
# 最低点的导数为0,导数接近为0代表成功 | |
if abs(d) < epsilon: | |
break | |
x0 = x0 - learning_rate * d | |
print("move to point ", x0, f(x0)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment