added moving avg features #79

Open: wants to merge 2 commits into base: master
pts/model/deepar/deepar_estimator.py (4 additions, 0 deletions)
@@ -61,6 +61,7 @@ def __init__(
         distr_output: DistributionOutput = StudentTOutput(),
         scaling: bool = True,
         lags_seq: Optional[List[int]] = None,
+        agg_lags: Optional[List[int]] = None,
         time_features: Optional[List[TimeFeature]] = None,
         num_parallel_samples: int = 100,
         dtype: np.dtype = np.float32,
@@ -95,6 +96,7 @@ def __init__(
             if lags_seq is not None
             else get_lags_for_frequency(freq_str=freq, lag_ub=self.context_length)
         )
+        self.agg_lags = agg_lags
         self.time_features = (
             time_features
             if time_features is not None
@@ -221,6 +223,7 @@ def create_training_network(self, device: torch.device) -> DeepARTrainingNetwork
             cardinality=self.cardinality,
             embedding_dimension=self.embedding_dimension,
             lags_seq=self.lags_seq,
+            agg_lags=self.agg_lags,
             scaling=self.scaling,
             dtype=self.dtype,
         ).to(device)
@@ -245,6 +248,7 @@ def create_predictor(
             cardinality=self.cardinality,
             embedding_dimension=self.embedding_dimension,
             lags_seq=self.lags_seq,
+            agg_lags=self.agg_lags,
             scaling=self.scaling,
             dtype=self.dtype,
         ).to(device)
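For context, a hedged usage sketch of the new parameter: everything besides `agg_lags` mirrors the existing constructor, the window sizes are illustrative, and any further required arguments (for example a trainer) are omitted and depend on the pts version.

```python
from pts.model.deepar import DeepAREstimator

# Sketch only: for hourly data, windows of 24 and 168 would add daily and
# weekly trailing means as extra input features.
estimator = DeepAREstimator(
    freq="1H",
    prediction_length=24,
    agg_lags=[24, 168],  # new in this PR; the default None disables the feature
)
```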
pts/model/deepar/deepar_network.py (95 additions, 14 deletions)
@@ -34,6 +34,7 @@ def __init__(
         cardinality: List[int],
         embedding_dimension: List[int],
         lags_seq: List[int],
+        agg_lags: Optional[List[int]] = None,
         scaling: bool = True,
         dtype: np.dtype = np.float32,
     ) -> None:
@@ -52,6 +53,7 @@ def __init__(
         self.dtype = dtype

         self.lags_seq = lags_seq
+        self.agg_lags = agg_lags

         self.distr_output = distr_output
         rnn = {"LSTM": nn.LSTM, "GRU": nn.GRU}[self.cell_type]
@@ -116,6 +118,31 @@ def get_lagged_subsequences(
             lagged_values.append(sequence[:, begin_index:end_index, ...])
         return torch.stack(lagged_values, dim=-1)

+    @staticmethod
+    def get_mean_agg_lags(
+        sequence: torch.Tensor,  # (batch, time); a univariate target is assumed
+        agg_lags: List[int],
+        past_observed_values: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # cumulative sum along the time axis
+        accumulated_sum = torch.cumsum(sequence, dim=1)
+        averaged_seqs = []
+
+        for w in agg_lags:
+            # trailing mean over the w most recent values; the first w
+            # positions have no complete window and are left at zero
+            moving_avg = torch.zeros_like(accumulated_sum)
+            moving_avg[:, w:, ...] = (
+                accumulated_sum[:, w:, ...] - accumulated_sum[:, :-w, ...]
+            ) / w
+
+            averaged_seqs.append(moving_avg)
+
+        if past_observed_values is not None:
+            # mark the warm-up positions as unobserved (mutates the mask in place)
+            past_observed_values[..., : max(agg_lags)] = 0.0
+
+        # (batch, time, len(agg_lags))
+        return torch.stack(averaged_seqs, dim=-1)
+
     def unroll_encoder(
         self,
         feat_static_cat: torch.Tensor,  # (batch_size, num_features)
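`get_mean_agg_lags` computes trailing means with a cumulative-sum difference, so each window costs O(T) rather than O(T * w); the first w positions carry no complete window and stay at zero, which is why the caller can mark them unobserved. A standalone sketch with a toy check (`trailing_means` is an illustrative name, not part of the diff):

```python
from typing import List

import torch

def trailing_means(x: torch.Tensor, windows: List[int]) -> torch.Tensor:
    """Trailing moving averages of x (batch, time), one output channel per
    window size; positions with fewer than w past values stay at zero."""
    csum = torch.cumsum(x, dim=1)
    outs = []
    for w in windows:
        avg = torch.zeros_like(csum)
        avg[:, w:] = (csum[:, w:] - csum[:, :-w]) / w
        outs.append(avg)
    return torch.stack(outs, dim=-1)  # (batch, time, len(windows))

x = torch.arange(1.0, 7.0).unsqueeze(0)  # [[1, 2, 3, 4, 5, 6]]
print(trailing_means(x, [2])[0, :, 0])   # tensor([0.0, 0.0, 2.5, 3.5, 4.5, 5.5])
```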
@@ -150,8 +177,24 @@ def unroll_encoder(
         sequence_length = self.history_length + self.prediction_length
         subsequences_length = self.context_length + self.prediction_length

+        if self.agg_lags is not None:
+            # append the trailing moving averages as extra target channels
+            moving_avg = self.get_mean_agg_lags(
+                sequence=sequence,
+                agg_lags=self.agg_lags,
+                past_observed_values=past_observed_values,
+            )
+            merged_sequence = torch.cat(
+                (
+                    sequence.unsqueeze(-1) if len(self.target_shape) == 0 else sequence,
+                    moving_avg,
+                ),
+                dim=-1,
+            )
+        else:
+            merged_sequence = sequence
+
         lags = self.get_lagged_subsequences(
-            sequence=sequence,
+            sequence=merged_sequence,
             sequence_length=sequence_length,
             indices=self.lags_seq,
             subsequences_length=subsequences_length,
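Downstream, `merged_sequence` is what `get_lagged_subsequences` slices, so every lag in `lags_seq` is gathered for the raw target and for each moving average alike; that is where the `(1 + len(agg_lags))` factor in the reshape further down comes from. A quick shape check with illustrative sizes (`moving_avg` stands in for the `get_mean_agg_lags` output):

```python
import torch

B, T = 4, 200
agg_lags = [24, 168]

sequence = torch.randn(B, T)                   # univariate target
moving_avg = torch.randn(B, T, len(agg_lags))  # stand-in for get_mean_agg_lags output

merged = torch.cat((sequence.unsqueeze(-1), moving_avg), dim=-1)
print(merged.shape)  # torch.Size([4, 200, 3]), i.e. 1 + len(agg_lags) channels
```

Note that the `else` branch of the concatenation (a multivariate target with `len(self.target_shape) > 0`) would try to concatenate a 3-D target with the 4-D moving-average stack, so the feature effectively assumes a univariate target.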
@@ -184,17 +227,28 @@ def unroll_encoder(
                 -1, subsequences_length, -1
             )

-        # (batch_size, sub_seq_len, *target_shape, num_lags)
-        lags_scaled = lags / scale.unsqueeze(-1)
-
-        # from (batch_size, sub_seq_len, *target_shape, num_lags)
-        # to (batch_size, sub_seq_len, prod(target_shape) * num_lags)
-        input_lags = lags_scaled.reshape(
-            (-1, subsequences_length, len(self.lags_seq) * prod(self.target_shape))
-        )
+        # flatten the target, moving-average and lag axes:
+        # (batch_size, sub_seq_len, prod(target_shape) * num_lags * (1 + num_agg_lags))
+        input_lags = lags.reshape(
+            (
+                -1,
+                subsequences_length,
+                len(self.lags_seq) * prod(self.target_shape) * (1 + len(self.agg_lags or [])),
+            )
+        )
+
+        # scale after flattening so the per-series scale broadcasts cleanly
+        lags_scaled = input_lags / scale.unsqueeze(-1)

         # (batch_size, sub_seq_len, input_dim)
-        inputs = torch.cat((input_lags, time_feat, repeated_static_feat), dim=-1)
+        inputs = torch.cat(
+            (
+                lags_scaled,
+                time_feat,
+                repeated_static_feat,
+            ),
+            dim=-1,
+        )

         # unroll encoder
         outputs, state = self.rnn(inputs)
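This hunk also reorders scaling: the old code divided the stacked lag tensor by the scale and then flattened, while the new code flattens first and divides after. For a per-series scale the two orderings agree, since division by a per-series scalar commutes with the reshape; a small check under illustrative shapes:

```python
import torch

B, sub, C, L = 2, 5, 3, 4      # batch, subseq length, 1 + num_agg_lags, num_lags
lags = torch.randn(B, sub, C, L)
scale = torch.rand(B, 1) + 0.5  # per-series scale, positive as in DeepAR

a = (lags / scale.unsqueeze(-1).unsqueeze(-1)).reshape(B, sub, C * L)
b = lags.reshape(B, sub, C * L) / scale.unsqueeze(-1)
print(torch.allclose(a, b))     # True
```

Flattening first is also the safer order here: in the univariate case, once `lags` carries the extra moving-average axis it is 4-D, and dividing it by the 3-D `scale.unsqueeze(-1)` would no longer broadcast as intended.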
@@ -359,26 +413,53 @@ def sampling_decoder(

         # for each future time step we draw new samples and update the state
         for k in range(self.prediction_length):
+            if self.agg_lags is not None:
+                # recompute the moving averages over the history extended
+                # with the samples drawn so far
+                repeated_moving_avg = self.get_mean_agg_lags(
+                    sequence=repeated_past_target,
+                    agg_lags=self.agg_lags,
+                )
+                repeated_merged_seq = torch.cat(
+                    (
+                        repeated_past_target.unsqueeze(-1)
+                        if len(self.target_shape) == 0
+                        else repeated_past_target,
+                        repeated_moving_avg,
+                    ),
+                    dim=-1,
+                )
+            else:
+                repeated_merged_seq = repeated_past_target
+
             # (batch_size * num_samples, 1, 1 + num_agg_lags, num_lags) in the univariate case
             lags = self.get_lagged_subsequences(
-                sequence=repeated_past_target,
+                sequence=repeated_merged_seq,
                 sequence_length=self.history_length + k,
                 indices=self.shifted_lags,
                 subsequences_length=1,
             )

-            # (batch_size * num_samples, 1, *target_shape, num_lags)
-            lags_scaled = lags / repeated_scale.unsqueeze(-1)
-
-            # from (batch_size * num_samples, 1, *target_shape, num_lags)
-            # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
-            input_lags = lags_scaled.reshape(
-                (-1, 1, prod(self.target_shape) * len(self.lags_seq))
-            )
+            # (batch_size * num_samples, 1, prod(target_shape) * num_lags * (1 + num_agg_lags))
+            input_lags = lags.reshape(
+                (
+                    -1,
+                    1,
+                    prod(self.target_shape)
+                    * len(self.lags_seq)
+                    * (1 + len(self.agg_lags or [])),
+                )
+            )
+
+            # scale after flattening, as in unroll_encoder
+            lags_scaled = input_lags / repeated_scale.unsqueeze(-1)

             # (batch_size * num_samples, 1, prod(target_shape) * num_lags * (1 + num_agg_lags) + num_time_features + num_static_features)
             decoder_input = torch.cat(
-                (input_lags, repeated_time_feat[:, k : k + 1, :], repeated_static_feat),
+                (
+                    lags_scaled,
+                    repeated_time_feat[:, k : k + 1, :],
+                    repeated_static_feat,
+                ),
                 dim=-1,
             )
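Because `get_mean_agg_lags` is called inside the sampling loop, the trailing windows slide over the samples drawn so far: as in the rest of `sampling_decoder`, each new draw is appended to `repeated_past_target`, so the averages have to be recomputed at every step. A self-contained toy illustration (`trailing_mean` is an illustrative helper, not part of the diff):

```python
import torch

def trailing_mean(x: torch.Tensor, w: int) -> torch.Tensor:
    csum = torch.cumsum(x, dim=1)
    m = torch.zeros_like(csum)
    m[:, w:] = (csum[:, w:] - csum[:, :-w]) / w
    return m

past = torch.ones(1, 10)
for k in range(3):
    new_sample = torch.full((1, 1), 2.0)         # stand-in for a drawn sample
    past = torch.cat((past, new_sample), dim=1)
    print(trailing_mean(past, 4)[0, -1].item())  # 1.25, then 1.5, then 1.75
```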
