summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Mollier <emollier@debian.org>2023-08-18 20:29:40 +0200
committerÉtienne Mollier <emollier@debian.org>2023-08-18 20:29:40 +0200
commit44e15dc2752db3389f7a58973c881428fe3ee9b9 (patch)
treece4d5ea53115cdfa854feed0e158c5deb94c382e
parentb856bf0d18f08ea7a7aee699e714ef9787809d46 (diff)
parent5a1ef3f52ab8ce64cbf992aa364462abf584c596 (diff)
Update upstream source from tag 'upstream/2023.7.0'
Update to upstream version '2023.7.0' with Debian dir 2618fa743e02c15eb038e44b557f7fb8e48e78c3
-rw-r--r--.github/workflows/ci-dev.yaml12
-rw-r--r--.github/workflows/ci.yml55
-rw-r--r--.github/workflows/join-release.yaml6
-rw-r--r--.github/workflows/tag-release.yaml7
-rw-r--r--LICENSE2
-rw-r--r--ci/recipe/meta.yaml3
-rw-r--r--q2_sample_classifier/__init__.py2
-rw-r--r--q2_sample_classifier/_format.py2
-rw-r--r--q2_sample_classifier/_transformer.py2
-rw-r--r--q2_sample_classifier/_type.py2
-rw-r--r--q2_sample_classifier/_version.py6
-rw-r--r--q2_sample_classifier/classify.py2
-rw-r--r--q2_sample_classifier/plugin_setup.py9
-rw-r--r--q2_sample_classifier/tests/__init__.py2
-rw-r--r--q2_sample_classifier/tests/test_actions.py2
-rw-r--r--q2_sample_classifier/tests/test_base_class.py2
-rw-r--r--q2_sample_classifier/tests/test_classifier.py2
-rw-r--r--q2_sample_classifier/tests/test_estimators.py55
-rw-r--r--q2_sample_classifier/tests/test_types_formats_transformers.py2
-rw-r--r--q2_sample_classifier/tests/test_utilities.py2
-rw-r--r--q2_sample_classifier/tests/test_visualization.py2
-rw-r--r--q2_sample_classifier/utilities.py43
-rw-r--r--q2_sample_classifier/visuals.py2
-rw-r--r--setup.py2
24 files changed, 121 insertions, 105 deletions
diff --git a/.github/workflows/ci-dev.yaml b/.github/workflows/ci-dev.yaml
new file mode 100644
index 0000000..f66c713
--- /dev/null
+++ b/.github/workflows/ci-dev.yaml
@@ -0,0 +1,12 @@
+# Example of workflow trigger for calling workflow (the client).
+name: ci-dev
+on:
+ pull_request:
+ branches: ["dev"]
+ push:
+ branches: ["dev"]
+jobs:
+ ci:
+ uses: qiime2/distributions/.github/workflows/lib-ci-dev.yaml@dev
+ with:
+ distro: core \ No newline at end of file
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
deleted file mode 100644
index e994794..0000000
--- a/.github/workflows/ci.yml
+++ /dev/null
@@ -1,55 +0,0 @@
-# This file is automatically generated by busywork.qiime2.org and
-# template-repos - any manual edits made to this file will be erased when
-# busywork performs maintenance updates.
-
-name: ci
-
-on:
- pull_request:
- push:
- branches:
- - master
-
-jobs:
- lint:
- runs-on: ubuntu-latest
- steps:
- - name: checkout source
- uses: actions/checkout@v2
-
- - name: set up python 3.8
- uses: actions/setup-python@v1
- with:
- python-version: 3.8
-
- - name: install dependencies
- run: python -m pip install --upgrade pip
-
- - name: lint
- run: |
- pip install -q https://github.com/qiime2/q2lint/archive/master.zip
- q2lint
- pip install -q flake8
- flake8
-
- build-and-test:
- needs: lint
- strategy:
- matrix:
- os: [ubuntu-latest, macos-latest]
- runs-on: ${{ matrix.os }}
- steps:
- - name: checkout source
- uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: set up git repo for versioneer
- run: git fetch --depth=1 origin +refs/tags/*:refs/tags/*
-
- - uses: qiime2/action-library-packaging@alpha1
- with:
- package-name: q2-sample-classifier
- build-target: dev
- additional-tests: py.test --pyargs q2_sample_classifier
- library-token: ${{ secrets.LIBRARY_TOKEN }}
diff --git a/.github/workflows/join-release.yaml b/.github/workflows/join-release.yaml
new file mode 100644
index 0000000..8669749
--- /dev/null
+++ b/.github/workflows/join-release.yaml
@@ -0,0 +1,6 @@
+name: join-release
+on:
+ workflow_dispatch: {}
+jobs:
+ release:
+ uses: qiime2/distributions/.github/workflows/lib-join-release.yaml@dev \ No newline at end of file
diff --git a/.github/workflows/tag-release.yaml b/.github/workflows/tag-release.yaml
new file mode 100644
index 0000000..8b0f228
--- /dev/null
+++ b/.github/workflows/tag-release.yaml
@@ -0,0 +1,7 @@
+name: tag-release
+on:
+ push:
+ branches: ["Release-*"]
+jobs:
+ tag:
+ uses: qiime2/distributions/.github/workflows/lib-tag-release.yaml@dev \ No newline at end of file
diff --git a/LICENSE b/LICENSE
index a77f678..71235b5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
BSD 3-Clause License
-Copyright (c) 2017-2022, QIIME 2 development team.
+Copyright (c) 2017-2023, QIIME 2 development team.
All rights reserved.
Redistribution and use in source and binary forms, with or without
diff --git a/ci/recipe/meta.yaml b/ci/recipe/meta.yaml
index 94f805c..1a8de08 100644
--- a/ci/recipe/meta.yaml
+++ b/ci/recipe/meta.yaml
@@ -32,6 +32,9 @@ requirements:
- q2-feature-table {{ qiime2_epoch }}.*
test:
+ commands:
+ - py.test --pyargs q2_sample_classifier
+
requires:
- qiime2 >={{ qiime2 }}
- q2-types >={{ q2_types }}
diff --git a/q2_sample_classifier/__init__.py b/q2_sample_classifier/__init__.py
index e7a023f..70193e9 100644
--- a/q2_sample_classifier/__init__.py
+++ b/q2_sample_classifier/__init__.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/_format.py b/q2_sample_classifier/_format.py
index fe125f3..a08c129 100644
--- a/q2_sample_classifier/_format.py
+++ b/q2_sample_classifier/_format.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/_transformer.py b/q2_sample_classifier/_transformer.py
index 04299ed..f5b4e2a 100644
--- a/q2_sample_classifier/_transformer.py
+++ b/q2_sample_classifier/_transformer.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/_type.py b/q2_sample_classifier/_type.py
index 66a4e82..55251c1 100644
--- a/q2_sample_classifier/_type.py
+++ b/q2_sample_classifier/_type.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/_version.py b/q2_sample_classifier/_version.py
index 97323c1..8809afd 100644
--- a/q2_sample_classifier/_version.py
+++ b/q2_sample_classifier/_version.py
@@ -23,9 +23,9 @@ def get_keywords():
# setup.py/versioneer.py will grep for the variable names, so they must
# each be defined on a line of their own. _version.py will just call
# get_keywords().
- git_refnames = " (HEAD -> master, tag: 2022.11.1)"
- git_full = "f693e2087a65c868846472d3284c7ba9f00d4bfc"
- git_date = "2022-12-21 22:30:20 +0000"
+ git_refnames = " (tag: 2023.7.0, Release-2023.7)"
+ git_full = "28215705a9e042fe0ac5b5c1a9313ec13e5d1d07"
+ git_date = "2023-08-17 19:35:51 +0000"
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
return keywords
diff --git a/q2_sample_classifier/classify.py b/q2_sample_classifier/classify.py
index 414881f..0e35436 100644
--- a/q2_sample_classifier/classify.py
+++ b/q2_sample_classifier/classify.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/plugin_setup.py b/q2_sample_classifier/plugin_setup.py
index 8ae4d3d..3a9b4b5 100644
--- a/q2_sample_classifier/plugin_setup.py
+++ b/q2_sample_classifier/plugin_setup.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
@@ -155,12 +155,15 @@ parameter_descriptions = {
classifiers = Str % Choices(
['RandomForestClassifier', 'ExtraTreesClassifier',
- 'GradientBoostingClassifier', 'AdaBoostClassifier',
+ 'GradientBoostingClassifier',
+ 'AdaBoostClassifier[DecisionTree]', 'AdaBoostClassifier[ExtraTrees]',
'KNeighborsClassifier', 'LinearSVC', 'SVC'])
regressors = Str % Choices(
['RandomForestRegressor', 'ExtraTreesRegressor',
- 'GradientBoostingRegressor', 'AdaBoostRegressor', 'ElasticNet',
+ 'GradientBoostingRegressor',
+ 'AdaBoostRegressor[DecisionTree]', 'AdaBoostRegressor[ExtraTrees]',
+ 'ElasticNet',
'Ridge', 'Lasso', 'KNeighborsRegressor', 'LinearSVR', 'SVR'])
output_descriptions = {
diff --git a/q2_sample_classifier/tests/__init__.py b/q2_sample_classifier/tests/__init__.py
index fed4ef6..357a7fc 100644
--- a/q2_sample_classifier/tests/__init__.py
+++ b/q2_sample_classifier/tests/__init__.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_actions.py b/q2_sample_classifier/tests/test_actions.py
index 2e1bd2a..e13a86c 100644
--- a/q2_sample_classifier/tests/test_actions.py
+++ b/q2_sample_classifier/tests/test_actions.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_base_class.py b/q2_sample_classifier/tests/test_base_class.py
index cdc0cc8..ed421aa 100644
--- a/q2_sample_classifier/tests/test_base_class.py
+++ b/q2_sample_classifier/tests/test_base_class.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_classifier.py b/q2_sample_classifier/tests/test_classifier.py
index e17bbd9..6a21ddc 100644
--- a/q2_sample_classifier/tests/test_classifier.py
+++ b/q2_sample_classifier/tests/test_classifier.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_estimators.py b/q2_sample_classifier/tests/test_estimators.py
index 277be3b..b42de3c 100644
--- a/q2_sample_classifier/tests/test_estimators.py
+++ b/q2_sample_classifier/tests/test_estimators.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
@@ -14,6 +14,7 @@ import json
import numpy as np
from sklearn.metrics import mean_squared_error, accuracy_score
from sklearn.ensemble import AdaBoostClassifier
+from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.pipeline import Pipeline
import skbio
@@ -235,8 +236,10 @@ class EstimatorsTests(SampleClassifierTestPluginBase):
# when a random seed is set.
def test_classifiers(self):
for classifier in ['RandomForestClassifier', 'ExtraTreesClassifier',
- 'GradientBoostingClassifier', 'AdaBoostClassifier',
- 'LinearSVC', 'SVC', 'KNeighborsClassifier']:
+ 'GradientBoostingClassifier',
+ 'AdaBoostClassifier[DecisionTree]',
+ 'AdaBoostClassifier[ExtraTrees]', 'LinearSVC',
+ 'SVC', 'KNeighborsClassifier']:
table_fp = self.get_data_path('chardonnay.table.qza')
table = qiime2.Artifact.load(table_fp)
res = sample_classifier.actions.classify_samples(
@@ -356,9 +359,11 @@ class EstimatorsTests(SampleClassifierTestPluginBase):
# when a random seed is set.
def test_regressors(self):
for regressor in ['RandomForestRegressor', 'ExtraTreesRegressor',
- 'GradientBoostingRegressor', 'AdaBoostRegressor',
- 'Lasso', 'Ridge', 'ElasticNet',
- 'KNeighborsRegressor', 'LinearSVR', 'SVR']:
+ 'GradientBoostingRegressor',
+ 'AdaBoostRegressor[DecisionTree]',
+ 'AdaBoostRegressor[ExtraTrees]', 'Lasso', 'Ridge',
+ 'ElasticNet', 'KNeighborsRegressor', 'LinearSVR',
+ 'SVR']:
table_fp = self.get_data_path('ecam-table-maturity.qza')
table = qiime2.Artifact.load(table_fp)
res = sample_classifier.actions.regress_samples(
@@ -386,13 +391,25 @@ class EstimatorsTests(SampleClassifierTestPluginBase):
regressor, accuracy, seeded_results[regressor]))
# test adaboost base estimator trainer
- def test_train_adaboost_base_estimator(self):
+ def test_train_adaboost_decision_tree(self):
abe = _train_adaboost_base_estimator(
self.table_chard_fp, self.mdc_chard_fp, 'Region',
n_estimators=10, n_jobs=1, cv=3, random_state=None,
parameter_tuning=True, classification=True,
- missing_samples='ignore')
+ missing_samples='ignore', base_estimator="DecisionTree")
+ self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
+ self.assertEqual(type(abe.named_steps.est.base_estimator),
+ DecisionTreeClassifier)
+
+ def test_train_adaboost_extra_trees(self):
+ abe = _train_adaboost_base_estimator(
+ self.table_chard_fp, self.mdc_chard_fp, 'Region',
+ n_estimators=10, n_jobs=1, cv=3, random_state=None,
+ parameter_tuning=True, classification=True,
+ missing_samples='ignore', base_estimator="ExtraTrees")
self.assertEqual(type(abe.named_steps.est), AdaBoostClassifier)
+ self.assertEqual(type(abe.named_steps.est.base_estimator),
+ ExtraTreeClassifier)
# test some invalid inputs/edge cases
def test_invalids(self):
@@ -458,8 +475,10 @@ class EstimatorsTests(SampleClassifierTestPluginBase):
# x, y, and z predicts the correct metadata values for those same samples.
def test_predict_classifications(self):
for classifier in ['RandomForestClassifier', 'ExtraTreesClassifier',
- 'GradientBoostingClassifier', 'AdaBoostClassifier',
- 'LinearSVC', 'SVC', 'KNeighborsClassifier']:
+ 'GradientBoostingClassifier',
+ 'AdaBoostClassifier[DecisionTree]',
+ 'AdaBoostClassifier[ExtraTrees]', 'LinearSVC',
+ 'SVC', 'KNeighborsClassifier']:
estimator, importances = fit_classifier(
self.table_chard_fp, self.mdc_chard_fp, random_state=123,
n_estimators=2, estimator=classifier, n_jobs=1,
@@ -494,7 +513,9 @@ class EstimatorsTests(SampleClassifierTestPluginBase):
def test_predict_regressions(self):
for regressor in ['RandomForestRegressor', 'ExtraTreesRegressor',
- 'GradientBoostingRegressor', 'AdaBoostRegressor',
+ 'GradientBoostingRegressor',
+ 'AdaBoostRegressor[DecisionTree]',
+ 'AdaBoostRegressor[ExtraTrees]',
'Lasso', 'Ridge', 'ElasticNet',
'KNeighborsRegressor', 'SVR', 'LinearSVR']:
estimator, importances = fit_regressor(
@@ -558,14 +579,16 @@ seeded_results = {
'RandomForestClassifier': 0.63636363636363635,
'ExtraTreesClassifier': 0.454545454545,
'GradientBoostingClassifier': 0.272727272727,
- 'AdaBoostClassifier': 0.272727272727,
+ 'AdaBoostClassifier[DecisionTree]': 0.272727272727,
+ 'AdaBoostClassifier[ExtraTrees]': 0.272727272727,
'LinearSVC': 0.818182,
'SVC': 0.36363636363636365,
'KNeighborsClassifier': 0.363636363636,
'RandomForestRegressor': 23.226508,
'ExtraTreesRegressor': 19.725397,
'GradientBoostingRegressor': 34.157100,
- 'AdaBoostRegressor': 30.920635,
+ 'AdaBoostRegressor[DecisionTree]': 30.920635,
+ 'AdaBoostRegressor[ExtraTrees]': 21.746031,
'Lasso': 722.827623,
'Ridge': 521.195194222418,
'ElasticNet': 618.532273,
@@ -577,14 +600,16 @@ seeded_predict_results = {
'RandomForestClassifier': 18,
'ExtraTreesClassifier': 21,
'GradientBoostingClassifier': 21,
- 'AdaBoostClassifier': 21,
+ 'AdaBoostClassifier[DecisionTree]': 21,
+ 'AdaBoostClassifier[ExtraTrees]': 21,
'LinearSVC': 21,
'SVC': 12,
'KNeighborsClassifier': 14,
'RandomForestRegressor': 7.4246031746,
'ExtraTreesRegressor': 0.,
'GradientBoostingRegressor': 50.1955883469,
- 'AdaBoostRegressor': 9.7857142857142865,
+ 'AdaBoostRegressor[DecisionTree]': 9.7857142857142865,
+ 'AdaBoostRegressor[ExtraTrees]': 33.95238095238095,
'Lasso': 0.173138653701,
'Ridge': 2.694020055323081e-05,
'ElasticNet': 0.0614243397637,
diff --git a/q2_sample_classifier/tests/test_types_formats_transformers.py b/q2_sample_classifier/tests/test_types_formats_transformers.py
index 4fc95f8..1aa9ffd 100644
--- a/q2_sample_classifier/tests/test_types_formats_transformers.py
+++ b/q2_sample_classifier/tests/test_types_formats_transformers.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_utilities.py b/q2_sample_classifier/tests/test_utilities.py
index 2aef9f5..0a38aa9 100644
--- a/q2_sample_classifier/tests/test_utilities.py
+++ b/q2_sample_classifier/tests/test_utilities.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/tests/test_visualization.py b/q2_sample_classifier/tests/test_visualization.py
index eb09e89..33eb4da 100644
--- a/q2_sample_classifier/tests/test_visualization.py
+++ b/q2_sample_classifier/tests/test_visualization.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/q2_sample_classifier/utilities.py b/q2_sample_classifier/utilities.py
index 6f57456..334044c 100644
--- a/q2_sample_classifier/utilities.py
+++ b/q2_sample_classifier/utilities.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
@@ -21,7 +21,10 @@ from sklearn.ensemble import (RandomForestRegressor, RandomForestClassifier,
from sklearn.svm import SVR, SVC
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
-from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn.tree import (
+ DecisionTreeClassifier, DecisionTreeRegressor,
+ ExtraTreeClassifier, ExtraTreeRegressor
+)
from sklearn.pipeline import Pipeline
import q2templates
@@ -32,6 +35,7 @@ import pkg_resources
from scipy.sparse import issparse
from scipy.stats import randint
import biom
+import re
from .visuals import (_linear_regress, _plot_confusion_matrix, _plot_RFE,
_regplot_from_dataframe, _generate_roc_plots)
@@ -781,32 +785,42 @@ def _select_estimator(estimator, n_jobs, n_estimators, random_state=None):
return param_dist, estimator
-def _train_adaboost_base_estimator(table, metadata, column, n_estimators,
- n_jobs, cv, random_state=None,
+def _train_adaboost_base_estimator(table, metadata, column, base_estimator,
+ n_estimators, n_jobs, cv, random_state=None,
parameter_tuning=False,
classification=True,
missing_samples='error'):
param_dist = parameters['ensemble']
+
if classification:
- base_estimator = DecisionTreeClassifier()
+ base_est = {
+ 'DecisionTree': DecisionTreeClassifier(),
+ 'ExtraTrees': ExtraTreeClassifier()
+ }
+ pipe_base_estimator = base_est[base_estimator]
adaboost_estimator = AdaBoostClassifier
else:
- base_estimator = DecisionTreeRegressor()
+ base_est = {
+ 'DecisionTree': DecisionTreeRegressor(),
+ 'ExtraTrees': ExtraTreeRegressor()
+ }
+ pipe_base_estimator = base_est[base_estimator]
adaboost_estimator = AdaBoostRegressor
- base_estimator = Pipeline(
- [('dv', DictVectorizer()), ('est', base_estimator)])
+
+ estimator = Pipeline(
+ [('dv', DictVectorizer()), ('est', pipe_base_estimator)])
if parameter_tuning:
features, targets = _load_data(
table, metadata, missing_samples=missing_samples)
param_dist = _map_params_to_pipeline(param_dist)
base_estimator = _tune_parameters(
- features, targets[column], base_estimator, param_dist,
+ features, targets[column], estimator, param_dist,
n_jobs=n_jobs, cv=cv, random_state=random_state).best_estimator_
return Pipeline(
- [('dv', base_estimator.named_steps.dv),
- ('est', adaboost_estimator(base_estimator.named_steps.est,
+ [('dv', estimator.named_steps.dv),
+ ('est', adaboost_estimator(estimator.named_steps.est,
n_estimators, random_state=random_state))])
@@ -830,10 +844,11 @@ def _set_parameters_and_estimator(estimator, table, metadata, column,
parameter_tuning, classification=True,
missing_samples='error'):
# specify parameters and distributions to sample from for parameter tuning
- if estimator in ['AdaBoostClassifier', 'AdaBoostRegressor']:
+ if estimator.startswith("AdaBoost"):
+ base_estimator = re.search(r"\[([A-Za-z]+)\]", estimator).group(1)
estimator = _train_adaboost_base_estimator(
- table, metadata, column, n_estimators, n_jobs, cv, random_state,
- parameter_tuning, classification=classification,
+ table, metadata, column, base_estimator, n_estimators, n_jobs, cv,
+ random_state, parameter_tuning, classification=classification,
missing_samples=missing_samples)
parameter_tuning = False
param_dist = None
diff --git a/q2_sample_classifier/visuals.py b/q2_sample_classifier/visuals.py
index 6a7f0b1..1cfced6 100644
--- a/q2_sample_classifier/visuals.py
+++ b/q2_sample_classifier/visuals.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
diff --git a/setup.py b/setup.py
index 925f1d4..7244737 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,5 @@
# ----------------------------------------------------------------------------
-# Copyright (c) 2017-2022, QIIME 2 development team.
+# Copyright (c) 2017-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#