I ran the code from the tutorial "Demand forecasting with the Temporal Fusion Transformer" and expected to get the SMAPE values. The error appears even though I didn't change any of the tutorial code. I set a breakpoint, but execution never stops there before the error is raised. I haven't been able to solve this on my own.
Code to reproduce the problem
```python
# calculate metric by which to display
predictions = best_tft.predict(val_dataloader,return_y=True)
mean_losses = SMAPE(reduction="none")
mean_losses =mean_losses(predictions.output, predictions.y)
mean_losses =mean_losses.mean(1)
indices = mean_losses.argsort(descending=True) # sort losses
for idx in range(10):  # plot 10 examples
    best_tft.plot_prediction(
        raw_predictions.x,  # raw_predictions is created in an earlier cell of the tutorial
        raw_predictions.output,
        idx=indices[idx],
        add_loss_to_title=SMAPE(quantiles=best_tft.loss.quantiles),
    )
```
```
Cell In[66], line 4
      2 predictions = best_tft.predict(val_dataloader,return_y=True)
      3 mean_losses = SMAPE(reduction="none")
----> 4 mean_losses =mean_losses(predictions.output, predictions.y)
      5 mean_losses =mean_losses.mean(1)
      6 indices = mean_losses.argsort(descending=True) # sort losses

File ~.conda\envs\cgm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:303, in Metric.forward(self, *args, **kwargs)
    301     self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    302 else:
--> 303     self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    305 return self._forward_cache

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:378, in Metric._forward_reduce_state_update(self, *args, **kwargs)
    376 self._update_count = _update_count + 1
    377 with torch.no_grad():
--> 378     self._reduce_states(global_state)
    380 # restore context
    381 self._is_synced = False

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:413, in Metric._reduce_states(self, incoming_state)
    411 elif reduce_fn == dim_zero_cat:
    412     if isinstance(global_state, Tensor):
--> 413         reduced = torch.cat([global_state, local_state])
    414     else:
    415         reduced = global_state + local_state

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
```
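For context, the last frame of the traceback fails in `torch.cat([global_state, local_state])`, and `torch.cat` rejects zero-dimensional tensors. The sketch below reproduces that error in plain PyTorch. It assumes (not verified against the pytorch-forecasting or torchmetrics source) that the metric's accumulated `global_state` is still a 0-dim scalar default when `reduction="none"` is used; the tensor values are made up for illustration.

```python
import torch

# Hypothetical stand-ins for the two states in Metric._reduce_states:
# a 0-dim scalar default and a 1-dim batch of per-sample losses.
global_state = torch.tensor(0.0)
local_state = torch.tensor([0.1, 0.2])

message = ""
try:
    torch.cat([global_state, local_state])
except RuntimeError as err:
    message = str(err)
    print(message)  # the same error text as in the traceback

# Giving the scalar one dimension makes the concatenation valid.
reduced = torch.cat([global_state.reshape(1), local_state])
print(reduced.shape)  # torch.Size([3])
```

Reshaping the scalar to shape `(1,)` is one way to make it concatenable; which state should actually be fixed inside the metric is not established by this sketch.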