.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_auto_examples_compose_plot_column_transformer.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_compose_plot_column_transformer.py:


==================================================
Column Transformer with Heterogeneous Data Sources
==================================================

Datasets often contain components that require different feature
extraction and processing pipelines.  This scenario might occur when:

1. Your dataset consists of heterogeneous data types (e.g. raster images and
   text captions)
2. Your dataset is stored in a Pandas DataFrame and different columns
   require different processing pipelines.
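
As a minimal sketch of the second scenario, assume a hypothetical DataFrame
with one numeric column (``age``) and one categorical column (``city``); the
column names and values are invented for illustration and are not part of
this example's dataset. Each column can be routed to its own transformer:

.. code-block:: python

    import pandas as pd

    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    # Hypothetical toy data: one numeric and one categorical column
    df = pd.DataFrame({
        "age": [22.0, 35.0, 58.0, 41.0],
        "city": ["Paris", "Oslo", "Paris", "Rome"],
    })

    preprocess = ColumnTransformer([
        # standardize the numeric column
        ("num", StandardScaler(), ["age"]),
        # one-hot encode the categorical column
        ("cat", OneHotEncoder(), ["city"]),
    ])

    X = preprocess.fit_transform(df)
    # one scaled column plus one column per distinct city
    print(X.shape)

The outputs of the two transformers are concatenated column-wise, in the
order in which the transformers are listed.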

This example demonstrates how to use
:class:`sklearn.compose.ColumnTransformer` on a dataset containing
different types of features.  We use the 20-newsgroups dataset and compute
standard bag-of-words features for the subject line and body in separate
pipelines as well as ad hoc features on the body. We combine them (with
weights) using a ColumnTransformer and finally train a classifier on the
combined set of features.
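
To make the weighting step concrete, here is a small offline sketch. It
assumes two hypothetical posts already split into a subject column and a body
column, and it swaps in ``CountVectorizer`` for the tf-idf pipelines so the
numbers are easy to inspect. The helper ``build`` is introduced only for this
illustration. It suggests that ``transformer_weights`` multiplies each
transformer's output block by its weight before the blocks are stacked:

.. code-block:: python

    import numpy as np

    from sklearn.compose import ColumnTransformer
    from sklearn.feature_extraction.text import CountVectorizer

    # Hypothetical toy posts: column 0 = subject, column 1 = body
    posts = np.array([
        ["free will", "a long argument about free will"],
        ["bible study", "notes from the weekly study group"],
    ], dtype=object)


    def build(weights=None):
        """Illustrative helper: one vectorizer per text column."""
        return ColumnTransformer(
            [("subject", CountVectorizer(), 0),
             ("body", CountVectorizer(), 1)],
            transformer_weights=weights,
            sparse_threshold=0,  # always return a dense array
        )


    plain_ct = build()
    plain = plain_ct.fit_transform(posts)
    weighted = build({"subject": 0.8, "body": 0.5}).fit_transform(posts)

    # number of columns contributed by the subject vectorizer
    n_subject = len(plain_ct.named_transformers_["subject"].vocabulary_)

    # each block of the weighted output is the unweighted block times its weight
    print(np.allclose(weighted[:, :n_subject], 0.8 * plain[:, :n_subject]))
    print(np.allclose(weighted[:, n_subject:], 0.5 * plain[:, n_subject:]))

Passing a scalar column index (``0`` or ``1``) hands each vectorizer a
one-dimensional sequence of strings, which is the input text vectorizers
expect.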

The choice of features is not particularly helpful, but serves to illustrate
the technique.








.. code-block:: default


    # Author: Matt Terry <matt.terry@gmail.com>
    #
    # License: BSD 3 clause

    import numpy as np

    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics import classification_report
    from sklearn.pipeline import Pipeline
    from sklearn.compose import ColumnTransformer
    from sklearn.svm import LinearSVC


    class TextStats(TransformerMixin, BaseEstimator):
        """Extract features from each document for DictVectorizer"""

        def fit(self, x, y=None):
            return self

        def transform(self, posts):
            return [{'length': len(text),
                     'num_sentences': text.count('.')}
                    for text in posts]


    class SubjectBodyExtractor(TransformerMixin, BaseEstimator):
        """Extract the subject & body from a usenet post in a single pass.

        Takes a sequence of strings and produces an object-dtype array of
        shape (n_posts, 2).  Column 0 holds the subject and column 1 holds
        the body.
        """
        def fit(self, x, y=None):
            return self

        def transform(self, posts):
            # construct object dtype array with two columns
            # first column = 'subject' and second column = 'body'
            features = np.empty(shape=(len(posts), 2), dtype=object)
            for i, text in enumerate(posts):
                headers, _, bod = text.partition('\n\n')
                features[i, 1] = bod

                prefix = 'Subject:'
                sub = ''
                for line in headers.split('\n'):
                    if line.startswith(prefix):
                        sub = line[len(prefix):]
                        break
                features[i, 0] = sub

            return features


    pipeline = Pipeline([
        # Extract the subject & body
        ('subjectbody', SubjectBodyExtractor()),

        # Use ColumnTransformer to combine the features from subject and body
        ('union', ColumnTransformer(
            [
                # Pulling features from the post's subject line (first column)
                ('subject', TfidfVectorizer(min_df=50), 0),

                # Pipeline for standard bag-of-words model for body (second column)
                ('body_bow', Pipeline([
                    ('tfidf', TfidfVectorizer()),
                    ('best', TruncatedSVD(n_components=50)),
                ]), 1),

                # Pipeline for pulling ad hoc features from post's body
                ('body_stats', Pipeline([
                    ('stats', TextStats()),  # returns a list of dicts
                    ('vect', DictVectorizer()),  # list of dicts -> feature matrix
                ]), 1),
            ],

            # weight components in ColumnTransformer
            transformer_weights={
                'subject': 0.8,
                'body_bow': 0.5,
                'body_stats': 1.0,
            }
        )),

        # Use a linear SVC classifier on the combined features
        ('svc', LinearSVC(dual=False)),
    ], verbose=True)

    # limit the list of categories to make running this example faster.
    categories = ['alt.atheism', 'talk.religion.misc']
    X_train, y_train = fetch_20newsgroups(random_state=1,
                                          subset='train',
                                          categories=categories,
                                          remove=('footers', 'quotes'),
                                          return_X_y=True)
    X_test, y_test = fetch_20newsgroups(random_state=1,
                                        subset='test',
                                        categories=categories,
                                        remove=('footers', 'quotes'),
                                        return_X_y=True)

    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    print(classification_report(y_test, y_pred))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.007 seconds)


.. _sphx_glr_download_auto_examples_compose_plot_column_transformer.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_column_transformer.py <plot_column_transformer.py>`



  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_column_transformer.ipynb <plot_column_transformer.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
