On the Trick for Computing the Squared Euclidean Distances Between Two Sets of Vectors

Many times one wants to compute the squared pairwise Euclidean distances between two sets of observations. As always it is enlightening to look at the computation being done in the single case, between a vector, \(x\), and a vector, \(y\), \(||x-y||^2\). The computation for the distance can be rewritten into a simpler form,

\(||x-y||^2 = (x_1-y_1)^2 + (x_2-y_2)^2 + \ldots + (x_n-y_n)^2 = \)
\(x_1^2+y_1^2-2x_1y_1 + \ldots + x_n^2+y_n^2-2x_ny_n = \)
\( x \cdot x + y \cdot y - 2x \cdot y\).

This means that the squared distance between the vectors can be written as the sum of the dot products of \(x\) and \(y\) with themselves, minus two times the dot product between \(x\) and \(y\).
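To make the identity concrete, here is a minimal sketch in plain C++ (standard library only) that evaluates both forms for a single pair of vectors. The function names `dot`, `squared_distance_direct` and `squared_distance_expanded` are made up for this illustration and are not taken from any library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Dot product of two equally sized vectors.
double dot(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size());
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}

// Direct definition: sum of squared coordinate differences.
double squared_distance_direct(const std::vector<double>& x,
                               const std::vector<double>& y) {
    assert(x.size() == y.size());
    double s = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        const double diff = x[i] - y[i];
        s += diff * diff;
    }
    return s;
}

// Expanded form: x.x + y.y - 2 x.y (hypothetical helper name).
double squared_distance_expanded(const std::vector<double>& x,
                                 const std::vector<double>& y) {
    return dot(x, x) + dot(y, y) - 2.0 * dot(x, y);
}
```

For \(x = (1, 2, 3)\) and \(y = (4, 6, 8)\) the direct form gives \(9 + 16 + 25 = 50\), and the expanded form gives \(14 + 116 - 2 \cdot 40 = 50\).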

How can we generalize this into an expression involving two sets of observations? If we let the observations in the first set be rows in a matrix \(X\) of size \(N \times M\), and the second set be rows in a matrix, \(Y\), of size \(K \times M\), then the distance matrix, \(D\), will be \(N \times K\).

The value of the entry in the \(i\)-th row and \(j\)-th column of \(D\) is the squared distance between the \(i\)-th row vector in \(X\) and the \(j\)-th row vector in \(Y\). That is, rows in \(D\) refer to observations in \(X\) and columns to observations in \(Y\).
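Before deriving the vectorized form, it can help to have the definition of \(D\) written out as a plain double loop to compare against. The following is a sketch in standard C++, assuming every row of both matrices has the same length \(M\); the `Matrix` alias and the function name are hypothetical and chosen only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical alias: a matrix stored as one inner vector per row.
using Matrix = std::vector<std::vector<double>>;

// D[i][j] = squared Euclidean distance between row i of X and row j of Y.
Matrix pairwise_squared_distances_naive(const Matrix& X, const Matrix& Y) {
    const std::size_t N = X.size();
    const std::size_t K = Y.size();
    Matrix D(N, std::vector<double>(K, 0.0));
    for (std::size_t i = 0; i < N; ++i) {
        for (std::size_t j = 0; j < K; ++j) {
            // Assumption: both rows have the same number of columns M.
            assert(X[i].size() == Y[j].size());
            for (std::size_t m = 0; m < X[i].size(); ++m) {
                const double diff = X[i][m] - Y[j][m];
                D[i][j] += diff * diff;
            }
        }
    }
    return D;
}
```

With two observations in \(X\) and three in \(Y\) the result has two rows and three columns, matching the \(N \times K\) shape described above.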

This means that the term generalizing the dot product of \(x\) with itself should be a matrix whose \(i\)-th row consists of \(K\) copies of the dot product of the \(i\)-th row vector in \(X\) with itself. In Matlab, there are several ways of writing this,

repmat(diag(X*X'),1,K),

or better,

repmat(sum(X.^2,2),1,K).

However, since Matlab's repmat function is slow (or at least used to be), we can write the duplication as a matrix multiplication with a row vector of ones,

sum(X.^2,2) * ones(1,K).
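As a small worked illustration with made-up values (\(N = 2\), \(K = 3\)), multiplying the column of squared row norms by a row of ones copies each norm across its row. The symbols \(s\) and \(\mathbf{1}_3\) are notation introduced only for this example.

```latex
X = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad
s = \begin{pmatrix} 1^2 + 2^2 \\ 3^2 + 4^2 \end{pmatrix}
  = \begin{pmatrix} 5 \\ 25 \end{pmatrix}, \qquad
s \, \mathbf{1}_3^{T}
  = \begin{pmatrix} 5 \\ 25 \end{pmatrix}
    \begin{pmatrix} 1 & 1 & 1 \end{pmatrix}
  = \begin{pmatrix} 5 & 5 & 5 \\ 25 & 25 & 25 \end{pmatrix}
```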

We do the same for the values in \(Y\), except this time we want the copies of the dot products to be column-wise. This gives the following expression

ones(N,1) * sum( Y.^2, 2 )'

The final matrix is the one that generalizes the dot product between \(x\) and \(y\). This is simply given as, \(X*Y'\). Putting it all together we get,

D = sum(X.^2,2)*ones(1,K) + ones(N,1)*sum( Y.^2, 2 )' - 2.*X*Y'
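For reference, the same expression can be written in matrix notation. Here \(\circ\) denotes the elementwise product and \(\mathbf{1}_n\) a column vector of \(n\) ones; both are notation assumed for this sketch rather than taken from the text above.

```latex
D = (X \circ X)\,\mathbf{1}_M \mathbf{1}_K^{T}
  + \mathbf{1}_N \mathbf{1}_M^{T} (Y \circ Y)^{T}
  - 2\, X Y^{T},
\qquad
D_{ij} = \sum_{m=1}^{M} X_{im}^2 + \sum_{m=1}^{M} Y_{jm}^2
       - 2 \sum_{m=1}^{M} X_{im} Y_{jm}
```

The entrywise form on the right is the single-vector identity applied to row \(i\) of \(X\) and row \(j\) of \(Y\).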

I have uploaded the code to my Matlab Machine Learning pack on Git. Nowadays you are probably better off using Matlab's fast compiled function, pdist2, which is about 200% faster when the number of vectors is large.

Distances using Eigen

If we want to implement this in Eigen, a C++ library for doing linear algebra much in the same manner as in Matlab, we can do it in the following way,

#include <iostream>
#include <Eigen/Dense>

int main()
{
  // Construct two simple matrices; each row is one observation
  Eigen::Matrix<double, 3, 4> X;
  Eigen::Matrix<double, 4, 4> Y;
  X << 1, 2, 3, 4,
       4, 5, 6, 4,
       7, 8, 9, 4;
  Y << 3, 2, 4, 20,
       4, 5, 5, 4,
       1, 5, 10, 4,
       3, 11, 0, 6;

  // Number of observations (rows) in each set
  const int N = X.rows();
  const int K = Y.rows();

  // Allocate parts of the expression
  Eigen::MatrixXd XX(N,1), YY(1,K), XY(N,K), D(N,K);

  // Compute squared norms and the cross term (XY already includes the factor 2)
  XX = X.array().square().rowwise().sum();
  YY = Y.array().square().rowwise().sum().transpose();
  XY = (2*X)*Y.transpose();

  // Compute final expression
  D = XX * Eigen::MatrixXd::Ones(1,K);
  D = D + Eigen::MatrixXd::Ones(N,1) * YY;
  D = D - XY;

  // For loop comparison
  Eigen::MatrixXd D2(N,K);
  for( int i_row = 0; i_row != N; i_row++ )
    for( int j_col = 0; j_col != K; j_col++ )
      D2(i_row,j_col) = XX(i_row,0) + YY(0,j_col) - XY(i_row,j_col);

  std::cout << D << std::endl;
  std::cout << D2 << std::endl;
}

If we just want to compute the distance between one point and all other points we can make use of some of Eigen's nice chaining functionality.

// Subtract row i of Y from every row in X (this modifies X in place).
// Y.row(i) is already a row vector, so no transpose is needed.
X.rowwise() -= Y.row(i);
// Compute row-wise squared norm
Eigen::VectorXd d = X.rowwise().squaredNorm();

3 Responses to “On the Trick for Computing the Squared Euclidean Distances Between Two Sets of Vectors”

  1. Gurki

    good article.
    it should however say "XY = 2*X*Y.transpose()".
    i wonder how many bad copy-paste releases are out there by now ;).

  2. Gurki

    still great, just noticed another typo though:
    YY = Y.array().square().rowwise().sum().transpose();
