summaryrefslogtreecommitdiff
path: root/silx/gui/plot/test/testStats.py
diff options
context:
space:
mode:
Diffstat (limited to 'silx/gui/plot/test/testStats.py')
-rw-r--r--silx/gui/plot/test/testStats.py273
1 files changed, 255 insertions, 18 deletions
diff --git a/silx/gui/plot/test/testStats.py b/silx/gui/plot/test/testStats.py
index 8db8cc9..d5046ba 100644
--- a/silx/gui/plot/test/testStats.py
+++ b/silx/gui/plot/test/testStats.py
@@ -35,6 +35,11 @@ from silx.gui.plot import StatsWidget
from silx.gui.plot.stats import statshandler
from silx.gui.utils.testutils import TestCaseQt, SignalListener
from silx.gui.plot import Plot1D, Plot2D
+from silx.gui.plot3d.SceneWidget import SceneWidget
+from silx.gui.plot.items.roi import RectangleROI, PolygonROI
+from silx.gui.plot.tools.roi import RegionOfInterestManager
+from silx.gui.plot.stats.stats import Stats
+from silx.gui.plot.CurvesROIWidget import ROI
from silx.utils.testutils import ParametricTestCase
import unittest
import logging
@@ -43,12 +48,9 @@ import numpy
_logger = logging.getLogger(__name__)
-class TestStats(TestCaseQt):
- """
- Test :class:`BaseClass` class and inheriting classes
- """
+class TestStatsBase(object):
+ """Base class for stats TestCase"""
def setUp(self):
- TestCaseQt.setUp(self)
self.createCurveContext()
self.createImageContext()
self.createScatterContext()
@@ -63,7 +65,6 @@ class TestStats(TestCaseQt):
self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.scatterPlot.close()
del self.scatterPlot
- TestCaseQt.tearDown(self)
def createCurveContext(self):
self.plot1d = Plot1D()
@@ -74,12 +75,13 @@ class TestStats(TestCaseQt):
self.curveContext = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=False)
+ onlimits=False,
+ roi=None)
def createScatterContext(self):
self.scatterPlot = Plot2D()
lgd = 'scatter plot'
- self.xScatterData = numpy.array([0, 1, 2, 20, 50, 60, 36])
+ self.xScatterData = numpy.array([0, 2, 3, 20, 50, 60, 36])
self.yScatterData = numpy.array([2, 3, 4, 26, 69, 6, 18])
self.valuesScatterData = numpy.array([5, 6, 7, 10, 90, 20, 5])
self.scatterPlot.addScatter(self.xScatterData, self.yScatterData,
@@ -87,7 +89,8 @@ class TestStats(TestCaseQt):
self.scatterContext = stats._ScatterContext(
item=self.scatterPlot.getScatter(lgd),
plot=self.scatterPlot,
- onlimits=False
+ onlimits=False,
+ roi=None
)
def createImageContext(self):
@@ -99,7 +102,8 @@ class TestStats(TestCaseQt):
self.imageContext = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
- onlimits=False
+ onlimits=False,
+ roi=None
)
def getBasicStats(self):
@@ -113,6 +117,19 @@ class TestStats(TestCaseQt):
'com': stats.StatCOM()
}
+
+class TestStats(TestStatsBase, TestCaseQt):
+ """
+ Test :class:`BaseClass` class and inheriting classes
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ TestStatsBase.setUp(self)
+
+ def tearDown(self):
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
def testBasicStatsCurve(self):
"""Test result for simple stats on a curve"""
_stats = self.getBasicStats()
@@ -155,7 +172,8 @@ class TestStats(TestCaseQt):
image2Context = stats._ImageContext(
item=self.plot2d.getImage(self._imgLgd),
plot=self.plot2d,
- onlimits=False
+ onlimits=False,
+ roi=None,
)
_stats = self.getBasicStats()
self.assertEqual(_stats['min'].calculate(image2Context), 0)
@@ -225,21 +243,24 @@ class TestStats(TestCaseQt):
curveContextOnLimits = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(curveContextOnLimits), 2)
self.plot2d.getXAxis().setLimitsConstraints(minPos=32)
imageContextOnLimits = stats._ImageContext(
item=self.plot2d.getImage('test image'),
plot=self.plot2d,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(imageContextOnLimits), 32)
self.scatterPlot.getXAxis().setLimitsConstraints(minPos=40)
scatterContextOnLimits = stats._ScatterContext(
item=self.scatterPlot.getScatter('scatter plot'),
plot=self.scatterPlot,
- onlimits=True)
+ onlimits=True,
+ roi=None)
self.assertEqual(stat.calculate(scatterContextOnLimits), 20)
@@ -255,7 +276,8 @@ class TestStatsFormatter(TestCaseQt):
self.curveContext = stats._CurveContext(
item=self.plot1d.getCurve('curve0'),
plot=self.plot1d,
- onlimits=False)
+ onlimits=False,
+ roi=None)
self.stat = stats.StatMin()
@@ -295,6 +317,7 @@ class TestStatsHandler(TestCaseQt):
self.stat = stats.StatMin()
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot1d.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot1d.close()
self.plot1d = None
@@ -391,6 +414,7 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
self.statsTable.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
self.statsTable = None
@@ -493,7 +517,6 @@ class TestStatsWidgetWithCurves(TestCaseQt, ParametricTestCase):
self.qapp.processEvents()
tableItems = self.statsTable._itemToTableItems(self.plot.getCurve('curve0'))
curve0_min = tableItems['min'].text()
- print(curve0_min)
self.assertTrue(float(curve0_min) == -1.)
self.plot.getCurve('curve0').setData(x=range(4), y=range(1, 5))
@@ -581,6 +604,7 @@ class TestStatsWidgetWithImages(TestCaseQt):
self.widget.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
@@ -641,6 +665,7 @@ class TestStatsWidgetWithScatters(TestCaseQt):
self.widget.setStats(mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.scatterPlot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.scatterPlot.close()
self.widget.setAttribute(qt.Qt.WA_DeleteOnClose)
@@ -694,6 +719,7 @@ class TestLineWidget(TestCaseQt):
stats=mystats)
def tearDown(self):
+ Stats._getContext.cache_clear()
self.qapp.processEvents()
self.plot.setAttribute(qt.Qt.WA_DeleteOnClose)
self.plot.close()
@@ -806,12 +832,223 @@ class TestUpdateModeWidget(TestCaseQt):
self.assertEqual(manualUpdateListener.callCount(), 2)
+class TestStatsROI(TestStatsBase, TestCaseQt):
+ """
+ Test stats based on ROI
+ """
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.createRois()
+ TestStatsBase.setUp(self)
+ self.createHistogramContext()
+
+ self.roiManager = RegionOfInterestManager(self.plot2d)
+ self.roiManager.addRoi(self._2Droi_rect)
+ self.roiManager.addRoi(self._2Droi_poly)
+
+ def tearDown(self):
+ self.roiManager.clear()
+ self.roiManager = None
+ self._1Droi = None
+ self._2Droi_rect = None
+ self._2Droi_poly = None
+ self.plotHisto.setAttribute(qt.Qt.WA_DeleteOnClose)
+ self.plotHisto.close()
+ self.plotHisto = None
+ TestStatsBase.tearDown(self)
+ TestCaseQt.tearDown(self)
+
+ def createRois(self):
+ self._1Droi = ROI(name='my1DRoi', fromdata=2.0, todata=5.0)
+ self._2Droi_rect = RectangleROI()
+ self._2Droi_rect.setGeometry(size=(10, 10), origin=(10, 0))
+ self._2Droi_poly = PolygonROI()
+ points = numpy.array(((0, 20), (0, 0), (10, 0)))
+ self._2Droi_poly.setPoints(points=points)
+
+ def createCurveContext(self):
+ TestStatsBase.createCurveContext(self)
+ self.curveContext = stats._CurveContext(
+ item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createHistogramContext(self):
+ self.plotHisto = Plot1D()
+ x = range(20)
+ y = range(20)
+ self.plotHisto.addHistogram(x, y, legend='histo0')
+
+ self.histoContext = stats._HistogramContext(
+ item=self.plotHisto.getHistogram('histo0'),
+ plot=self.plotHisto,
+ onlimits=False,
+ roi=self._1Droi)
+
+ def createScatterContext(self):
+ TestStatsBase.createScatterContext(self)
+ self.scatterContext = stats._ScatterContext(
+ item=self.scatterPlot.getScatter('scatter plot'),
+ plot=self.scatterPlot,
+ onlimits=False,
+ roi=self._1Droi
+ )
+
+ def createImageContext(self):
+ TestStatsBase.createImageContext(self)
+
+ self.imageContext = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_rect
+ )
+
+ self.imageContext_2 = stats._ImageContext(
+ item=self.plot2d.getImage(self._imgLgd),
+ plot=self.plot2d,
+ onlimits=False,
+ roi=self._2Droi_poly
+ )
+
+ def testErrors(self):
+ # test if onlimits is True and give also a roi
+ with self.assertRaises(ValueError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=True,
+ roi=self._1Droi)
+
+ # test if is a curve context and give an invalid 2D roi
+ with self.assertRaises(TypeError):
+ stats._CurveContext(item=self.plot1d.getCurve('curve0'),
+ plot=self.plot1d,
+ onlimits=False,
+ roi=self._2Droi_rect)
+
+ def testBasicStatsCurve(self):
+ """Test result for simple stats on a curve"""
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(0, 10))
+ self.assertEqual(_stats['min'].calculate(self.curveContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.curveContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.curveContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.curveContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.curveContext), numpy.std(yData[2:6]))
+ self.assertEqual(_stats['mean'].calculate(self.curveContext), numpy.mean(yData[2:6]))
+ com = numpy.sum(xData[2:6] * yData[2:6]) / numpy.sum(yData[2:6])
+ self.assertEqual(_stats['com'].calculate(self.curveContext), com)
+
+ def testBasicStatsImageRectRoi(self):
+ """Test result for simple stats on an image"""
+ self.assertEqual(self.imageContext.values.compressed().size, 121)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext), 10)
+ self.assertEqual(_stats['max'].calculate(self.imageContext), 1300)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext), (10, 0))
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext), (20.0, 10.0))
+ self.assertAlmostEqual(_stats['std'].calculate(self.imageContext),
+ numpy.std(self.imageData[0:11, 10:21]))
+ self.assertAlmostEqual(_stats['mean'].calculate(self.imageContext),
+ numpy.mean(self.imageData[0:11, 10:21]))
+
+ compressed_values = self.imageContext.values.compressed()
+ compressed_values = compressed_values.reshape(11, 11)
+ yData = numpy.sum(compressed_values.astype(numpy.float64), axis=1)
+ xData = numpy.sum(compressed_values.astype(numpy.float64), axis=0)
+
+ dataYRange = range(11)
+ dataXRange = range(10, 21)
+
+ ycom = numpy.sum(yData*dataYRange) / numpy.sum(yData)
+ xcom = numpy.sum(xData*dataXRange) / numpy.sum(xData)
+ self.assertEqual(_stats['com'].calculate(self.imageContext), (xcom, ycom))
+
+ def testBasicStatsImagePolyRoi(self):
+ """Test a simple rectangle ROI"""
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.imageContext_2), 0)
+ self.assertEqual(_stats['max'].calculate(self.imageContext_2), 2432)
+ self.assertEqual(_stats['minCoords'].calculate(self.imageContext_2), (0.0, 0.0))
+ # not 0.0, 19.0 because not fully in. Should all pixel have a weight,
+ # on to manage them in stats. For now 0 if the center is not in, else 1
+ self.assertEqual(_stats['maxCoords'].calculate(self.imageContext_2), (0.0, 19.0))
+
+ def testBasicStatsScatter(self):
+ self.assertEqual(self.scatterContext.values.compressed().size, 2)
+ _stats = self.getBasicStats()
+ self.assertEqual(_stats['min'].calculate(self.scatterContext), 6)
+ self.assertEqual(_stats['max'].calculate(self.scatterContext), 7)
+ self.assertEqual(_stats['minCoords'].calculate(self.scatterContext), (2, 3))
+ self.assertEqual(_stats['maxCoords'].calculate(self.scatterContext), (3, 4))
+ self.assertEqual(_stats['std'].calculate(self.scatterContext), numpy.std([6, 7]))
+ self.assertEqual(_stats['mean'].calculate(self.scatterContext), numpy.mean([6, 7]))
+
+ def testBasicHistogram(self):
+ _stats = self.getBasicStats()
+ xData = yData = numpy.array(range(2, 6))
+ self.assertEqual(_stats['min'].calculate(self.histoContext), 2)
+ self.assertEqual(_stats['max'].calculate(self.histoContext), 5)
+ self.assertEqual(_stats['minCoords'].calculate(self.histoContext), (2,))
+ self.assertEqual(_stats['maxCoords'].calculate(self.histoContext), (5,))
+ self.assertEqual(_stats['std'].calculate(self.histoContext), numpy.std(yData))
+ self.assertEqual(_stats['mean'].calculate(self.histoContext), numpy.mean(yData))
+ com = numpy.sum(xData * yData) / numpy.sum(yData)
+ self.assertEqual(_stats['com'].calculate(self.histoContext), com)
+
+
+class TestAdvancedROIImageContext(TestCaseQt):
+ """Test stats result on an image context with different scale and
+ origins"""
+
+ def setUp(self):
+ TestCaseQt.setUp(self)
+ self.data_dims = (100, 100)
+ self.data = numpy.random.rand(*self.data_dims)
+ self.plot = Plot2D()
+
+ def test(self):
+ """Test stats result on an image context with different scale and
+ origins"""
+ roi_origins = [(0, 0), (2, 10), (14, 20)]
+ img_origins = [(0, 0), (14, 20), (2, 10)]
+ img_scales = [1.0, 0.5, 2.0]
+ _stats = {'sum': stats.Stat(name='sum', fct=numpy.sum), }
+ for roi_origin in roi_origins:
+ for img_origin in img_origins:
+ for img_scale in img_scales:
+ with self.subTest(roi_origin=roi_origin,
+ img_origin=img_origin,
+ img_scale=img_scale):
+ self.plot.addImage(self.data, legend='img',
+ origin=img_origin,
+ scale=img_scale)
+ roi = RectangleROI()
+ roi.setGeometry(origin=roi_origin, size=(20, 20))
+ context = stats._ImageContext(
+ item=self.plot.getImage('img'),
+ plot=self.plot,
+ onlimits=False,
+ roi=roi)
+ x_start = int((roi_origin[0] - img_origin[0]) / img_scale)
+ x_end = int(x_start + (20 / img_scale)) + 1
+ y_start = int((roi_origin[1] - img_origin[1])/ img_scale)
+ y_end = int(y_start + (20 / img_scale)) + 1
+ x_start = max(x_start, 0)
+ x_end = min(max(x_end, 0), self.data_dims[1])
+ y_start = max(y_start, 0)
+ y_end = min(max(y_end, 0), self.data_dims[0])
+ th_sum = numpy.sum(self.data[y_start:y_end, x_start:x_end])
+ self.assertAlmostEqual(_stats['sum'].calculate(context),
+ th_sum)
+
def suite():
test_suite = unittest.TestSuite()
for TestClass in (TestStats, TestStatsHandler, TestStatsWidgetWithScatters,
TestStatsWidgetWithImages, TestStatsWidgetWithCurves,
- TestStatsFormatter, TestEmptyStatsWidget,
- TestLineWidget, TestUpdateModeWidget):
+ TestStatsFormatter, TestEmptyStatsWidget, TestStatsROI,
+ TestLineWidget, TestUpdateModeWidget, ):
test_suite.addTest(
unittest.defaultTestLoader.loadTestsFromTestCase(TestClass))
return test_suite