Training, validation, and test data sets

In machine learning, a common task is the study and construction of algorithms that can learn from and make predictions on data. [1] Such algorithms function by making data-driven predictions or decisions, [2] through building a mathematical model from input data. These input data used to build the model are usually divided into multiple data sets. In particular, three data sets are commonly used in different stages of the creation of the model: training, validation, and test sets.

The model is initially fit on a training data set, [3] which is a set of examples used to fit the parameters (e.g. weights of connections between neurons in artificial neural networks) of the model. [4] The model (e.g. a naive Bayes classifier) is trained on the training data set using a supervised learning method, for example using optimization methods such as gradient descent or stochastic gradient descent. In practice, the training data set often consists of pairs of an input vector (or scalar) and the corresponding output vector (or scalar), where the answer key is commonly denoted as the target (or label). The current model is run with the training data set and produces a result, which is then compared with the target, for each input vector in the training data set. Based on the result of the comparison and the specific learning algorithm being used, the parameters of the model are adjusted. The model fitting can include both variable selection and parameter estimation.
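The fitting loop described above can be sketched in a few lines. This is a minimal illustration, not a reference implementation: it assumes a one-parameter-pair linear model, a mean squared error loss, and made-up training pairs, learning rate, and epoch count.

```python
# Minimal sketch: fit y ≈ w*x + b by batch gradient descent on a toy training set.
# The data, learning rate, and epoch count are illustrative assumptions.

def fit_linear(train_pairs, learning_rate=0.05, epochs=2000):
    w, b = 0.0, 0.0  # model parameters, adjusted during training
    n = len(train_pairs)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for x, target in train_pairs:
            result = w * x + b            # run the current model on the input
            error = result - target       # compare the result with the target
            grad_w += 2 * error * x / n   # gradient of the mean squared error w.r.t. w
            grad_b += 2 * error / n       # gradient of the mean squared error w.r.t. b
        # adjust the parameters based on the comparison
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b

# Hypothetical training pairs (input, target); the targets follow y = 2x + 1.
train_pairs = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
w, b = fit_linear(train_pairs)  # w approaches 2.0 and b approaches 1.0
```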

Subsequently, the fitted model is used to predict the responses for the observations in a second data set called the validation data set. [3] The validation data set provides an unbiased evaluation of a model fit on the training data set while tuning the model's hyperparameters [5] (e.g. the number of hidden units, i.e. layers and layer widths, in a neural network [4] ). Validation data sets can be used for regularization by early stopping (stopping training when the error on the validation data set increases, as this is a sign of over-fitting to the training data set). [6] This simple procedure is complicated in practice by the fact that the validation data set's error may fluctuate during training, producing multiple local minima. This complication has led to the creation of many ad-hoc rules for deciding when over-fitting has truly begun. [6]

Finally, the test data set is a data set used to provide an unbiased evaluation of a final model fit on the training data set. [5] If the data in the test data set has never been used in training (for example in cross-validation), the test data set is also called a holdout data set. Some literature uses the term "validation set" instead of "test set" (e.g., if the original data set was partitioned into only two subsets, the test set might be referred to as the validation set). [5]

The choice of sizes and strategies for dividing a data set into training, validation, and test sets depends heavily on the problem and the data available. [7]
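As an illustration only, a random three-way division might look like the sketch below. The function name `split_dataset`, the 70/15/15 proportions, and the fixed seed are assumptions chosen for the example, not recommendations from the sources cited here.

```python
import random

def split_dataset(examples, val_fraction=0.15, test_fraction=0.15, seed=0):
    """Shuffle the examples and cut them into training, validation, and test sets."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # fixed seed so the split is reproducible
    n_test = round(len(shuffled) * test_fraction)
    n_val = round(len(shuffled) * val_fraction)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]      # everything left over is used for training
    return train, val, test

# Hypothetical data set of 100 examples, represented here by their indices.
train, val, test = split_dataset(range(100))
```

A plain random shuffle is only one possible strategy; time-ordered or grouped data usually call for a split that respects that structure.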

Training data set

Simplified example of training a neural network in object detection: The network is trained by multiple images that are known to depict starfish and sea urchins, which are correlated with "nodes" that represent visual features. The starfish match with a ringed texture and a star outline, whereas most sea urchins match with a striped texture and oval shape. However, the instance of a ring textured sea urchin creates a weakly weighted association between them.
Subsequent run of the network on an input image (left): [8] The network correctly detects the starfish. However, the weakly weighted association between ringed texture and sea urchin also confers a weak signal to the latter from one of two intermediate nodes. In addition, a shell that was not included in the training gives a weak signal for the oval shape, also resulting in a weak signal for the sea urchin output. These weak signals may result in a false positive result for sea urchin.
In reality, textures and outlines would not be represented by single nodes, but rather by associated weight patterns of multiple nodes.

A training data set is a data set of examples used during the learning process and is used to fit the parameters (e.g., weights) of, for example, a classifier. [9] [10]

For classification tasks, a supervised learning algorithm looks at the training data set to determine, or learn, the optimal combinations of variables that will generate a good predictive model. [11] The goal is to produce a trained (fitted) model that generalizes well to new, unknown data. [12] The fitted model is evaluated using “new” examples from the held-out data sets (validation and test data sets) to estimate the model’s accuracy in classifying new data. [5] To reduce the risk of issues such as over-fitting, the examples in the validation and test data sets should not be used to train the model. [5]

Most approaches that search through training data for empirical relationships tend to overfit the data, meaning that they can identify and exploit apparent relationships in the training data that do not hold in general.

When a training set is continuously expanded with new data, the process is called incremental learning.

Validation data set

A validation data set is a data set of examples used to tune the hyperparameters (e.g. the architecture) of a model. It is sometimes also called the development set or the "dev set". [13] An example of a hyperparameter for artificial neural networks is the number of hidden units in each layer. [9] [10] The validation data set, like the test data set (described below), should follow the same probability distribution as the training data set.

In order to avoid overfitting, when any classification parameter needs to be adjusted, it is necessary to have a validation data set in addition to the training and test data sets. For example, if the most suitable classifier for the problem is sought, the training data set is used to train the different candidate classifiers, the validation data set is used to compare their performances and decide which one to take and, finally, the test data set is used to obtain the performance characteristics such as accuracy, sensitivity, specificity, F-measure, and so on. The validation data set functions as a hybrid: it is training data used for testing, but neither as part of the low-level training nor as part of the final testing.

The basic process of using a validation data set for model selection (as part of training data set, validation data set, and test data set) is: [10] [14]

Since our goal is to find the network having the best performance on new data, the simplest approach to the comparison of different networks is to evaluate the error function using data which is independent of that used for training. Various networks are trained by minimization of an appropriate error function defined with respect to a training data set. The performance of the networks is then compared by evaluating the error function using an independent validation set, and the network having the smallest error with respect to the validation set is selected. This approach is called the hold out method. Since this procedure can itself lead to some overfitting to the validation set, the performance of the selected network should be confirmed by measuring its performance on a third independent set of data called a test set.
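The hold out method quoted above can be sketched as follows. Instead of neural networks, this example assumes a much simpler family of candidates (nearest-neighbour regressors that differ only in the number of neighbours) and synthetic data; both are illustrative assumptions, as are all function names.

```python
import math
import random

def knn_predict(train_pairs, x, n_neighbors):
    """Average the targets of the n_neighbors training inputs closest to x."""
    nearest = sorted(train_pairs, key=lambda pair: abs(pair[0] - x))[:n_neighbors]
    return sum(target for _, target in nearest) / len(nearest)

def error(train_pairs, eval_pairs, n_neighbors):
    """Mean squared error on eval_pairs of the model built from train_pairs."""
    return sum((knn_predict(train_pairs, x, n_neighbors) - target) ** 2
               for x, target in eval_pairs) / len(eval_pairs)

# Hypothetical toy data: noisy samples of sin(x), divided 60/20/20.
rng = random.Random(0)
xs = [rng.uniform(0, 6) for _ in range(100)]
data = [(x, math.sin(x) + rng.gauss(0, 0.3)) for x in xs]
train, validation, test = data[:60], data[60:80], data[80:]

# Compare the candidates on the independent validation set and keep the best one.
candidates = [1, 3, 5, 15, 45]
validation_errors = {k: error(train, validation, k) for k in candidates}
selected = min(validation_errors, key=validation_errors.get)

# Confirm the selected model on the third, untouched set.
test_error = error(train, test, selected)
```

The test set is consulted once, after the selection is finished, so its error is not influenced by the comparison among candidates.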

An application of this process is in early stopping, where the candidate models are successive iterations of the same network, and training stops when the error on the validation set grows, choosing the previous model (the one with minimum error).
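A minimal sketch of early stopping is given below, assuming a one-parameter model y ≈ w*x trained by gradient descent. Because validation error can fluctuate, the sketch uses a "patience" counter, which is one of many possible ad-hoc stopping rules; the data, learning rate, and patience value are all made up for illustration.

```python
def mean_squared_error(w, pairs):
    return sum((w * x - target) ** 2 for x, target in pairs) / len(pairs)

def train_with_early_stopping(train_pairs, val_pairs,
                              learning_rate=0.01, max_epochs=1000, patience=5):
    """Gradient descent on y ≈ w*x that keeps the weight with the lowest validation error."""
    w = 0.0
    best_w, best_error = w, mean_squared_error(w, val_pairs)
    epochs_without_improvement = 0
    epochs_run = 0
    for _ in range(max_epochs):
        gradient = sum(2 * (w * x - t) * x for x, t in train_pairs) / len(train_pairs)
        w -= learning_rate * gradient
        epochs_run += 1
        current_error = mean_squared_error(w, val_pairs)
        if current_error < best_error:
            best_w, best_error = w, current_error   # new best candidate model
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break                               # validation error stopped improving
    return best_w, best_error, epochs_run

# Hypothetical data: the training targets are biased (slope 2.2), the validation
# targets follow slope 2.0, so fitting the training set too closely hurts validation error.
train_pairs = [(1.0, 2.2), (2.0, 4.4), (3.0, 6.6)]
val_pairs = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
best_w, best_error, epochs_run = train_with_early_stopping(train_pairs, val_pairs)
```

The returned weight is the earlier iterate with minimum validation error, not the last one computed, which matches the idea of choosing the previous model once the error starts to grow.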

Test data set

A test data set is a data set that is independent of the training data set, but that follows the same probability distribution as the training data set. If a model fit to the training data set also fits the test data set well, minimal overfitting has taken place (see figure below). A better fitting of the training data set as opposed to the test data set usually points to over-fitting.

A test set is therefore a set of examples used only to assess the performance (i.e. generalization) of a fully specified classifier. [9] [10] To do this, the final model is used to predict classifications of examples in the test set. Those predictions are compared to the examples' true classifications to assess the model's accuracy. [11]
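Comparing predictions with true classifications can be sketched as below for a binary problem. The label lists are invented for illustration, the function name is hypothetical, and the code assumes both classes occur in the test set (otherwise some ratios would divide by zero).

```python
def classification_metrics(true_labels, predicted_labels, positive=1):
    """Compute common test-set metrics for a binary classifier."""
    pairs = list(zip(true_labels, predicted_labels))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)  # true positives
    tn = sum(1 for t, p in pairs if t != positive and p != positive)  # true negatives
    fp = sum(1 for t, p in pairs if t != positive and p == positive)  # false positives
    fn = sum(1 for t, p in pairs if t == positive and p != positive)  # false negatives
    return {
        "accuracy": (tp + tn) / len(pairs),
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
        "f_measure": 2 * tp / (2 * tp + fp + fn),
    }

# Hypothetical test-set labels and the final model's predictions for them.
true_labels      = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
predicted_labels = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]
metrics = classification_metrics(true_labels, predicted_labels)
# accuracy is 7/10 and sensitivity is 3/4 for these lists
```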

In a scenario where both validation and test data sets are used, the test data set is typically used to assess the final model that is selected during the validation process. In the case where the original data set is partitioned into two subsets (training and test data sets), the test data set might assess the model only once (e.g., in the holdout method). [15] Note that some sources advise against such a method. [12] However, when using a method such as cross-validation, two partitions can be sufficient and effective since results are averaged after repeated rounds of model training and testing to help reduce bias and variability. [5] [12]


A training set (left) and a test set (right) from the same statistical population are shown as blue points. Two predictive models are fit to the training data. Both fitted models are plotted with both the training and test sets. In the training set, the MSE of the fit shown in orange is 4 whereas the MSE for the fit shown in green is 9. In the test set, the MSE for the fit shown in orange is 15 and the MSE for the fit shown in green is 13. The orange curve severely overfits the training data, since its MSE increases by almost a factor of four when comparing the test set to the training set. The green curve overfits the training data much less, as its MSE increases by less than a factor of 2.
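The arithmetic behind the caption can be checked with a short sketch. The MSE values are taken from the caption; the helper name `mse_increase` is an assumption introduced for the example.

```python
def mse_increase(train_mse, test_mse):
    """Factor by which the mean squared error grows from the training set to the test set."""
    return test_mse / train_mse

ratio_orange = mse_increase(train_mse=4, test_mse=15)  # 3.75, "almost a factor of four"
ratio_green = mse_increase(train_mse=9, test_mse=13)   # about 1.44, less than a factor of 2

# The orange fit has the lower training error but the larger train-to-test increase,
# and it also ends up with the higher test error (15 versus 13).
```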

Confusion in terminology

Testing is trying something to find out about it ("To put to the proof; to prove the truth, genuineness, or quality of by experiment" according to the Collaborative International Dictionary of English) and to validate is to prove that something is valid ("To confirm; to render valid" Collaborative International Dictionary of English). From this perspective, the most common use of the terms test set and validation set is the one described here. However, in both industry and academia, the two terms are sometimes swapped: the internal process is regarded as testing different models in order to improve them (test set as a development set), and the final model is the one that needs to be validated on unseen data before real use (validation set). "The literature on machine learning often reverses the meaning of 'validation' and 'test' sets. This is the most blatant example of the terminological confusion that pervades artificial intelligence research." [16] Nevertheless, the key point is that the final set, whether called test or validation, should only be used in the final experiment.

Cross-validation

In order to get more stable results and use all valuable data for training, a data set can be repeatedly split into several training and validation data sets. This is known as cross-validation. To confirm the model's performance, an additional test data set held out from cross-validation is normally used.
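One common form of repeated splitting is k-fold cross-validation, sketched below. The "model" is deliberately trivial (it predicts the mean of the training targets) so that the splitting logic stays visible; the function names and data are assumptions, and the strided fold assignment presumes the examples are already in random order.

```python
def k_fold_splits(n_examples, k):
    """Yield (train_indices, validation_indices); each example is validated exactly once."""
    folds = [list(range(i, n_examples, k)) for i in range(k)]
    for i, validation_indices in enumerate(folds):
        train_indices = [j for m, fold in enumerate(folds) if m != i for j in fold]
        yield train_indices, validation_indices

def cross_validate(targets, k):
    """Average validation MSE of a mean-predictor baseline over k folds."""
    fold_errors = []
    for train_idx, val_idx in k_fold_splits(len(targets), k):
        # "Training": the baseline model is just the mean of the training targets.
        prediction = sum(targets[j] for j in train_idx) / len(train_idx)
        fold_errors.append(
            sum((targets[j] - prediction) ** 2 for j in val_idx) / len(val_idx))
    return sum(fold_errors) / len(fold_errors)

# Hypothetical targets for six examples.
targets = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
score = cross_validate(targets, k=3)
```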

It is possible to use cross-validation on training and test sets, and within each training set run a further cross-validation that splits off validation sets for hyperparameter tuning. This is known as nested cross-validation.
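A rough sketch of nested cross-validation follows. The outer loop holds out one fold for evaluating the tuned model; the inner loop, run only on the remaining data, picks the hyperparameter. The nearest-neighbour regressor, the candidate values, the fold counts, and the synthetic data are all illustrative assumptions.

```python
import random

def folds(items, k):
    """Yield (rest, held_out) pairs; each item is held out exactly once."""
    parts = [items[i::k] for i in range(k)]
    for i in range(k):
        rest = [item for j, part in enumerate(parts) if j != i for item in part]
        yield rest, parts[i]

def knn_predict(train_pairs, x, n_neighbors):
    nearest = sorted(train_pairs, key=lambda pair: abs(pair[0] - x))[:n_neighbors]
    return sum(target for _, target in nearest) / len(nearest)

def mse(train_pairs, eval_pairs, n_neighbors):
    return sum((knn_predict(train_pairs, x, n_neighbors) - target) ** 2
               for x, target in eval_pairs) / len(eval_pairs)

def inner_cv_error(pairs, n_neighbors, k):
    """Inner loop: average held-out error of one hyperparameter value."""
    errors = [mse(rest, held_out, n_neighbors) for rest, held_out in folds(pairs, k)]
    return sum(errors) / len(errors)

def nested_cv(pairs, candidates, outer_k=5, inner_k=4):
    """Outer loop: tune on the outer training part, then score on the outer held-out fold."""
    outer_errors = []
    for outer_train, outer_held_out in folds(pairs, outer_k):
        best = min(candidates, key=lambda c: inner_cv_error(outer_train, c, inner_k))
        outer_errors.append(mse(outer_train, outer_held_out, best))
    return sum(outer_errors) / len(outer_errors)

# Hypothetical data: noisy samples of y = 2x.
rng = random.Random(0)
xs = [rng.uniform(0, 10) for _ in range(60)]
pairs = [(x, 2 * x + rng.gauss(0, 0.5)) for x in xs]
estimate = nested_cv(pairs, candidates=[1, 3, 5, 9])
```

Because the hyperparameter is re-selected inside every outer fold, the outer held-out data never influence the tuning that is being evaluated.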

Causes of error

Comic strip demonstrating a fictional erroneous computer output (making a coffee 5 million degrees, from a previous definition of "extra hot"). This can be classified as both a failure in logic and a failure to include various relevant environmental conditions.

Omissions in the training of algorithms are a major cause of erroneous outputs. [17] Types of such omissions include: [17]

An example of an omission of particular circumstances is a case where a boy was able to unlock his mother's phone with his own face because she had registered her face under indoor, nighttime lighting, a condition which was not appropriately included in the training of the system. [17] [18]

Usage of relatively irrelevant input can include situations where algorithms use the background rather than the object of interest for object detection. For example, an algorithm trained on pictures of sheep on grasslands risks interpreting a different object as a sheep if it is located on a grassland. [17]

References

  1. Ron Kohavi; Foster Provost (1998). "Glossary of terms". Machine Learning. 30: 271–274. doi:10.1023/A:1007411609915.
  2. Bishop, Christopher M. (2006). Pattern Recognition and Machine Learning. New York: Springer. p. vii. ISBN 0-387-31073-8. "Pattern recognition has its origins in engineering, whereas machine learning grew out of computer science. However, these activities can be viewed as two facets of the same field, and together they have undergone substantial development over the past ten years."
  3. James, Gareth (2013). An Introduction to Statistical Learning: with Applications in R. Springer. p. 176. ISBN 978-1461471370.
  4. Ripley, Brian (1996). Pattern Recognition and Neural Networks. Cambridge University Press. p. 354. ISBN 978-0521717700.
  5. Brownlee, Jason (2017-07-13). "What is the Difference Between Test and Validation Datasets?". Retrieved 2017-10-12.
  6. Prechelt, Lutz; Geneviève B. Orr (2012-01-01). "Early Stopping - But When?". In Grégoire Montavon; Klaus-Robert Müller (eds.). Neural Networks: Tricks of the Trade. Lecture Notes in Computer Science. Springer Berlin Heidelberg. pp. 53–67. doi:10.1007/978-3-642-35289-8_5. ISBN 978-3-642-35289-8.
  7. "Machine learning - Is there a rule-of-thumb for how to divide a dataset into training and validation sets?". Stack Overflow. Retrieved 2021-08-12.
  8. Ferrie, C., & Kaiser, S. (2019). Neural Networks for Babies. Sourcebooks. ISBN 978-1492671206.
  9. Ripley, B.D. (1996). Pattern Recognition and Neural Networks. Cambridge: Cambridge University Press. p. 354.
  10. "Subject: What are the population, sample, training set, design set, validation set, and test set?", Neural Network FAQ, part 1 of 7: Introduction (txt), comp.ai.neural-nets, Sarle, W.S., ed. (1997, last modified 2002-05-17).
  11. Larose, D. T.; Larose, C. D. (2014). Discovering knowledge in data: an introduction to data mining. Hoboken: Wiley. doi:10.1002/9781118874059. ISBN 978-0-470-90874-7. OCLC 869460667.
  12. Xu, Yun; Goodacre, Royston (2018). "On Splitting Training and Validation Set: A Comparative Study of Cross-Validation, Bootstrap and Systematic Sampling for Estimating the Generalization Performance of Supervised Learning". Journal of Analysis and Testing. 2 (3). Springer Science and Business Media LLC: 249–262. doi:10.1007/s41664-018-0068-2. ISSN 2096-241X. PMC 6373628. PMID 30842888.
  13. "Deep Learning". Coursera. Retrieved 2021-05-18.
  14. Bishop, C.M. (1995). Neural Networks for Pattern Recognition. Oxford: Oxford University Press. p. 372.
  15. Kohavi, Ron (2001-03-03). "A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection". 14.
  16. Ripley, Brian D. (2008-01-10). "Glossary". Pattern recognition and neural networks. Cambridge University Press. ISBN 9780521717700. OCLC 601063414.
  17. Chanda SS, Banerjee DN (2022). "Omission and commission errors underlying AI failures". AI Soc. 39 (3): 1–24. doi:10.1007/s00146-022-01585-x. PMC 9669536. PMID 36415822.
  18. Greenberg A (2017-11-14). "Watch a 10-Year-Old's Face Unlock His Mom's iPhone X". Wired.