diff options
author | Onderwaater <onderwaa@esrf.fr> | 2016-01-15 14:35:36 +0100 |
---|---|---|
committer | Onderwaater <onderwaa@esrf.fr> | 2016-01-15 14:35:36 +0100 |
commit | 66847a9ab58b66008450efdfafafe51aefa942d7 (patch) | |
tree | 7f65571b0422166b9e41105f4cf1f36ebd274819 | |
parent | 14e59b85fc43d2367f1c38b09aa494ccd9091138 (diff) |
bugfix + fitaid mod + id03 projection
-rw-r--r-- | binoculars/backends/id03.py | 30 | ||||
-rw-r--r-- | binoculars/backends/id03_xu.py | 2 | ||||
-rw-r--r-- | binoculars/backends/io7.py | 4 | ||||
-rw-r--r-- | binoculars/fit.py | 52 | ||||
-rwxr-xr-x | binoculars/main.py | 2 | ||||
-rwxr-xr-x | binoculars/space.py | 4 | ||||
-rwxr-xr-x | scripts/binoculars-fitaid | 241 |
7 files changed, 188 insertions, 147 deletions
diff --git a/binoculars/backends/id03.py b/binoculars/backends/id03.py index a388348..5b9cc8c 100644 --- a/binoculars/backends/id03.py +++ b/binoculars/backends/id03.py @@ -148,6 +148,7 @@ class QProjection(backend.ProjectionBase): return 'qx', 'qy', 'qz' + class SphericalQProjection(QProjection): def project(self, wavelength, UB, gamma, delta, theta, mu, chi, phi): qz, qy, qx = super(SphericalQProjection, self).project(wavelength, UB, gamma, delta, theta, mu, chi, phi) @@ -171,17 +172,9 @@ class CylindricalQProjection(QProjection): return 'qpar', 'qz', 'Phi' -class nrQProjection(backend.ProjectionBase): +class nrQProjection(QProjection): def project(self, wavelength, UB, gamma, delta, theta, mu, chi, phi): - k0 = 2 * numpy.pi / wavelength - delta, gamma = numpy.meshgrid(delta, gamma) - mu *= numpy.pi/180 - delta *= numpy.pi/180 - gamma *= numpy.pi/180 - - qy = k0 * (numpy.cos(gamma) * numpy.cos(delta) - numpy.cos(mu)) # definition of qx, and qy same as spec at theta = 0 - qx = k0 * (numpy.cos(gamma) * numpy.sin(delta)) - qz = k0 * (numpy.sin(gamma) + numpy.sin(mu)) + qx, qy, qz = super(nrQProjection, self).project(wavelength, UB, gamma, delta, 0, mu, chi, phi) return (qx, qy, qz) def get_axis_labels(self): @@ -236,6 +229,23 @@ class GammaDeltaMu(HKLProjection): # just passing on the coordinates, makes it def get_axis_labels(self): return 'Gamma', 'Delta', 'Mu' +class QTransformation(QProjection): + def project(self, wavelength, UB, gamma, delta, theta, mu, chi, phi): + qx, qy, qz = super(QTransformation, self).project(wavelength, UB, gamma, delta, theta, mu, chi, phi) + + M = self.config.matrix + q1 = qx * M[0] + qy * M[1] + qz * M[2] + q2 = qx * M[3] + qy * M[4] + qz * M[5] + q3 = qx * M[6] + qy * M[7] + qz * M[8] + + return (q1, q2, q3) + + def get_axis_labels(self): + return 'q1', 'q2', 'q3' + + def parse_config(self, config): + super(QTransformation, self).parse_config(config) + self.config.matrix = util.parse_tuple(config.pop('matrix'), length=9, type=float) class ID03Input(backend.InputBase): # OFFICIAL API diff --git a/binoculars/backends/id03_xu.py b/binoculars/backends/id03_xu.py index a10f617..61d1062 100644 --- a/binoculars/backends/id03_xu.py +++ b/binoculars/backends/id03_xu.py @@ -24,7 +24,7 @@ PY3 = sys.version_info > (3,) if PY3: pass else: - import itertools import izip as zip + from itertools import izip as zip try: from PyMca import specfilewrapper, EdfFile diff --git a/binoculars/backends/io7.py b/binoculars/backends/io7.py index 34a4c9d..e239712 100644 --- a/binoculars/backends/io7.py +++ b/binoculars/backends/io7.py @@ -329,10 +329,6 @@ class EH1(IO7Input): dy = self.apply_mask(dy, self.config.xmask, self.config.ymask) dz = self.apply_mask(dz, self.config.xmask, self.config.ymask) - - #X,Y = numpy.meshgrid(x,y) - #Z = numpy.ones(X.shape) * sdd - pixels = dx,dy,dz return intensity, weights, (energy, UB, pixels, gamma, delta, omega, alpha, nu) diff --git a/binoculars/fit.py b/binoculars/fit.py index a9333ad..c7262f0 100644 --- a/binoculars/fit.py +++ b/binoculars/fit.py @@ -4,7 +4,6 @@ import scipy.special import inspect import re - class FitBase(object): parameters = None guess = None @@ -14,7 +13,10 @@ class FitBase(object): def __init__(self, space, guess=None): self.space = space - args = inspect.getargspec(self.func).args + code = inspect.getsource(self.func) + + args = tuple( re.findall('\((.*?)\)', line)[0].split(',') for line in code.split('\n')[2:4]) + if space.dimension != len(args[0]): raise ValueError('dimension mismatch: space has {0}, {1.__class__.__name__} expects {2}'.format(space.dimension, self, len(args[0]))) self.parameters = args[1] @@ -145,9 +147,9 @@ def get_class_by_name(name): # fitting functions class Lorentzian1D(PeakFitBase): @staticmethod - def func(xxx_todo_changeme, xxx_todo_changeme1): - (x, ) = xxx_todo_changeme - (I, loc, gamma, slope, offset) = xxx_todo_changeme1 + def func(grid, params): + (x, ) = grid + (I, loc, gamma, slope, offset) = params return I / ((x - loc)**2 + gamma**2) + offset + x * slope def set_guess(self, maximum, argmax, linparams): @@ -157,8 +159,8 @@ class Lorentzian1D(PeakFitBase): class Lorentzian1DNoBkg(PeakFitBase): @staticmethod - def func(xxx_todo_changeme2, xxx_todo_changeme3): - (x, ) = xxx_todo_changeme2 + def func(grid, params): + (x, ) = grid (I, loc, gamma) = xxx_todo_changeme3 return I / ((x - loc)**2 + gamma**2) @@ -169,9 +171,9 @@ class Lorentzian1DNoBkg(PeakFitBase): class PolarLorentzian2Dnobkg(PeakFitBase): @staticmethod - def func(xxx_todo_changeme4, xxx_todo_changeme5): - (x, y) = xxx_todo_changeme4 - (I, loc0, loc1, gamma0, gamma1, th) = xxx_todo_changeme5 + def func(grid, params): + (x, y) = grid + (I, loc0, loc1, gamma0, gamma1, th) = params a, b = tuple(grid - center for grid, center in zip(rot2d(x, y, th), rot2d(loc0, loc1, th))) return (I / (1 + (a / gamma0)**2 + (b / gamma1)**2)) @@ -183,9 +185,9 @@ class PolarLorentzian2Dnobkg(PeakFitBase): class PolarLorentzian2D(PeakFitBase): @staticmethod - def func(xxx_todo_changeme6, xxx_todo_changeme7): - (x, y) = xxx_todo_changeme6 - (I, loc0, loc1, gamma0, gamma1, th, slope1, slope2, offset) = xxx_todo_changeme7 + def func(grid, params): + (x, y) = grid + (I, loc0, loc1, gamma0, gamma1, th, slope1, slope2, offset) = params a, b = tuple(grid - center for grid, center in zip(rot2d(x, y, th), rot2d(loc0, loc1, th))) return (I / (1 + (a / gamma0)**2 + (b / gamma1)**2) + x * slope1 + y * slope2 + offset) @@ -200,9 +202,9 @@ class PolarLorentzian2D(PeakFitBase): class Lorentzian2D(PeakFitBase): @staticmethod - def func(xxx_todo_changeme8, xxx_todo_changeme9): - (x, y) = xxx_todo_changeme8 - (I, loc0, loc1, gamma0, gamma1, th, slope1, slope2, offset) = xxx_todo_changeme9 + def func(grid, params): + (x, y) = grid + (I, loc0, loc1, gamma0, gamma1, th, slope1, slope2, offset) = params a, b = tuple(grid - center for grid, center in zip(rot2d(x, y, th), rot2d(loc0, loc1, th))) return (I / (1 + (a/gamma0)**2) * 1 / (1 + (b/gamma1)**2) + x * slope1 + y * slope2 + offset) @@ -214,9 +216,9 @@ class Lorentzian2D(PeakFitBase): class Lorentzian2Dnobkg(PeakFitBase): @staticmethod - def func(xxx_todo_changeme10, xxx_todo_changeme11): - (x, y) = xxx_todo_changeme10 - (I, loc0, loc1, gamma0, gamma1, th) = xxx_todo_changeme11 + def func(grid, params): + (x, y) = grid + (I, loc0, loc1, gamma0, gamma1, th) = params a, b = tuple(grid - center for grid, center in zip(rot2d(x, y, th), rot2d(loc0, loc1, th))) return (I / (1 + (a/gamma0)**2) * 1 / (1 + (b/gamma1)**2)) @@ -232,17 +234,17 @@ class Lorentzian(AutoDimensionFit): class Gaussian1D(PeakFitBase): @staticmethod - def func(xxx_todo_changeme12, xxx_todo_changeme13): - (x,) = xxx_todo_changeme12 - (loc, I, sigma, offset, slope) = xxx_todo_changeme13 + def func(grid, params): + (x,) = grid + (loc, I, sigma, offset, slope) = params return I * numpy.exp(-((x-loc)/sigma)**2/2) + offset + x * slope class Voigt1D(PeakFitBase): @staticmethod - def func(xxx_todo_changeme14, xxx_todo_changeme15): - (x, ) = xxx_todo_changeme14 - (I, loc, sigma, gamma, slope, offset) = xxx_todo_changeme15 + def func(grid, params): + (x, ) = grid + (I, loc, sigma, gamma, slope, offset) = params z = (x - loc + numpy.complex(0, gamma)) / (sigma * numpy.sqrt(2)) return I * numpy.real(scipy.special.wofz(z))/(sigma * numpy.sqrt(2 * numpy.pi)) + offset + x * slope diff --git a/binoculars/main.py b/binoculars/main.py index 6c80f1b..68ac7cc 100755 --- a/binoculars/main.py +++ b/binoculars/main.py @@ -136,7 +136,7 @@ class Split(Main): # completely ignores the dispatcher, just yields a space per for intensity, weights, params in self.input.process_job(job): coords = self.projection.project(*params) if self.projection.config.limits == None: - yield space.Multiverse(space.Space.from_image(res, labels, coords, intensity, weights=weights)) + yield space.Space.from_image(res, labels, coords, intensity, weights=weights) else: yield space.Multiverse(space.Space.from_image(res, labels, coords, intensity, weights=weights, limits=limits) for limits in self.projection.config.limits) diff --git a/binoculars/space.py b/binoculars/space.py index 73fe1cc..3ff1dd2 100755 --- a/binoculars/space.py +++ b/binoculars/space.py @@ -571,14 +571,13 @@ class Space(object): def transform_coordinates(self, resolutions, labels, transformation): # gather data and transform - coords = self.get_grid() transcoords = transformation(*coords) intensity = self.get() weights = self.contributions # get rid of invalid coords - valid = reduce(numpy.bitwise_and, chain((numpy.isfinite(t) for t in transcoords)), (weights > 0, )) + valid = reduce(numpy.bitwise_and, chain((numpy.isfinite(t) for t in transcoords)), (weights > 0)) transcoords = tuple(t[valid] for t in transcoords) return self.from_image(resolutions, labels, transcoords, intensity[valid], weights[valid]) @@ -879,7 +878,6 @@ def dstack(spaces, dindices, dlabel, dresolution): return space.transform_coordinates(resolutions, labels, transformation) return sum(transform(space, dindex) for space, dindex in zip(spaces, dindices)) - def axis_offset(space, label, offset): exprs = list(ax.label for ax in space.axes) index = space.axes.index(label) diff --git a/scripts/binoculars-fitaid b/scripts/binoculars-fitaid index 69f15b0..08ca823 100755 --- a/scripts/binoculars-fitaid +++ b/scripts/binoculars-fitaid @@ -150,10 +150,10 @@ class TopWidget(QtGui.QWidget): self.table.check_changed.connect(self.refresh_plot) self.tab_widget = QtGui.QTabWidget() - self.fitwidget = FitWidget(self.database) - self.integratewidget = IntegrateWidget(self.database) - self.plotwidget = OverviewWidget(self.database) - self.peakwidget = PeakWidget(self.database) + self.fitwidget = FitWidget(self.database, self) + self.integratewidget = IntegrateWidget(self.database, self) + self.plotwidget = OverviewWidget(self.database, self) + self.peakwidget = PeakWidget(self.database, self) self.tab_widget.addTab(self.fitwidget, 'Fit') self.tab_widget.addTab(self.integratewidget, 'Integrate') @@ -643,11 +643,21 @@ class RodData(FitData): for i, value in enumerate(loc): self.save_sliceattr(index, 'guessloc{0}'.format(i), value) - def save_segments(self, index, loc): - db[self.rodkey][self.slicekey].create_dataset('segments', data.shape, dtype=data.dtype, compression='gzip').write_direct(data) + def save_segments(self, segments): + with h5py.File(self.filename, 'a') as db: + try: + db[self.rodkey][self.slicekey].create_dataset('segment', segments.shape, dtype=segments.dtype, compression='gzip').write_direct(segments) + except RuntimeError: + del db[self.rodkey][self.slicekey]['segment'] + db[self.rodkey][self.slicekey].create_dataset('segment', segments.shape, dtype=segments.dtype, compression='gzip').write_direct(segments) + + def load_segments(self): + with h5py.File(self.filename, 'a') as db: + try: + return numpy.array(db[self.rodkey][self.slicekey]['segment'][:]) + except KeyError: + return None - def load_segments(self, index, loc): - return db[self.rodkey][self.slicekey]['segments'] def __iter__(self): for index in range(self.rodlength()): @@ -716,10 +726,7 @@ class FitWidget(QtGui.QWidget): def fit(self, index, space, function): print(index) if not len(space.get_masked().compressed()) == 0: - loc = self.database.load_loc(index) - if loc: - if loc == list(float(0) for i in loc): - loc = None + loc = get_loc(index) fit = function(space, loc = loc) fit.fitdata.mask = space.get_masked().mask self.database.save_data(index, 'fit', fit.fitdata) @@ -730,6 +737,9 @@ class FitWidget(QtGui.QWidget): for key, value in zip(params, fit.variance): self.database.save_sliceattr(index, 'var_{0}'.format(key), value) + def get_loc(self): + return self.database.load_loc(self.currentindex()) + def currentindex(self): index = self.database.load('index') if index == None: @@ -740,6 +750,7 @@ class FitWidget(QtGui.QWidget): class IntegrateWidget(QtGui.QWidget): def __init__(self, database, parent = None): super(IntegrateWidget, self).__init__(parent) + self.parent = parent self.database = database self.figure = matplotlib.figure.Figure() @@ -764,14 +775,6 @@ class IntegrateWidget(QtGui.QWidget): intensitybox = QtGui.QHBoxLayout() backgroundbox = QtGui.QHBoxLayout() - self.tracker = QtGui.QCheckBox('peak tracker') - self.tracker.setChecked(1) - self.locx = QtGui.QDoubleSpinBox() - self.locy = QtGui.QDoubleSpinBox() - self.locx.setDisabled(True) - self.locy.setDisabled(True) - self.tracker.clicked.connect(self.refresh_tracker) - self.aroundroi = QtGui.QCheckBox('background around roi') self.aroundroi.setChecked(1) self.aroundroi.clicked.connect(self.refresh_aroundroi) @@ -804,11 +807,25 @@ class IntegrateWidget(QtGui.QWidget): integratebox.addLayout(intensitybox) integratebox.addLayout(backgroundbox) - minibox = QtGui.QHBoxLayout() - minibox.addWidget(self.tracker) - minibox.addWidget(self.locx) - minibox.addWidget(self.locy) - integratebox.addLayout(minibox) + self.fromfit = QtGui.QRadioButton('peak from fit', self) + self.fromfit.setChecked(True) + self.fromfit.toggled.connect(self.plot_box) + self.fromfit.toggled.connect(self.refresh_tracker) + + self.fromsegment = QtGui.QRadioButton('peak from segment', self) + self.fromsegment.setChecked(False) + self.fromsegment.toggled.connect(self.plot_box) + self.fromsegment.toggled.connect(self.refresh_tracker) + + self.trackergroup = QtGui.QButtonGroup(self) + self.trackergroup.addButton(self.fromfit) + self.trackergroup.addButton(self.fromsegment) + + radiobox = QtGui.QHBoxLayout() + radiobox.addWidget(self.fromfit) + radiobox.addWidget(self.fromsegment) + + integratebox.addLayout(radiobox) self.control_widget.setLayout(integratebox) @@ -835,13 +852,7 @@ class IntegrateWidget(QtGui.QWidget): self.bottom.setMaximum(axes[1].max - axes[1].min) def refresh_tracker(self): - self.database.save('tracker', self.tracker.checkState()) - if self.tracker.checkState(): - self.locx.setDisabled(True) - self.locy.setDisabled(True) - else: - self.locx.setDisabled(False) - self.locy.setDisabled(False) + self.database.save('fromfit', self.fromfit.isChecked()) self.plot_box() def set_axis(self): @@ -869,42 +880,24 @@ class IntegrateWidget(QtGui.QWidget): self.bottom.setSingleStep(axes[0].res) self.bottom.setDecimals(len(str(axes[0].res)) - 2) - self.locx.setSingleStep(axes[0].res) - self.locx.setDecimals(len(str(axes[0].res)) - 2) - self.locx.setMinimum(axes[0].min) - self.locx.setMaximum(axes[0].max) - - self.locy.setSingleStep(axes[1].res) - self.locy.setDecimals(len(str(axes[1].res)) - 2) - self.locy.setMinimum(axes[1].min) - self.locy.setMaximum(axes[1].max) - - tracker = self.database.load('tracker') + tracker = self.database.load('fromfit') if tracker != None: - self.tracker.setChecked(tracker) - else: - self.tracker.setChecked(True) + if tracker: + self.fromfit.setChecked(True) + else: + self.fromsegment.setChecked(True) if roi is not None: for box, value in zip([self.hsize, self.vsize, self.left, self.right, self.top, self.bottom], roi): box.setValue(value) - if self.fixed_loc() != None: - x,y = self.fixed_loc() - self.locx.setValue(x) - self.locy.setValue(y) - def send(self): roi = [self.hsize.value(), self.vsize.value(), self.left.value() ,self.right.value() ,self.top.value(), self.bottom.value()] self.database.save('roi', roi) self.plot_box() def integrate(self, index, space): - if self.tracker.checkState(): - loc = self.database.load_loc(index) - else: - loc = self.fixed_loc() - + loc = self.get_loc() if loc != None: axes = space.axes @@ -952,7 +945,6 @@ class IntegrateWidget(QtGui.QWidget): print('Structurefactor {0}: {1}'.format(index, structurefactor)) def intkey(self, coords, axes): - vsize = self.vsize.value() / 2 hsize = self.hsize.value() / 2 return tuple(ax.restrict(slice(coord - size, coord + size)) for ax, coord, size in zip(axes, coords, [vsize, hsize])) @@ -974,23 +966,22 @@ class IntegrateWidget(QtGui.QWidget): else: return [(axes[0].restrict(slice(self.left.value(), self.right.value())), axes[1].restrict(slice(self.top.value(), self.bottom.value())))] - def fixed_loc(self): - x = self.database.load('fixed_locx') - y = self.database.load('fixed_locy') - - if x != None and y != None: - return numpy.array([x, y]) + def get_loc(self): + if self.fromfit.isChecked(): + return self.database.load_loc(self.currentindex()) else: - return None + index = self.currentindex() + indexvalue = self.database.get_index_value(index) + return self.parent.peakwidget.get_coords(indexvalue) def loc_callback(self, x, y): if self.ax: - self.database.save_loc(self.currentindex(), numpy.array([x, y])) - if not self.tracker.checkState(): - self.database.save('fixed_locx', x) - self.database.save('fixed_locy', y) - self.locx.setValue(x) - self.locy.setValue(y) + if self.fromfit.isChecked(): + self.database.save_loc(self.currentindex(), numpy.array([x, y])) + else: + index = self.currentindex() + indexvalue = self.database.get_index_value(index) + self.parent.peakwidget.add_row(numpy.array([indexvalue, x, y])) self.plot_box() def plot(self, index = None): @@ -1023,10 +1014,7 @@ class IntegrateWidget(QtGui.QWidget): self.canvas.draw() def plot_box(self): - if not self.tracker.checkState(): - loc = self.fixed_loc() - else: - loc = self.database.load_loc(self.currentindex()) + loc = self.get_loc() if len(self.figure.get_axes()) != 0 and loc != None: ax = self.figure.get_axes()[0] axes = self.figure.space_axes @@ -1187,6 +1175,9 @@ class OverviewWidget(QtGui.QWidget): self.table.removeRow(0) allparams = list(list(param for param in database.all_attrkeys() if not param.startswith('mask')) for database in databaselist) + + allparams.extend(list(['locx_s', 'locy_s']) for database in databaselist if database.load_segments() is not None) + if len(allparams) > 0: uniqueparams = numpy.unique(numpy.hstack(params for params in allparams)) else: @@ -1216,8 +1207,22 @@ class OverviewWidget(QtGui.QWidget): self.ax = self.figure.add_subplot(111) for param in params: for database in self.databaselist: - x, y = database.all_from_key(param) - self.ax.plot(x, y, '+', label = '{0} - {1}'.format(param, database.rodkey)) + if param == 'locx_s': + segments = database.load_segments() + if segments is not None: + x = numpy.hstack(database.get_index_value(index) for index in range(database.rodlength())) + y = numpy.vstack(get_coords(xvalue, segments) for xvalue in x) + self.ax.plot(x, y[:,0], '+', label = '{0} - {1}'.format('locx_s', database.rodkey)) + elif param == 'locy_s': + segments = database.load_segments() + if segments is not None: + x = numpy.hstack(database.get_index_value(index) for index in range(database.rodlength())) + y = numpy.vstack(get_coords(xvalue, segments) for xvalue in x) + self.ax.plot(x, y[:,1], '+', label = '{0} - {1}'.format('locy_s', database.rodkey)) + else: + x, y = database.all_from_key(param) + self.ax.plot(x, y, '+', label = '{0} - {1}'.format(param, database.rodkey)) + self.ax.legend() if self.log.checkState(): self.ax.semilogy() @@ -1229,9 +1234,10 @@ class PeakWidget(QtGui.QWidget): self.database = database # create a QTableWidget - self.table = QtGui.QTableWidget(1, 3, self) + self.table = QtGui.QTableWidget(0, 3, self) self.table.horizontalHeader().setStretchLastSection(True) self.table.verticalHeader().setVisible(False) + self.table.itemChanged.connect(self.save) self.btn_add_row = QtGui.QPushButton('+', self) self.btn_add_row.clicked.connect(self.add_row) @@ -1251,44 +1257,73 @@ class PeakWidget(QtGui.QWidget): def set_axis(self): self.axes = self.database.paxes() + while self.table.rowCount() > 0: + self.table.removeRow(0) + segments = self.database.load_segments() + if segments is not None: + for index in range(segments.shape[0]): + self.add_row(segments[index, :]) self.table.setHorizontalHeaderLabels(['{0}'.format(self.database.axis), '{0}'.format(self.axes[0].label), '{0}'.format(self.axes[1].label)]) - def add_row(self): - self.table.insertRow(self.table.rowCount()) + def add_row(self, row = None): + rowindex = self.table.rowCount() + self.table.insertRow(rowindex) + if row is not None: + for index in range(3): + newitem = QtGui.QTableWidgetItem(str(row[index])) + self.table.setItem(rowindex, index, newitem) def remove(self): self.table.removeRow(self.table.currentRow()) + self.save() def axis_coords(self): - return numpy.vstack( numpy.array([float(self.table.item(index, 0)), float(self.table.item(index, 1)), float(self.table.item(index, 2))]) for index in range(self.table.rowCount())) + a = numpy.zeros((self.table.rowCount(), self.table.columnCount())) + for rowindex in range(a.shape[0]): + for columnindex in range(a.shape[1]): + item = self.table.item(rowindex, columnindex) + if item is not None: + a[rowindex, columnindex] = float(item.text()) + return a + + def save(self): + self.database.save_segments(self.axis_coords()) def get_coords(self, x): - coords = self.axis_coords() - if coords.shape[0] == 1: - return coords[0,1:] + return get_coords(x, self.axis_coords()) - args = numpy.argsort(coords[:,0]) - - x0 = coords[args,0] - x1 = coords[args,1] - x2 = coords[args,2] - - if x < x0.min(): - first = 0 - last = 1 - elif x > x0.max(): - first = -2 - last = -1 - else: - first = numpy.searchsorted(x, x0) - 1 - last = numpy.searchsorted(x, x0) +def get_coords(x, coords): + + if coords.shape[0] == 0: + return None + + if coords.shape[0] == 1: + return coords[0,1:] - a1 = (x1[last] - x1[first]) / (x0[last] - x0[first]) - b1 = x1[first] - a * x0[first] - a2 = (x2[last] - x2[first]) / (x0[last] - x0[first]) - b2 = x2[first] - a * x0[first] + args = numpy.argsort(coords[:,0]) + + x0 = coords[args,0] + x1 = coords[args,1] + x2 = coords[args,2] + + if x < x0.min(): + first = 0 + last = 1 + elif x > x0.max(): + first = -2 + last = -1 + else: + first = numpy.searchsorted(x0, x) - 1 + last = numpy.searchsorted(x0, x) - return numpy.array([a1 * x + b1, a2 * x + b2]) + a1 = (x1[last] - x1[first]) / (x0[last] - x0[first]) + b1 = x1[first] - a1 * x0[first] + a2 = (x2[last] - x2[first]) / (x0[last] - x0[first]) + b2 = x2[first] - a2 * x0[first] + + return numpy.array([a1 * x + b1, a2 * x + b2]) + + def interpolate(space): data = space.get_masked() |