
Benchmarking NumPy fancy indexing vs. np.take

I recently discovered that fancy indexing in NumPy is not that fast. If you do a lot of it, it can eat up valuable computation time. So I benchmarked the approaches I know for selecting the rows of a matrix X whose label in y equals 1:

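  • Indexing directly with the True-False (boolean) vector
  • Indexing with the result of np.where on the True-False vector
  • Using np.take with the result of np.where on the True-False vector
  • (I also tried np.take directly on the True-False vector. It is really slow and should be avoided. Note also that np.take does not treat a boolean vector as a mask: it casts True/False to the indices 1/0, so it does not even select the right rows.)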
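To make the list above concrete, here is a small sketch on toy data (the arrays X and y below are made up for illustration) showing that the first three approaches select the same rows, and what np.take does with a raw boolean vector. One detail worth knowing: np.where returns a tuple with one index array per dimension, so the row indices are np.where(mask)[0].

```python
import numpy as np

# Toy data (illustrative only): 5 samples with 2 features, and 0/1 labels.
X = np.arange(10).reshape(5, 2)
y = np.array([1, 0, 1, 1, 0])

mask = y == 1                      # True-False (boolean) vector
rows = np.where(mask)[0]           # np.where returns a tuple; [0] holds the row indices

X_mask = X[mask, :]                # boolean indexing
X_where = X[rows, :]               # integer (fancy) indexing with np.where
X_take = np.take(X, rows, axis=0)  # np.take with the np.where indices

# All three select rows 0, 2 and 3.
assert np.array_equal(X_mask, X_where)
assert np.array_equal(X_mask, X_take)

# np.take does not treat a boolean array as a mask: True/False are cast to
# the integer indices 1/0, so this returns 5 rows (row 1 or row 0 each time).
X_take_bool = np.take(X, mask, axis=0)
print(X_take_bool.shape)  # (5, 2)
```

The last line is why the commented-out "Taken True/False" variant in the benchmark is not just slow but also produces a different result from the other three.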

From the plots below we can see that it is a close call between indexing with np.where and np.take with np.where. Still, in percentage terms, np.take has a slight advantage over indexing with np.where alone. As the number of dimensions grows, the difference seems to shrink, which might be due to the internal workings of np.take. The source code is included below as well.

[Figures whereprofile_1, whereprofile_2, whereprofile_3: average time as % of True/False indexing vs. sample size]

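import time

import numpy as np
import matplotlib.pyplot as plt

Nmax = 1E6    # largest number of samples
Niters = 100  # repetitions per sample size
pts = 20      # number of sample sizes to test
dims = 100    # features per sample (columns of X)
# Rows: True/False, Where, Taken True/False (disabled), Taken Where
runtime = np.zeros((4, pts))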
 
for idx, Npts in enumerate(np.linspace(10, Nmax, num=pts, dtype=int)):
    for ii in range(Niters):
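        # Fresh random data each iteration: Npts samples and 0/1 labels
        X = np.random.rand(Npts, dims)
        y = np.random.randint(2, size=Npts)

        # 1. Boolean (True/False) mask indexing
        t0 = time.time()
        idxs = y == 1
        X1 = X[idxs, :]
        t1 = time.time()
        runtime[0, idx] += t1 - t0
        del X1, idxs  # free the copy before the next allocation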
 
        # 2. Integer (fancy) indexing with np.where; np.where returns a
        #    tuple, so [0] extracts the row-index array
        t0 = time.time()
        idxs = np.where(y == 1)[0]
        X1 = X[idxs, :]
        t1 = time.time()
        runtime[1, idx] += t1 - t0
        del X1, idxs
 
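        # 3. np.take on the True/False vector: sloooow, and np.take casts the
        #    booleans to the indices 1/0, so it does not select the right rows
        # t0 = time.time()
        # idxs = y == 1
        # X1 = np.take(X, idxs, axis=0)
        # t1 = time.time()
        # runtime[2, idx] += t1 - t0
        # del X1, idxs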
 
        # 4. np.take with the np.where row indices
        t0 = time.time()
        idxs = np.where(y == 1)[0]
        X1 = np.take(X, idxs, axis=0)
        t1 = time.time()
        runtime[3, idx] += t1 - t0
        del X1, idxs
 
print(runtime)
s = np.linspace(10, Nmax, num=pts, dtype=int)
# Express every approach as a percentage of the True/False time
runtime = 100 * (runtime / runtime[0, :])
 
plt.plot(s, runtime[0, :], 'ro-', label='True/False')
plt.plot(s, runtime[1, :], 'bo-', label='Where')
# plt.plot(s, runtime[2, :], 'gx-', label='Taken True/False')
plt.plot(s, runtime[3, :], 'yx-', label='Taken Where')
plt.xlabel('Sample Size')
plt.ylabel('Average Time % of True/False')
plt.legend(loc=2)
plt.title('Sample dim 100')
plt.savefig('whereprofile.png', format='png', bbox_inches="tight")
plt.show()
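For a quicker check at a single sample size, the same three approaches can also be timed with the standard-library timeit module instead of manual time.time() bookkeeping. This is a sketch under assumptions: the helper name bench_indexing and its default sizes are hypothetical, chosen only for illustration, and unlike the loop above it reuses one random X and y for all repetitions rather than drawing new data each time.

```python
import timeit

import numpy as np


def bench_indexing(n_pts=10_000, dims=100, number=20, seed=0):
    """Hypothetical helper: time three row-selection approaches with timeit.

    Returns the mean time per call (seconds) for each approach, and each
    approach's time as a percentage of plain True/False (boolean) indexing.
    """
    rng = np.random.default_rng(seed)
    X = rng.random((n_pts, dims))
    y = rng.integers(0, 2, size=n_pts)

    # As in the benchmark above, the y == 1 comparison is part of each timed call.
    candidates = {
        "True/False": lambda: X[y == 1, :],
        "Where": lambda: X[np.where(y == 1)[0], :],
        "Taken Where": lambda: np.take(X, np.where(y == 1)[0], axis=0),
    }

    # timeit.timeit returns the total time for `number` calls of the callable.
    seconds = {name: timeit.timeit(fn, number=number) / number
               for name, fn in candidates.items()}
    baseline = seconds["True/False"]
    percent = {name: 100.0 * t / baseline for name, t in seconds.items()}
    return seconds, percent


if __name__ == "__main__":
    seconds, percent = bench_indexing()
    for name in seconds:
        print(f"{name:12s} {seconds[name] * 1e3:8.3f} ms  {percent[name]:6.1f} %")
```

The percentages are computed the same way as in the plotting code (time divided by the True/False time, times 100), so the numbers should be directly comparable to a single point on the plots, though absolute values will depend on the machine and NumPy version.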
