
On the Trick for Computing the Squared Euclidean Distances Between Two Sets of Vectors

One often wants to compute the squared pairwise Euclidean distances between two sets of observations. As always, it is enlightening to first look at the computation in the single case, the squared distance \(||x-y||^2\) between a vector \(x\) and a vector \(y\). The computation can be rewritten in a simpler form,

\[\begin{align}||x-y||^2 &= (x_1-y_1)^2 + (x_2-y_2)^2 + \ldots + (x_n-y_n)^2 \\ &= x_1^2+y_1^2-2x_1y_1 + \ldots + x_n^2+y_n^2-2x_ny_n \\ &= x \cdot x + y \cdot y - 2x \cdot y.\end{align}\]

This means that the squared distance between the vectors can be written as the sum of the dot products of \(x\) and \(y\) with themselves, minus two times the dot product between \(x\) and \(y\).
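
As a quick sanity check with two small vectors, say \(x = (1, 2)\) and \(y = (3, 5)\), both sides agree:

\[(1-3)^2 + (2-5)^2 = 4 + 9 = 13, \qquad x \cdot x + y \cdot y - 2 x \cdot y = 5 + 34 - 2 \cdot 13 = 13.\]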

How can we generalize this into an expression involving two sets of observations? If we let the observations in the first set be rows in a matrix \(X\) of size \(N \times M\), and the second set be rows in a matrix \(Y\) of size \(K \times M\), then the distance matrix \(D\) will be \(N \times K\).

The entry in the \(i\)-th row and \(j\)-th column of \(D\) is the squared distance between the \(i\)-th row vector in \(X\) and the \(j\)-th row vector in \(Y\). That is, rows in \(D\) refer to observations in \(X\) and columns to observations in \(Y\).
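
Written entry by entry, the scalar identity above becomes

\[D_{ij} = ||X_{i,:} - Y_{j,:}||^2 = (XX^T)_{ii} + (YY^T)_{jj} - 2(XY^T)_{ij},\]

so each of the three terms has to be expanded into an \(N \times K\) matrix.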

This means that the expression that generalizes the dot product of \(x\) with itself should be a matrix whose \(i\)-th row consists of copies of the dot product of the \(i\)-th row vector in \(X\) with itself. In Matlab, there are several ways of writing this,

repmat(diag(X*X'),1,K),

or better,

repmat(sum(X.^2,2),1,K).

However, since Matlab's repmat function is (or at least used to be) slow, we can instead write the duplication as a matrix multiplication with a row vector of ones,

sum(X.^2,2) * ones(1,K).

We do the same for the values in \(Y\), except this time we want the copies of the dot products to be replicated column-wise. This gives the following expression,

ones(N,1) * sum(Y.^2,2)'

The final matrix is the one that generalizes the dot product between \(x\) and \(y\). This is simply given as \(X*Y'\). Putting it all together we get,

D = sum(X.^2,2)*ones(1,K) + ones(N,1)*sum( Y.^2, 2 )' - 2.*X*Y'

Nowadays you are probably better off using Matlab's fast compiled version, pdist2, which is about 200% faster when the number of vectors is large.

Distances using Eigen

If we want to implement this in Eigen, a C++ library for doing linear algebra in much the same manner as in Matlab, we can do it in the following way,

#include <Eigen/Dense>
#include <iostream>

// Construct two simple matrices
Eigen::MatrixXd X(3,4);
X << 1, 2, 3, 4,
     4, 5, 6, 4,
     7, 8, 9, 4;
Eigen::MatrixXd Y(4,4);
Y << 3, 2, 4, 20,
     4, 5, 5, 4,
     1, 5, 10, 4,
     3, 11, 0, 6;

const int N = X.rows();
const int K = Y.rows();

// Allocate parts of the expression
Eigen::MatrixXd XX, YY, XY, D;
XX.resize(N,1);
YY.resize(1,K);
XY.resize(N,K);
D.resize(N,K);

// Compute squared norms and the (doubled) cross term
XX = X.array().square().rowwise().sum();
YY = Y.array().square().rowwise().sum().transpose();
XY = (2*X)*Y.transpose();   // already includes the factor of 2

// Compute final expression
D = XX * Eigen::MatrixXd::Ones(1,K);
D = D + Eigen::MatrixXd::Ones(N,1) * YY;
D = D - XY;

// For loop comparison
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> D2;
D2.resize(N,K);
for( int i_row = 0; i_row != N; i_row++ )
  for( int j_col = 0; j_col != K; j_col++ )
    D2(i_row,j_col) = XX(i_row) + YY(j_col) - XY(i_row,j_col);   // XY already holds 2*x.y

std::cout << D << std::endl;
std::cout << D2 << std::endl;
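
For reference, the whole trick can also be wrapped up in a small reusable function. This is only a sketch of my own (the name pairwiseSquaredDistances is not from any library), assuming Eigen 3 and double-precision matrices throughout; it uses Eigen's row-wise and column-wise broadcasting instead of the explicit matrices of ones:

#include <Eigen/Dense>

// Sketch: squared pairwise distances between the rows of X (N x M) and Y (K x M).
// Returns an N x K matrix D with D(i,j) = ||X.row(i) - Y.row(j)||^2.
Eigen::MatrixXd pairwiseSquaredDistances(const Eigen::MatrixXd& X,
                                         const Eigen::MatrixXd& Y)
{
    const Eigen::VectorXd XX = X.rowwise().squaredNorm();   // N x 1 squared row norms
    const Eigen::VectorXd YY = Y.rowwise().squaredNorm();   // K x 1 squared row norms
    Eigen::MatrixXd D = -2.0 * X * Y.transpose();           // N x K cross term
    D.colwise() += XX;                // D(i,j) += ||x_i||^2 (broadcast down each column)
    D.rowwise() += YY.transpose();    // D(i,j) += ||y_j||^2 (broadcast along each row)
    return D;
}

Calling pairwiseSquaredDistances(X, Y) on the matrices above should reproduce the matrix D from the listing.
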
If we just want to compute the squared distance between one point and all other points, we can make use of some of Eigen's nice chaining functionality.

// Subtract the i-th row of Y from every row in X (note: modifies X in place)
X.rowwise() -= Y.row(i);
// Compute the row-wise squared norms, i.e. the squared distances to Y.row(i)
Eigen::VectorXd d = X.rowwise().squaredNorm();
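
As one of the comments below points out, floating-point round-off can make some entries of D come out as tiny negative numbers, which turn into NaN if you take the square root to get actual distances. A minimal guard, assuming the Eigen matrix D from above, is to clamp at zero first:

// Clamp tiny negative values caused by round-off before taking the square root
D = D.cwiseMax(0.0);
Eigen::MatrixXd dist = D.cwiseSqrt();   // actual (non-squared) distances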

10 thoughts on “On the Trick for Computing the Squared Euclidean Distances Between Two Sets of Vectors”

  1. I’m not sure if you’ve further changed this post, but your code as is does not work. For one, the matrices are not 3×3. Also I imagine N and K should be .rows(), no?

  2. Line 23:

    YY = X.array().square().rowwise().sum().transpose();

    should be:

    YY = Y.array().square().rowwise().sum().transpose();

    And let me point out that the elements in D can end up to be very small negative values by the numerical error. If you ever think about taking the square root of D to get the distance, this can produce NaN’s.

    1. True, there are many pitfalls that must be taken care of to make the code production-ready. This is just to illustrate the trick. Fixed the typo. Thanks!
