#!/usr/bin/env python import sys import os.path import itertools import h5py import matplotlib.figure import matplotlib.image import numpy from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT from matplotlib.pyplot import Rectangle from PyQt5.Qt import (Qt) from PyQt5.QtCore import (pyqtSignal) from PyQt5.QtWidgets import (QAction, QApplication, QSlider, QMenuBar, QTabWidget, QFileDialog, QStatusBar, QMessageBox, QRadioButton, QButtonGroup, QCheckBox, QPushButton, QHBoxLayout, QVBoxLayout, QSplitter, QTableWidgetItem, QTableWidget, QLabel, QLineEdit, QMainWindow, QWidget, QComboBox, QProgressDialog, QDoubleSpinBox) from scipy.interpolate import griddata def set_src(): import sys import os.path as osp dirpath = osp.join(osp.dirname(osp.abspath(__file__)), osp.pardir) sys.path.insert(0, osp.abspath(dirpath)) try: import binoculars.main import binoculars.space import binoculars.plot import binoculars.fit import binoculars.util except ImportError: # try to use code from src distribution set_src() import binoculars.main import binoculars.space import binoculars.plot import binoculars.fit import binoculars.util class Window(QMainWindow): def __init__(self, parent=None): super(Window, self).__init__(parent) newproject = QAction("New project", self) newproject.triggered.connect(self.newproject) loadproject = QAction("Open project", self) loadproject.triggered.connect(self.loadproject) addspace = QAction("Import space", self) addspace.triggered.connect(self.add_to_project) menu_bar = QMenuBar() file = menu_bar.addMenu("&File") file.addAction(newproject) file.addAction(loadproject) file.addAction(addspace) self.setMenuBar(menu_bar) self.statusbar = QStatusBar() self.tab_widget = QTabWidget(self) self.tab_widget.setTabsClosable(True) self.tab_widget.tabCloseRequested.connect(self.tab_widget.removeTab) self.setCentralWidget(self.tab_widget) self.setMenuBar(menu_bar) self.setStatusBar(self.statusbar) def newproject(self): dialog = QFileDialog(self, "project filename"); dialog.setNameFilters(['binoculars fit file (*.fit)']); dialog.setDefaultSuffix('fit'); dialog.setFileMode(QFileDialog.AnyFile); dialog.setAcceptMode(QFileDialog.AcceptSave); if not dialog.exec_(): return fname = dialog.selectedFiles()[0] if not fname: return try: widget = TopWidget(str(fname), parent=self) self.tab_widget.addTab(widget, short_filename(str(fname))) self.tab_widget.setCurrentWidget(widget) except Exception as e: QMessageBox.critical(self, 'New project', 'Unable to save project to {}: {}'.format(fname, e)) def loadproject(self, filename=None): if not filename: dialog = QFileDialog(self, "Load project"); dialog.setNameFilters(['binoculars fit file (*.fit)']); dialog.setFileMode(QFileDialog.ExistingFiles); dialog.setAcceptMode(QFileDialog.AcceptOpen); if not dialog.exec_(): return fname = dialog.selectedFiles()[0] if not fname: return try: widget = TopWidget(str(fname), parent=self) self.tab_widget.addTab(widget, short_filename(str(fname))) self.tab_widget.setCurrentWidget(widget) except Exception as e: QMessageBox.critical(self, 'Load project', 'Unable to load project from {}: {}'.format(fname, e)) else: widget = TopWidget(str(fname), parent=self) self.tab_widget.addTab(widget, 'fname') self.tab_widget.setCurrentWidget(widget) def add_to_project(self, filename=None): if self.tab_widget.count() == 0: QMessageBox.warning(self, 'Warning', 'First select a file to store data') self.newproject() if not filename: dialog = QFileDialog(self, "Import spaces"); dialog.setNameFilters(['binoculars space file (*.hdf5)']); dialog.setFileMode(QFileDialog.ExistingFiles); dialog.setAcceptMode(QFileDialog.AcceptOpen); if not dialog.exec_(): return fname = dialog.selectedFiles() if not fname: return for name in fname: try: widget = self.tab_widget.currentWidget() widget.addspace(str(name)) except Exception as e: QMessageBox.critical(self, 'Import spaces', 'Unable to import space {}: {}'.format(fname, e)) else: widget = self.tab_widget.currentWidget() widget.addspace(filename) class TopWidget(QWidget): def __init__(self, filename, parent=None): super(TopWidget, self).__init__(parent) hbox = QHBoxLayout() vbox = QVBoxLayout() minihbox = QHBoxLayout() minihbox2 = QHBoxLayout() self.database = FitData(filename) self.table = TableWidget(self.database) self.nav = ButtonedSlider() self.nav.slice_index.connect(self.index_change) self.table.trigger.connect(self.active_change) self.table.check_changed.connect(self.refresh_plot) self.tab_widget = QTabWidget() self.fitwidget = FitWidget(self.database, self) self.integratewidget = IntegrateWidget(self.database, self) self.plotwidget = OverviewWidget(self.database, self) self.peakwidget = PeakWidget(self.database, self) self.tab_widget.addTab(self.fitwidget, 'Fit') self.tab_widget.addTab(self.integratewidget, 'Integrate') self.tab_widget.addTab(self.plotwidget, 'plot') self.tab_widget.addTab(self.peakwidget, 'Peaktracker') self.emptywidget = QWidget() self.emptywidget.setLayout(vbox) vbox.addWidget(self.table) vbox.addWidget(self.nav) self.functions = list() self.function_box = QComboBox() for function in dir(binoculars.fit): cls = getattr(binoculars.fit, function) if isinstance(cls, type) and issubclass(cls, binoculars.fit.PeakFitBase): self.functions.append(cls) self.function_box.addItem(function) self.function_box.setCurrentIndex(self.function_box.findText('PolarLorentzian2D')) vbox.addWidget(self.function_box) vbox.addLayout(minihbox) vbox.addLayout(minihbox2) self.all_button = QPushButton('fit all') self.rod_button = QPushButton('fit rod') self.slice_button = QPushButton('fit slice') self.all_button.clicked.connect(self.fit_all) self.rod_button.clicked.connect(self.fit_rod) self.slice_button.clicked.connect(self.fit_slice) minihbox.addWidget(self.all_button) minihbox.addWidget(self.rod_button) minihbox.addWidget(self.slice_button) self.allint_button = QPushButton('int all') self.rodint_button = QPushButton('int rod') self.sliceint_button = QPushButton('int slice') self.allint_button.clicked.connect(self.int_all) self.rodint_button.clicked.connect(self.int_rod) self.sliceint_button.clicked.connect(self.int_slice) minihbox2.addWidget(self.allint_button) minihbox2.addWidget(self.rodint_button) minihbox2.addWidget(self.sliceint_button) splitter = QSplitter(Qt.Horizontal) splitter.addWidget(self.emptywidget) splitter.addWidget(self.tab_widget) self.tab_widget.currentChanged.connect(self.tab_change) hbox.addWidget(splitter) self.setLayout(hbox) def tab_change(self, index): if index == 2: self.refresh_plot() def addspace(self, filename=None): if filename == None: filename = str(QFileDialog.getOpenFileName(self, 'Open Project', '.', '*.hdf5')) self.table.addspace(filename) def active_change(self): rodkey, axis, resolution = self.table.currentkey() newdatabase = RodData(self.database.filename, rodkey, axis, resolution) self.integratewidget.database = newdatabase self.peakwidget.database = newdatabase self.integratewidget.set_axis() self.peakwidget.set_axis() self.fitwidget.database = newdatabase self.nav.set_length(newdatabase.rodlength()) index = newdatabase.load('index') if index == None: index = 0 self.nav.set_index(index) self.index_change(index) def index_change(self, index): if index == None: index = 0 self.fitwidget.database.save('index', self.nav.index()) self.fitwidget.plot(index) self.integratewidget.plot(index) def refresh_plot(self): self.plotwidget.refresh(list(RodData(self.database.filename, rodkey, axis, resolution) for rodkey, axis, resolution in self.table.checked())) @property def fitclass(self): return self.functions[self.function_box.currentIndex()] def fit_slice(self): index = self.nav.index() space = self.fitwidget.database.space_from_index(index) self.fitwidget.fit(index, space, self.fitclass) self.fit_loc(self.fitwidget.database) self.fitwidget.plot(index) def fit_rod(self): def function(index, space): self.fitwidget.fit(index, space, self.fitclass) self.progressbox(self.fitwidget.database.rodkey, function, enumerate(self.fitwidget.database), self.fitwidget.database.rodlength()) self.fit_loc(self.fitwidget.database) self.fitwidget.plot() def fit_all(self): def function(index, space): self.fitwidget.fit(index, space, self.fitclass) for rodkey, axis, resolution in self.table.checked(): self.fitwidget.database = RodData(self.database.filename, rodkey, axis, resolution) self.progressbox(self.fitwidget.database.rodkey, function, enumerate(self.fitwidget.database), self.fitwidget.database.rodlength()) self.fit_loc(self.fitwidget.database) self.fitwidget.plot() def int_slice(self): index = self.nav.index() space = self.fitwidget.database.space_from_index(index) self.integratewidget.integrate(index, space) self.integratewidget.plot(index) def int_rod(self): self.progressbox(self.integratewidget.database.rodkey, self.integratewidget.integrate, enumerate(self.integratewidget.database), self.integratewidget.database.rodlength()) self.integratewidget.plot() def int_all(self): for rodkey, axis, resolution in self.table.checked(): self.integratewidget.database = RodData(self.database.filename, rodkey, axis, resolution) self.progressbox(self.integratewidget.database.rodkey, self.integratewidget.integrate, enumerate(self.integratewidget.database), self.integratewidget.database.rodlength()) self.integratewidget.plot() def fit_loc(self, database): deg = 2 for param in database.all_attrkeys(): if param.startswith('loc'): x, y = database.all_from_key(param) x, yvar = database.all_from_key('var_{0}'.format(param)) cx = x[numpy.invert(y.mask)] y = y.compressed() yvar = yvar.compressed() w = numpy.log(1 / yvar) w[w == numpy.inf] = 0 w = numpy.nan_to_num(w) w[w < 0] = 0 w[w < numpy.median(w)] = 0 if len(x) > 0: c = numpy.polynomial.polynomial.polyfit(cx, y, deg, w=w) newy = numpy.polynomial.polynomial.polyval(x, c) for index, newval in enumerate(newy): database.save_sliceattr(index, 'guessloc{0}'.format(param.lstrip('loc')), newval) def progressbox(self, rodkey, function, iterator, length): pd = QProgressDialog('Processing {0}'.format(rodkey), 'Cancel', 0, length) pd.setWindowModality(Qt.WindowModal) pd.show() def progress(index, item): pd.setValue(index) if pd.wasCanceled(): raise KeyboardInterrupt QApplication.processEvents() function(*item) for index, item in enumerate(iterator): progress(index, item) pd.close() class TableWidget(QWidget): trigger = pyqtSignal() check_changed = pyqtSignal() def __init__(self, database, parent=None): super(TableWidget, self).__init__(parent) hbox = QHBoxLayout() self.database = database self.activeindex = 0 self.table = QTableWidget(0, 5) self.table.setHorizontalHeaderLabels(['', 'rod', 'axis', 'res', 'remove']) self.table.cellClicked.connect(self.setlength) for index, width in enumerate([25, 150, 40, 50, 70]): self.table.setColumnWidth(index, width) for filename, rodkey in zip(database.filelist, database.rods()): self.addspace(filename, rodkey) hbox.addWidget(self.table) self.setLayout(hbox) def addspace(self, filename, rodkey=None): def remove_callback(rodkey): return lambda: self.remove(rodkey) def activechange_callback(index): return lambda: self.setlength(index, 1) if rodkey == None: rodkey = short_filename(filename) if rodkey in self.database.rods(): newkey = find_unused_rodkey(rodkey, self.database.rods()) self.database.copy(rodkey, newkey) rodkey = newkey old_axis, old_resolution = self.database.load(rodkey, 'axis'), self.database.load(rodkey, 'resolution') self.database.create_rod(rodkey, filename) index = self.table.rowCount() self.table.insertRow(index) axes = binoculars.space.Axes.fromfile(filename) checkboxwidget = QCheckBox() checkboxwidget.rodkey = rodkey checkboxwidget.setChecked(0) self.table.setCellWidget(index, 0, checkboxwidget) checkboxwidget.clicked.connect(self.check_changed) item = QTableWidgetItem(rodkey) self.table.setItem(index, 1, item) axis = QComboBox() for ax in axes: axis.addItem(ax.label) self.table.setCellWidget(index, 2, axis) if not old_axis == None: self.table.cellWidget(index, 2).setCurrentIndex(axes.index(old_axis)) elif index > 0: self.table.cellWidget(0, 2).setCurrentIndex(self.table.cellWidget(0, 2).currentIndex()) resolution = QLineEdit() if not old_resolution == None: resolution.setText(str(old_resolution)) elif index > 0: resolution.setText(self.table.cellWidget(0, 3).text()) else: resolution.setText(str(axes[axes.index(str(axis.currentText()))].res)) resolution.editingFinished.connect(activechange_callback(index)) self.table.setCellWidget(index, 3, resolution) buttonwidget = QPushButton('remove') buttonwidget.clicked.connect(remove_callback(rodkey)) self.table.setCellWidget(index, 4, buttonwidget) def remove(self, rodkey): table_rodkeys = list(self.table.cellWidget(index, 0).rodkey for index in range(self.table.rowCount())) for index, label in enumerate(table_rodkeys): if rodkey == label: self.table.removeRow(index) self.database.delete_rod(rodkey) print('removed: {0}'.format(rodkey)) def setlength(self, y, x=1): if x == 1: self.activeindex = y rodkey, axis, resolution = self.currentkey() self.database.save(rodkey, 'axis', axis) self.database.save(rodkey, 'resolution', resolution) self.trigger.emit() def currentkey(self): rodkey = self.table.cellWidget(self.activeindex, 0).rodkey axis = str(self.table.cellWidget(self.activeindex, 2).currentText()) resolution = float(self.table.cellWidget(self.activeindex, 3).text()) return rodkey, axis, resolution def checked(self): selection = [] for index in range(self.table.rowCount()): checkbox = self.table.cellWidget(index, 0) if checkbox.checkState(): rodkey = self.table.cellWidget(index, 0).rodkey axis = str(self.table.cellWidget(index, 2).currentText()) resolution = float(self.table.cellWidget(index, 3).text()) selection.append((rodkey, axis, resolution)) return selection class FitData(object): def __init__(self, filename): self.filename = filename self.axdict = dict() with h5py.File(self.filename, 'a') as db: for rodkey in self.rods(): spacename = db[rodkey].attrs['filename'] if not os.path.exists(spacename): warningbox = QMessageBox(2, 'Warning', 'Cannot find space {0} at file {1}; locate proper space'.format(rodkey, spacename), buttons=QMessageBox.Open) warningbox.exec_() spacename = str(QFileDialog.getOpenFileName(caption='Open space {0}'.format(rodkey), directory='.', filter='*.hdf5')) if not spacename: raise IOError('Select proper input') db[rodkey].attrs['filename'] = spacename self.axdict[rodkey] = binoculars.space.Axes.fromfile(spacename) def create_rod(self, rodkey, spacename): with h5py.File(self.filename, 'a') as db: if rodkey not in list(db.keys()): db.create_group(rodkey) db[rodkey].attrs['filename'] = spacename self.axdict[rodkey] = binoculars.space.Axes.fromfile(spacename) def delete_rod(self, rodkey): with h5py.File(self.filename, 'a') as db: del db[rodkey] def rods(self): with h5py.File(self.filename, 'a') as db: rods = list(db.keys()) return rods def copy(self, oldkey, newkey): with h5py.File(self.filename, 'a') as db: if oldkey in list(db.keys()): db.copy(db[oldkey], db, name=newkey) @property def filelist(self): filelist = [] with h5py.File(self.filename, 'a') as db: for key in db.keys(): filelist.append(db[key].attrs['filename']) return filelist def save(self, rodkey, key, value): with h5py.File(self.filename, 'a') as db: db[rodkey].attrs[str(key)] = value def load(self, rodkey, key): with h5py.File(self.filename, 'a') as db: if rodkey in db: if key in db[rodkey].attrs: return db[rodkey].attrs[str(key)] else: return None class RodData(FitData): def __init__(self, filename, rodkey, axis, resolution): super(RodData, self).__init__(filename) self.rodkey = rodkey self.slicekey = '{0}_{1}'.format(axis, resolution) self.axis = axis self.resolution = resolution with h5py.File(self.filename, 'a') as db: if rodkey in db: if self.slicekey not in db[rodkey]: db[rodkey].create_group(self.slicekey) db[rodkey][self.slicekey].create_group('attrs') def save(self, key, value): super(RodData, self).save(self.rodkey, key, value) def load(self, key): return super(RodData, self).load(self.rodkey, key) def paxes(self): axes = self.axdict[self.rodkey] projected = list(axes) axindex = axes.index(self.axis) projected.pop(axindex) return projected def get_bins(self): axes = self.axdict[self.rodkey] axindex = axes.index(self.axis) ax = axes[axindex] bins = binoculars.space.get_bins(ax, self.resolution) return bins, ax, axindex def rodlength(self): bins, ax, axindex = self.get_bins() return numpy.alen(bins) - 1 def get_index_value(self, index): return binoculars.space.get_axis_values(self.axdict[self.rodkey], self.axis, self.resolution)[index] def get_key(self, index): axes = self.axdict[self.rodkey] bins, ax, axindex = self.get_bins() start, stop = bins[index], bins[index + 1] k = [slice(None) for i in axes] k[axindex] = slice(start, stop) return k def space_from_index(self, index): with h5py.File(self.filename, 'a') as db: filename = db[self.rodkey].attrs['filename'] return binoculars.space.Space.fromfile(filename, self.get_key(index)).project(self.axis) def save_data(self, index, key, data): with h5py.File(self.filename, 'a') as db: id = '{0}_{1}_data'.format(int(index), key) mid = '{0}_{1}_mask'.format(int(index), key) try: db[self.rodkey][self.slicekey].create_dataset(id, data.shape, dtype=data.dtype, compression='gzip').write_direct(data) db[self.rodkey][self.slicekey].create_dataset(mid, data.shape, dtype=data.mask.dtype, compression='gzip').write_direct(data.mask) except RuntimeError: del db[self.rodkey][self.slicekey][id] del db[self.rodkey][self.slicekey][mid] db[self.rodkey][self.slicekey].create_dataset(id, data.shape, dtype=data.dtype, compression='gzip').write_direct(data) db[self.rodkey][self.slicekey].create_dataset(mid, data.shape, dtype=data.mask.dtype, compression='gzip').write_direct(data.mask) def load_data(self, index, key): with h5py.File(self.filename, 'a') as db: id = '{0}_{1}_data'.format(int(index), key) mid = '{0}_{1}_mask'.format(int(index), key) try: return numpy.ma.array(db[self.rodkey][self.slicekey][id][...], mask=db[self.rodkey][self.slicekey][mid][...]) except KeyError: return None def save_sliceattr(self, index, key, value): mkey = 'mask{0}'.format(key) with h5py.File(self.filename, 'a') as db: try: group = db[self.rodkey][self.slicekey]['attrs'] # # else it breaks with the old fitaid except KeyError: db[self.rodkey][self.slicekey].create_group('attrs') group = db[self.rodkey][self.slicekey]['attrs'] if not key in group: group.create_dataset(key, (self.rodlength(),)) group.create_dataset(mkey, (self.rodlength(),), dtype=numpy.bool).write_direct(numpy.ones(self.rodlength(), dtype=numpy.bool)) group[key][index] = value group[mkey][index] = 0 def load_sliceattr(self, index, key): mkey = 'mask{0}'.format(key) with h5py.File(self.filename, 'a') as db: try: group = db[self.rodkey][self.slicekey]['attrs'] except KeyError: db[self.rodkey][self.slicekey].create_group('attrs') group = db[self.rodkey][self.slicekey]['attrs'] if key in list(group.keys()): return numpy.ma.array(group[key][index], mask=group[mkey][index]) else: return None def all_attrkeys(self): with h5py.File(self.filename, 'a') as db: group = db[self.rodkey][self.slicekey]['attrs'] return list(group.keys()) def all_from_key(self, key): mkey = 'mask{0}'.format(key) axes = self.axdict[self.rodkey] with h5py.File(self.filename, 'a') as db: group = db[self.rodkey][self.slicekey]['attrs'] if key in list(group.keys()): return binoculars.space.get_axis_values(axes, self.axis, self.resolution), numpy.ma.array(group[key], mask=numpy.array(group[mkey])) def load_loc(self, index): loc = list() count = itertools.count() key = 'guessloc{0}'.format(next(count)) while self.load_sliceattr(index, key) != None: loc.append(self.load_sliceattr(index, key)) key = 'guessloc{0}'.format(next(count)) if len(loc) > 0: return loc else: count = itertools.count() key = 'loc{0}'.format(next(count)) while self.load_sliceattr(index, key) != None: loc.append(self.load_sliceattr(index, key)) key = 'loc{0}'.format(next(count)) if len(loc) > 0: return loc else: return None def save_loc(self, index, loc): for i, value in enumerate(loc): self.save_sliceattr(index, 'guessloc{0}'.format(i), value) def save_segments(self, segments): with h5py.File(self.filename, 'a') as db: try: db[self.rodkey][self.slicekey].create_dataset('segment', segments.shape, dtype=segments.dtype, compression='gzip').write_direct(segments) except RuntimeError: del db[self.rodkey][self.slicekey]['segment'] db[self.rodkey][self.slicekey].create_dataset('segment', segments.shape, dtype=segments.dtype, compression='gzip').write_direct(segments) def load_segments(self): with h5py.File(self.filename, 'a') as db: try: return numpy.array(db[self.rodkey][self.slicekey]['segment'][:]) except KeyError: return None def __iter__(self): for index in range(self.rodlength()): yield self.space_from_index(index) def short_filename(filename): return filename.split('/')[-1].split('.')[0] class HiddenToolbar(NavigationToolbar2QT): def __init__(self, corner_callback, canvas): super(HiddenToolbar, self).__init__(canvas, None) self._corner_callback = corner_callback self.zoom() def press(self, event): self._corner_preclick = self._views() def release(self, event): if self._corner_preclick == self._views(): self._corner_callback(event.xdata, event.ydata) self._corner_preclick = None class FitWidget(QWidget): def __init__(self, database ,parent=None): super(FitWidget, self).__init__(parent) self.database = database vbox = QHBoxLayout() self.figure = matplotlib.figure.Figure() self.canvas = FigureCanvasQTAgg(self.figure) self.toolbar = HiddenToolbar(self.loc_callback, self.canvas) vbox.addWidget(self.canvas) self.setLayout(vbox) def loc_callback(self, x, y): if self.ax: self.database.save_loc(self.currentindex(), numpy.array([x, y])) def plot(self, index = None): if index == None: index = self.currentindex() space = self.database.space_from_index(index) fitdata = self.database.load_data(index, 'fit') self.figure.clear() self.figure.space_axes = space.axes info = self.database.get_index_value(index) label = self.database.axis if fitdata is not None: if space.dimension == 1: self.ax = self.figure.add_subplot(111) binoculars.plot.plot(space, self.figure, self.ax, fit = fitdata) elif space.dimension == 2: self.ax = self.figure.add_subplot(121) binoculars.plot.plot(space, self.figure, self.ax, fit = None) self.ax = self.figure.add_subplot(122) binoculars.plot.plot(space, self.figure, self.ax, fit = fitdata) else: self.ax = self.figure.add_subplot(111) binoculars.plot.plot(space, self.figure, self.ax) self.figure.suptitle('{0}, res = {1}, {2} = {3}'.format(self.database.rodkey, self.database.resolution, label, info)) self.canvas.draw() def fit(self, index, space, function): print(index) if not len(space.get_masked().compressed()) == 0: loc = self.get_loc() fit = function(space, loc = loc) fit.fitdata.mask = space.get_masked().mask self.database.save_data(index, 'fit', fit.fitdata) params = list(line.split(':')[0] for line in fit.summary.split('\n')) print(fit.result, fit.variance) for key, value in zip(params, fit.result): self.database.save_sliceattr(index, key, value) for key, value in zip(params, fit.variance): self.database.save_sliceattr(index, 'var_{0}'.format(key), value) def get_loc(self): return self.database.load_loc(self.currentindex()) def currentindex(self): index = self.database.load('index') if index == None: return 0 else: return index class IntegrateWidget(QWidget): def __init__(self, database, parent = None): super(IntegrateWidget, self).__init__(parent) self.parent = parent self.database = database self.figure = matplotlib.figure.Figure() self.canvas = FigureCanvasQTAgg(self.figure) self.toolbar = HiddenToolbar(self.loc_callback, self.canvas) hbox = QHBoxLayout() splitter = QSplitter(Qt.Vertical) self.make_controlwidget() splitter.addWidget(self.canvas) splitter.addWidget(self.control_widget) hbox.addWidget(splitter) self.setLayout(hbox) def make_controlwidget(self): self.control_widget = QWidget() integratebox = QVBoxLayout() intensitybox = QHBoxLayout() backgroundbox = QHBoxLayout() self.aroundroi = QCheckBox('background around roi') self.aroundroi.setChecked(1) self.aroundroi.clicked.connect(self.refresh_aroundroi) self.hsize = QDoubleSpinBox() self.vsize = QDoubleSpinBox() intensitybox.addWidget(QLabel('roi size:')) intensitybox.addWidget(self.hsize) intensitybox.addWidget(self.vsize) self.left = QDoubleSpinBox() self.right = QDoubleSpinBox() self.top = QDoubleSpinBox() self.bottom = QDoubleSpinBox() self.hsize.valueChanged.connect(self.send) self.vsize.valueChanged.connect(self.send) self.left.valueChanged.connect(self.send) self.right.valueChanged.connect(self.send) self.top.valueChanged.connect(self.send) self.bottom.valueChanged.connect(self.send) backgroundbox.addWidget(self.aroundroi) backgroundbox.addWidget(self.left) backgroundbox.addWidget(self.right) backgroundbox.addWidget(self.top) backgroundbox.addWidget(self.bottom) integratebox.addLayout(intensitybox) integratebox.addLayout(backgroundbox) self.fromfit = QRadioButton('peak from fit', self) self.fromfit.setChecked(True) self.fromfit.toggled.connect(self.plot_box) self.fromfit.toggled.connect(self.refresh_tracker) self.fromsegment = QRadioButton('peak from segment', self) self.fromsegment.setChecked(False) self.fromsegment.toggled.connect(self.plot_box) self.fromsegment.toggled.connect(self.refresh_tracker) self.trackergroup = QButtonGroup(self) self.trackergroup.addButton(self.fromfit) self.trackergroup.addButton(self.fromsegment) radiobox = QHBoxLayout() radiobox.addWidget(self.fromfit) radiobox.addWidget(self.fromsegment) integratebox.addLayout(radiobox) self.control_widget.setLayout(integratebox) def refresh_aroundroi(self): self.database.save('aroundroi', self.aroundroi.checkState()) axes = self.database.paxes() if not self.aroundroi.checkState(): self.left.setMinimum(axes[0].min) self.left.setMaximum(axes[0].max) self.right.setMinimum(axes[0].min) self.right.setMaximum(axes[0].max) self.top.setMinimum(axes[1].min) self.top.setMaximum(axes[1].max) self.bottom.setMinimum(axes[1].min) self.bottom.setMaximum(axes[1].max) else: self.left.setMinimum(0) self.left.setMaximum(axes[0].max - axes[0].min) self.right.setMinimum(0) self.right.setMaximum(axes[0].max - axes[0].min) self.top.setMinimum(0) self.top.setMaximum(axes[1].max - axes[1].min) self.bottom.setMinimum(0) self.bottom.setMaximum(axes[1].max - axes[1].min) def refresh_tracker(self): self.database.save('fromfit', self.fromfit.isChecked()) self.plot_box() def set_axis(self): roi = self.database.load('roi') aroundroi = self.database.load('aroundroi') if aroundroi != None: self.aroundroi.setChecked(aroundroi) else: self.aroundroi.setChecked(True) self.refresh_aroundroi() axes = self.database.paxes() self.hsize.setSingleStep(axes[1].res) self.hsize.setDecimals(len(str(axes[1].res)) - 2) self.vsize.setSingleStep(axes[0].res) self.vsize.setDecimals(len(str(axes[0].res)) - 2) self.left.setSingleStep(axes[1].res) self.left.setDecimals(len(str(axes[1].res)) - 2) self.right.setSingleStep(axes[1].res) self.right.setDecimals(len(str(axes[1].res)) - 2) self.top.setSingleStep(axes[0].res) self.top.setDecimals(len(str(axes[0].res)) - 2) self.bottom.setSingleStep(axes[0].res) self.bottom.setDecimals(len(str(axes[0].res)) - 2) tracker = self.database.load('fromfit') if tracker != None: if tracker: self.fromfit.setChecked(True) else: self.fromsegment.setChecked(True) if roi is not None: for box, value in zip([self.hsize, self.vsize, self.left, self.right, self.top, self.bottom], roi): box.setValue(value) def send(self): roi = [self.hsize.value(), self.vsize.value(), self.left.value() ,self.right.value() ,self.top.value(), self.bottom.value()] self.database.save('roi', roi) self.plot_box() def integrate(self, index, space): loc = self.get_loc() if loc != None: axes = space.axes key = space.get_key(self.intkey(loc, axes)) fitdata = self.database.load_data(index, 'fit') if fitdata != None: fitintensity = fitdata[key].data.flatten() fitbkg = numpy.hstack(fitdata[space.get_key(bkgkey)].data.flatten() for bkgkey in self.bkgkeys(loc, axes)) if numpy.alen(fitbkg) == 0: fitstructurefactor = fitintensity.sum() elif numpy.alen(fitintensity) == 0: fitstructurefactor = numpy.nan else: fitstructurefactor = numpy.sqrt(fitintensity.sum() - numpy.alen(fitintensity) * 1.0 / numpy.alen(fitbkg) * fitbkg.sum()) self.database.save_sliceattr(index, 'fitsf', fitstructurefactor) niintensity = space[self.intkey(loc, axes)].get_masked().compressed() try: intensity = interpolate(space[self.intkey(loc, axes)]).flatten() bkg = numpy.hstack(space[bkgkey].get_masked().compressed() for bkgkey in self.bkgkeys(loc, axes)) interdata = space.get_masked() interdata[key] = intensity.reshape(interdata[key].shape) interdata[key].mask = numpy.zeros_like(interdata[key]) self.database.save_data(index, 'inter', interdata) except Exception as e: print('Warning error interpolating slice {0}: {1}'.format(index, e.message)) intensity = numpy.array([]) bkg = numpy.array([]) if numpy.alen(intensity) == 0: structurefactor = numpy.nan nistructurefactor = numpy.nan elif numpy.alen(bkg) == 0: structurefactor = numpy.sqrt(intensity.sum()) nistructurefactor = numpy.sqrt(niintensity.sum()) else: structurefactor = numpy.sqrt(intensity.sum() - numpy.alen(intensity) * 1.0 / numpy.alen(bkg) * bkg.sum()) nistructurefactor = numpy.sqrt(niintensity.sum() - numpy.alen(niintensity) * 1.0 / numpy.alen(bkg) * bkg.sum()) self.database.save_sliceattr(index, 'sf', structurefactor) self.database.save_sliceattr(index, 'nisf', nistructurefactor) print('Structurefactor {0}: {1}'.format(index, structurefactor)) def intkey(self, coords, axes): vsize = self.vsize.value() / 2 hsize = self.hsize.value() / 2 return tuple(ax.restrict(slice(coord - size, coord + size)) for ax, coord, size in zip(axes, coords, [vsize, hsize])) def bkgkeys(self, coords, axes): aroundroi = self.database.load('aroundroi') if aroundroi: key = self.intkey(coords, axes) vsize = self.vsize.value() / 2 hsize = self.hsize.value() / 2 leftkey = (key[0], axes[1].restrict(slice(coords[1] - hsize - self.left.value(), coords[1] - hsize))) rightkey = (key[0], axes[1].restrict(slice(coords[1] + hsize, coords[1] + hsize + self.right.value()))) topkey = (axes[0].restrict(slice(coords[0] - vsize - self.top.value(), coords[0] - vsize)), key[1]) bottomkey = (axes[0].restrict(slice(coords[0] + vsize, coords[0] + vsize + self.bottom.value())), key[1]) return leftkey, rightkey, topkey, bottomkey else: return [(axes[0].restrict(slice(self.left.value(), self.right.value())), axes[1].restrict(slice(self.top.value(), self.bottom.value())))] def get_loc(self): if self.fromfit.isChecked(): return self.database.load_loc(self.currentindex()) else: index = self.currentindex() indexvalue = self.database.get_index_value(index) return self.parent.peakwidget.get_coords(indexvalue) def loc_callback(self, x, y): if self.ax: if self.fromfit.isChecked(): self.database.save_loc(self.currentindex(), numpy.array([x, y])) else: index = self.currentindex() indexvalue = self.database.get_index_value(index) self.parent.peakwidget.add_row(numpy.array([indexvalue, x, y])) self.plot_box() def plot(self, index = None): if index == None: index = self.currentindex() space = self.database.space_from_index(index) interdata = self.database.load_data(index, 'inter') info = self.database.get_index_value(index) label = self.database.axis self.figure.clear() self.figure.space_axes = space.axes if interdata is not None: if space.dimension == 1: self.ax = self.figure.add_subplot(111) binoculars.plot.plot(space, self.figure, self.ax, fit = interdata) elif space.dimension == 2: self.ax = self.figure.add_subplot(121) binoculars.plot.plot(space, self.figure, self.ax, fit = None) self.ax = self.figure.add_subplot(122) binoculars.plot.plot(space, self.figure, self.ax, fit = interdata) else: self.ax = self.figure.add_subplot(111) binoculars.plot.plot(space, self.figure, self.ax) self.figure.suptitle('{0}, res = {1}, {2} = {3}'.format(self.database.rodkey, self.database.resolution, label, info)) self.plot_box() self.canvas.draw() def plot_box(self): loc = self.get_loc() if len(self.figure.get_axes()) != 0 and loc != None: ax = self.figure.get_axes()[0] axes = self.figure.space_axes key = self.intkey(loc, axes) bkgkey = self.bkgkeys(loc, axes) ax.patches = [] rect = Rectangle((key[0].start, key[1].start), key[0].stop - key[0].start, key[1].stop - key[1].start, alpha = 0.2,color = 'k') ax.add_patch(rect) for k in bkgkey: bkg = Rectangle((k[0].start, k[1].start), k[0].stop - k[0].start, k[1].stop - k[1].start, alpha = 0.2,color = 'r') ax.add_patch(bkg) self.canvas.draw() def currentindex(self): index = self.database.load('index') if index == None: return 0 else: return index class ButtonedSlider(QWidget): slice_index = pyqtSignal(int) def __init__(self,parent=None): super(ButtonedSlider, self).__init__(parent) self.navigation_button_left_end = QPushButton('|<') self.navigation_button_left_one = QPushButton('<') self.navigation_slider = QSlider(Qt.Horizontal) self.navigation_slider.sliderReleased.connect(self.send) self.navigation_button_right_one = QPushButton('>') self.navigation_button_right_end = QPushButton('>|') self.navigation_button_left_end.setMaximumWidth(20) self.navigation_button_left_one.setMaximumWidth(20) self.navigation_button_right_end.setMaximumWidth(20) self.navigation_button_right_one.setMaximumWidth(20) self.navigation_button_left_end.clicked.connect(self.slider_change_left_end) self.navigation_button_left_one.clicked.connect(self.slider_change_left_one) self.navigation_button_right_end.clicked.connect(self.slider_change_right_end) self.navigation_button_right_one.clicked.connect(self.slider_change_right_one) box = QHBoxLayout() box.addWidget(self.navigation_button_left_end) box.addWidget(self.navigation_button_left_one) box.addWidget(self.navigation_slider) box.addWidget(self.navigation_button_right_one) box.addWidget(self.navigation_button_right_end) self.setDisabled(True) self.setLayout(box) def set_length(self,length): self.navigation_slider.setMinimum(0) self.navigation_slider.setMaximum(length - 1) self.navigation_slider.setTickPosition(QSlider.TicksBelow) self.navigation_slider.setValue(0) self.setEnabled(True) def send(self): self.slice_index.emit(self.navigation_slider.value()) def slider_change_left_one(self): self.navigation_slider.setValue(max(self.navigation_slider.value() - 1, 0)) self.send() def slider_change_left_end(self): self.navigation_slider.setValue(0) self.send() def slider_change_right_one(self): self.navigation_slider.setValue(min(self.navigation_slider.value() + 1, self.navigation_slider.maximum())) self.send() def slider_change_right_end(self): self.navigation_slider.setValue(self.navigation_slider.maximum()) self.send() def index(self): return self.navigation_slider.value() def set_index(self, index): self.navigation_slider.setValue(index) class HiddenToolbar2(NavigationToolbar2QT): def __init__(self, canvas): super(HiddenToolbar2, self).__init__(canvas, None) self.zoom() class OverviewWidget(QWidget): def __init__(self, database, parent = None): super(OverviewWidget, self).__init__(parent) self.databaselist = list() self.figure = matplotlib.figure.Figure() self.canvas = FigureCanvasQTAgg(self.figure) self.toolbar = HiddenToolbar2(self.canvas) self.table = QTableWidget(0,2) self.make_table() self.table.cellClicked.connect(self.plot) hbox = QHBoxLayout() splitter = QSplitter(Qt.Horizontal) splitter.addWidget(self.canvas) splitter.addWidget(self.control_widget) hbox.addWidget(splitter) self.setLayout(hbox) def select(self): selection = [] for index in range(self.table.rowCount()): checkbox = self.table.cellWidget(index, 0) if checkbox.checkState(): selection.append(str(self.table.cellWidget(index,1).text())) return selection def make_table(self): self.control_widget = QWidget() vbox = QVBoxLayout() minibox = QHBoxLayout() vbox.addWidget(self.table) self.table.setHorizontalHeaderLabels(['','param']) for index, width in enumerate([25,50]): self.table.setColumnWidth(index, width) self.log = QCheckBox('log') self.log.clicked.connect(self.plot) self.export_button = QPushButton('export curves') self.export_button.clicked.connect(self.export) minibox.addWidget(self.log) minibox.addWidget(self.export_button) vbox.addLayout(minibox) self.control_widget.setLayout(vbox) def export(self): folder = str(QFileDialog.getExistingDirectory(self, "Select directory to save curves")) params = self.select() for param in params: for database in self.databaselist: x, y = database.all_from_key(param) args = numpy.argsort(x) numpy.savetxt( os.path.join(folder,'{0}_{1}.txt'.format(param, database.rodkey)), numpy.vstack(arr[args] for arr in [x, y]).T) def refresh(self, databaselist): self.databaselist = databaselist params = self.select() while self.table.rowCount() > 0: self.table.removeRow(0) allparams = list(list(param for param in database.all_attrkeys() if not param.startswith('mask')) for database in databaselist) allparams.extend(list(['locx_s', 'locy_s']) for database in databaselist if database.load_segments() is not None) if len(allparams) > 0: uniqueparams = numpy.unique(numpy.hstack(params for params in allparams)) else: uniqueparams = [] for param in uniqueparams: index = self.table.rowCount() self.table.insertRow(index) checkboxwidget = QCheckBox() if param in params: checkboxwidget.setChecked(1) else: checkboxwidget.setChecked(0) self.table.setCellWidget(index,0, checkboxwidget) checkboxwidget.clicked.connect(self.plot) item = QLabel(param) self.table.setCellWidget(index, 1, item) self.plot() def plot(self): params = self.select() self.figure.clear() self.ax = self.figure.add_subplot(111) for param in params: for database in self.databaselist: if param == 'locx_s': segments = database.load_segments() if segments is not None: x = numpy.hstack(database.get_index_value(index) for index in range(database.rodlength())) y = numpy.vstack(get_coords(xvalue, segments) for xvalue in x) self.ax.plot(x, y[:,0], '+', label = '{0} - {1}'.format('locx_s', database.rodkey)) elif param == 'locy_s': segments = database.load_segments() if segments is not None: x = numpy.hstack(database.get_index_value(index) for index in range(database.rodlength())) y = numpy.vstack(get_coords(xvalue, segments) for xvalue in x) self.ax.plot(x, y[:,1], '+', label = '{0} - {1}'.format('locy_s', database.rodkey)) else: x, y = database.all_from_key(param) self.ax.plot(x, y, '+', label = '{0} - {1}'.format(param, database.rodkey)) self.ax.legend() if self.log.checkState(): self.ax.semilogy() self.canvas.draw() class PeakWidget(QWidget): def __init__(self, database, parent=None): super(PeakWidget, self).__init__(parent) self.database = database # create a QTableWidget self.table = QTableWidget(0, 3, self) self.table.horizontalHeader().setStretchLastSection(True) self.table.verticalHeader().setVisible(False) self.table.itemChanged.connect(self.save) self.btn_add_row = QPushButton('+', self) self.btn_add_row.clicked.connect(self.add_row) self.buttonRemove = QPushButton('-', self) self.buttonRemove.clicked.connect(self.remove) vbox = QVBoxLayout() hbox = QHBoxLayout() hbox.addWidget(self.btn_add_row) hbox.addWidget(self.buttonRemove) vbox.addLayout(hbox) vbox.addWidget(self.table) self.setLayout(vbox) def set_axis(self): self.axes = self.database.paxes() while self.table.rowCount() > 0: self.table.removeRow(0) segments = self.database.load_segments() if segments is not None: for index in range(segments.shape[0]): self.add_row(segments[index, :]) self.table.setHorizontalHeaderLabels(['{0}'.format(self.database.axis), '{0}'.format(self.axes[0].label), '{0}'.format(self.axes[1].label)]) def add_row(self, row = None): rowindex = self.table.rowCount() self.table.insertRow(rowindex) if row is not None: for index in range(3): newitem = QTableWidgetItem(str(row[index])) self.table.setItem(rowindex, index, newitem) def remove(self): self.table.removeRow(self.table.currentRow()) self.save() def axis_coords(self): a = numpy.zeros((self.table.rowCount(), self.table.columnCount())) for rowindex in range(a.shape[0]): for columnindex in range(a.shape[1]): item = self.table.item(rowindex, columnindex) if item is not None: a[rowindex, columnindex] = float(item.text()) return a def save(self): self.database.save_segments(self.axis_coords()) def get_coords(self, x): return get_coords(x, self.axis_coords()) def get_coords(x, coords): if coords.shape[0] == 0: return None if coords.shape[0] == 1: return coords[0,1:] args = numpy.argsort(coords[:,0]) x0 = coords[args,0] x1 = coords[args,1] x2 = coords[args,2] if x < x0.min(): first = 0 last = 1 elif x > x0.max(): first = -2 last = -1 else: first = numpy.searchsorted(x0, x) - 1 last = numpy.searchsorted(x0, x) a1 = (x1[last] - x1[first]) / (x0[last] - x0[first]) b1 = x1[first] - a1 * x0[first] a2 = (x2[last] - x2[first]) / (x0[last] - x0[first]) b2 = x2[first] - a2 * x0[first] return numpy.array([a1 * x + b1, a2 * x + b2]) def interpolate(space): data = space.get_masked() mask = data.mask grid = numpy.vstack(numpy.ma.array(g, mask=mask).compressed() for g in space.get_grid()).T open = numpy.vstack(numpy.ma.array(g, mask=numpy.invert(mask)).compressed() for g in space.get_grid()).T if open.shape[0] == 0: return data.compressed() elif grid.shape[0] == 0: return data.compressed() else: interpolated = griddata(grid, data.compressed(), open) values = data.data.copy() values[mask] = interpolated mask = numpy.isnan(values) if mask.sum() > 0: data = numpy.ma.array(values, mask = mask) grid = numpy.vstack(numpy.ma.array(g, mask=mask).compressed() for g in space.get_grid()).T open = numpy.vstack(numpy.ma.array(g, mask=numpy.invert(mask)).compressed() for g in space.get_grid()).T interpolated = griddata(grid, data.compressed(), open, method = 'nearest') values[mask] = interpolated return values def find_unused_rodkey(rodkey, rods): if not rodkey in rods: return rodkey for index in itertools.count(0): newkey = '{0}_{1}'.format(rodkey, index) if newkey not in rods: return newkey if __name__ == "__main__": app = QApplication(sys.argv) main = Window() main.resize(1000, 600) main.show() sys.exit(app.exec_())