"""Caching loader for the 20 newsgroups text classification dataset

The description of the dataset is available on the official website at:

    http://people.csail.mit.edu/jrennie/20Newsgroups/

Quoting the introduction:

    The 20 Newsgroups data set is a collection of approximately 20,000
    newsgroup documents, partitioned (nearly) evenly across 20 different
    newsgroups. To the best of my knowledge, it was originally collected
    by Ken Lang, probably for his Newsweeder: Learning to filter netnews
    paper, though he does not explicitly mention this collection. The 20
    newsgroups collection has become a popular data set for experiments
    in text applications of machine learning techniques, such as text
    classification and text clustering.

This dataset loader will download the recommended "by date" variant of the
dataset, which features a point-in-time split between the train and
test sets. The compressed dataset size is around 14 MB. Once
uncompressed, the train set is 52 MB and the test set is 34 MB.

The data is downloaded, extracted and cached in the '~/scikit_learn_data'
folder.

The `fetch_20newsgroups` function will not vectorize the data into numpy
arrays; instead the dataset lists the filenames of the posts and their
categories as target labels.

The `fetch_20newsgroups_vectorized` function will in addition do a simple
vectorization step.
"""
# Copyright (c) 2011 Olivier Grisel <olivier.grisel@ensta.org>
# License: BSD 3 clause

import os
import logging
import tarfile
import pickle
import shutil
import re
import codecs

import numpy as np
import scipy.sparse as sp

from .base import get_data_home
from .base import Bunch
from .base import load_files
from ..utils import check_random_state
from ..feature_extraction.text import CountVectorizer
from ..preprocessing import normalize
from ..externals import joblib, six

if six.PY3:
    from urllib.request import urlopen
else:
    from urllib2 import urlopen


logger = logging.getLogger(__name__)

URL = ("http://people.csail.mit.edu/jrennie/"
       "20Newsgroups/20news-bydate.tar.gz")
ARCHIVE_NAME = "20news-bydate.tar.gz"
CACHE_NAME = "20news-bydate.pkz"
TRAIN_FOLDER = "20news-bydate-train"
TEST_FOLDER = "20news-bydate-test"


def download_20newsgroups(target_dir, cache_path):
    """Unpack the 20 newsgroups archive and store it as a zipped pickle.

    The archive is expected at ``os.path.join(target_dir, ARCHIVE_NAME)``.
    It is extracted into ``target_dir``, the train and test folders are
    loaded with ``load_files`` (latin1 encoding), and the resulting dict is
    pickled, zlib-compressed and written to ``cache_path``.

    Side effects: the archive file is deleted after extraction and the whole
    ``target_dir`` tree is removed once the cache has been written.

    Parameters
    ----------
    target_dir : str
        Working folder that holds the archive and receives the extracted
        files. It is deleted before returning.

    cache_path : str
        Destination file of the compressed pickle.

    Returns
    -------
    cache : dict
        ``{'train': ..., 'test': ...}`` with the objects returned by
        ``load_files`` for each split.

    Raises
    ------
    IOError
        If the archive is not present in ``target_dir``.
    """
    archive_path = os.path.join(target_dir, ARCHIVE_NAME)
    train_path = os.path.join(target_dir, TRAIN_FOLDER)
    test_path = os.path.join(target_dir, TEST_FOLDER)

    # NOTE(review): the automatic download below is disabled in this file, so
    # the archive has to be placed in `target_dir` beforehand. The original
    # statements are kept for reference -- confirm whether they should be
    # re-enabled.
    #
    # if not os.path.exists(target_dir):
    #     os.makedirs(target_dir)
    #
    # if os.path.exists(archive_path):
    #     # Download is not complete as the .tar.gz file is removed after
    #     # download.
    #     logger.warn("Download was incomplete, downloading again.")
    #     os.remove(archive_path)
    # logger.warn("Downloading dataset from %s (14 MB)", URL)
    # opener = urlopen(URL)
    # open(archive_path, 'wb').write(opener.read())

    if not os.path.exists(archive_path):
        # Fail with an actionable message instead of an opaque tarfile error.
        raise IOError("20 newsgroups archive not found at %r; automatic "
                      "download is disabled, fetch it manually from %s"
                      % (archive_path, URL))

    logger.info("Decompressing %s", archive_path)
    archive = tarfile.open(archive_path, "r:gz")
    try:
        archive.extractall(path=target_dir)
    finally:
        # Always release the file handle, even if extraction fails.
        archive.close()
    os.remove(archive_path)

    # Store a zipped pickle
    cache = dict(train=load_files(train_path, encoding='latin1'),
                 test=load_files(test_path, encoding='latin1'))
    compressed_content = codecs.encode(pickle.dumps(cache), 'zlib_codec')
    with open(cache_path, 'wb') as f:
        f.write(compressed_content)

    shutil.rmtree(target_dir)
    return cache


def strip_newsgroup_header(text):
    """
    Given text in "news" format, strip the headers, by removing everything
    before the first blank line.

    If the text contains no blank line at all, an empty string is returned.
    """
    _before, _blankline, after = text.partition('\n\n')
    return after


# Lines matching this pattern are treated as quoting (or introducing a quote
# of) another post.
_QUOTE_RE = re.compile(r'(writes in|writes:|wrote:|says:|said:'
                       r'|^In article|^Quoted from|^\||^>)')


def strip_newsgroup_quoting(text):
    """
    Given text in "news" format, strip lines beginning with the quote
    characters > or |, plus lines that often introduce a quoted section
    (for example, because they contain the string 'writes:'.)
    """
    good_lines = [line for line in text.split('\n')
                  if not _QUOTE_RE.search(line)]
    return '\n'.join(good_lines)


def strip_newsgroup_footer(text):
    """
    Given text in "news" format, attempt to remove a signature block.

    As a rough heuristic, we assume that signatures are set apart by either
    a blank line or a line made of hyphens, and that it is the last such line
    in the file (disregarding blank lines at the end).

    If no such separator is found (or it is the very first line), the text
    is returned unchanged.
    """
    lines = text.strip().split('\n')
    # Scan backwards for the last blank / all-hyphen line. `lines` always has
    # at least one element, so `line_num` is bound after the loop.
    for line_num in range(len(lines) - 1, -1, -1):
        line = lines[line_num]
        if line.strip().strip('-') == '':
            break

    if line_num > 0:
        return '\n'.join(lines[:line_num])
    else:
        return text


def fetch_20newsgroups(data_home=None, subset='train', categories=None,
                       shuffle=True, random_state=42,
                       remove=(),
                       download_if_missing=True):
    """Load the filenames and data from the 20 newsgroups dataset.

    Parameters
    ----------
    subset : 'train' or 'test', 'all', optional
        Select the dataset to load: 'train' for the training set, 'test'
        for the test set, 'all' for both, with shuffled ordering.

    data_home : optional, default: None
        Specify a download and cache folder for the datasets. If None,
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    categories : None or collection of string or unicode
        If None (default), load all the categories.
        If not None, list of category names to load (other categories
        ignored).

    shuffle : bool, optional
        Whether or not to shuffle the data: might be important for models that
        make the assumption that the samples are independent and identically
        distributed (i.i.d.), such as stochastic gradient descent.

    random_state : numpy random number generator or seed integer
        Used to shuffle the dataset.

    download_if_missing : optional, True by default
        If False, raise an IOError if the data is not locally available
        instead of trying to obtain it through `download_20newsgroups`.

    remove : tuple
        May contain any subset of ('headers', 'footers', 'quotes'). Each of
        these are kinds of text that will be detected and removed from the
        newsgroup posts, preventing classifiers from overfitting on
        metadata.

        'headers' removes newsgroup headers, 'footers' removes blocks at the
        ends of posts that look like signatures, and 'quotes' removes lines
        that appear to be quoting another post.

        'headers' follows an exact standard; the other filters are not always
        correct.

    Returns
    -------
    data : object as produced by `load_files` (presumably a Bunch)
        The attributes used here are ``data`` (list of post texts),
        ``target`` (array of integer labels), ``filenames`` (array) and
        ``target_names`` (list of category names). For ``subset='all'`` a
        ``description`` attribute is set as well.

    Raises
    ------
    ValueError
        If `subset` is not one of 'train', 'test' or 'all'.
    IOError
        If the dataset is not cached and `download_if_missing` is False.
    """
    # Fail fast: reject a bad `subset` before any expensive loading happens.
    if subset not in ('train', 'test', 'all'):
        raise ValueError(
            "subset can only be 'train', 'test' or 'all', got '%s'" % subset)

    data_home = get_data_home(data_home=data_home)
    cache_path = os.path.join(data_home, CACHE_NAME)
    twenty_home = os.path.join(data_home, "20news_home")
    cache = None
    if os.path.exists(cache_path):
        try:
            with open(cache_path, 'rb') as f:
                compressed_content = f.read()
            uncompressed_content = codecs.decode(
                compressed_content, 'zlib_codec')
            # SECURITY NOTE: unpickling executes arbitrary code. This is only
            # acceptable because the cache file is written locally by
            # `download_20newsgroups`; never point `data_home` at untrusted
            # files.
            cache = pickle.loads(uncompressed_content)
        except Exception as e:
            # Best effort: a corrupt cache is reported and then rebuilt below.
            print(80 * '_')
            print('Cache loading failed')
            print(80 * '_')
            print(e)

    if cache is None:
        if download_if_missing:
            cache = download_20newsgroups(target_dir=twenty_home,
                                          cache_path=cache_path)
        else:
            raise IOError('20Newsgroups dataset not found')

    if subset in ('train', 'test'):
        data = cache[subset]
    else:
        # subset == 'all': concatenate train then test. The 'test' object is
        # reused as the container for the merged result.
        data_lst = list()
        target = list()
        filenames = list()
        for part in ('train', 'test'):
            data = cache[part]
            data_lst.extend(data.data)
            target.extend(data.target)
            filenames.extend(data.filenames)

        data.data = data_lst
        data.target = np.array(target)
        data.filenames = np.array(filenames)
        data.description = 'the 20 newsgroups by date dataset'

    if 'headers' in remove:
        data.data = [strip_newsgroup_header(text) for text in data.data]
    if 'footers' in remove:
        data.data = [strip_newsgroup_footer(text) for text in data.data]
    if 'quotes' in remove:
        data.data = [strip_newsgroup_quoting(text) for text in data.data]

    if categories is not None:
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        mask = np.in1d(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()

    if shuffle:
        random_state = check_random_state(random_state)
        indices = np.arange(data.target.shape[0])
        random_state.shuffle(indices)
        data.filenames = data.filenames[indices]
        data.target = data.target[indices]
        # Use an object array to shuffle: avoids memory copy
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[indices]
        data.data = data_lst.tolist()

    return data


def fetch_20newsgroups_vectorized(subset="train", remove=(), data_home=None):
    """Load the 20 newsgroups dataset and transform it into feature vectors.

    This is a convenience function; the posts are vectorized with a
    `CountVectorizer` (default settings apart from ``dtype=np.int16``),
    converted to float64 and passed through `normalize`. For more advanced
    usage (stopword filtering, n-gram extraction, etc.), combine
    fetch_20newsgroups with a custom vectorizer.

    NOTE(review): this function has historically been described as producing
    tf-idf vectors, but no idf weighting is visible in the code below --
    confirm the intended transformation.

    The vectorized train/test matrices are memoized with joblib in
    `data_home`, in a file whose name depends on `remove`.

    Parameters
    ----------
    subset : 'train' or 'test', 'all', optional
        Select the dataset to load: 'train' for the training set, 'test'
        for the test set, 'all' for both, with shuffled ordering.

    data_home : optional, default: None
        Specify a download and cache folder for the datasets. If None,
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    remove : tuple
        May contain any subset of ('headers', 'footers', 'quotes'). Each of
        these are kinds of text that will be detected and removed from the
        newsgroup posts, preventing classifiers from overfitting on
        metadata.

        'headers' removes newsgroup headers, 'footers' removes blocks at the
        ends of posts that look like signatures, and 'quotes' removes lines
        that appear to be quoting another post.

    Returns
    -------
    bunch : Bunch object
        bunch.data: sparse matrix, shape [n_samples, n_features]
        bunch.target: array, shape [n_samples]
        bunch.target_names: list, length [n_classes]

    Raises
    ------
    ValueError
        If `subset` is not one of 'train', 'test' or 'all'.
    """
    # Fail fast: reject a bad `subset` before fetching/vectorizing anything.
    if subset not in ("train", "test", "all"):
        raise ValueError("%r is not a valid subset: should be one of "
                         "['train', 'test', 'all']" % subset)

    data_home = get_data_home(data_home=data_home)
    filebase = '20newsgroup_vectorized'
    if remove:
        filebase += 'remove-' + ('-'.join(remove))
    target_file = os.path.join(data_home, filebase + ".pk")

    # we shuffle but use a fixed seed for the memoization
    data_train = fetch_20newsgroups(data_home=data_home,
                                    subset='train',
                                    categories=None,
                                    shuffle=True,
                                    random_state=12,
                                    remove=remove)

    data_test = fetch_20newsgroups(data_home=data_home,
                                   subset='test',
                                   categories=None,
                                   shuffle=True,
                                   random_state=12,
                                   remove=remove)

    if os.path.exists(target_file):
        X_train, X_test = joblib.load(target_file)
    else:
        vectorizer = CountVectorizer(dtype=np.int16)
        X_train = vectorizer.fit_transform(data_train.data).tocsr()
        X_test = vectorizer.transform(data_test.data).tocsr()
        joblib.dump((X_train, X_test), target_file, compress=9)

    # the data is stored as int16 for compactness
    # but normalize needs floats
    X_train = X_train.astype(np.float64)
    X_test = X_test.astype(np.float64)
    normalize(X_train, copy=False)
    normalize(X_test, copy=False)

    target_names = data_train.target_names

    if subset == "train":
        data = X_train
        target = data_train.target
    elif subset == "test":
        data = X_test
        target = data_test.target
    else:
        # subset == "all" (validated at the top of the function)
        data = sp.vstack((X_train, X_test)).tocsr()
        target = np.concatenate((data_train.target, data_test.target))

    return Bunch(data=data, target=target, target_names=target_names)