Text Classification
Our focus is to solve a text classification problem using deep learning. To reiterate, the NLP (Natural Language Processing) based text classification problem is described below:
Problem
In today’s world, websites have to deal with toxic and divisive content. This is especially true for major websites like Quora, which cater to large traffic and whose purpose is to provide a platform for people to ask and answer questions. A key challenge is to weed out insincere questions: those founded upon false premises, or those that intend to make a statement rather than look for helpful answers.
A question is classified as insincere if it:
- Has a non-neutral tone directed at someone
- Is discriminatory or contains abusive language
- Contains false information
For more information regarding the challenge you can use the following link.
Code
The full deep learning code used is available here.
Deep Learning Approach
In part 1, we saw how machine learning algorithms could be applied to text classification. We had to identify and create a variety of features to reduce the complexity in the data. This required considerable effort but was essential for the learning algorithms to be able to detect patterns in the data.
In this section, we will approach the same problem using deep learning techniques. Deep learning has become the state-of-the-art method for a variety of classification tasks. But first we need to understand the motivation for using it, which is outlined here:
Advantages:
- Saves Effort: No need to spend time exploring the intricacies of the text. We can start with a plug-and-play neural network and then evaluate whether any pre-processing or data cleaning is necessary. During our experimentation we found that cleaning the text added little incremental value.
- Feature Extraction: Deep learning tries to learn high-level features from data in an incremental manner. This removes the need to perform extensive feature engineering.
- Superior Performance: Given sufficient data, deep learning often outperforms other techniques on tasks like this one.
Issues:
That all sounds great, but what is the difficult part of this task? The major problem is defining the right architecture. To get started, our primary task is to decide on the type of neural network we wish to use and then on its architecture. In other words, we need to make choices along the following dimensions:
- Type of Neural Network: Should we use an RNN, a CNN or some hybrid of both? As we are dealing with textual data, we need a sequence model. In other words, our model should be able to remember the earlier words in a sentence in order to draw value from the context of the sentence. This is where an RNN shines.
- Size, Width & Depth: We need to decide on the total number of nodes in the model, the number of layers and the number of neurons in the respective layers.
- Elements of a Neuron: That is, the activation function, scaling, limiting and so on.
- Components/Layers: We need to decide which components or layers to include, such as the embedding layer, dropout or max-pooling, and how these layers will be ordered and connected.
Let’s understand some key concepts before we proceed further:
Word embeddings
Words represented as real-valued vectors are what we call word embeddings. The vector associated with each word is learned either by training a neural network on a large dataset with a predefined task (like document classification) or by an unsupervised process that uses document statistics. These pre-trained vectors can be reused as a form of transfer learning, which reduces training time and improves performance.
Recurrent neural networks
Consider how we make sense of a sentence: we look not only at a word but also at how it fits with the preceding words. Recurrent neural networks (RNN) take into consideration both the current input and the preceding inputs. They are suitable for sequence problems because their internal memory stores important details about the inputs they received in the past, which helps them predict the output of the next time step more accurately.
GRU and LSTM are improved variations of the vanilla RNN, which tackle the problem of vanishing gradients and handling of long term dependencies.
Consider a simple use case in which we are trying to infer the weather from a conversation between two people. Let’s take the following text as an example:
“We were walking on the road when it started to pour, luckily my friend was carrying an umbrella.”
The target variable here has a classification of “rain”. In such a case the RNN has to keep “pour” in memory while considering “umbrella” to correctly predict rain, as the occurrence of “umbrella” alone is not definitive proof of rain. Similarly, the occurrence of “didn’t” before “pour” would have changed everything.
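The idea of carrying memory from one word to the next can be sketched with a vanilla RNN in NumPy. This is a minimal illustration under assumed toy sizes and random weights (`W_x`, `W_h`, `b` and the dimensions are all hypothetical), not the model used in this article.

```python
import numpy as np

def rnn_forward(inputs, W_x, W_h, b):
    """Run a vanilla RNN over a sequence and return every hidden state."""
    hidden = np.zeros(W_h.shape[0])           # the memory starts empty
    states = []
    for x_t in inputs:                         # one word vector per time step
        # The new state mixes the current word with the previous state,
        # which is how earlier words such as "pour" can influence later steps.
        hidden = np.tanh(W_x @ x_t + W_h @ hidden + b)
        states.append(hidden)
    return np.stack(states)

rng = np.random.default_rng(0)
seq_len, embed_dim, hidden_dim = 5, 4, 3       # illustrative sizes only
sentence = rng.normal(size=(seq_len, embed_dim))
W_x = rng.normal(size=(hidden_dim, embed_dim)) * 0.5
W_h = rng.normal(size=(hidden_dim, hidden_dim)) * 0.5
b = np.zeros(hidden_dim)

states = rnn_forward(sentence, W_x, W_h, b)
print(states.shape)  # (5, 3): one hidden state per word
```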
LSTM:
- Long Short-Term Memory (LSTM) networks improve on the vanilla RNN: they capture long-term dependencies by introducing input, forget and output gates, which control what previous information is stored and updated. In other words, if we need to keep word 11 in memory along with word 50 of a conversation to truly derive the context, an LSTM is better equipped to do so. LSTM is based on the following three major ideas:
- Input: introducing a word so that the network can learn from it.
- Forget: deciding whether a previously learned word still makes sense or should be forgotten.
- Output: creating the final output, that is, the predicted value.
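The three gates can be made concrete with a single LSTM time step in NumPy. This is a sketch of the standard LSTM equations with hypothetical toy parameters. The biases are set by hand (not learned) so that the forget gate stays open and the input gate stays closed, which shows how a stored value can survive many time steps.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step. W, U and b stack the parameters of the input,
    forget and output gates and of the candidate values, in that order."""
    n = h_prev.shape[0]
    z = W @ x_t + U @ h_prev + b               # shape (4 * n,)
    i = sigmoid(z[:n])                         # input gate: how much new information to write
    f = sigmoid(z[n:2 * n])                    # forget gate: how much old memory to keep
    o = sigmoid(z[2 * n:3 * n])                # output gate: how much memory to expose
    g = np.tanh(z[3 * n:])                     # candidate values to write
    c_t = f * c_prev + i * g                   # updated cell state (long-term memory)
    h_t = o * np.tanh(c_t)                     # hidden state passed to the next step
    return h_t, c_t

embed_dim, hidden_dim = 4, 3                   # illustrative sizes only
rng = np.random.default_rng(0)
W = rng.normal(size=(4 * hidden_dim, embed_dim)) * 0.1
U = rng.normal(size=(4 * hidden_dim, hidden_dim)) * 0.1
b = np.zeros(4 * hidden_dim)
b[:hidden_dim] = -20.0                         # hand-set: input gate almost closed
b[hidden_dim:2 * hidden_dim] = 20.0            # hand-set: forget gate almost fully open

h = np.zeros(hidden_dim)
c = np.array([0.5, -0.3, 0.8])                 # pretend this memory encodes an early word
for x_t in rng.normal(size=(40, embed_dim)):   # 40 further words arrive
    h, c = lstm_step(x_t, h, c, W, U, b)
print(np.round(c, 3))                          # still very close to the stored values
```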
GRU:
- Gated Recurrent Units (GRU) also deal with long-term dependencies in a similar fashion to LSTM. However, they combine the input and forget gates into a single update gate, resulting in a simpler design compared to LSTM. The update gate decides how much past information to hold on to, and the reset gate, as the name suggests, decides how much of it should be discarded or forgotten when the new candidate state is computed.
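A single GRU step can be sketched in the same way. The parameter layout and the helper `make_params` are hypothetical, and the update rule follows the convention in which an update gate close to 0 keeps the previous state (some references write the roles of `z` and `1 - z` the other way round).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, params):
    """One GRU step with an update gate (z) and a reset gate (r)."""
    W_z, U_z, b_z, W_r, U_r, b_r, W_h, U_h, b_h = params
    z = sigmoid(W_z @ x_t + U_z @ h_prev + b_z)              # update gate
    r = sigmoid(W_r @ x_t + U_r @ h_prev + b_r)              # reset gate
    h_cand = np.tanh(W_h @ x_t + U_h @ (r * h_prev) + b_h)   # candidate state
    return (1.0 - z) * h_prev + z * h_cand                   # blend old state and candidate

def make_params(rng, embed_dim, hidden_dim, update_bias):
    """Hypothetical helper: small random weights, hand-set update-gate bias."""
    def mat(cols):
        return rng.normal(size=(hidden_dim, cols)) * 0.1
    zeros = np.zeros(hidden_dim)
    return (mat(embed_dim), mat(hidden_dim), np.full(hidden_dim, update_bias),
            mat(embed_dim), mat(hidden_dim), zeros,
            mat(embed_dim), mat(hidden_dim), zeros)

rng = np.random.default_rng(0)
x_t = rng.normal(size=4)
h_prev = np.array([0.5, -0.3, 0.8])

# Update gate almost closed: the previous state is carried over nearly unchanged.
h_kept = gru_step(x_t, h_prev, make_params(rng, 4, 3, update_bias=-20.0))
# Update gate almost open: the state is replaced by the new candidate.
h_replaced = gru_step(x_t, h_prev, make_params(rng, 4, 3, update_bias=20.0))
print(np.round(h_kept, 3), np.round(h_replaced, 3))
```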
Bidirectional RNN
A bidirectional RNN first forward propagates left to right, starting with the initial time step. Then starting at the final time step, it moves right to left until it reaches the initial time step. This learning of representations by incorporating future time steps helps understand context better.
In other words, let’s go back to our example of deciphering the weather based on conversations. Reading left to right, it makes sense to make the leap from “pour” to “umbrella”. Reading right to left as well adds to the power of the model, because another conversation may present the words in a different order, for example:
“I took out an umbrella as it started to pour.”
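A bidirectional pass can be sketched by running one RNN left to right, a second one right to left, and concatenating their states per word. The snippet below uses a plain tanh RNN with random, hypothetical weights purely for illustration. With 40 units per direction, each word ends up with 80 features.

```python
import numpy as np

def run_rnn(inputs, W_x, W_h):
    """Plain tanh RNN that returns the hidden state at every time step."""
    hidden = np.zeros(W_h.shape[0])
    states = []
    for x_t in inputs:
        hidden = np.tanh(W_x @ x_t + W_h @ hidden)
        states.append(hidden)
    return np.stack(states)

def bidirectional(inputs, fwd_params, bwd_params):
    forward = run_rnn(inputs, *fwd_params)                  # reads left to right
    # Read the reversed sentence, then flip the states back so that
    # row t of both halves refers to the same word.
    backward = run_rnn(inputs[::-1], *bwd_params)[::-1]
    return np.concatenate([forward, backward], axis=1)

rng = np.random.default_rng(0)
seq_len, embed_dim, hidden_dim = 6, 5, 40                   # illustrative sizes
sentence = rng.normal(size=(seq_len, embed_dim))
fwd = (rng.normal(size=(hidden_dim, embed_dim)) * 0.1,
       rng.normal(size=(hidden_dim, hidden_dim)) * 0.1)
bwd = (rng.normal(size=(hidden_dim, embed_dim)) * 0.1,
       rng.normal(size=(hidden_dim, hidden_dim)) * 0.1)

encoded = bidirectional(sentence, fwd, bwd)
print(encoded.shape)  # (6, 80): 40 forward + 40 backward features per word
```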
Attention Layer
I think the best way to describe attention is to look at a basic CNN use case, image classification (dogs vs. cats). If you are given an image of a dog, what is the most defining characteristic that helps you identify it? Is it the dog’s nose or ears? Loosely speaking, the attention mechanism blurs certain pixels and focuses on only a portion of the image. It assigns weights that tell the model what to focus on.
In our context, the attention model takes into account input from several time steps back and assigns a weight, signifying interest, to each of them. Attention is used to focus on specific words in the text sequence; for example, over a large dataset the attention layer will give more weight to words like “rain”, “pour” and “umbrella”.
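One common way to implement this kind of weighting over time steps is to score each hidden state, turn the scores into weights with a softmax, and take the weighted sum. The sketch below assumes that form. The scoring vector `w` and the toy states are hypothetical, and the attention layer used in our model may differ in its details.

```python
import numpy as np

def attention_pool(states, w, b=0.0):
    """states: (time_steps, features). Returns the weighted sum and the weights."""
    scores = np.tanh(states @ w + b)               # one score per time step
    exp_scores = np.exp(scores - scores.max())     # numerically stable softmax
    weights = exp_scores / exp_scores.sum()        # weights sum to 1
    context = weights @ states                     # weighted sum over time steps
    return context, weights

words = ["it", "started", "to", "pour"]
# Toy hidden states, one row per word (values chosen by hand).
states = np.array([[0.1, 0.0],
                   [0.2, 0.1],
                   [0.0, 0.1],
                   [0.9, 0.8]])
w = np.array([2.0, 2.0])                           # hypothetical learned scoring vector

context, weights = attention_pool(states, w)
for word, weight in zip(words, weights):
    print(f"{word:8s} {weight:.2f}")               # "pour" receives the largest weight
```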
Stochastic weight averaging
We used stochastic weight averaging (SWA) to update the weights of our network for the following reasons:
- SWA can be applied to any architecture and dataset, and it shows good results. It is essentially an ensemble technique: from a chosen epoch onwards, you store the weights at each subsequent epoch and average them.
- But wouldn’t that add to the computational load? SWA only requires the weights of one pre-trained model to initialize; after that, we keep a running average of the weights that is updated at each epoch.
The image below (Illustrations of SWA and SGD with a Pre-activation ResNet-164 on CIFAR-100) shows how well SWA generalizes and results in better performance. On the left, W1, W2 and W3 are the weights of three networks collected during training and Wswa is their average. SWA generalizes better even though, on the right, we see a greater train loss for SWA than for SGD.
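The running average mentioned above can be sketched in a few lines. The per-epoch weights here are made-up arrays standing in for a real model's parameters, and the start epoch of 4 mirrors the choice reported later in this article. In practice this logic would live in a training callback.

```python
import numpy as np

def update_swa(swa_weights, new_weights, n_averaged):
    """Fold one more set of epoch weights into the running average."""
    return [(swa * n_averaged + w) / (n_averaged + 1)
            for swa, w in zip(swa_weights, new_weights)]

# Hypothetical weights after each of 10 epochs for a model with two parameter arrays.
epoch_weights = [[np.full((2, 2), float(e)), np.full(2, 10.0 * e)]
                 for e in range(1, 11)]

swa_start = 4                      # begin averaging at the 4th epoch
swa, n = None, 0
for epoch, weights in enumerate(epoch_weights, start=1):
    if epoch < swa_start:
        continue                   # plain training, nothing is averaged yet
    if swa is None:
        swa, n = [w.copy() for w in weights], 1   # initialize from one set of weights
    else:
        swa, n = update_swa(swa, weights, n), n + 1

# Epochs 4 to 10 were averaged, so each entry is the mean of its 7 values.
print(n, swa[0][0, 0], swa[1][0])
```

Only one extra copy of the weights is kept in memory, which is why the added cost is small.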
Cyclical Learning Rate (CLR)
The essence of this learning rate policy comes from the observation that increasing the learning rate might have a short-term negative effect and yet achieve a longer-term beneficial effect. This observation leads to the idea of letting the learning rate vary within a range of values rather than adopting a stepwise fixed or exponentially decreasing value. That is, one sets minimum and maximum boundaries and the learning rate cyclically varies between these bounds based on a predefined function. During gradient descent we may get stuck at a local minimum, and a cyclical learning rate can help us jump out to a different location and move towards the global optimum.
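A simple instance of such a predefined function is the triangular policy, in which the learning rate climbs linearly from the lower bound to the upper bound and back again. The bounds and step size below are illustrative values, not the ones used in our experiments.

```python
import math

def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical learning rate; step_size is half a cycle, in iterations."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)   # 1 at the bounds, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

# Illustrative settings: oscillate between 0.001 and 0.003 with 4 iterations per half cycle.
schedule = [triangular_lr(i, base_lr=0.001, max_lr=0.003, step_size=4)
            for i in range(9)]
print([round(lr, 4) for lr in schedule])
# Rises from 0.001 to 0.003 over iterations 0 to 4, then falls back by iteration 8.
```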
Dropout
This is a technique used specifically to prevent overfitting. It involves dropping some percentage of the units during the training process. By dropping a unit out, we mean temporarily removing it from the network, along with all its incoming and outgoing connections. The units are selected randomly. By randomly zeroing out activations, dropout acts as a regularization technique that forces the network to build in redundancies, overfit less and generalize better.
There are different ways to drop values. Think of the pixels in an image: each pixel is highly correlated with its neighbours, so randomly dropping individual pixels will not accomplish much. That is where techniques like spatial dropout come into the picture. Spatial dropout involves dropping entire feature maps. For explanation purposes, consider an image cropped into smaller segments, each of which is mapped using a function. Out of all these mapped values, we randomly delete some.
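If we take textual data, in the 1D variant typically applied after an embedding layer you can think of it as dropping entire embedding dimensions across all the words of a sentence, rather than isolated values, which forces the model to generalize better.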
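The difference between the two kinds of dropout is easiest to see side by side. The sketch below assumes the 1D variant commonly used for text, where a feature map is one embedding dimension across all time steps. The rate and array sizes are illustrative.

```python
import numpy as np

def dropout(x, rate, rng):
    """Standard dropout: every value is dropped independently."""
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)             # rescale so the expected value is unchanged

def spatial_dropout_1d(x, rate, rng):
    """Spatial dropout for x of shape (time_steps, channels):
    a dropped channel is zeroed at every time step."""
    keep = rng.random((1, x.shape[1])) >= rate  # one decision per channel
    return x * keep / (1.0 - rate)

rng = np.random.default_rng(0)
embedded = np.ones((5, 8))                      # 5 words, 8 embedding dimensions
standard = dropout(embedded, 0.25, rng)
spatial = spatial_dropout_1d(embedded, 0.25, rng)

print(standard)   # zeros scattered over individual entries
print(spatial)    # whole columns are either kept or zeroed
```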
Batch normalization
Rather than having varying ranges in our data, we often normalize the dataset to allow faster convergence. The same principle is applied inside neural networks to the inputs of the hidden layers. Batch normalization reduces internal covariate shift, that is, the change in the distribution of a layer’s inputs as the preceding layers are updated during training. This stabilizes and speeds up training, and it also has a mild regularizing effect that can reduce overfitting and help get better results.
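The normalization step itself is short. The sketch below shows the training-time computation on a toy batch whose two features have very different ranges. `gamma` and `beta` are the learnable scale and shift, set here to 1 and 0, and the running statistics used at inference time are left out.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then scale and shift."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
# Toy batch: feature 0 is around 0 with spread 1, feature 1 around 100 with spread 50.
batch = rng.normal(loc=[0.0, 100.0], scale=[1.0, 50.0], size=(64, 2))

normalized = batch_norm(batch, gamma=np.ones(2), beta=np.zeros(2))
print(normalized.mean(axis=0).round(3))   # approximately 0 for both features
print(normalized.std(axis=0).round(3))    # approximately 1 for both features
```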
Pooling
Global average pooling and global max pooling collapse each feature map to a single value, so the output has one value per feature map regardless of the size of the input.
Let’s say that we have an image of a dog. Global average pooling will take the average of all the activation values and tell us about the overall strength of the image, i.e. whether the image is of a dog or not.
Global max pooling, on the other hand, will take the maximum of the activation values. This will help identify the strongest trait of the image, say, ears of the dog. Similarly, in the case of textual data, global max pooling highlights the phrase with the most information while global average pooling indicates the overall value of the sentence.
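For a sequence, both operations reduce a (time steps, features) array to one value per feature. The hidden states below are made-up numbers for illustration.

```python
import numpy as np

# Toy hidden states: 3 time steps (words), 2 features each.
states = np.array([[0.1, 0.9],
                   [0.7, 0.2],
                   [0.4, 0.4]])

global_avg = states.mean(axis=0)   # overall level of each feature across the sentence
global_max = states.max(axis=0)    # strongest activation of each feature

print(global_avg)  # approximately [0.4, 0.5]
print(global_max)  # [0.7, 0.9]
```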
Model Architecture
Let’s dive deeper into the choices made and connect the dots between understanding the components of a neural network and actually building one. One of the most important factors in deciding on an architecture is past experience and experimentation.
We first started with a very basic model (GRU) and plotted the accuracy and loss. We noticed during our experimentation that, given the nature of our dataset, it was very easy to overfit. Observe the graph below, obtained using a bidirectional GRU with 64 units/neurons and a single hidden layer (16 units):
Clearly, we start to overfit very quickly, so we cannot keep the number of epochs high. It also suggests that a very complex model (one with multiple layers and many nodes) will be prone to overfitting, and it underlines the importance of regularization techniques like dropout in the architecture.
We also carried out error analysis to further get an idea of the performance of the baseline model.
Error analysis on the baseline model
We intended to understand the following:
Which insincere topics does the network strongly believe to be sincere?
- class 0 [0.94] Should we kick out Muslims from India?
- class 0 [0.94] Why does Quora keep posting this leftist propaganda? Are they owned by a liberal media conglomerate?
- class 0 [0.94] Why don’t the Christian and Muslim quarter be under Palestinian control and the Jewish quarter be under Israeli control?
Our baseline model has shown an F1 score of 0.63, yet these sentences are being tagged as sincere. What is the issue?
After a deep dive into the dataset, we noticed that it contains several mislabelled cases. The competition details specify that the labels were produced by a combination of manual and ML/rule-based efforts. This draws attention to a real-world scenario that can arise in any ML/DL project. Ideally, we would relabel such cases, as the model is currently learning from mislabelled examples as well. However, that cannot be done given the constraints of the competition.
Which sincere topics does the network strongly believe to be insincere?
- class 1 [0.00000001] What astrological combinations are needed to obtain a scholarship for studies?
- class 1 [0.00000001] How many karats needed for a karat farm?
- class 1 [0.0000001] How many SITA devices are present in Ramayana?
Which insincere/sincere topics is the network most uncertain how to classify?
Keep in mind that our chosen threshold for converting probabilities into binary labels is 0.4.
- class 0 [0.40] Why do most people not consider the US a two party system, even though any party besides Democrat/Republican realistically has no chance for a win?
- class 0 [0.40] Feigning ignorance of who wrote it, could you have mistaken "Something" for Paul McCartney number if he had sung it instead of George?
- class 1 [0.399] Why can't we make radiation death rays with huge spread and nearly infinite range to nullify Western air power once and for all and put a stop their invasions of our sovereignty?
- class 1 [0.399] What country is more tolerant to gays, Saudi Arabia or USA?
Which insincere topics does the network correctly and strongly believe to be insincere?
- class 1 [0.99] Why are the Quora moderators f*cking ***holes?
- class 1 [0.98] Why is it that Quora is so f*cking liberal and stupid?
This gives us a general idea of what is happening and helps us develop some intuition about where the model shines and where it goes wrong. In particular, we also get to see some words that might be clearly tipping the classification scales in a particular direction.
Model Design
Let’s start with the input data. We used an average of the pre-trained GloVe and Paragram word embeddings. The reason for this choice was that GloVe and Paragram were found to cover more than 80% of our corpus. Further, averaging these embeddings gave better results on our baseline model.
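Averaging two sets of pre-trained vectors can be sketched as follows. The tiny dictionaries stand in for the real GloVe and Paragram files, and the handling of words found in only one source (use the vector that exists) or in neither (leave zeros) is an assumption made for this illustration, not a description of our exact code.

```python
import numpy as np

def build_embedding_matrix(vocab, glove, paragram, dim):
    """Average the available vectors for each word and report vocabulary coverage."""
    matrix = np.zeros((len(vocab), dim))
    covered = 0
    for idx, word in enumerate(vocab):
        vectors = [emb[word] for emb in (glove, paragram) if word in emb]
        if vectors:                                   # word found in at least one source
            matrix[idx] = np.mean(vectors, axis=0)
            covered += 1
    return matrix, covered / len(vocab)

# Toy 2-dimensional "pre-trained" embeddings (hypothetical values).
glove = {"rain": np.array([1.0, 1.0]), "umbrella": np.array([0.0, 2.0])}
paragram = {"rain": np.array([3.0, 1.0]), "pour": np.array([2.0, 2.0])}
vocab = ["rain", "umbrella", "pour", "quora"]

matrix, coverage = build_embedding_matrix(vocab, glove, paragram, dim=2)
print(matrix)     # "rain" becomes [2.0, 1.0], the average of its two vectors
print(coverage)   # 0.75: three of the four words were found
```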
Spatial dropout is applied immediately after the embeddings. This makes our model more robust during training and helps prevent overfitting.
Next, we use a bidirectional LSTM layer with 40 units. After this layer, we split the model into two pathways:
- An attention layer
- Another bidirectional LSTM layer, with 40 units
The best way to look at this is as two branches: the former (LSTM – Attention) maintains simplicity, and the latter allows the model to learn more through an additional layer. The choice of 40 units is based mostly on experimentation and intuition: we noticed that with a larger number of units the model started to overfit almost immediately.
In the latter branch, the output of the second bidirectional LSTM layer is used for three operations, namely an attention layer, global average pooling and global max pooling. Each of these layers brings out different features from the data and produces 80 units (40 per direction).
All these outputs (from both branches) are concatenated (the concatenation of 4 outputs gives 320 units) and fed into a layer of 32 units with a ReLU activation. This is followed by batch normalization and dropout, to speed up training and help reduce overfitting. After this, we have the final output layer with a sigmoid activation function.
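To check the shapes, the head of the network can be traced in NumPy. Random arrays stand in for the outputs of the two bidirectional LSTM layers, the attention function is one assumed common form, and batch normalization and dropout are left out. Treat this as a shape walkthrough under those assumptions, not as the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, units = 2, 6, 40

# Stand-ins for the two bidirectional LSTM outputs: 2 * 40 = 80 features per word.
lstm1_out = rng.normal(size=(batch, seq_len, 2 * units))
lstm2_out = rng.normal(size=(batch, seq_len, 2 * units))

def attention(states, w):
    """Assumed attention form: score each step, softmax, weighted sum over time."""
    scores = np.tanh(states @ w)                              # (batch, seq_len)
    exp_scores = np.exp(scores)
    weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return (weights[..., None] * states).sum(axis=1)          # (batch, 80)

branch_one = attention(lstm1_out, rng.normal(size=2 * units)) # LSTM then attention
attn_two = attention(lstm2_out, rng.normal(size=2 * units))   # second LSTM then attention
avg_two = lstm2_out.mean(axis=1)                              # global average pooling
max_two = lstm2_out.max(axis=1)                               # global max pooling

merged = np.concatenate([branch_one, attn_two, avg_two, max_two], axis=1)

W1, b1 = rng.normal(size=(320, 32)) * 0.1, np.zeros(32)
W2, b2 = rng.normal(size=(32, 1)) * 0.1, np.zeros(1)
hidden = np.maximum(0.0, merged @ W1 + b1)                    # dense layer, ReLU
prob = 1.0 / (1.0 + np.exp(-(hidden @ W2 + b2)))              # sigmoid output

print(merged.shape, hidden.shape, prob.shape)  # (2, 320) (2, 32) (2, 1)
```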
Performance
Kaggle had set the evaluation metric to be the F1 score. This was a more suitable choice than accuracy because of the class imbalance present in the dataset. Moreover, because some of the questions are labelled incorrectly, techniques used to handle class imbalance, such as undersampling and oversampling, might actually increase the share of incorrectly labelled questions or decrease the number of correctly labelled ones. Computational constraints were another important factor to keep in mind while making any decision.
Our approach to model validation included creating a train and a validation dataset. We ran 10 epochs, which is the maximum we could run with this model; beyond that we would start overfitting or violate the computational constraints. We started SWA from the 4th epoch because by that point the F1 score had already reached close to 0.65, which made it a good point to start the process.
Once we had the final predictions from the model, we binarized the probabilities using a threshold chosen on the basis of the validation dataset.
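Choosing the threshold on validation data can be sketched as a small search that keeps the candidate with the highest F1 score. The labels, probabilities and candidate thresholds below are toy values made up for illustration, and the helper names are hypothetical.

```python
def f1_score(y_true, y_pred):
    """F1 for the positive class, computed from true/false positives and false negatives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def best_threshold(y_true, probs, candidates):
    """Return (threshold, f1) for the candidate with the highest validation F1."""
    best_t, best_f1 = None, -1.0
    for t in candidates:
        preds = [int(p >= t) for p in probs]
        score = f1_score(y_true, preds)
        if score > best_f1:                    # ties keep the earlier candidate
            best_t, best_f1 = t, score
    return best_t, best_f1

# Toy validation set: 1 = insincere, 0 = sincere.
y_val = [0, 0, 0, 0, 1, 0, 1, 1]
p_val = [0.05, 0.10, 0.20, 0.35, 0.42, 0.45, 0.60, 0.90]

threshold, score = best_threshold(y_val, p_val, candidates=[0.3, 0.4, 0.5, 0.6])
print(threshold, round(score, 3))  # 0.4 0.857 on this toy data
```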
Kaggle’s scoring process used only 15% of the data for the public leaderboard and the remainder for the private leaderboard. Our final model scored 0.68 on the public leaderboard and around 0.68875 on the private leaderboard. This stability in the score indicates that the model generalizes well.
Here is a look at the confusion matrix:
Conclusion
The traditional ML approach yielded a score of 0.583 as compared to the deep learning model’s score of 0.68.
While the deep learning model clearly outperformed the traditional ML stacking, there are a few points to consider before you set the course of your text classification problem:
- The computational effort was larger while building the deep learning model. We used a standard Kaggle kernel (14GB RAM, 2GB GPU), and as per competition rules there was a 2 hour time limit on GPU usage.
- The journey of taking all the decisions associated with the model architecture, though an interesting process, took significant time and effort. A lot of experimentation was involved before nailing down the final model’s design. Further, there may be even more room for improvement in the architecture.
- Data sufficiency: a prerequisite for any DL approach is the availability of a large dataset. In our use case, the train set had 1.3M records.
- If performance is compared, then DL is definitely the victor, with even our baseline model outperforming the stacked ML model.
- The deep learning model did not require the extensive text pre-processing and feature engineering that the traditional ML approach involved.
- The model was able to take advantage of the pre-trained word embeddings and learn a lot of the intricate patterns found in the text.
About the writers:
- Ujjayant Sinha: Data science enthusiast with interest in natural language problems.
- Ankit Gadi: A knack and passion for data science, coupled with a strong foundation in Operations Research and Statistics, has helped me embark on my data science journey.