Explaining Deep Network Classification of Matrices: A Case Study on Monotonicity
By: Leandro Farina, Sergey Korotov
Potential Business Impact:
Finds simple mathematical rules for classifying matrices.
This work demonstrates a methodology for using deep learning to discover simple, practical criteria for classifying matrices based on abstract algebraic properties. By combining a high-performance neural network with explainable AI (XAI) techniques, we can distill a model's learned strategy into human-interpretable rules. We apply this approach to the challenging case of monotone matrices, defined by the condition that their inverses are entrywise nonnegative. Despite this simple definition, no easy characterization in terms of the matrix entries or of parameters derived from them is known. Here, we present, to the best of our knowledge, the first systematic machine-learning approach for deriving a practical criterion that distinguishes monotone from non-monotone matrices. After building a labelled dataset of randomly generated monotone and non-monotone matrices with entries drawn uniformly from $(-1,1)$, we train deep neural networks to classify the matrices as monotone or non-monotone, using both their entries and a comprehensive set of matrix features. Using saliency methods such as integrated gradients, we identify, among all features, two matrix parameters that alone provide sufficient information to classify the matrices with $95\%$ accuracy: the absolute values of the two lowest-order coefficients, $c_0$ and $c_1$, of the matrix's characteristic polynomial. A data-driven study of 18,000 random $7\times7$ matrices shows that the monotone class obeys $\lvert c_{0}/c_{1}\rvert\le0.18$ with probability $>99.98\%$; because $\lvert c_{0}/c_{1}\rvert = 1/\mathrm{tr}(A^{-1})$ for monotone $A$, this is equivalent to the simple bound $\mathrm{tr}(A^{-1})\ge5.7$.
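The identity $\lvert c_{0}/c_{1}\rvert = 1/\mathrm{tr}(A^{-1})$ used in the abstract can be checked numerically. The following is a minimal sketch, not the authors' code: the helper names `is_monotone` and `c0_c1_ratio`, the tolerance `tol`, and the $3\times3$ test matrix are assumptions chosen for illustration. The test matrix is a small tridiagonal example, not a sample from the paper's $7\times7$ dataset, so it illustrates only the identity and not the empirical threshold of $0.18$.

```python
import numpy as np


def is_monotone(A, tol=1e-12):
    """Return True if A is invertible and its inverse is entrywise nonnegative.

    `tol` is an illustrative tolerance for floating-point round-off.
    """
    try:
        inv = np.linalg.inv(A)
    except np.linalg.LinAlgError:
        return False  # singular matrices are not monotone
    return bool(np.all(inv >= -tol))


def c0_c1_ratio(A):
    """Return |c_0 / c_1| for the characteristic polynomial of A.

    np.poly returns the coefficients of det(lambda*I - A), highest degree
    first, so the last two entries are c_1 and c_0.
    """
    coeffs = np.poly(A)
    c1, c0 = coeffs[-2], coeffs[-1]
    return abs(c0 / c1)


# Hypothetical example: a tridiagonal matrix whose inverse is entrywise positive.
A = np.array([[2.0, -1.0, 0.0],
              [-1.0, 2.0, -1.0],
              [0.0, -1.0, 2.0]])

ratio = c0_c1_ratio(A)
inv_trace = np.trace(np.linalg.inv(A))

print(is_monotone(A))
print(ratio, 1.0 / inv_trace)
```

For this matrix the characteristic polynomial is $\lambda^3 - 6\lambda^2 + 10\lambda - 4$, so $c_0=-4$, $c_1=10$, and both printed quantities agree up to round-off. The identity holds because $c_1/c_0 = -\mathrm{tr}(A^{-1})$ for any invertible $A$, and the trace is nonnegative when $A^{-1}$ is entrywise nonnegative.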
Similar Papers
Robustly Learning Monotone Single-Index Models
Machine Learning (CS)
Teaches computers to learn from messy, tricky data.
The Nondecreasing Rank
Machine Learning (Stat)
Finds hidden patterns in data using math.
Logical Expressivity and Explanations for Monotonic GNNs with Scoring Functions
Machine Learning (CS)
Explains computer predictions by finding simple rules.