Created June 22, 2022 17:02
Save jogardi/5cde48d1e2fb20154069843f76e0a13d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def f(x):
    """Return the ground-truth curve x * sin(x).

    Accepts a scalar or a NumPy array; np.sin applies elementwise.
    """
    sine = np.sin(x)
    return x * sine
# Dense evaluation grid of 2000 points on [-1, 11] and the ground-truth
# values on it; train_dict_and_plot reads both as module-level globals.
xs = np.linspace(start=-1, stop=11, num=2000)
ys = f(xs)
def train_dict_and_plot(num_train, quantize_scale, missing_value=-10):
    """Fit a bucketed lookup-table regressor on a subsample of (xs, ys) and plot it.

    Training data is every ``len(xs) // num_train``-th sample of the
    module-level ``xs``/``ys`` arrays (at most ``num_train`` samples). Each
    training x is mapped to the integer bucket ``floor(x * quantize_scale)``,
    and the prediction for a bucket is the mean y of the training points that
    fell into it. The true curve (``xs``, ``ys``) is drawn as a line and the
    prediction at every ``xs`` as orange dots on the current matplotlib axes.

    Args:
        num_train: Requested number of training samples; must be >= 1.
        quantize_scale: Buckets per unit of x. Each bucket is
            ``1 / quantize_scale`` wide, so larger values mean finer buckets.
        missing_value: Prediction used for x whose bucket received no
            training point. Defaults to -10 (the original behaviour). A finite
            sentinel can be mistaken for a genuine prediction, so pass
            ``float("nan")`` to leave those points out of the scatter instead.

    Raises:
        ValueError: If ``num_train`` is less than 1.
    """
    if num_train < 1:
        raise ValueError("num_train must be at least 1")

    # A step of at least 1 avoids "slice step cannot be zero" when
    # num_train > len(xs); capping at num_train keeps the sample count at the
    # requested value when len(xs) is not a multiple of num_train.
    step = max(1, len(xs) // num_train)
    train_x = xs[::step][:num_train]
    train_y = ys[::step][:num_train]

    def quantize(x):
        # floor, not int(): int() truncates toward zero, which would merge
        # (-1/scale, 0) and [0, 1/scale) into one double-width bucket.
        return int(np.floor(x * quantize_scale))

    # bucket index -> [count, running sum of y]
    db = {}
    for x, y in zip(train_x, train_y):
        stat = db.setdefault(quantize(x), [0, 0.0])
        stat[0] += 1
        stat[1] += y

    def pred(x):
        stat = db.get(quantize(x))
        if stat is None:
            return missing_value
        return stat[1] / stat[0]

    preds = np.array([pred(x) for x in xs])
    plt.plot(xs, ys)
    plt.scatter(xs, preds, s=10, c='tab:orange')
# Compare three quantization granularities side by side, each trained on 10
# samples: bucket widths of 1, 1/2 and 1/50 along x.
plt.figure(figsize=(15, 8))
for panel, scale in enumerate((1, 2, 50), start=1):
    plt.subplot(1, 3, panel)
    train_dict_and_plot(10, scale)
    # Label each panel so the figure shows which scale produced it.
    plt.title(f"quantize_scale = {scale}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.