alex_bn_lee


【749】Kirby - Temporal Fusion Transformer related materials

Reference: Darts - Temporal Fusion Transformer (Examples)

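Reference: What is a covariate, and how is it defined? (When studying how an independent variable affects a dependent variable, many variables besides those two also influence the experiment. Among these other variables, the ones that can be controlled are called control variables, and the ones that cannot be controlled are called covariates.)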


Temporal Fusion Transformer Unleashed: Deep Forecasting of Multivariate Time Series in Python | by Heiko Onnen | Medium

Temporal Fusion Transformer: Time Series Forecasting | Towards Data Science

Inter-Series Attention Model for COVID-19 Forecasting (siam.org)

Interpretable Temporal Attention Network for COVID-19 forecasting - PMC (nih.gov)

A Transformer-based Framework for Multivariate Time Series Representation Learning (acm.org) 

Time Series Made Easy in Python — darts documentation (unit8co.github.io)

Temporal Fusion Transformer (TFT) — darts documentation (unit8co.github.io)

TemporalFusionTransformer — pytorch-forecasting documentation


A walkthrough of the official example:

1. Read the data

# Read data
from darts.datasets import AirPassengersDataset

series = AirPassengersDataset().load()

2. Convert monthly totals to daily averages using the number of days in each month

import numpy as np
from darts import TimeSeries

# we convert monthly number of passengers to average daily number of passengers per month
series = series / TimeSeries.from_series(series.time_index.days_in_month)
series = series.astype(np.float32)
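Dividing by the days in each month removes the calendar effect of February being shorter than March, so the model sees daily demand rather than month length. The plain-Python sketch below shows the same arithmetic with the standard library's `calendar.monthrange`; the helper `to_daily_average` and the passenger numbers are hypothetical, while Darts does this vectorised through `series.time_index.days_in_month`.

```python
import calendar
from datetime import date


def to_daily_average(monthly_totals):
    """Divide each monthly total by the number of days in that month.

    `monthly_totals` maps the first day of a month to that month's total.
    """
    result = {}
    for month_start, total in monthly_totals.items():
        # calendar.monthrange returns (weekday of the 1st, number of days)
        _, days = calendar.monthrange(month_start.year, month_start.month)
        result[month_start] = total / days
    return result


# Hypothetical monthly totals, for illustration only
totals = {date(1949, 1, 1): 310, date(1949, 2, 1): 280, date(1952, 2, 1): 290}
daily = to_daily_average(totals)
print(daily[date(1949, 2, 1)])  # 280 / 28 = 10.0
print(daily[date(1952, 2, 1)])  # 1952 is a leap year: 290 / 29 = 10.0
```

Without the division, January and February would differ by about 10% here purely because of their lengths.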

3. Split into training and validation sets

import pandas as pd

# Create training and validation sets:
training_cutoff = pd.Timestamp("19571201")
train, val = series.split_after(training_cutoff)
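`split_after` keeps the cutoff timestamp itself in the first (training) part. A stdlib sketch of that rule over a monthly index, assuming the 144 monthly points of AirPassengers (1949-01 to 1960-12); `month_starts` and `split_after` here are hypothetical helpers, not Darts functions, and the values are placeholders.

```python
from datetime import date


def month_starts(start_year, n_months):
    """Generate `n_months` consecutive first-of-month dates from January."""
    out = []
    year, month = start_year, 1
    for _ in range(n_months):
        out.append(date(year, month, 1))
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return out


def split_after(times, values, cutoff):
    """Split sorted, aligned lists so that `cutoff` lands in the first part."""
    idx = sum(1 for t in times if t <= cutoff)
    return (times[:idx], values[:idx]), (times[idx:], values[idx:])


times = month_starts(1949, 144)
values = list(range(144))  # placeholder values
(train_t, _), (val_t, _) = split_after(times, values, date(1957, 12, 1))
print(len(train_t), len(val_t))  # 108 36
print(train_t[-1], val_t[0])     # 1957-12-01 1958-01-01
```

So training covers nine full years (108 months) and validation the last three (36 months).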

4. Normalize the data

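from darts.dataprocessing.transformers import Scaler

# Normalize the time series: fit the scaler on the training set only,
# so no statistics from the validation set leak into preprocessing
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)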
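By default Darts' `Scaler` wraps scikit-learn's `MinMaxScaler` with a (0, 1) range. Fitting on `train` and only transforming `val` means the min and max come from the training data alone, so later values can fall outside [0, 1]. A stdlib sketch of that behaviour; the class `MinMax` and the numbers are hypothetical.

```python
class MinMax:
    """Minimal min-max scaler: fit on one sequence, apply to any other."""

    def fit(self, xs):
        # Only the fitted data determines the scaling range
        self.lo, self.hi = min(xs), max(xs)
        return self

    def transform(self, xs):
        span = self.hi - self.lo
        return [(x - self.lo) / span for x in xs]

    def inverse_transform(self, zs):
        span = self.hi - self.lo
        return [z * span + self.lo for z in zs]


train = [4.0, 6.0, 8.0, 12.0]  # hypothetical training values
val = [10.0, 14.0]             # hypothetical later values, one above the training max

scaler = MinMax().fit(train)
print(scaler.transform(train))  # [0.0, 0.25, 0.5, 1.0]
print(scaler.transform(val))    # [0.75, 1.25]
print(scaler.inverse_transform(scaler.transform(val)))  # [10.0, 14.0]
```

Because the model predicts in this scaled space, `transformer.inverse_transform(...)` is what maps a forecast back to passenger units.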

5. Build the covariates

  • year: the year
  • month: the month, which carries the seasonality
  • integer index: a continuously increasing time step
from darts.utils.timeseries_generation import datetime_attribute_timeseries

# create year, month and integer index covariate series

covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series, attribute="month", one_hot=False)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series.time_index,
        values=np.arange(len(series)),
        columns=["linear_increase"],
    )
)
covariates = covariates.astype(np.float32)
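Each of these covariates is a deterministic function of the timestamp, so its values are known in advance for any future date; that is why they can be passed to TFT as `future_covariates`. A plain-Python sketch of the three stacked columns; `build_covariates` is a hypothetical helper, and it encodes months as 1 to 12, which may differ from the exact encoding `datetime_attribute_timeseries` uses (after min-max scaling a constant offset makes no difference).

```python
from datetime import date


def build_covariates(times):
    """Return one [year, month, linear_increase] row per timestamp."""
    return [[float(t.year), float(t.month), float(i)] for i, t in enumerate(times)]


# Hypothetical three-month index around a year boundary
times = [date(1949, 11, 1), date(1949, 12, 1), date(1950, 1, 1)]
for row in build_covariates(times):
    print(row)
# [1949.0, 11.0, 0.0]
# [1949.0, 12.0, 1.0]
# [1950.0, 1.0, 2.0]
```

`year` and `linear_increase` both grow over time, while `month` repeats every 12 steps.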

6. Normalize the covariates

# transform covariates (the scaler is fit on the training period only)
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(training_cutoff)
scaler_covs.fit(cov_train)
covariates_transformed = scaler_covs.transform(covariates)

7. Build the model

from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression

# default quantiles for QuantileRegression
quantiles = [
    0.01,
    0.05,
    0.1,
    0.15,
    0.2,
    0.25,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.75,
    0.8,
    0.85,
    0.9,
    0.95,
    0.99,
]
input_chunk_length = 24
forecast_horizon = 12
my_model = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon,
    hidden_size=64,
    lstm_layers=1,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=16,
    n_epochs=10,
    add_relative_index=False,
    add_encoders=None,
    likelihood=QuantileRegression(
        quantiles=quantiles
    ),  # QuantileRegression is also TFTModel's default likelihood
    # loss_fn=MSELoss(),
    random_state=42,
)
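With `likelihood=QuantileRegression(quantiles=...)`, TFT outputs one value per quantile level (17 levels in the list above) and is trained with the quantile, or pinball, loss, which penalises under- and over-prediction asymmetrically. A minimal stdlib sketch of that loss for a single target value; the function names are hypothetical, and Darts evaluates the same idea on tensors across all quantile levels at once.

```python
def pinball_loss(y_true, y_pred, q):
    """Quantile (pinball) loss for one quantile level q in (0, 1)."""
    err = y_true - y_pred
    # Under-prediction (err > 0) costs q per unit, over-prediction costs (1 - q)
    return max(q * err, (q - 1) * err)


def mean_quantile_loss(y_true, preds_by_q):
    """Average pinball loss; `preds_by_q` maps a quantile level to its prediction."""
    losses = [pinball_loss(y_true, p, q) for q, p in preds_by_q.items()]
    return sum(losses) / len(losses)


# Under-predicting by 2 is cheap for q=0.25 but expensive for q=0.75
print(pinball_loss(10.0, 8.0, 0.25))   # 0.5
print(pinball_loss(10.0, 8.0, 0.75))   # 1.5
# Over-predicting by 2 flips the asymmetry
print(pinball_loss(10.0, 12.0, 0.25))  # 1.5
print(pinball_loss(10.0, 12.0, 0.75))  # 0.5
print(mean_quantile_loss(10.0, {0.25: 8.0, 0.75: 8.0}))  # 1.0
```

At q = 0.5 the loss is half the absolute error, so the median output behaves like an MAE-trained point forecast, while the outer levels such as 0.01 and 0.99 form the prediction band.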

8. Train the model

my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
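`fit` does not consume the series in one piece: the training series is cut into samples of `input_chunk_length` past steps followed by `output_chunk_length` target steps, and for TFT the future covariates are read over both parts. Assuming stride-1 sampling over the 108 training months, that gives 108 - 24 - 12 + 1 = 73 samples, i.e. about 5 batches of 16 per epoch. The sketch below, with the hypothetical helper `make_windows`, only illustrates that windowing arithmetic.

```python
def make_windows(values, input_len, output_len):
    """Cut a sequence into (past, future) pairs with stride 1."""
    windows = []
    last_start = len(values) - input_len - output_len
    for start in range(last_start + 1):
        past = values[start:start + input_len]
        future = values[start + input_len:start + input_len + output_len]
        windows.append((past, future))
    return windows


train = list(range(108))  # stand-in for the 108 monthly points 1949-01 .. 1957-12
windows = make_windows(train, input_len=24, output_len=12)
print(len(windows))        # 73
print(windows[0][1][0])    # 24: the first target step follows 24 input steps
print(windows[-1][1][-1])  # 107: the last target step is the final training month
```

With only 73 samples, `n_epochs=10` amounts to roughly 50 optimiser steps, which is small for a model of this size.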

9. Evaluate and plot the results

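import matplotlib.pyplot as plt
from darts.metrics import mape

# constants eval_model relies on (missing from the original snippet)
num_samples = 200  # Monte Carlo samples drawn per forecast step
figsize = (9, 6)
lowest_q, low_q, high_q, highest_q = 0.01, 0.1, 0.9, 0.99
label_q_outer = f"{int(lowest_q * 100)}-{int(highest_q * 100)}th percentiles"
label_q_inner = f"{int(low_q * 100)}-{int(high_q * 100)}th percentiles"


def eval_model(model, n, actual_series, val_series):
    # with no series given, predict continues the series (and covariates) passed to fit()
    pred_series = model.predict(n=n, num_samples=num_samples)

    # plot actual series
    plt.figure(figsize=figsize)
    actual_series[: pred_series.end_time()].plot(label="actual")

    # plot prediction with quantile ranges
    pred_series.plot(
        low_quantile=lowest_q, high_quantile=highest_q, label=label_q_outer
    )
    pred_series.plot(low_quantile=low_q, high_quantile=high_q, label=label_q_inner)

    plt.title("MAPE: {:.2f}%".format(mape(val_series, pred_series)))
    plt.legend()
    plt.show()


eval_model(my_model, 24, series_transformed, val_transformed)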
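Two quantities drive the plot: the MAPE in the title and the shaded bands, which are empirical quantiles of the `num_samples` draws at each forecast step. Note that the call evaluates on the scaled series, so the reported MAPE is relative to min-max scaled values rather than passenger counts. A stdlib sketch of both quantities; the function names and numbers are hypothetical, and the quantile uses linear interpolation between sorted samples, which is also NumPy's default method.

```python
def mape(actual, predicted):
    """Mean absolute percentage error in percent; undefined if an actual value is 0."""
    if any(a == 0 for a in actual):
        raise ValueError("MAPE is undefined when an actual value is zero")
    total = sum(abs((a - p) / a) for a, p in zip(actual, predicted))
    return 100 * total / len(actual)


def sample_quantile(samples, q):
    """Empirical q-quantile of Monte Carlo samples, linearly interpolated."""
    xs = sorted(samples)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


# Hypothetical numbers for illustration
print(round(mape([100.0, 200.0], [110.0, 180.0]), 6))  # 10.0
samples = list(range(101))  # stand-in for one forecast step's samples
print(sample_quantile(samples, 0.1), sample_quantile(samples, 0.9))  # 10.0 90.0
```

The inner band in the plot corresponds to `sample_quantile(..., 0.1)` to `sample_quantile(..., 0.9)` at each step, the outer one to 0.01 and 0.99.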

posted on 2022-10-12 14:13  McDelfino