Note

This tutorial is intended to be run in an IPython notebook. It is also available as a notebook file here.

Debugging scikit-learn text classification pipeline

The scikit-learn docs provide a nice text classification tutorial. Make sure to read it first. We’ll be doing something similar to it, while taking a more detailed look at classifier weights and predictions.

1. Baseline model

First, we need some data. Let’s load the 20 Newsgroups data, keeping only 4 categories:

from sklearn.datasets import fetch_20newsgroups

categories = ['alt.atheism', 'soc.religion.christian',
              'comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42
)

A basic text processing pipeline, with bag-of-words features and Logistic Regression as the classifier:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline

vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);

We’re using LogisticRegressionCV here to adjust the regularization parameter C automatically. This allows us to compare different vectorizers fairly: the optimal C value could differ for different input features (e.g. for bigrams or for character-level input). An alternative would be to use GridSearchCV or RandomizedSearchCV.
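To see which C value LogisticRegressionCV actually picked, one can read the fitted estimator's Cs_ (candidate values) and C_ (chosen value per class) attributes. Below is a minimal sketch on a tiny made-up corpus rather than the newsgroups data; the Cs=5 and cv=3 settings are our own choices to keep it fast, not the tutorial's.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.pipeline import make_pipeline

# A tiny made-up corpus standing in for the newsgroup messages.
docs = [
    "graphics card renders 3d images", "jpeg image file format", "polygon mesh rendering",
    "bitmap image editing software", "3d graphics package", "video card drivers",
    "doctor prescribed pain medication", "kidney disease treatment", "symptoms of the disease",
    "the patient saw a doctor", "medical treatment for cancer", "health and diet advice",
]
labels = [0] * 6 + [1] * 6

# Cs=5: try 5 values of C on a log scale; cv=3: 3-fold stratified cross-validation.
pipe = make_pipeline(CountVectorizer(), LogisticRegressionCV(Cs=5, cv=3))
pipe.fit(docs, labels)

# make_pipeline names each step after its lowercased class name.
lr = pipe.named_steps["logisticregressioncv"]
print("candidate C values:", lr.Cs_)
print("chosen C:", lr.C_)
```

Comparing lr.C_ across pipelines with different vectorizers is a quick way to confirm that the best regularization strength really does shift with the input features.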

Let’s check the quality of this pipeline:

from sklearn import metrics

def print_report(pipe):
    y_test = twenty_test.target
    y_pred = pipe.predict(twenty_test.data)
    report = metrics.classification_report(y_test, y_pred,
        target_names=twenty_test.target_names)
    print(report)
    print("accuracy: {:0.3f}".format(metrics.accuracy_score(y_test, y_pred)))

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.93      0.80      0.86       319
         comp.graphics       0.87      0.96      0.91       389
               sci.med       0.94      0.81      0.87       396
soc.religion.christian       0.85      0.98      0.91       398

           avg / total       0.90      0.89      0.89      1502

accuracy: 0.891

Not bad. We could try other classifiers and preprocessing methods, but first let’s check what the model learned, using the eli5.show_weights() function:

import eli5
eli5.show_weights(clf, top=10)
y=0 top features
Weight Feature
+1.991 x21167
+1.925 x19218
+1.834 x5714
+1.813 x23677
+1.697 x15511
+1.696 x26415
+1.617 x6440
+1.594 x26412
… 10174 more positive …
… 25605 more negative …
-1.686 x28473
-10.453 <BIAS>

y=1 top features
Weight Feature
+1.702 x15699
+0.825 x17366
+0.798 x14281
+0.786 x30117
+0.779 x14277
+0.773 x17356
+0.729 x24267
+0.724 x7874
+0.702 x2148
… 11710 more positive …
… 24069 more negative …
-1.379 <BIAS>

y=2 top features
Weight Feature
+2.016 x25234
+1.951 x12026
+1.758 x17854
+1.697 x11729
+1.655 x32847
+1.522 x22379
+1.518 x16328
… 15007 more positive …
… 20772 more negative …
-1.764 x15521
-2.171 x15699
-5.013 <BIAS>

y=3 top features
Weight Feature
+1.193 x28473
+1.030 x8609
+1.021 x8559
+0.946 x8798
+0.899 x8544
+0.797 x8553
… 11122 more positive …
… 24657 more negative …
-0.852 x15699
-0.894 x25663
-1.181 x23122
-1.243 x16881

The table above doesn’t make any sense; the problem is that eli5 was not able to get feature and class names from the classifier object alone. We can provide feature and target names explicitly:

# eli5.show_weights(clf,
#                   feature_names=vec.get_feature_names(),  # get_feature_names_out() on scikit-learn >= 1.0
#                   target_names=twenty_test.target_names)

The code above works, but a better way is to provide the vectorizer instead and let eli5 figure out the details automatically:

eli5.show_weights(clf, vec=vec, top=10,
                  target_names=twenty_test.target_names)
y=alt.atheism top features
Weight Feature
+1.991 mathew
+1.925 keith
+1.834 atheism
+1.813 okcforum
+1.697 go
+1.696 psuvm
+1.617 believing
+1.594 psu
… 10174 more positive …
… 25605 more negative …
-1.686 rutgers
-10.453 <BIAS>

y=comp.graphics top features
Weight Feature
+1.702 graphics
+0.825 images
+0.798 files
+0.786 software
+0.779 file
+0.773 image
+0.729 package
+0.724 card
+0.702 3d
… 11710 more positive …
… 24069 more negative …
-1.379 <BIAS>

y=sci.med top features
Weight Feature
+2.016 pitt
+1.951 doctor
+1.758 information
+1.697 disease
+1.655 treatment
+1.522 msg
+1.518 health
… 15007 more positive …
… 20772 more negative …
-1.764 god
-2.171 graphics
-5.013 <BIAS>

y=soc.religion.christian top features
Weight Feature
+1.193 rutgers
+1.030 church
+1.021 christians
+0.946 clh
+0.899 christ
+0.797 christian
… 11122 more positive …
… 24657 more negative …
-0.852 graphics
-0.894 posting
-1.181 nntp
-1.243 host

This starts to make more sense. Each target class gets its own list of features and their weights, and the intercept (bias) term is shown as <BIAS> in the same list. We can inspect features and weights this way because we’re using a bag-of-words vectorizer and a linear classifier, so there is a direct mapping between individual words and classifier coefficients. For other classifiers, features can be harder to inspect.
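The word-to-coefficient mapping can be made concrete. The sketch below, on a made-up three-class corpus (not the newsgroups data), does roughly what show_weights displays for a linear model: it inverts the vectorizer's vocabulary_ so that each column index gets its word, then sorts each row of coef_. It illustrates the idea only and is not eli5's implementation. It also explains the x21167-style names in the first table: without the vectorizer, only column indices are known.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

# Made-up three-class corpus standing in for the newsgroup messages.
docs = [
    "graphics card renders 3d images", "jpeg image file format", "3d graphics package",
    "doctor prescribed pain medication", "kidney disease treatment", "the doctor treated the disease",
    "church service on sunday", "faith and prayer in church", "christians read the bible",
]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
target_names = ["comp.graphics", "sci.med", "soc.religion.christian"]

vec = CountVectorizer()
X = vec.fit_transform(docs)
clf = LogisticRegression(max_iter=1000).fit(X, labels)

# vocabulary_ maps word -> column index; invert it so that column i -> word.
feature_names = np.empty(len(vec.vocabulary_), dtype=object)
for word, idx in vec.vocabulary_.items():
    feature_names[idx] = word

def top_features(class_idx, k=3):
    """Return the k (word, weight) pairs with the largest weights for one class."""
    weights = clf.coef_[class_idx]          # one weight per column of X
    order = np.argsort(weights)[::-1][:k]   # column indices, largest weight first
    return [(feature_names[i], float(weights[i])) for i in order]

for class_idx, name in enumerate(target_names):
    bias = float(clf.intercept_[class_idx])  # what eli5 shows as <BIAS>
    print(name, top_features(class_idx), "<BIAS>:", round(bias, 3))
```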

Some features look good, but some don’t. It seems the model learned some names specific to this dataset (email parts, etc.) instead of topic-specific words. Let’s check the prediction results on an example:

eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)

y=alt.atheism (probability 0.000, score -8.709) top features

Contribution Feature
+1.743 Highlighted in text (sum)
-10.453 <BIAS>

from: brian@ucsd.edu (brian kantor) subject: re: help for kidney stones .............. organization: the avant-garde of the now, ltd. lines: 12 nntp-posting-host: ucsd.edu as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less. demerol worked, although i nearly got arrested on my way home when i barfed all over the police car parked just outside the er. - brian

y=comp.graphics (probability 0.010, score -4.592) top features

Contribution Feature
-1.379 <BIAS>
-3.213 Highlighted in text (sum)

from: brian@ucsd.edu (brian kantor) subject: re: help for kidney stones .............. organization: the avant-garde of the now, ltd. lines: 12 nntp-posting-host: ucsd.edu as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less. demerol worked, although i nearly got arrested on my way home when i barfed all over the police car parked just outside the er. - brian

y=sci.med (probability 0.989, score 3.945) top features

Contribution Feature
+8.958 Highlighted in text (sum)
-5.013 <BIAS>

from: brian@ucsd.edu (brian kantor) subject: re: help for kidney stones .............. organization: the avant-garde of the now, ltd. lines: 12 nntp-posting-host: ucsd.edu as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less. demerol worked, although i nearly got arrested on my way home when i barfed all over the police car parked just outside the er. - brian

y=soc.religion.christian (probability 0.001, score -7.157) top features

Contribution Feature
-0.258 <BIAS>
-6.899 Highlighted in text (sum)

from: brian@ucsd.edu (brian kantor) subject: re: help for kidney stones .............. organization: the avant-garde of the now, ltd. lines: 12 nntp-posting-host: ucsd.edu as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less. demerol worked, although i nearly got arrested on my way home when i barfed all over the police car parked just outside the er. - brian

Features that occur in the text are highlighted in the text itself; features that can’t be highlighted (only <BIAS> in this case) are listed in a separate table. In the notebook, hovering the mouse over a highlighted word shows its weight in a tooltip, and words are colored according to their weights.
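For a linear model, a word's contribution is its feature value (here a count) multiplied by the class weight for that word, and the contributions plus <BIAS> add up to the score in the header: for alt.atheism above, 1.743 + (-10.453) = -8.710, matching the displayed -8.709 up to rounding. A small sketch checking that identity against decision_function, on made-up data rather than the newsgroups model:

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

# Made-up three-class corpus (0: graphics, 1: medicine, 2: religion).
docs = [
    "graphics card renders 3d images", "jpeg image file format", "3d graphics package",
    "doctor prescribed pain medication", "kidney disease treatment", "the doctor treated the disease",
    "church service on sunday", "faith and prayer in church", "christians read the bible",
]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]

vec = CountVectorizer()
clf = LogisticRegression(max_iter=1000).fit(vec.fit_transform(docs), labels)

doc = "the doctor said my kidney stones were painful"
x = vec.transform([doc])  # 1 x n_features sparse row of word counts
class_idx = 1             # the "medicine" class in this toy setup

# Per-word contribution = count * weight; words outside the vocabulary have no column.
index_to_word = {idx: word for word, idx in vec.vocabulary_.items()}
contributions = {
    index_to_word[idx]: count * clf.coef_[class_idx, idx]
    for idx, count in zip(x.indices, x.data)
}
score = sum(contributions.values()) + clf.intercept_[class_idx]

print(contributions)
print("sum + bias:", score)
print("decision_function:", clf.decision_function(x)[0, class_idx])
```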

2. Baseline model, improved data

Aha: the highlighting above shows that the classifier did learn some uninteresting things, e.g. it memorized parts of email addresses. We should probably clean the data first to make the task more interesting; improving the model (trying different classifiers, etc.) doesn’t make sense at this point, because it may just learn to exploit these email addresses better.

In practice we’d have to do the cleaning ourselves; here, the 20 Newsgroups loader provides an option to remove headers and footers from the messages. Nice. Let’s clean up the data and re-train a classifier.

twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
    remove=['headers', 'footers'],
)

vec = CountVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target);

We just made the task harder and more realistic for a classifier.
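Roughly speaking, removing headers means keeping only the message body after the header block. As far as we can tell, scikit-learn's own helper keeps the text after the first blank line; strip_header below is our own hypothetical function imitating that idea, not the library code. The sample message reuses the (lowercased) header lines shown in the first prediction above.

```python
def strip_header(message):
    """Hypothetical helper: drop everything up to the first blank line.

    Newsgroup messages start with 'from:', 'subject:', ... header lines
    followed by an empty line; the body comes after it.
    """
    _header, _blank, body = message.partition("\n\n")
    return body

raw = (
    "from: brian@ucsd.edu (brian kantor)\n"
    "subject: re: help for kidney stones\n"
    "organization: the avant-garde of the now, ltd.\n"
    "lines: 12\n"
    "nntp-posting-host: ucsd.edu\n"
    "\n"
    "as i recall from my bout with kidney stones, there isn't any\n"
    "medication that can do anything about them except relieve the pain.\n"
)
print(strip_header(raw))
```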

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.83      0.78      0.80       319
         comp.graphics       0.82      0.96      0.88       389
               sci.med       0.89      0.80      0.84       396
soc.religion.christian       0.88      0.86      0.87       398

           avg / total       0.85      0.85      0.85      1502

accuracy: 0.852

A great result: we just made quality worse! Does that mean the pipeline is worse now? No; it likely performs better on unseen messages. It is the evaluation that is fairer now. Inspecting the features used by the classifier allowed us to notice a problem with the data and make a good change, despite the numbers telling us not to do that.

Instead of removing headers and footers, we could have improved the evaluation setup directly, e.g. using GroupKFold from scikit-learn. Then the quality of the old model would have dropped, and after removing headers/footers we would have seen accuracy increase, so the numbers themselves would have told us to remove headers and footers. It is not obvious how to split the data, though: it is unclear which groups to use with GroupKFold.
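One possible choice, and it is only our assumption, not something the tutorial settles on: group messages by sender address taken from the From: header, so the same sender never appears in both a training and a test fold. The sender_of helper and the toy messages are hypothetical; with real data the same groups list could be passed to cross_val_score via its groups argument together with cv=GroupKFold(...).

```python
import re
from sklearn.model_selection import GroupKFold

FROM_RE = re.compile(r"^From:\s*(\S+@\S+)", re.MULTILINE)

def sender_of(message):
    """Hypothetical grouping key: the sender's address, or '' if not found."""
    match = FROM_RE.search(message)
    return match.group(1).lower() if match else ""

messages = [
    "From: alice@a.edu\n\nnew graphics card",
    "From: alice@a.edu\n\n3d rendering question",
    "From: bob@b.org\n\nkidney stones",
    "From: bob@b.org\n\ndoctor visit",
    "From: carol@c.com\n\nchurch service",
    "From: carol@c.com\n\nprayer meeting",
]
groups = [sender_of(m) for m in messages]

# Each sender's messages land entirely in either the training or the test part of a fold.
for train_idx, test_idx in GroupKFold(n_splits=3).split(messages, groups=groups):
    train_senders = {groups[i] for i in train_idx}
    test_senders = {groups[i] for i in test_idx}
    print("test senders:", test_senders, "overlap:", train_senders & test_senders)
```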

So, what has the updated classifier learned? (The output is less verbose because only a subset of classes is shown; see the “targets” argument.)

eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])

y=sci.med (probability 0.732, score 0.031) top features

Contribution Feature
+1.747 Highlighted in text (sum)
-1.716 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

Hm, it no longer uses email addresses, but it still doesn’t look good: the classifier assigns high weights to seemingly unrelated words like ‘do’ or ‘my’. These words appear in many texts, so maybe the classifier uses them as a proxy for the bias. Or maybe some of them are more common in some of the classes.

3. Pipeline improvements

To help the classifier, we can filter out stop words:

vec = CountVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.87      0.76      0.81       319
         comp.graphics       0.85      0.95      0.90       389
               sci.med       0.93      0.85      0.89       396
soc.religion.christian       0.85      0.89      0.87       398

           avg / total       0.87      0.87      0.87      1502

accuracy: 0.871
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])

y=sci.med (probability 0.714, score 0.510) top features

Contribution Feature
+2.184 Highlighted in text (sum)
-1.674 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

Looks better, doesn’t it?
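To see exactly which tokens stop_words='english' throws away, one can fit two vectorizers on the same text and compare their vocabularies. A minimal sketch on a sentence from the example message; ENGLISH_STOP_WORDS is the built-in list scikit-learn uses for 'english'.

```python
from sklearn.feature_extraction.text import CountVectorizer, ENGLISH_STOP_WORDS

text = ["as i recall from my bout with kidney stones, there isn't any medication "
        "that can do anything about them except relieve the pain"]

plain = CountVectorizer().fit(text)
filtered = CountVectorizer(stop_words="english").fit(text)

# Tokens present without filtering but dropped by the English stop-word list.
removed = sorted(set(plain.vocabulary_) - set(filtered.vocabulary_))
print(removed)
print(sorted(filtered.vocabulary_))  # what the classifier still sees
```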

Alternatively, we can use the TF*IDF scheme; it should have a somewhat similar effect.
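Why a similar effect: IDF gives words that occur in many documents a smaller weight, so common words like ‘my’ shrink relative to rarer, topic-specific ones. A small sketch on made-up documents, using TfidfVectorizer's default smoothed formula idf = ln((1 + n) / (1 + df)) + 1:

```python
import math
from sklearn.feature_extraction.text import TfidfVectorizer

docs = [
    "my doctor treated my kidney stones",
    "my graphics card renders images",
    "my church holds a service",
]
vec = TfidfVectorizer().fit(docs)

# Map each word to its learned IDF weight.
idf = {word: vec.idf_[idx] for word, idx in vec.vocabulary_.items()}

n_docs = len(docs)
# smooth_idf=True (the default): idf = ln((1 + n) / (1 + df)) + 1
expected_my = math.log((1 + n_docs) / (1 + 3)) + 1      # "my" is in all 3 documents
expected_kidney = math.log((1 + n_docs) / (1 + 1)) + 1  # "kidney" is in 1 document

print(sorted(idf.items(), key=lambda kv: kv[1])[:3])
print(expected_my, expected_kidney)
```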

Note that we’re cross-validating the LogisticRegression regularization parameter here, as in the other examples (LogisticRegressionCV, not LogisticRegression). TF*IDF values are different from word counts, so the optimal C value can be different too. With a fixed regularization strength we could draw the wrong conclusion: the chosen C value might simply suit one kind of input better than the other.

from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer()
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.91      0.79      0.85       319
         comp.graphics       0.83      0.97      0.90       389
               sci.med       0.95      0.87      0.91       396
soc.religion.christian       0.90      0.91      0.91       398

           avg / total       0.90      0.89      0.89      1502

accuracy: 0.892
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])

y=sci.med (probability 0.987, score 1.585) top features

Contribution Feature
+6.788 Highlighted in text (sum)
-5.203 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

It helped, but didn’t have quite the same effect. Why not do both?

vec = TfidfVectorizer(stop_words='english')
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.93      0.77      0.84       319
         comp.graphics       0.84      0.97      0.90       389
               sci.med       0.95      0.89      0.92       396
soc.religion.christian       0.88      0.92      0.90       398

           avg / total       0.90      0.89      0.89      1502

accuracy: 0.893
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])

y=sci.med (probability 0.939, score 1.910) top features

Contribution Feature
+5.488 Highlighted in text (sum)
-3.578 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

This starts to look good!

4. Char-based pipeline

We might get somewhat better quality by choosing a different classifier, but let’s skip that for now and try other analyzers instead, using character n-grams rather than words:

vec = TfidfVectorizer(stop_words='english', analyzer='char',
                      ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.93      0.79      0.85       319
         comp.graphics       0.81      0.97      0.89       389
               sci.med       0.95      0.86      0.90       396
soc.religion.christian       0.89      0.91      0.90       398

           avg / total       0.89      0.89      0.89      1502

accuracy: 0.888
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)

y=alt.atheism (probability 0.002, score -7.318) top features

Contribution Feature
-0.838 Highlighted in text (sum)
-6.480 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=comp.graphics (probability 0.017, score -5.118) top features

Contribution Feature
+0.934 <BIAS>
-6.052 Highlighted in text (sum)

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=sci.med (probability 0.963, score -0.656) top features

Contribution Feature
+4.493 Highlighted in text (sum)
-5.149 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=soc.religion.christian (probability 0.018, score -5.048) top features

Contribution Feature
+0.600 Highlighted in text (sum)
-5.648 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

It works, but the quality is a bit worse. Also, it takes ages to train.

It looks like stop_words has no effect now. In fact, this is documented in the scikit-learn docs (stop words are only applied when analyzer='word'), so our stop_words='english' was useless. But at least it is now more obvious what the text looks like to a char n-gram-based classifier. Grab a cup of tea and see what char_wb looks like:

vec = TfidfVectorizer(analyzer='char_wb', ngram_range=(3,5))
clf = LogisticRegressionCV()
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.93      0.79      0.85       319
         comp.graphics       0.87      0.96      0.91       389
               sci.med       0.91      0.90      0.90       396
soc.religion.christian       0.89      0.91      0.90       398

           avg / total       0.90      0.89      0.89      1502

accuracy: 0.894
eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names)

y=alt.atheism (probability 0.000, score -8.878) top features

Contribution Feature
-2.560 Highlighted in text (sum)
-6.318 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=comp.graphics (probability 0.005, score -6.007) top features

Contribution Feature
+0.974 <BIAS>
-6.981 Highlighted in text (sum)

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=sci.med (probability 0.834, score -0.440) top features

Contribution Feature
+2.134 Highlighted in text (sum)
-2.573 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

y=soc.religion.christian (probability 0.160, score -2.510) top features

Contribution Feature
+3.263 Highlighted in text (sum)
-5.773 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

The result is similar, with some minor changes. Quality is better for an unknown reason; char_wb only builds n-grams from characters inside word boundaries, so maybe cross-word dependencies are not that important.
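The difference between the two analyzers is easy to see with build_analyzer(), which returns the function a vectorizer uses to turn a document into features. A short sketch (ngram_range=(3, 3) instead of the tutorial's (3, 5), just to keep the output short); it also confirms that stop_words leaves the char analyzer's output unchanged.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

text = "my bout"

char = TfidfVectorizer(analyzer="char", ngram_range=(3, 3)).build_analyzer()
char_wb = TfidfVectorizer(analyzer="char_wb", ngram_range=(3, 3)).build_analyzer()

print(char(text))     # -> ['my ', 'y b', ' bo', 'bou', 'out']  ('y b' spans two words)
print(char_wb(text))  # -> [' my', 'my ', ' bo', 'bou', 'out', 'ut ']  (space-padded words only)

# stop_words is only used by the word analyzer; here it is ignored
# (newer scikit-learn versions may emit a warning about it).
char_sw = TfidfVectorizer(analyzer="char", ngram_range=(3, 3),
                          stop_words="english").build_analyzer()
print(char_sw(text) == char(text))  # -> True
```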

5. Debugging HashingVectorizer

To check that, we could try fitting word n-grams instead of char n-grams. But let’s deal with efficiency first. To handle large vocabularies we can use HashingVectorizer from scikit-learn; to make training faster we can employ SGDClassifier:

from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier

vec = HashingVectorizer(stop_words='english', ngram_range=(1,2))
clf = SGDClassifier(max_iter=10, tol=None, random_state=42)  # older scikit-learn: n_iter=10
pipe = make_pipeline(vec, clf)
pipe.fit(twenty_train.data, twenty_train.target)

print_report(pipe)
                        precision    recall  f1-score   support

           alt.atheism       0.90      0.80      0.85       319
         comp.graphics       0.88      0.96      0.92       389
               sci.med       0.93      0.90      0.92       396
soc.religion.christian       0.89      0.91      0.90       398

           avg / total       0.90      0.90      0.90      1502

accuracy: 0.899

It was super-fast! We’re not choosing the regularization parameter with cross-validation, though. Let’s check what the model learned:

eli5.show_prediction(clf, twenty_test.data[0], vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['sci.med'])

y=sci.med (score 0.097) top features

Contribution Feature
+0.678 Highlighted in text (sum)
-0.581 <BIAS>

as i recall from my bout with kidney stones, there isn't any medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney stones and children, and the childbirth hurt less.

The result looks similar to the CountVectorizer one. But with HashingVectorizer we don’t even have a vocabulary! Why does it work?
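Part of the answer: HashingVectorizer computes a term's column by hashing the term itself, so no fitting is needed and a given word always lands in the same column, in every document. A linear model can therefore learn a weight per column just as it would per vocabulary entry. A small sketch with default settings (2**20 columns):

```python
from sklearn.feature_extraction.text import HashingVectorizer

vec = HashingVectorizer()  # default n_features = 2**20 columns

# No fit() call: the column of each term comes from hashing the term string.
X = vec.transform(["kidney stones hurt", "stones and graphics"])
print(X.shape)

# "stones" lands in the same column in both documents.
col_a = set(vec.transform(["stones"]).indices)
print(col_a, set(X[0].indices), set(X[1].indices))
```

The price is that the column number alone says nothing about which word produced it, which is what the next table shows.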

eli5.show_weights(clf, vec=vec, top=10,
                  target_names=twenty_test.target_names)
y=alt.atheism top features
Weight Feature
+2.836 x199378
+2.378 x938889
+1.776 x718537
+1.625 x349126
+1.554 x242643
+1.509 x71928
… 50341 more positive …
… 50567 more negative …
-1.634 x683213
-1.795 x741207
-1.872 x199709
-2.132 x641063

y=comp.graphics top features
Weight Feature
+3.737 x580586
+2.056 x342790
+1.956 x771885
+1.787 x363686
+1.717 x111283
… 32081 more positive …
… 31710 more negative …
-1.760 x857427
-1.779 x85557
-1.813 x693269
-2.021 x120354
-2.447 x814572

y=sci.med top features
Weight Feature
+2.209 x988761
+2.194 x337555
+2.162 x154565
+1.818 x806262
… 44124 more positive …
… 43892 more negative …
-1.704 x790864
-1.750 x580586
-1.851 x34701
-2.085 x85557
-2.147 x365313
-2.150 x494508

y=soc.religion.christian top features
Weight Feature
+3.034 x641063
+3.016 x199709
+2.977 x741207
+2.092 x396081
+1.901 x274863
… 51475 more positive …
… 51717 more negative …
-1.963 x672777
-2.096 x199378
-2.143 x443433
-2.963 x718537
-3.245 x970058

OK, we don’t have a vocabulary, so we don’t have feature names. Are we out of luck? Nope, eli5 has an answer for that: InvertableHashingVectorizer. It can be used to get feature names for HashingVectorizer without fitting a huge vocabulary. It still needs some data to learn the words -> hashes mapping, though; we can use a random subset of the data to fit it.

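from eli5.sklearn import InvertableHashingVectorizer
import numpy as np

ivec = InvertableHashingVectorizer(vec)
sample_size = len(twenty_train.data) // 10
# Pick document indices without replacement (seeded for reproducibility)
rng = np.random.RandomState(42)
sample_idx = rng.choice(len(twenty_train.data), size=sample_size, replace=False)
X_sample = [twenty_train.data[i] for i in sample_idx]
ivec.fit(X_sample);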
eli5.show_weights(clf, vec=ivec, top=20,
                  target_names=twenty_test.target_names)
y=alt.atheism top features
Weight Feature
+2.836 atheism
+2.378 writes
+1.634 morality
+1.625 motto
+1.554 religion
+1.509 islam
+1.489 keith
+1.476 religious
+1.439 objective
+1.414 wrote
+1.405 said
+1.361 punishment
+1.335 livesey
+1.332 mathew
+1.324 atheist
+1.320 agree
… 47696 more positive …
… 53202 more negative …
-1.776 rutgers edu
-1.795 rutgers
-1.872 christ
-2.132 christians

y=comp.graphics top features
Weight Feature
+3.737 graphics
+2.447 image
+2.056 code
+2.021 files
+1.956 images
+1.813 3d
+1.787 software
+1.717 file
+1.701 ftp
+1.587 video
+1.572 keywords
+1.572 card
+1.509 points
+1.500 line
+1.494 need
+1.483 computer
+1.470 hi
… 30146 more positive …
… 33635 more negative …
-1.654 people
-1.760 keyboard
-1.779 god

y=sci.med top features
Weight Feature
+2.209 health
+2.194 msg
+2.162 doctor
+2.150 disease
+2.147 treatment
+1.851 medical
+1.818 com
+1.704 pain
+1.663 effects
+1.616 cancer
+1.513 case
+1.453 diet
+1.447 blood
+1.439 information
+1.435 keyboard
+1.407 pitt
… 42291 more positive …
… 45715 more negative …
-1.462 church
-1.697 FEATURE[354651]
-1.750 graphics
-2.085 god

y=soc.religion.christian top features
Weight Feature
+3.245 church
+3.034 christians
+3.016 christ
+2.977 rutgers
+2.963 rutgers edu
+2.143 christian
+2.092 heaven
+1.963 love
+1.901 athos rutgers
+1.901 athos
+1.741 satan
+1.714 authority
+1.653 faith
+1.644 1993
+1.643 article apr
+1.633 understanding
+1.541 sin
+1.509 god
… 49948 more positive …
… 53234 more negative …
-1.525 graphics
-2.096 atheism

There are collisions (hover the mouse over features with “…”), and there are important features that were not seen in the random sample (FEATURE[…]), but overall it looks fine.
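Both effects follow from how such an inverse mapping can be built. The sketch below shows the general idea only and is not eli5's code: column_to_terms is a hypothetical helper that runs the vectorizer's own analyzer over sample documents, hashes each term with a FeatureHasher configured like the HashingVectorizer, and records which terms hit each column. A column that receives several terms is a collision; a column no sampled term hits stays nameless, like FEATURE[…].

```python
from collections import defaultdict
from sklearn.feature_extraction import FeatureHasher
from sklearn.feature_extraction.text import HashingVectorizer

vec = HashingVectorizer(ngram_range=(1, 2))
analyzer = vec.build_analyzer()  # the same tokenization HashingVectorizer uses
# A FeatureHasher with the same settings maps each term to the same column.
hasher = FeatureHasher(n_features=vec.n_features, input_type="string",
                       alternate_sign=vec.alternate_sign)

def column_to_terms(sample_docs):
    """Hypothetical helper: record which sampled terms hash to each column."""
    mapping = defaultdict(set)
    for doc in sample_docs:
        for term in set(analyzer(doc)):
            column = hasher.transform([[term]]).indices[0]
            mapping[column].add(term)
    return mapping

sample = ["kidney stones hurt", "rutgers edu mail", "graphics card drivers"]
mapping = column_to_terms(sample)

# Columns that received several terms are collisions.
collisions = {col: terms for col, terms in mapping.items() if len(terms) > 1}
print(len(mapping), "columns named;", len(collisions), "collisions")
```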

The “rutgers edu” bigram feature is suspicious, though: it looks like part of a URL.

rutgers_example = [x for x in twenty_train.data if 'rutgers' in x.lower()][0]
print(rutgers_example)
In article <Apr.8.00.57.41.1993.28246@athos.rutgers.edu> REXLEX@fnal.gov writes:
>In article <Apr.7.01.56.56.1993.22824@athos.rutgers.edu> shrum@hpfcso.fc.hp.com
>Matt. 22:9-14 'Go therefore to the main highways, and as many as you find
>there, invite to the wedding feast.'...

>hmmmmmm.  Sounds like your theology and Christ's are at odds. Which one am I
>to believe?

Yep, it looks like the model learned this address instead of learning something useful.

eli5.show_prediction(clf, rutgers_example, vec=vec,
                     target_names=twenty_test.target_names,
                     targets=['soc.religion.christian'])

y=soc.religion.christian (score 2.044) top features

Contribution Feature
+2.706 Highlighted in text (sum)
-0.662 <BIAS>

in article <apr.8.00.57.41.1993.28246@athos.rutgers.edu> rexlex@fnal.gov writes: >in article <apr.7.01.56.56.1993.22824@athos.rutgers.edu> shrum@hpfcso.fc.hp.com >matt. 22:9-14 'go therefore to the main highways, and as many as you find >there, invite to the wedding feast.'... >hmmmmmm. sounds like your theology and christ's are at odds. which one am i >to believe?

Quoted text makes it too easy for the model to classify some of the messages, and that won’t generalize to new messages. So a next step for improving the model could be to process the data further, e.g. remove quoted text or replace email addresses with a special token.
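For the quoted text, fetch_20newsgroups also accepts 'quotes' in its remove argument. For a standalone version of both steps, here is a minimal sketch: clean_message is a hypothetical helper, and both the quote convention (lines starting with ">" or "|") and the loose email pattern are our assumptions. The input is the rutgers example printed above.

```python
import re

# Deliberately loose pattern (an assumption): any run of non-spaces containing "@".
EMAIL_RE = re.compile(r"\S+@\S+")

def clean_message(text, email_token="EMAILADDR"):
    """Hypothetical preprocessing: drop quoted lines, replace email addresses."""
    kept = []
    for line in text.splitlines():
        # Lines starting with ">" or "|" are usually quotes of earlier messages.
        if line.lstrip().startswith((">", "|")):
            continue
        kept.append(EMAIL_RE.sub(email_token, line))
    return "\n".join(kept)

example = (
    "In article <Apr.8.00.57.41.1993.28246@athos.rutgers.edu> REXLEX@fnal.gov writes:\n"
    ">In article <Apr.7.01.56.56.1993.22824@athos.rutgers.edu> shrum@hpfcso.fc.hp.com\n"
    ">Matt. 22:9-14 'Go therefore to the main highways, and as many as you find\n"
    ">there, invite to the wedding feast.'...\n"
    "\n"
    ">hmmmmmm.  Sounds like your theology and Christ's are at odds. Which one am I\n"
    ">to believe?\n"
)
print(clean_message(example))
# first line printed: In article EMAILADDR EMAILADDR writes:
```

Applied to twenty_train.data and twenty_test.data before fitting, this would remove the “rutgers edu” shortcut; whether accuracy then drops or rises is something to re-check with print_report.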

You get the idea: looking at features helps you understand how a classifier works. Maybe even more importantly, it helps you notice preprocessing bugs, data leaks, and issues with task specification: all the nasty problems you get in the real world.