diff options
author | Andrej Shadura <andrewsh@debian.org> | 2019-03-09 22:30:32 +0000 |
---|---|---|
committer | Andrej Shadura <andrewsh@debian.org> | 2019-03-09 22:30:32 +0000 |
commit | c36aa2a8fd31ca5e225ff30278e910070cd2c8c1 (patch) | |
tree | bdfe1a5ccb57999d4d664a2a44121a78c88b19d4 /src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java | |
parent | 89aa1958dbaf9052da0c24706308a2ef8cefa96e (diff) |
Import Upstream version 0.5.0~beta2
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java | 118 |
1 files changed, 66 insertions, 52 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java b/src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java index 8448e4c4..bb277b17 100644 --- a/src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java +++ b/src/de/lmu/ifi/dbs/elki/datasource/parser/TermFrequencyParser.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.datasource.parser; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2011 + Copyright (C) 2012 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -23,28 +23,25 @@ package de.lmu.ifi.dbs.elki.datasource.parser; along with this program. If not, see <http://www.gnu.org/licenses/>. */ -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.util.ArrayList; +import gnu.trove.iterator.TIntFloatIterator; +import gnu.trove.map.hash.TIntFloatHashMap; + import java.util.BitSet; -import java.util.Collections; import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.TreeMap; import java.util.regex.Pattern; import de.lmu.ifi.dbs.elki.data.LabelList; import de.lmu.ifi.dbs.elki.data.SparseFloatVector; -import de.lmu.ifi.dbs.elki.data.type.TypeUtil; +import de.lmu.ifi.dbs.elki.data.type.SimpleTypeInformation; import de.lmu.ifi.dbs.elki.data.type.VectorFieldTypeInformation; -import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.utilities.documentation.Description; import de.lmu.ifi.dbs.elki.utilities.documentation.Title; -import de.lmu.ifi.dbs.elki.utilities.pairs.Pair; +import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag; /** * A parser to load term frequency data, which essentially are sparse vectors @@ -54,7 +51,6 @@ import de.lmu.ifi.dbs.elki.utilities.pairs.Pair; * * @apiviz.has SparseFloatVector */ -// TODO: add a flag to perform TF normalization when using term counts @Title("Term frequency parser") @Description("Parse a file containing term frequencies. The expected format is 'label term1 <freq> term2 <freq> ...'. Terms must not contain the separator character!") public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVector> { @@ -74,29 +70,32 @@ public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVect HashMap<String, Integer> keymap; /** + * Normalize + */ + boolean normalize; + + /** * Constructor. * + * @param normalize Normalize * @param colSep * @param quoteChar * @param labelIndices */ - public TermFrequencyParser(Pattern colSep, char quoteChar, BitSet labelIndices) { - super(colSep, quoteChar, labelIndices); + public TermFrequencyParser(boolean normalize, Pattern colSep, char quoteChar, BitSet labelIndices) { + super(colSep, quoteChar, labelIndices, SparseFloatVector.STATIC); + this.normalize = normalize; this.maxdim = 0; this.keymap = new HashMap<String, Integer>(); } @Override - protected SparseFloatVector createDBObject(List<Double> attributes) { - throw new UnsupportedOperationException("This method should never be reached."); - } - - @Override - public Pair<SparseFloatVector, LabelList> parseLineInternal(String line) { + protected void parseLineInternal(String line) { List<String> entries = tokenize(line); - Map<Integer, Float> values = new TreeMap<Integer, Float>(); - LabelList labels = new LabelList(); + double len = 0; + TIntFloatHashMap values = new TIntFloatHashMap(); + LabelList labels = null; String curterm = null; for(int i = 0; i < entries.size(); i++) { @@ -105,7 +104,7 @@ public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVect } else { try { - Float attribute = Float.valueOf(entries.get(i)); + float attribute = Float.valueOf(entries.get(i)); Integer curdim = keymap.get(curterm); if(curdim == null) { curdim = maxdim + 1; @@ -113,10 +112,14 @@ public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVect maxdim += 1; } values.put(curdim, attribute); + len += attribute; curterm = null; } catch(NumberFormatException e) { if(curterm != null) { + if(labels == null) { + labels = new LabelList(1); + } labels.add(curterm); } curterm = entries.get(i); @@ -124,41 +127,33 @@ public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVect } } if(curterm != null) { + if(labels == null) { + labels = new LabelList(1); + } labels.add(curterm); } - - return new Pair<SparseFloatVector, LabelList>(new SparseFloatVector(values, maxdim), labels); - } - - @Override - public MultipleObjectsBundle parse(InputStream in) { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - int lineNumber = 1; - List<SparseFloatVector> vectors = new ArrayList<SparseFloatVector>(); - List<LabelList> lblc = new ArrayList<LabelList>(); - try { - for(String line; (line = reader.readLine()) != null; lineNumber++) { - if(!line.startsWith(COMMENT) && line.length() > 0) { - Pair<SparseFloatVector, LabelList> pair = parseLineInternal(line); - vectors.add(pair.first); - lblc.add(pair.second); + if(normalize) { + if(Math.abs(len - 1.0) > 1E-10 && len > 1E-10) { + for(TIntFloatIterator iter = values.iterator(); iter.hasNext();) { + iter.advance(); + iter.setValue((float) (iter.value() / len)); } } } - catch(IOException e) { - throw new IllegalArgumentException("Error while parsing line " + lineNumber + "."); - } - // Set maximum dimensionality - for(int i = 0; i < vectors.size(); i ++) { - vectors.get(i).setDimensionality(maxdim); - } - return MultipleObjectsBundle.makeSimple(getTypeInformation(maxdim), vectors, TypeUtil.LABELLIST, lblc); + + curvec = new SparseFloatVector(values, maxdim); + curlbl = labels; } @Override - protected VectorFieldTypeInformation<SparseFloatVector> getTypeInformation(int dimensionality) { - final Map<Integer, Float> emptyMap = Collections.emptyMap(); - return new VectorFieldTypeInformation<SparseFloatVector>(SparseFloatVector.class, dimensionality, new SparseFloatVector(emptyMap, dimensionality)); + protected SimpleTypeInformation<SparseFloatVector> getTypeInformation(int dimensionality) { + if(dimensionality > 0) { + return new VectorFieldTypeInformation<SparseFloatVector>(SparseFloatVector.class, dimensionality, new SparseFloatVector(SparseFloatVector.EMPTYMAP, dimensionality)); + } + if(dimensionality == DIMENSIONALITY_VARIABLE) { + return new SimpleTypeInformation<SparseFloatVector>(SparseFloatVector.class); + } + throw new AbortException("No vectors were read from the input file - cannot determine vector data type."); } @Override @@ -174,9 +169,28 @@ public class TermFrequencyParser extends NumberVectorLabelParser<SparseFloatVect * @apiviz.exclude */ public static class Parameterizer extends NumberVectorLabelParser.Parameterizer<SparseFloatVector> { + /** + * Option ID for normalization + */ + public static final OptionID NORMALIZE_FLAG = OptionID.getOrCreateOptionID("tf.normalize", "Normalize vectors to manhattan length 1 (convert term counts to term frequencies)"); + + /** + * Normalization flag + */ + boolean normalize = false; + + @Override + protected void makeOptions(Parameterization config) { + super.makeOptions(config); + Flag normF = new Flag(NORMALIZE_FLAG); + if(config.grab(normF)) { + normalize = normF.getValue(); + } + } + @Override protected TermFrequencyParser makeInstance() { - return new TermFrequencyParser(colSep, quoteChar, labelIndices); + return new TermFrequencyParser(normalize, colSep, quoteChar, labelIndices); } } }
\ No newline at end of file |