diff options
Diffstat (limited to 'silx/gui/plot/test/testStats.py')
-rw-r--r-- | silx/gui/plot/test/testStats.py | 273 |
1 files changed, 255 insertions, 18 deletions
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py index 8db8cc9..d5046ba 100644 --- a/silx/gui/plot/test/testStats.py +++ b/silx/gui/plot/test/testStats.py @@ -35,6 +35,11 @@ from silx.gui.plot import StatsWidget from silx.gui.plot.stats import statshandler from silx.gui.utils.testutils import TestCaseQt, SignalListener from silx.gui.plot import Plot1D, Plot2D +from silx.gui.plot3d.SceneWidget import SceneWidget +from silx.gui.plot.items.roi import RectangleROI, PolygonROI +from silx.gui.plot.tools.roi import RegionOfInterestManager +from silx.gui.plot.stats.stats import Stats +from silx.gui.plot.CurvesROIWidget import ROI from silx.utils.testutils import ParametricTestCase import unittest import logging @@ -43,12 +48,9 @@ import numpy _logger = logging.getLogger(__name__) -class TestStats(TestCaseQt): - """ - Test :class:`BaseClass` class and inheriting classes - """ +class TestStatsBase(object): + """Base class for stats TestCase""" def setUp(self): - TestCaseQt.setUp(self) self.createCurveContext() self.createImageContext() self.createScatterContext() @@ -63,7 +65,6 @@ class TestStats(TestCaseQt): self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) self.scatterPlot.close() del self.scatterPlot - TestCaseQt.tearDown(self) def createCurveContext(self): self.plot1d = Plot1D() @@ -74,12 +75,13 @@ class TestStats(TestCaseQt): self.curveContext = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=False) + onlimits=False, + roi=None) def createScatterContext(self): self.scatterPlot = Plot2D() lgd = 'scatter plot' - self.xScatterData = numpy.array([0, 1, 2, 20, 50, 60, 36]) + self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36]) self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18]) self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5]) self.scatterPlot.addScatter(self.xScatterData, self.yScatterData, @@ -87,7 +89,8 @@ class TestStats(TestCaseQt): self.scatterContext = stats._ScatterContext( item=self.scatterPlot.getScatter(lgd), plot=self.scatterPlot, - onlimits=False + onlimits=False, + roi=None ) def createImageContext(self): @@ -99,7 +102,8 @@ class TestStats(TestCaseQt): self.imageContext = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, - onlimits=False + onlimits=False, + roi=None ) def getBasicStats(self): @@ -113,6 +117,19 @@ class TestStats(TestCaseQt): 'com': stats.StatCOM() } + +class TestStats(TestStatsBase, TestCaseQt): + """ + Test :class:`BaseClass` class and inheriting classes + """ + def setUp(self): + TestCaseQt.setUp(self) + TestStatsBase.setUp(self) + + def tearDown(self): + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + def testBasicStatsCurve(self): """Test result for simple stats on a curve""" _stats = self.getBasicStats() @@ -155,7 +172,8 @@ class TestStats(TestCaseQt): image2Context = stats._ImageContext( item=self.plot2d.getImage(self._imgLgd), plot=self.plot2d, - onlimits=False + onlimits=False, + roi=None, ) _stats = self.getBasicStats() self.assertEqual(_stats['min'].calculate(image2Context), 0) @@ -225,21 +243,24 @@ class TestStats(TestCaseQt): curveContextOnLimits = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(curveContextOnLimits), 2) self.plot2d.getXAxis().setLimitsConstraints(minPos=32) imageContextOnLimits = stats._ImageContext( item=self.plot2d.getImage('test image'), plot=self.plot2d, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(imageContextOnLimits), 32) self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40) scatterContextOnLimits = stats._ScatterContext( item=self.scatterPlot.getScatter('scatter plot'), plot=self.scatterPlot, - onlimits=True) + onlimits=True, + roi=None) self.assertEqual(stat.calculate(scatterContextOnLimits), 20) @@ -255,7 +276,8 @@ class TestStatsFormatter(TestCaseQt): self.curveContext = stats._CurveContext( item=self.plot1d.getCurve('curve0'), plot=self.plot1d, - onlimits=False) + onlimits=False, + roi=None) self.stat = stats.StatMin() @@ -295,6 +317,7 @@ class TestStatsHandler(TestCaseQt): self.stat = stats.StatMin() def tearDown(self): + Stats._getContext.cache_clear() self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot1d.close() self.plot1d = None @@ -391,6 +414,7 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): self.statsTable.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.statsTable = None @@ -493,7 +517,6 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase): self.qapp.processEvents() tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0')) curve0_min = tableItems['min'].text() - print(curve0_min) self.assertTrue(float(curve0_min) == -1.) self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5)) @@ -581,6 +604,7 @@ class TestStatsWidgetWithImages(TestCaseQt): self.widget.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) @@ -641,6 +665,7 @@ class TestStatsWidgetWithScatters(TestCaseQt): self.widget.setStats(mystats) def tearDown(self): + Stats._getContext.cache_clear() self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose) self.scatterPlot.close() self.widget.setAttribute(qt.Qt.WA_DeleteOnClose) @@ -694,6 +719,7 @@ class TestLineWidget(TestCaseQt): stats=mystats) def tearDown(self): + Stats._getContext.cache_clear() self.qapp.processEvents() self.plot.setAttribute(qt.Qt.WA_DeleteOnClose) self.plot.close() @@ -806,12 +832,223 @@ class TestUpdateModeWidget(TestCaseQt): self.assertEqual(manualUpdateListener.callCount(), 2) +class TestStatsROI(TestStatsBase, TestCaseQt): + """ + Test stats based on ROI + """ + def setUp(self): + TestCaseQt.setUp(self) + self.createRois() + TestStatsBase.setUp(self) + self.createHistogramContext() + + self.roiManager = RegionOfInterestManager(self.plot2d) + self.roiManager.addRoi(self._2Droi_rect) + self.roiManager.addRoi(self._2Droi_poly) + + def tearDown(self): + self.roiManager.clear() + self.roiManager = None + self._1Droi = None + self._2Droi_rect = None + self._2Droi_poly = None + self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose) + self.plotHisto.close() + self.plotHisto = None + TestStatsBase.tearDown(self) + TestCaseQt.tearDown(self) + + def createRois(self): + self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0) + self._2Droi_rect = RectangleROI() + self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0)) + self._2Droi_poly = PolygonROI() + points = numpy.array(((0, 20), (0, 0), (10, 0))) + self._2Droi_poly.setPoints(points=points) + + def createCurveContext(self): + TestStatsBase.createCurveContext(self) + self.curveContext = stats._CurveContext( + item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._1Droi) + + def createHistogramContext(self): + self.plotHisto = Plot1D() + x = range(20) + y = range(20) + self.plotHisto.addHistogram(x, y, legend='histo0') + + self.histoContext = stats._HistogramContext( + item=self.plotHisto.getHistogram('histo0'), + plot=self.plotHisto, + onlimits=False, + roi=self._1Droi) + + def createScatterContext(self): + TestStatsBase.createScatterContext(self) + self.scatterContext = stats._ScatterContext( + item=self.scatterPlot.getScatter('scatter plot'), + plot=self.scatterPlot, + onlimits=False, + roi=self._1Droi + ) + + def createImageContext(self): + TestStatsBase.createImageContext(self) + + self.imageContext = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_rect + ) + + self.imageContext_2 = stats._ImageContext( + item=self.plot2d.getImage(self._imgLgd), + plot=self.plot2d, + onlimits=False, + roi=self._2Droi_poly + ) + + def testErrors(self): + # test if onlimits is True and give also a roi + with self.assertRaises(ValueError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=True, + roi=self._1Droi) + + # test if is a curve context and give an invalid 2D roi + with self.assertRaises(TypeError): + stats._CurveContext(item=self.plot1d.getCurve('curve0'), + plot=self.plot1d, + onlimits=False, + roi=self._2Droi_rect) + + def testBasicStatsCurve(self): + """Test result for simple stats on a curve""" + _stats = self.getBasicStats() + xData = yData = numpy.array(range(0, 10)) + self.assertEqual(_stats['min'].calculate(self.curveContext), 2) + self.assertEqual(_stats['max'].calculate(self.curveContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6])) + self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6])) + com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6]) + self.assertEqual(_stats['com'].calculate(self.curveContext), com) + + def testBasicStatsImageRectRoi(self): + """Test result for simple stats on an image""" + self.assertEqual(self.imageContext.values.compressed().size, 121) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext), 10) + self.assertEqual(_stats['max'].calculate(self.imageContext), 1300) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0)) + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0)) + self.assertAlmostEqual(_stats['std'].calculate(self.imageContext), + numpy.std(self.imageData[0:11, 10:21])) + self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext), + numpy.mean(self.imageData[0:11, 10:21])) + + compressed_values = self.imageContext.values.compressed() + compressed_values = compressed_values.reshape(11, 11) + yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1) + xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0) + + dataYRange = range(11) + dataXRange = range(10, 21) + + ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData) + xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData) + self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom)) + + def testBasicStatsImagePolyRoi(self): + """Test a simple rectangle ROI""" + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0) + self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432) + self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0)) + # not 0.0, 19.0 because not fully in. Should all pixel have a weight, + # on to manage them in stats. For now 0 if the center is not in, else 1 + self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0)) + + def testBasicStatsScatter(self): + self.assertEqual(self.scatterContext.values.compressed().size, 2) + _stats = self.getBasicStats() + self.assertEqual(_stats['min'].calculate(self.scatterContext), 6) + self.assertEqual(_stats['max'].calculate(self.scatterContext), 7) + self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3)) + self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4)) + self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7])) + self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7])) + + def testBasicHistogram(self): + _stats = self.getBasicStats() + xData = yData = numpy.array(range(2, 6)) + self.assertEqual(_stats['min'].calculate(self.histoContext), 2) + self.assertEqual(_stats['max'].calculate(self.histoContext), 5) + self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,)) + self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,)) + self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData)) + self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData)) + com = numpy.sum(xData * yData) / numpy.sum(yData) + self.assertEqual(_stats['com'].calculate(self.histoContext), com) + + +class TestAdvancedROIImageContext(TestCaseQt): + """Test stats result on an image context with different scale and + origins""" + + def setUp(self): + TestCaseQt.setUp(self) + self.data_dims = (100, 100) + self.data = numpy.random.rand(*self.data_dims) + self.plot = Plot2D() + + def test(self): + """Test stats result on an image context with different scale and + origins""" + roi_origins = [(0, 0), (2, 10), (14, 20)] + img_origins = [(0, 0), (14, 20), (2, 10)] + img_scales = [1.0, 0.5, 2.0] + _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), } + for roi_origin in roi_origins: + for img_origin in img_origins: + for img_scale in img_scales: + with self.subTest(roi_origin=roi_origin, + img_origin=img_origin, + img_scale=img_scale): + self.plot.addImage(self.data, legend='img', + origin=img_origin, + scale=img_scale) + roi = RectangleROI() + roi.setGeometry(origin=roi_origin, size=(20, 20)) + context = stats._ImageContext( + item=self.plot.getImage('img'), + plot=self.plot, + onlimits=False, + roi=roi) + x_start = int((roi_origin[0] - img_origin[0]) / img_scale) + x_end = int(x_start + (20 / img_scale)) + 1 + y_start = int((roi_origin[1] - img_origin[1])/ img_scale) + y_end = int(y_start + (20 / img_scale)) + 1 + x_start = max(x_start, 0) + x_end = min(max(x_end, 0), self.data_dims[1]) + y_start = max(y_start, 0) + y_end = min(max(y_end, 0), self.data_dims[0]) + th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end]) + self.assertAlmostEqual(_stats['sum'].calculate(context), + th_sum) + def suite(): test_suite = unittest.TestSuite() for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters, TestStatsWidgetWithImages, TestStatsWidgetWithCurves, - TestStatsFormatter, TestEmptyStatsWidget, - TestLineWidget, TestUpdateModeWidget): + TestStatsFormatter, TestEmptyStatsWidget, TestStatsROI, + TestLineWidget, TestUpdateModeWidget, ): test_suite.addTest( unittest.defaultTestLoader.loadTestsFromTestCase(TestClass)) return test_suite |