
Unit 4.5.3 Matrix-matrix multiplication for Machine Learning

Matrix-matrix multiplication turns out to be an operation that is frequently employed by algorithms in machine learning. In this unit, we discuss the all-nearest-neighbor problem and show how knowing how to optimize matrix-matrix multiplication leads to a practical algorithm for solving it.

The k-nearest-neighbor problem (KNN) takes as input \(m \) points in \(\mathbb{R}^n \text{,}\) \(\{ x_j \}_{j=0}^{m-1} \text{,}\) and a reference point, \(x \text{,}\) and computes the \(k \) nearest neighbors of \(x \) among the \(m \) points. The all-nearest-neighbor (ANN) problem computes the \(k \) nearest neighbors of each of the points \(x_j \text{.}\)

The trick to computing ANN is to observe that we need to compute the distances between all points \(x_i \) and \(x_j \text{,}\) given by \(\| x_i - x_j \|_2 \text{.}\) But,

\begin{equation*} \| x_i - x_j \|_2^2 = ( x_i - x_j )^T ( x_i - x_j ) = x_i^T x_i - 2 x_i^T x_j + x_j^T x_j . \end{equation*}

So, if one creates the matrix

\begin{equation*} X = \left( \begin{array}{c | c | c | c } x_0 \amp x_1 \amp \cdots \amp x_{m-1} \end{array} \right) \end{equation*}

and computes

\begin{equation*} C = X^T X = \left( \begin{array}{c | c | c | c} x_0^T x_0 \amp \star \amp \cdots \amp \star \\ \hline x_1^T x_0 \amp x_1^T x_1 \amp \cdots \amp \star \\ \hline \vdots \amp \vdots \amp \ddots \amp \vdots \\ \hline x_{m-1}^T x_0 \amp x_{m-1}^T x_1 \amp \cdots \amp x_{m-1}^T x_{m-1} \end{array} \right). \end{equation*}

then, letting \(\gamma_{i,j} \) denote the \((i,j) \) entry of \(C \text{,}\) computing only the lower triangular part of \(C \) yields all the needed information, since for \(i \geq j \)

\begin{equation*} \| x_i - x_j \|_2^2 = \gamma_{i,i} - 2 \gamma_{i,j} + \gamma_{j,j}. \end{equation*}

By sorting this information, the \(k \) nearest neighbors of each \(x_i \) can be found.
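The identity can be checked with a small computation. The following sketch (plain Python rather than an optimized implementation; the helper names `gram_lower` and `sqdist_from_gram` are ours, chosen for illustration) computes only the lower triangular part of \(C = X^T X \) and recovers a squared distance from it:

```python
import math
import random

def gram_lower(X):
    """Lower triangular part of C = X^T X, where X is an
    n x m list of lists whose columns are the points x_j."""
    n, m = len(X), len(X[0])
    C = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i + 1):  # lower triangle only: j <= i
            C[i][j] = sum(X[p][i] * X[p][j] for p in range(n))
    return C

def sqdist_from_gram(C, i, j):
    """|| x_i - x_j ||_2^2 = gamma_{i,i} - 2 gamma_{i,j} + gamma_{j,j},
    valid for i >= j so that only the lower triangle is touched."""
    return C[i][i] - 2.0 * C[i][j] + C[j][j]

# Example: m = 4 random points in R^3.
random.seed(0)
n, m = 3, 4
X = [[random.uniform(-1, 1) for _ in range(m)] for _ in range(n)]
C = gram_lower(X)

# The Gram-matrix formula agrees with the direct computation.
direct = sum((X[p][2] - X[p][1]) ** 2 for p in range(n))
assert math.isclose(sqdist_from_gram(C, 2, 1), direct, abs_tol=1e-9)
```

The point of the trick is that the expensive part, forming \(C \text{,}\) is exactly a matrix-matrix multiplication, so all of the optimization techniques from this course apply.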

There are three problems with this approach:

  • Only the lower triangular part of \(C \) needs to be computed. This operation is known as a symmetric rank-k update. (Here the \(k \) refers to the number of rows in \(X \) rather than the \(k \) in k-nearest-neighbors.) How Goto's algorithm can be modified to compute the symmetric rank-k update is discussed in [11].

  • Typically \(m \gg n \text{,}\) and hence the \(m \times m \) intermediate matrix \(C \) requires a very large amount of space to store.

  • Sorting the resulting information in order to determine the \(k \) nearest neighbors for each point means reading and writing data multiple times from and to memory.
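The storage and sorting problems suggest computing \(C \) one panel of columns at a time, keeping a running set of the \(k \) nearest neighbors of each point, so that only an \(m \times b \) block of \(C \) ever exists. The following is a sketch of that idea in plain Python (the function name `knn_blocked` and the panel width `b` are ours, not taken from [34]; a high-performance version would form each panel with an optimized matrix-matrix multiplication rather than these loops):

```python
import heapq

def knn_blocked(X, k, b):
    """k nearest neighbors of every column of X (an n x m list of
    lists), computing C = X^T X one panel of b columns at a time so
    that no m x m matrix is ever stored."""
    n, m = len(X), len(X[0])
    sq = [sum(X[p][j] ** 2 for p in range(n)) for j in range(m)]  # x_j^T x_j
    # For each point, a max-heap (via negated distances) of its k nearest so far.
    heaps = [[] for _ in range(m)]
    for j0 in range(0, m, b):
        j1 = min(j0 + b, m)
        # Panel of C: all rows, columns j0 .. j1-1.
        for i in range(m):
            for j in range(j0, j1):
                if i == j:
                    continue
                gij = sum(X[p][i] * X[p][j] for p in range(n))  # gamma_{i,j}
                d2 = sq[i] - 2.0 * gij + sq[j]
                if len(heaps[i]) < k:
                    heapq.heappush(heaps[i], (-d2, j))
                elif -d2 > heaps[i][0][0]:
                    heapq.heapreplace(heaps[i], (-d2, j))
    # Neighbor indices for each point, ordered by increasing distance.
    return [[j for _, j in sorted((-nd, j) for nd, j in h)] for h in heaps]

# Example: four points on a line, two nearest neighbors each.
nbrs = knn_blocked([[0.0, 1.0, 2.0, 10.0]], k=2, b=2)
assert nbrs[0] == [1, 2]
assert nbrs[3] == [2, 1]
```

Maintaining a small heap per point also sidesteps the full sort of the distance information; only \(k \) candidates per point are ever kept.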

You can read up on how to achieve a high-performance implementation for solving the all-nearest-neighbor problem, by exploiting what we know about implementing matrix-matrix multiplication, in [34].