Understanding NumPy’s dstack function

It’s easier to understand what np.vstack, np.hstack and np.dstack* do by looking at the .shape attribute of the output array.

Using your two example arrays:

print(a.shape, b.shape)
# (3, 2) (3, 2)

  • np.vstack concatenates along the first dimension (axis 0)…

    print(np.vstack((a, b)).shape)
    # (6, 2)
    
  • np.hstack concatenates along the second dimension (axis 1)…

    print(np.hstack((a, b)).shape)
    # (3, 4)
    
  • and np.dstack concatenates along the third dimension (axis 2).

    print(np.dstack((a, b)).shape)
    # (3, 2, 2)
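
To reproduce the three shapes above in one go, here is a minimal sketch. The arrays a and b are hypothetical stand-ins built with np.arange, since the arrays from the question are not shown here; any two arrays of shape (3, 2) would behave the same way.

```python
import numpy as np

# Hypothetical stand-ins for the question's arrays, both of shape (3, 2).
a = np.arange(6).reshape(3, 2)
b = np.arange(6, 12).reshape(3, 2)

# Collect the output shape of each stacking function for comparison.
stacked_shapes = {
    "vstack": np.vstack((a, b)).shape,  # rows are appended: 3 + 3 = 6
    "hstack": np.hstack((a, b)).shape,  # columns are appended: 2 + 2 = 4
    "dstack": np.dstack((a, b)).shape,  # a new third axis of length 2
}

for name, shape in stacked_shapes.items():
    print(name, shape)
# vstack (6, 2)
# hstack (3, 4)
# dstack (3, 2, 2)
```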
    

Since a and b are both two-dimensional, np.dstack first expands each of them by inserting a third dimension of size 1, then concatenates along that new dimension. The expansion is equivalent to indexing them in the third dimension with np.newaxis (or alternatively, None) like this:

print(a[:, :, np.newaxis].shape)
# (3, 2, 1)
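
As a quick check that the different spellings really are interchangeable, the sketch below compares np.newaxis, None and the Ellipsis form. The array a is again a hypothetical (3, 2) example, not the one from the question.

```python
import numpy as np

a = np.arange(6).reshape(3, 2)  # hypothetical (3, 2) array

# np.newaxis is an alias for None, so both index expressions are identical.
with_newaxis = a[:, :, np.newaxis]
with_none = a[:, :, None]

# The Ellipsis stands for "all existing axes", so this appends the new
# axis at the end regardless of how many dimensions a has.
with_ellipsis = a[..., None]

print(np.newaxis is None)
# True
print(with_newaxis.shape, with_none.shape, with_ellipsis.shape)
# (3, 2, 1) (3, 2, 1) (3, 2, 1)
```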

If c = np.dstack((a, b)), then c[:, :, 0] == a and c[:, :, 1] == b for every element.
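
Because == on arrays compares element by element and returns an array of booleans, np.array_equal is a convenient way to confirm the statement above as a single True/False. This is only a sketch with hypothetical (3, 2) arrays standing in for the ones from the question.

```python
import numpy as np

# Hypothetical stand-ins for the question's arrays.
a = np.arange(6).reshape(3, 2)
b = np.arange(6, 12).reshape(3, 2)

c = np.dstack((a, b))

# Slicing along the third axis recovers the original inputs.
first_matches = np.array_equal(c[:, :, 0], a)
second_matches = np.array_equal(c[:, :, 1], b)
print(first_matches, second_matches)
# True True

# Each position (i, j) of c holds the pair (a[i, j], b[i, j]).
print(c[0, 0])
# [0 6]
```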

You could do the same operation more explicitly using np.concatenate like this:

print(np.concatenate((a[..., None], b[..., None]), axis=2).shape)
# (3, 2, 2)
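
The shapes match, but it is also worth confirming that the values do. The sketch below compares the two results directly, again using hypothetical (3, 2) arrays. It also tries np.stack with axis=2, which is not mentioned above and is included only as an assumption about what a reader might reach for next; for two-dimensional inputs it gives the same result.

```python
import numpy as np

# Hypothetical stand-ins for the question's arrays.
a = np.arange(6).reshape(3, 2)
b = np.arange(6, 12).reshape(3, 2)

via_dstack = np.dstack((a, b))

# Explicit version: add the trailing axis by hand, then join along it.
via_concatenate = np.concatenate((a[..., None], b[..., None]), axis=2)

# np.stack creates the new axis itself, so no manual indexing is needed.
via_stack = np.stack((a, b), axis=2)

same_as_concatenate = np.array_equal(via_dstack, via_concatenate)
same_as_stack = np.array_equal(via_dstack, via_stack)
print(same_as_concatenate, same_as_stack)
# True True
```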

* Importing the entire contents of a module into your global namespace using from numpy import * is considered bad practice for several reasons. The idiomatic way is to import numpy as np.
