Learning the Positions in CountSketch
This work addresses a bottleneck in learning-based sketching that matters to researchers and practitioners in machine learning and optimization. It extends previous methods, which learned only the values of the sketch's non-zero entries, by also learning their positions; the contribution is nonetheless incremental, since it builds on an existing paradigm.
The paper tackles the problem of learning both the values and the positions of the non-zero entries in sketching matrices, specifically CountSketch, to improve optimization tasks such as low-rank approximation and Hessian approximation. The resulting algorithms achieve good accuracy with fast running times and can significantly reduce error even with limited training data.
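To make "positions" and "values" concrete, the following is a minimal sketch, assuming NumPy and the standard CountSketch construction in which each column of an m-by-n sketch matrix S has exactly one non-zero entry: `positions[j]` is the row of that entry and `values[j]` is its value, initialised to a random sign. The helper names are hypothetical. Applying S requires only one pass over the rows of A, which is what makes CountSketch fast; the learning-based approach then changes `values` (previous work) and also `positions` (this work).

```python
import numpy as np

def make_countsketch(n, m, rng):
    """Random CountSketch for an m x n matrix S, stored sparsely.

    Column j of S has a single non-zero entry: S[positions[j], j] = values[j].
    """
    positions = rng.integers(0, m, size=n)    # random row (hash bucket) per column
    values = rng.choice([-1.0, 1.0], size=n)  # random sign per column
    return positions, values

def to_dense(positions, values, m):
    """Materialise S as a dense m x n array (used here only for checking)."""
    n = len(positions)
    S = np.zeros((m, n))
    S[positions, np.arange(n)] = values
    return S

def apply_countsketch(positions, values, m, A):
    """Compute S @ A without forming S: row j of A, scaled by values[j],
    is added to row positions[j] of the result."""
    SA = np.zeros((m, A.shape[1]))
    np.add.at(SA, positions, values[:, None] * A)
    return SA

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 20))
positions, values = make_countsketch(1000, 50, rng)
SA = apply_countsketch(positions, values, 50, A)
print(SA.shape)                                              # (50, 20)
print(np.allclose(SA, to_dense(positions, values, 50) @ A))  # True
```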
We consider sketching algorithms that first compress data by multiplying it with a random sketch matrix and then use the sketch to quickly solve an optimization problem, e.g., low-rank approximation or regression. In the learning-based sketching paradigm proposed by~\cite{indyk2019learning}, the sketch matrix is found by choosing a random sparse matrix, e.g., CountSketch, and then updating the values of its non-zero entries by running gradient descent on a training data set. Despite the growing body of work on this paradigm, a notable omission is that previous algorithms fixed the locations of the non-zero entries and learned only their values. In this work, we propose the first learning-based algorithms that also optimize the locations of the non-zero entries. Our first proposed algorithm is greedy; its drawback is a slow training time. We address this issue and propose approaches for learning a sketching matrix for both low-rank approximation and Hessian approximation in second-order optimization. The latter is helpful for a range of constrained optimization problems, such as LASSO and matrix estimation with a nuclear norm constraint. Both approaches achieve good accuracy with a fast running time. Moreover, our experiments suggest that our algorithm can still reduce the error significantly even when only a very limited number of training matrices is available.
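The idea of greedily learning positions for low-rank approximation can be illustrated under simplifying assumptions. This is a hypothetical sketch, not the paper's algorithm: it keeps the values fixed (rather than learning them by gradient descent), scores a sketch S on a training matrix A by a common sketch-based error ||A - [AV]_k V^T||_F, where V is an orthonormal basis of the row space of SA and [.]_k is the best rank-k approximation, and, for each column of S, tries every row for its non-zero entry and keeps the one with the lowest total error over the training matrices. Every column costs m + 1 evaluations of the loss on the whole training set, which shows why a greedy search over positions trains slowly.

```python
import numpy as np

def sketch_lowrank_error(S, A, k):
    """Frobenius error of the best rank-k approximation of A within rowspace(S @ A)."""
    _, sv, Vt = np.linalg.svd(S @ A, full_matrices=False)
    # Keep only directions with non-negligible singular values, so V is an
    # orthonormal basis of the row space of SA even when SA is rank-deficient.
    V = Vt[sv > 1e-10 * sv.max(initial=0.0)].T
    U, s, Wt = np.linalg.svd(A @ V, full_matrices=False)
    AV_k = (U[:, :k] * s[:k]) @ Wt[:k]  # best rank-k approximation of A V
    return np.linalg.norm(A - AV_k @ V.T)

def training_loss(positions, values, m, train, k):
    """Total sketch-based error over the training matrices (hypothetical objective)."""
    n = len(positions)
    S = np.zeros((m, n))
    S[positions, np.arange(n)] = values  # one non-zero per column, as in CountSketch
    return sum(sketch_lowrank_error(S, A, k) for A in train)

def greedy_positions(positions, values, m, train, k):
    """Hypothetical greedy pass: move each column's non-zero to the best row."""
    positions = positions.copy()
    for j in range(len(positions)):
        best_row = positions[j]
        best_loss = training_loss(positions, values, m, train, k)
        for r in range(m):  # exhaustive search over the m candidate rows
            positions[j] = r
            loss = training_loss(positions, values, m, train, k)
            if loss < best_loss:
                best_row, best_loss = r, loss
        positions[j] = best_row  # never increases the training loss
    return positions

rng = np.random.default_rng(0)
# Synthetic near-low-rank training matrices (illustrative data only).
train = [rng.standard_normal((40, 3)) @ rng.standard_normal((3, 10))
         + 0.1 * rng.standard_normal((40, 10)) for _ in range(3)]
m, k = 5, 2
positions = rng.integers(0, m, size=40)
values = rng.choice([-1.0, 1.0], size=40)
learned = greedy_positions(positions, values, m, train, k)
print(training_loss(positions, values, m, train, k),
      training_loss(learned, values, m, train, k))
```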