Soft Decision Tree classifier: explainable and extendable PyTorch implementation
📝 Abstract
We implemented a Soft Decision Tree (SDT) and a Short-term Memory Soft Decision Tree (SM-SDT) using PyTorch. The methods were extensively tested on simulated and clinical datasets. The SDT was visualized to demonstrate the potential for its explainability. SDT, SM-SDT, and XGBoost demonstrated similar area under the curve (AUC) values. These methods were better than Random Forest, Logistic Regression, and Decision Tree. The results on clinical datasets suggest that, aside from a decision tree, all tested classification methods yield comparable results. The code and datasets are available online on GitHub: https://github.com/KI-Research-Institute/Soft-Decision-Tree
📄 Content
A soft decision tree (SDT) is a variant of the traditional decision tree in which the splits at internal nodes are probabilistic rather than deterministic [1]. A single SDT trained for classification can approximate multiple hard decision trees, so it forms a compact representation of a trained tree-based classifier. It can be visualized and analysed to provide insights into the classifier's results and the training dataset (e.g. feature importance). Moreover, the SDT can serve as a backbone for developing novel tree-based machine-learning architectures. For a detailed description of SDT, the reader is referred to [1], [2].
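The probabilistic routing described above can be illustrated with a minimal NumPy sketch; it is not the authors' PyTorch code. It assumes a complete binary tree stored in breadth-first order, in which each inner node i sends a sample to its right child with probability σ(x·w_i + b_i), the probability of reaching a leaf is the product of the branch probabilities along its path, and each leaf holds class logits. The function `sdt_forward` and its argument names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sdt_forward(X, W, b, leaf_logits):
    """Forward pass of a hypothetical complete soft decision tree.

    X:           (n_samples, n_features) input batch
    W, b:        (n_inner, n_features) and (n_inner,) linear filter of each inner node,
                 inner nodes in breadth-first order (children of i are 2i+1 and 2i+2)
    leaf_logits: (n_inner + 1, n_classes) learned class logits of each leaf
    Returns (class_probs, path_probs).
    """
    n_samples, n_inner = X.shape[0], W.shape[0]
    p_right = sigmoid(X @ W.T + b)  # (n_samples, n_inner): P(go right) at each node

    # reach[:, k] = probability that a sample arrives at node k; the root is reached surely.
    reach = np.ones((n_samples, 2 * n_inner + 1))
    for i in range(n_inner):  # parents always precede their children in BFS order
        reach[:, 2 * i + 1] = reach[:, i] * (1.0 - p_right[:, i])  # left child
        reach[:, 2 * i + 2] = reach[:, i] * p_right[:, i]  # right child

    path_probs = reach[:, n_inner:]  # the last n_inner + 1 nodes are the leaves
    leaf_dists = softmax(leaf_logits, axis=1)  # class distribution stored at each leaf
    class_probs = path_probs @ leaf_dists  # mixture of leaf distributions
    return class_probs, path_probs


# Usage: a depth-2 tree (3 inner nodes, 4 leaves) on random data.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 5))
probs, paths = sdt_forward(X, rng.normal(size=(3, 5)), np.zeros(3), rng.normal(size=(4, 2)))
print(paths.sum(axis=1))  # each row sums to 1
```

Averaging the leaf distributions by path probability, as done here, is one prediction rule; picking only the leaf with the greatest path probability is a common alternative that yields a single, easier-to-read decision path.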
We have implemented an SDT in PyTorch following Frosst et al. [2]. We developed a method to track the internal SDT parameters and to visualize them (Figure 1). Moreover, we introduce the short-term memory SDT (SM-SDT) to demonstrate the extensibility of our implementation. SM-SDT incorporates short-term memory capabilities into the nodes of the SDT: each node is aware of the outputs of its immediate parent and grandparent in the tree. In addition, a linear neural-network layer was added at the input of each node. Finally, we conducted experiments on simulated and clinical data to validate and evaluate our implementations of SDT and SM-SDT. Classification results (AUC) are reported below; the typical runtime of the SDT was 30-60 seconds on an NVIDIA T4 GPU with 16 GB of memory.
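The description of SM-SDT above leaves several details open, so the following NumPy sketch is only one hypothetical reading of it, not the authors' implementation. It assumes that a node's "output" is its probability of branching right, that each inner node first passes the input through its own linear layer, and that the split logit is a linear function of that transformed input concatenated with the parent's and grandparent's outputs (zeros where an ancestor does not exist). All names (`sm_sdt_split_probs`, `parent`, `node_params`, `n_hidden`) are illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def parent(i):
    """Breadth-first index of the parent of node i, or None for the root."""
    return (i - 1) // 2 if i > 0 else None


def sm_sdt_split_probs(X, node_params):
    """HYPOTHETICAL short-term-memory split probabilities for every inner node.

    node_params[i] = (A, c, w, bias) for inner node i (breadth-first order):
      A (n_features, n_hidden), c (n_hidden,): per-node linear input layer
      w (n_hidden + 2,), bias (float): split over [h, p_parent, p_grandparent]
    Ancestors that do not exist (above the root) contribute 0.
    """
    n_samples = X.shape[0]
    p = []  # p[i] = (n_samples,) probability of going right at node i
    for i, (A, c, w, bias) in enumerate(node_params):
        h = X @ A + c  # linear layer applied to the raw input at this node
        par = parent(i)
        grand = parent(par) if par is not None else None
        p_par = p[par] if par is not None else np.zeros(n_samples)
        p_grand = p[grand] if grand is not None else np.zeros(n_samples)
        z = np.column_stack([h, p_par, p_grand]) @ w + bias  # memory of parent and grandparent
        p.append(sigmoid(z))
    return np.column_stack(p)  # (n_samples, n_inner)


# Usage: 3 inner nodes (depth-2 tree), 4 input features, 3 hidden units per node.
rng = np.random.default_rng(1)
X = rng.normal(size=(6, 4))
params = [(rng.normal(size=(4, 3)), np.zeros(3), rng.normal(size=5), 0.0) for _ in range(3)]
P = sm_sdt_split_probs(X, params)
```

Under these assumptions, the resulting split probabilities would then be multiplied along root-to-leaf paths as in a plain SDT. Because a node's input now depends on its ancestors, nodes must be evaluated in breadth-first order rather than all at once.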
Figure 1: A visualization of a soft decision tree (SDT) classifier trained on the heart failure dataset [3], [4]. The features with the largest weights are shown at each internal node, together with the split probabilities and the leaf logits. This information can be analysed to provide insights into the classifier's results and the training dataset (e.g. feature importance and classifier reasoning).
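The per-node feature summaries in Figure 1 can be approximated by ranking each inner node's filter weights by absolute value. The sketch below assumes the inner-node weights are available as a NumPy matrix with one row per node; `top_features_per_node` and the feature names are hypothetical, and the weights in the usage example are made up.

```python
import numpy as np


def top_features_per_node(W, feature_names, k=3):
    """Return, for each inner node (row of W), the k features with the largest |weight|.

    Each entry is a list of (feature_name, signed_weight) pairs, strongest first.
    """
    W = np.asarray(W, dtype=float)
    # Stable sort on -|w| keeps the original feature order among ties.
    order = np.argsort(-np.abs(W), axis=1, kind="stable")[:, :k]
    return [
        [(feature_names[j], float(W[i, j])) for j in order[i]]
        for i in range(W.shape[0])
    ]


# Usage with made-up weights for two inner nodes and four features.
W = np.array([[0.1, -2.0, 0.5, 1.0],
              [3.0, 0.0, -0.2, 0.4]])
for node, feats in enumerate(top_features_per_node(W, ["x0", "x1", "x2", "x3"], k=2)):
    print(f"node {node}: {feats}")
```

The sign is kept because, with a sigmoid split, it indicates whether increasing the feature raises or lowers the probability of taking a given branch.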
Method: We used the scikit-learn ‘make_classification’ function to generate random binary-outcome classification datasets with varying numbers of independent samples (1K, 100K, and 1M) and varying numbers of features (50, 250, and 500). The number of features that contribute to the outcome was held constant at 30 in all experiments. We compared six classification algorithms on each simulated dataset: 1. decision tree (DT); 2. logistic regression (LR); 3. random forest with 1000 trees; 4. XGBoost; 5. SDT; and 6. SM-SDT. All tree-based algorithms had a maximum depth of three.
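The simulation grid can be sketched with scikit-learn as follows. `make_classification` and the estimator classes are real scikit-learn APIs; the choice `n_redundant=0` (so that the 30 informative features are the only ones carrying signal), the fixed `random_state`, and the names `simulated_datasets` and `baseline_models` are assumptions for illustration. XGBoost, SDT, and SM-SDT are omitted to keep the sketch to scikit-learn only.

```python
from itertools import product

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

N_SAMPLES = [1_000, 100_000, 1_000_000]
N_FEATURES = [50, 250, 500]
N_INFORMATIVE = 30  # features that drive the outcome, fixed across all runs
MAX_DEPTH = 3  # shared depth limit for the tree-based models


def simulated_datasets(seed=0):
    """Lazily yield (n_samples, n_features, X, y) for every cell of the grid."""
    for n_samples, n_features in product(N_SAMPLES, N_FEATURES):
        X, y = make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=N_INFORMATIVE,
            n_redundant=0,  # assumption: remaining features are pure noise
            n_classes=2,
            random_state=seed,
        )
        yield n_samples, n_features, X, y


def baseline_models(seed=0):
    """Scikit-learn baselines; XGBoost, SDT, and SM-SDT are not included here."""
    return {
        "DT": DecisionTreeClassifier(max_depth=MAX_DEPTH, random_state=seed),
        "LR": LogisticRegression(max_iter=1000),
        "RF1000": RandomForestClassifier(
            n_estimators=1000, max_depth=MAX_DEPTH, n_jobs=-1, random_state=seed
        ),
    }
```

Generating the datasets lazily avoids holding the largest (1M × 500) matrix in memory alongside the others.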
Method: We downloaded, cleaned, and normalized seven clinical datasets and binarized the outcome in every case. The seven datasets are: 1. Diagnostic Wisconsin breast cancer database [5]; 2. Heart disease (Cleveland subset) dataset [6]; 3. Heart failure dataset [3]; 4. Indian liver patient dataset (ILPD) [7]; 5. Pima Indians diabetes database (PIMA) [8]; 6. Stroke prediction dataset [9]; and 7. Thyroid disease dataset [10]. We compared seven classification algorithms on each clinical dataset: 1. decision tree (DT); 2. logistic regression (LR); 3. random forest with 100 trees; 4. random forest with 1000 trees; 5. XGBoost; 6. SDT; and 7. SM-SDT. All tree-based algorithms had a maximum depth of three. In all cases, 80% of the data was used for training and 20% for testing, and the experiment was repeated five times with different random train/test splits.
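The evaluation protocol (80/20 train/test splits repeated five times, with the AUC averaged) can be sketched as a small helper that works with any scikit-learn-style classifier exposing `predict_proba`. Stratifying the splits by outcome, the seeds 0-4, and the name `repeated_split_auc` are assumptions; the usage example runs on synthetic data, not on the clinical datasets.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split


def repeated_split_auc(model, X, y, n_repeats=5, test_size=0.2):
    """Mean and standard deviation of test AUC over repeated random train/test splits."""
    aucs = []
    for seed in range(n_repeats):
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, random_state=seed, stratify=y  # assumption: stratified
        )
        fitted = clone(model).fit(X_tr, y_tr)  # fresh unfitted copy for every repeat
        scores = fitted.predict_proba(X_te)[:, 1]  # probability of the positive class
        aucs.append(roc_auc_score(y_te, scores))
    return float(np.mean(aucs)), float(np.std(aucs))


# Usage on synthetic data (a stand-in for one cleaned, normalized clinical dataset).
X_demo, y_demo = make_classification(n_samples=500, n_features=10, random_state=0)
mean_auc, std_auc = repeated_split_auc(LogisticRegression(max_iter=1000), X_demo, y_demo)
print(f"AUC {mean_auc:.3f} ± {std_auc:.3f}")
```

Cloning the estimator for each repeat ensures that no fitted state carries over from one split to the next.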
The average AUC values of all classification methods on the clinical datasets are presented in Table 1. SDT and SM-SDT achieve state-of-the-art results: their AUC values are comparable to those of XGBoost, random forest, and logistic regression.
We have developed and tested a PyTorch implementation of a soft decision tree, developed a method to visualize the tree, and extended it with short-term memory capabilities. Our results suggest that the SDT achieves state-of-the-art AUC values. The drawbacks of the suggested method are that it requires a GPU and that its computation time is typically 30-60 s, compared with less than a few (3-4) seconds for the other methods. We have shared the code on the KI Research Institute GitHub page [11] and hope that others will further explore the practical potential of SDT.
RF100: Random Forest with 100 trees; RF1000: Random Forest with 1000 trees; SDT: Soft Decision Tree; XGB: XGBoost.
This content is AI-processed based on ArXiv data.