【749】Kirby - Temporal Fusion Transformer related materials
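Reference: Darts - Temporal Fusion Transformer (Examples)
Reference: What is a covariate, and how is it defined? (Covariate: when studying the effect of an independent variable on a dependent variable, many other variables besides those two also influence the experiment. Among these other variables, the ones that can be controlled are called control variables, and the ones that cannot be controlled are called covariates.)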
Temporal Fusion Transformer: Time Series Forecasting | Towards Data Science
Inter-Series Attention Model for COVID-19 Forecasting (siam.org)
Interpretable Temporal Attention Network for COVID-19 forecasting - PMC (nih.gov)
A Transformer-based Framework for Multivariate Time Series Representation Learning (acm.org)
Time Series Made Easy in Python — darts documentation (unit8co.github.io)
Temporal Fusion Transformer (TFT) — darts documentation (unit8co.github.io)
TemporalFusionTransformer — pytorch-forecasting documentation
Walking through the official example!
1. Read the data
# Read data
from darts.datasets import AirPassengersDataset
series = AirPassengersDataset().load()
2. Average by the number of days in each month
import numpy as np
from darts import TimeSeries
# convert the monthly number of passengers to the average daily number of passengers per month
series = series / TimeSeries.from_series(series.time_index.days_in_month)
series = series.astype(np.float32)
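To make the per-day averaging concrete, here is a minimal standard-library sketch that does not use Darts at all. The function name `monthly_to_daily_average` and the sample totals are made up for illustration; only the idea (divide each monthly total by that month's day count, leap years included) comes from the step above.

```python
import calendar


def monthly_to_daily_average(monthly_totals):
    """Divide each (year, month, total) entry by the number of days in that month."""
    result = []
    for year, month, total in monthly_totals:
        # monthrange returns (weekday of first day, number of days in month)
        days = calendar.monthrange(year, month)[1]
        result.append((year, month, total / days))
    return result


# Hypothetical totals, chosen so that the arithmetic is easy to check by hand
sample = [(1949, 1, 310.0), (1949, 2, 280.0), (1952, 2, 290.0)]
daily = monthly_to_daily_average(sample)
```

January has 31 days, February 1949 has 28 and February 1952 (a leap year) has 29, so months of different lengths become comparable after the division.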
3. Split into training and validation sets
import pandas as pd
# Create training and validation sets; the cutoff month itself stays in the training set
training_cutoff = pd.Timestamp("19571201")
train, val = series.split_after(training_cutoff)
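The split can be pictured with plain Python lists. The sketch below is not the Darts implementation; it only mimics the behaviour that `split_after` keeps the cutoff timestamp in the first part. The helper name and the placeholder values are assumptions for illustration.

```python
from datetime import date


def split_after(points, cutoff):
    """Split (date, value) pairs so that the cutoff date is the last entry of the first part."""
    first = [p for p in points if p[0] <= cutoff]
    second = [p for p in points if p[0] > cutoff]
    return first, second


# Monthly timestamps from 1949-01 to 1960-12, the span of the AirPassengers data
months = [date(y, m, 1) for y in range(1949, 1961) for m in range(1, 13)]
points = [(d, float(i)) for i, d in enumerate(months)]  # placeholder values, not real data
train_part, val_part = split_after(points, date(1957, 12, 1))
```

With a cutoff of 1957-12-01 this leaves 108 months for training and 36 months for validation.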
4. Normalize the data
from darts.dataprocessing.transformers import Scaler
# Normalize the time series: fit the scaler on the training set only, then reuse it for the validation set and the full series
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)
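Why fit on the training set only? The Darts `Scaler` wraps a scikit-learn min-max scaler by default, and the toy class below imitates that idea in pure Python. The class name `MinMaxSketch` and the numbers are invented for illustration; this is a sketch of the technique, not the library code.

```python
class MinMaxSketch:
    """Toy min-max scaler: learns min and max in fit(), reuses them in transform()."""

    def fit(self, values):
        self.lo = min(values)
        self.hi = max(values)
        return self

    def transform(self, values):
        span = self.hi - self.lo
        return [(v - self.lo) / span for v in values]


train_values = [2.0, 4.0, 6.0]  # made-up training data
val_values = [5.0, 8.0]         # made-up validation data
scaler = MinMaxSketch().fit(train_values)
train_scaled = scaler.transform(train_values)
val_scaled = scaler.transform(val_values)
```

The training values land in [0, 1], while a validation value larger than the training maximum is mapped above 1. That is expected: no information from the validation period leaks into the scaling parameters.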
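5. Build the covariates
- year: the year of each time step
- month: the month of each time step, which carries the seasonality
- integer index: a steadily increasing index that captures the continuous passage of time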
from darts.utils.timeseries_generation import datetime_attribute_timeseries
# create year, month and integer index covariate series
covariates = datetime_attribute_timeseries(series, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series, attribute="month", one_hot=False)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series.time_index,
        values=np.arange(len(series)),
        columns=["linear_increase"],
    )
)
covariates = covariates.astype(np.float32)
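To see what the three stacked covariate columns contain, here is a standard-library sketch that builds the same kind of rows as dictionaries. The helper `build_covariates` is hypothetical; only the column meanings (year, month, running integer index) are taken from the step above.

```python
from datetime import date


def build_covariates(months):
    """Return one row per timestamp with year, month and a running integer index."""
    return [
        {"year": d.year, "month": d.month, "linear_increase": i}
        for i, d in enumerate(months)
    ]


# Same monthly span as the AirPassengers data: 1949-01 to 1960-12
months = [date(y, m, 1) for y in range(1949, 1961) for m in range(1, 13)]
rows = build_covariates(months)
```

All three columns are known for any future date in advance, which is why they can be passed to the model as future covariates.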
6. Normalize the covariates
# transform covariates: fit the scaler on the training part only, then scale the full covariate series
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(training_cutoff)
scaler_covs.fit(cov_train)
covariates_transformed = scaler_covs.transform(covariates)
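7. Build the model
# import paths as in the Darts TFT example (they may differ across Darts versions)
from darts.models import TFTModel
from darts.utils.likelihood_models import QuantileRegression
# default quantiles for QuantileRegression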
quantiles = [
    0.01,
    0.05,
    0.1,
    0.15,
    0.2,
    0.25,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.75,
    0.8,
    0.85,
    0.9,
    0.95,
    0.99,
]
input_chunk_length = 24
forecast_horizon = 12
my_model = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon,
    hidden_size=64,
    lstm_layers=1,
    num_attention_heads=4,
    dropout=0.1,
    batch_size=16,
    n_epochs=10,
    add_relative_index=False,
    add_encoders=None,
    likelihood=QuantileRegression(
        quantiles=quantiles
    ),  # QuantileRegression is set per default
    # loss_fn=MSELoss(),
    random_state=42,
)
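Quantile regression is usually trained with the pinball (quantile) loss, evaluated once per quantile level in the list above. The function below is a minimal sketch of that textbook loss for a single prediction; it is not the Darts implementation, and the numbers are made up.

```python
def pinball_loss(y_true, y_pred, q):
    """Quantile (pinball) loss for one prediction at quantile level q."""
    diff = y_true - y_pred
    # under-prediction (diff > 0) is weighted by q, over-prediction by (1 - q)
    return max(q * diff, (q - 1) * diff)


under = pinball_loss(10.0, 8.0, 0.9)  # prediction is 2 units too low
over = pinball_loss(10.0, 12.0, 0.9)  # prediction is 2 units too high
```

At q = 0.9 the same 2-unit error costs far more when the prediction is too low than when it is too high, which pushes the 0.9 output towards a value that the target exceeds only about 10% of the time.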
8. Train the model
my_model.fit(train_transformed, future_covariates=covariates_transformed, verbose=True)
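With input_chunk_length = 24 and forecast_horizon = 12, each training sample spans 36 consecutive months. A rough count of how many such windows fit into the 108-month training series can be sketched as below. The helper is hypothetical, and Darts may build or subsample its training windows differently, so treat the number as an upper-level intuition rather than the exact dataset size.

```python
def count_training_windows(series_length, input_chunk_length, output_chunk_length):
    """Number of contiguous (input, output) windows that fit into a series."""
    window = input_chunk_length + output_chunk_length
    return max(series_length - window + 1, 0)


# 108 months of training data (1949-01 to 1957-12), 24 in, 12 out
n_windows = count_training_windows(108, 24, 12)
```

With so few windows and batch_size=16, an epoch is only a handful of gradient steps, which is one reason this example trains quickly.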
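9. Display the results
import matplotlib.pyplot as plt
from darts.metrics import mape
# Settings that eval_model relies on; the values below are assumptions, since these notes do not define them
num_samples = 200
figsize = (9, 6)
lowest_q, low_q, high_q, highest_q = 0.01, 0.1, 0.9, 0.99
label_q_outer = f"{int(lowest_q * 100)}-{int(highest_q * 100)}th percentiles"
label_q_inner = f"{int(low_q * 100)}-{int(high_q * 100)}th percentiles"

def eval_model(model, n, actual_series, val_series):
    # draw num_samples sampled forecasts so that quantile bands can be plotted
    pred_series = model.predict(n=n, num_samples=num_samples)
    # plot actual series
    plt.figure(figsize=figsize)
    actual_series[: pred_series.end_time()].plot(label="actual")
    # plot prediction with quantile ranges
    pred_series.plot(
        low_quantile=lowest_q, high_quantile=highest_q, label=label_q_outer
    )
    pred_series.plot(low_quantile=low_q, high_quantile=high_q, label=label_q_inner)
    plt.title("MAPE: {:.2f}%".format(mape(val_series, pred_series)))
    plt.legend()

eval_model(my_model, 24, series_transformed, val_transformed)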
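The MAPE shown in the plot title is the mean absolute percentage error. A minimal pure-Python sketch of the formula follows; it is written for plain lists of point values and is not the Darts `mape` function itself, which works on TimeSeries objects. The sample numbers are made up.

```python
def mape(actual, predicted):
    """Mean absolute percentage error, in percent."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    errors = [abs((a - p) / a) for a, p in zip(actual, predicted)]
    return 100.0 * sum(errors) / len(errors)


# Each prediction is off by 10% of the true value
score = mape([100.0, 200.0], [110.0, 180.0])
```

Because every error is divided by the actual value, MAPE is undefined when an actual value is zero. Note also that eval_model computes it on the scaled series, so the percentage refers to the normalized values rather than to raw passenger counts.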