Unit 4.5.3 Matrix-matrix multiplication for Machine Learning

Matrix-matrix multiplication is frequently employed by algorithms in machine learning. In this unit, we discuss the all-nearest-neighbor problem. Knowing how to optimize matrix-matrix multiplication allows one to optimize a practical algorithm for solving it.

The k-nearest-neighbor problem (KNN) takes as input $m$ points in $\Rn \text{,}$ $\{ x_j \}_{j=0}^{m-1} \text{,}$ and a reference point, $x \text{,}$ and computes the $k$ nearest neighbors of $x$ among the $m$ points. The all-nearest-neighbor (ANN) problem computes the $k$ nearest neighbors of each of the points $x_j \text{.}$
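To make the two problem statements concrete, here is a minimal brute-force sketch in Python that computes every distance directly. The function names `squared_distance`, `knn`, and `ann` are illustrative and not part of the original text. The sketch also assumes that, for ANN, a point is not counted as its own neighbor, which the definition above leaves open.

```python
import heapq


def squared_distance(u, v):
    """Squared Euclidean distance between two points given as sequences."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def knn(points, x, k):
    """Indices of the k points nearest to x, closest first (ties broken by index)."""
    return heapq.nsmallest(
        k, range(len(points)),
        key=lambda j: (squared_distance(points[j], x), j))


def ann(points, k):
    """For each point, the indices of its k nearest neighbors.

    Assumption: a point is excluded from its own neighbor list.
    """
    result = []
    for i, x in enumerate(points):
        others = [j for j in range(len(points)) if j != i]
        result.append(heapq.nsmallest(
            k, others,
            key=lambda j: (squared_distance(points[j], x), j)))
    return result
```

Comparing squared distances gives the same ordering as comparing distances, so the square root is never needed.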

The trick to computing ANN is to observe that we need to compute the distances between all points $x_i$ and $x_j \text{,}$ given by $\| x_i - x_j \|_2 \text{.}$ But,

\begin{equation*} \| x_i - x_j \|_2^2 = ( x_i - x_j )^T ( x_i - x_j ) = x_i^T x_i - 2 x_i^T x_j + x_j^T x_j . \end{equation*}

So, one can create the matrix

\begin{equation*} X = \left( \begin{array}{c | c | c | c } x_0 \amp x_1 \amp \cdots \amp x_{m-1} \end{array} \right) \end{equation*}

and compute

\begin{equation*} C = X^T X = \left( \begin{array}{c | c | c | c} x_0^T x_0 \amp \star \amp \cdots \amp \star \\ \hline x_1^T x_0 \amp x_1^T x_1 \amp \cdots \amp \star \\ \hline \vdots \amp \vdots \amp \ddots \amp \vdots \\ \hline x_{m-1}^T x_0 \amp x_{m-1}^T x_1 \amp \cdots \amp x_{m-1}^T x_{m-1} \end{array} \right). \end{equation*}

Hence, if the lower triangular part of $C$ is computed, then, with $\gamma_{i,j}$ denoting the $(i,j)$ element of $C$ and $i \geq j \text{,}$

\begin{equation*} \| x_i - x_j \|_2^2 = \gamma_{i,i} - 2 \gamma_{i,j} + \gamma_{j,j}. \end{equation*}

By sorting these distances for each $x_i \text{,}$ its $k$ nearest neighbors can be found.
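The approach just described can be sketched as follows. This is a plain Python illustration under stated assumptions, not a high-performance implementation: the names `gram_lower` and `ann_from_gram` are hypothetical, points are stored as tuples (the columns of $X$), and a point is assumed not to be its own neighbor. Only the lower triangular part of $C$ is formed, and symmetry is used to look up $\gamma_{i,j}$ when $i \lt j \text{.}$

```python
def gram_lower(points):
    """Lower triangular part of C = X^T X, where the columns of X are the points.

    Row i of the result holds gamma[i][0], ..., gamma[i][i].
    """
    return [[sum(a * b for a, b in zip(points[i], points[j]))
             for j in range(i + 1)]
            for i in range(len(points))]


def ann_from_gram(points, k):
    """k nearest neighbors of every point, using only the lower triangle of C."""
    gamma = gram_lower(points)
    m = len(points)

    def dist2(i, j):
        # C is symmetric, so gamma_{i,j} is read from the stored lower triangle.
        hi, lo = max(i, j), min(i, j)
        return gamma[i][i] - 2 * gamma[hi][lo] + gamma[j][j]

    neighbors = []
    for i in range(m):
        # Sort the squared distances from x_i to every other point.
        candidates = sorted((dist2(i, j), j) for j in range(m) if j != i)
        neighbors.append([j for _, j in candidates[:k]])
    return neighbors
```

In floating point arithmetic, the expression $\gamma_{i,i} - 2 \gamma_{i,j} + \gamma_{j,j}$ can lose accuracy when two points are very close to each other, since nearly equal quantities are subtracted.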

There are three problems with this:

• Only the lower triangular part of $C$ needs to be computed. This operation is known as a symmetric rank-k update. (Here the $k$ refers to the number of rows in $X$ rather than the $k$ in k-nearest-neighbors.) How Goto's algorithm can be modified to compute the symmetric rank-k update is discussed in .

• Typically $m \gg n$ and hence the $m \times m$ intermediate matrix $C$ requires a very large amount of memory, even though it is only an intermediate result.

• Sorting the resulting information in order to determine the $k$ nearest neighbors for each point means reading and writing data multiple times from and to memory.
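One way to picture how the second and third issues might be eased is to compute $C$ a block of columns at a time and to fold each block into a small per-point list of the best $k$ candidates, so that the full $m \times m$ matrix is never stored and no full sort is needed. The sketch below is only an illustration of that idea under stated assumptions. It is not claimed to be the method of any particular paper, the name `ann_blocked` and the parameter `block` are hypothetical, and for simplicity it does not exploit the symmetry of $C \text{.}$

```python
import heapq


def ann_blocked(points, k, block=2):
    """k nearest neighbors of every point, forming C one column block at a time."""
    m = len(points)
    norms = [sum(a * a for a in p) for p in points]  # the diagonal gamma_{i,i}
    # heaps[i] keeps at most k entries (-distance, -index); negation turns
    # Python's min-heap into a max-heap, so the worst candidate sits on top.
    heaps = [[] for _ in range(m)]

    for j0 in range(0, m, block):
        cols = range(j0, min(j0 + block, m))
        # An m x block tile of C = X^T X: the only part of C held at one time.
        tile = [[sum(a * b for a, b in zip(points[i], points[j])) for j in cols]
                for i in range(m)]
        for i in range(m):
            for jj, j in enumerate(cols):
                if j == i:
                    continue  # assumption: a point is not its own neighbor
                d = norms[i] - 2 * tile[i][jj] + norms[j]
                if len(heaps[i]) < k:
                    heapq.heappush(heaps[i], (-d, -j))
                else:
                    # Push the new candidate, then drop the farthest one.
                    heapq.heappushpop(heaps[i], (-d, -j))

    # Undo the negation and list each point's neighbors closest first.
    return [[j for _, j in sorted((-nd, -nj) for nd, nj in heap)]
            for heap in heaps]
```

With this organization the extra storage is one $m \times$ `block` tile plus $k$ entries per point, rather than all of $C \text{.}$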

You can read up on how to achieve a high performance implementation for solving the all-nearest-neighbor problem by exploiting what we know about implementing matrix-matrix multiplication in :