Skip to content

Instantly share code, notes, and snippets.

Revisions

  1. @nilsleh nilsleh revised this gist Dec 14, 2022. 1 changed file with 7 additions and 1 deletion.
    8 changes: 7 additions & 1 deletion darts_load_checkpoint.py
    Original file line number Diff line number Diff line change
    @@ -108,7 +108,13 @@ def main():
    work_dir="./work_dir"
    )

    my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
    my_model.fit(
    train_transformed,
    future_covariates=covariates_transformed,
    val_series=train_transformed,
    val_future_covariates=covariates_transformed,
    verbose=True
    )
    my_model.load_from_checkpoint(model_name="my_model", best=True)
    # my_model.load_from_checkpoint(model_name="my_model") throws same error

  2. @nilsleh nilsleh revised this gist Dec 14, 2022. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions darts_load_checkpoint.py
    Original file line number Diff line number Diff line change
    @@ -110,6 +110,7 @@ def main():

    my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
    my_model.load_from_checkpoint(model_name="my_model", best=True)
    # my_model.load_from_checkpoint(model_name="my_model") throws same error


    if __name__ == "__main__":
  3. @nilsleh nilsleh revised this gist Dec 14, 2022. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion darts_load_checkpoint.py
    Original file line number Diff line number Diff line change
    @@ -104,7 +104,8 @@ def main():
    # loss_fn=MSELoss(),
    random_state=42,
    save_checkpoints=True,
    model_name="my_model"
    model_name="my_model",
    work_dir="./work_dir"
    )

    my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
  4. @nilsleh nilsleh created this gist Dec 14, 2022.
    115 changes: 115 additions & 0 deletions darts_load_checkpoint.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,115 @@
    import numpy as np
    import pandas as pd
    from tqdm import tqdm_notebook as tqdm

    import matplotlib.pyplot as plt

    from darts import TimeSeries, concatenate
    from darts.dataprocessing.transformers import Scaler
    from darts.models import TFTModel
    from darts.metrics import mape
    from darts.utils.statistics import check_seasonality, plot_acf
    from darts.datasets import AirPassengersDataset, IceCreamHeaterDataset
    from darts.utils.timeseries_generation import datetime_attribute_timeseries
    from darts.utils.likelihood_models import QuantileRegression

    import warnings

    warnings.filterwarnings("ignore")
    import logging

    logging.disable(logging.CRITICAL)

    def main():
    # before starting, we define some constants
    num_samples = 200

    figsize = (9, 6)
    lowest_q, low_q, high_q, highest_q = 0.01, 0.1, 0.9, 0.99
    label_q_outer = f"{int(lowest_q * 100)}-{int(highest_q * 100)}th percentiles"
    label_q_inner = f"{int(low_q * 100)}-{int(high_q * 100)}th percentiles"

    # Read data
    series = AirPassengersDataset().load()

    # we convert monthly number of passengers to average daily number of passengers per month
    series = series / TimeSeries.from_series(series.time_index.days_in_month)
    series = series.astype(np.float32)

    # Create training and validation sets:
    training_cutoff = pd.Timestamp("19571201")
    train, val = series.split_after(training_cutoff)

    # Normalize the time series (note: we avoid fitting the transformer on the validation set)
    transformer = Scaler()
    train_transformed = transformer.fit_transform(train)
    val_transformed = transformer.transform(val)
    series_transformed = transformer.transform(series)

    # create year, month and integer index covariate series
    covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
    covariates = covariates.stack(
    datetime_attribute_timeseries(series, attribute="month", one_hot=False)
    )
    covariates = covariates.stack(
    TimeSeries.from_times_and_values(
    times=series.time_index,
    values=np.arange(len(series)),
    columns=["linear_increase"],
    )
    )
    covariates = covariates.astype(np.float32)

    # transform covariates (note: we fit the transformer on train split and can then transform the entire covariates series)
    scaler_covs = Scaler()
    cov_train, cov_val = covariates.split_after(training_cutoff)
    scaler_covs.fit(cov_train)
    covariates_transformed = scaler_covs.transform(covariates)

    quantiles = [
    0.01,
    0.05,
    0.1,
    0.15,
    0.2,
    0.25,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.75,
    0.8,
    0.85,
    0.9,
    0.95,
    0.99,
    ]
    input_chunk_length = 24
    forecast_horizon = 12
    my_model = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon,
    hidden_size=64,
    lstm_layers=1,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=128,
    n_epochs=1,
    add_relative_index=False,
    add_encoders=None,
    likelihood=QuantileRegression(
    quantiles=quantiles
    ), # QuantileRegression is set per default
    # loss_fn=MSELoss(),
    random_state=42,
    save_checkpoints=True,
    model_name="my_model"
    )

    my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
    my_model.load_from_checkpoint(model_name="my_model", best=True)


    if __name__ == "__main__":
    main()