"""
Functions to transform / filter sequences
"""
import collections
import contextlib
import pickle
import gzip
import itertools
import logging
import re
import string
import tempfile
import random

from Bio import SeqIO
from Bio.Data import CodonTable
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.SeqUtils.CheckSum import seguid
from functools import reduce

# Characters to be treated as gaps
GAP_CHARS = "-."
GAP_TABLE = {ord(c): None for c in GAP_CHARS}

# Size of temporary file buffer: default to 256MB
DEFAULT_BUFFER_SIZE = 268435456  # 256 * 2**20


@contextlib.contextmanager
def _record_buffer(records, buffer_size=DEFAULT_BUFFER_SIZE):
    """
    Buffer for transform functions which require multiple passes through data.

    Value returned by context manager is a function which returns an iterator
    through records.
    """
    with tempfile.SpooledTemporaryFile(buffer_size, mode='wb+') as tf:
        pickler = pickle.Pickler(tf)
        for record in records:
            pickler.dump(record)

        def record_iter():
            tf.seek(0)
            # tf._file is used below because it implements the methods that
            # pickle.Unpickler() requires, notably 'readinto', which is newly
            # required in Python 3.8.  See
            # https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile
            # for details on the _file attribute of SpooledTemporaryFile.
            unpickler = pickle.Unpickler(tf._file)
            while True:
                try:
                    yield unpickler.load()
                except EOFError:
                    break

        yield record_iter
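
# Illustrative usage of _record_buffer (a sketch, not part of the original
# module).  The context manager yields a zero-argument callable; each call
# returns a fresh iterator over the buffered records:
#
#   >>> recs = [SeqRecord(Seq("ACGT"), id="a")]
#   >>> with _record_buffer(recs) as r:
#   ...     ids = [rec.id for rec in r()] + [rec.id for rec in r()]
#   >>> ids
#   ['a', 'a']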


def dashes_cleanup(records, prune_chars='.:?~'):
    """
    Take an alignment and convert any undesirable characters, such as '?'
    or '~', to '-'.
    """
    logging.info(
        "Applying _dashes_cleanup: converting any of '{}' to '-'.".format(prune_chars))
    translation_table = {ord(c): '-' for c in prune_chars}
    for record in records:
        record.seq = Seq(str(record.seq).translate(translation_table))
        yield record
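
# Illustrative usage (a sketch, not part of the original module):
#
#   >>> recs = [SeqRecord(Seq("AC?G~T"), id="s1")]
#   >>> [str(r.seq) for r in dashes_cleanup(recs)]
#   ['AC-G-T']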


def deduplicate_sequences(records, out_file):
    """
    Remove any duplicate records with identical sequences, keeping the first
    instance seen and discarding additional occurrences.
    """

    logging.info('Applying _deduplicate_sequences generator: '
                 'removing any duplicate records with identical sequences.')
    checksum_sequences = collections.defaultdict(list)
    for record in records:
        checksum = seguid(record.seq)
        sequences = checksum_sequences[checksum]
        if not sequences:
            yield record
        sequences.append(record.id)

    if out_file is not None:
        with out_file:
            for sequences in checksum_sequences.values():
                out_file.write('%s\n' % (' '.join(sequences),))
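
# Illustrative usage (a sketch, not part of the original module; out_file
# may be None, in which case duplicate IDs are not reported):
#
#   >>> recs = [SeqRecord(Seq("ACGT"), id="a"), SeqRecord(Seq("ACGT"), id="b")]
#   >>> [r.id for r in deduplicate_sequences(recs, None)]
#   ['a']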


def deduplicate_taxa(records):
    """
    Remove any duplicate records with identical IDs, keeping the first
    instance seen and discarding additional occurrences.
    """
    logging.info('Applying _deduplicate_taxa generator: '
                 'removing any duplicate records with identical IDs.')
    taxa = set()
    for record in records:
        # Default to full ID, split if | is found.
        taxid = record.id
        if '|' in record.id:
            try:
                taxid = int(record.id.split("|")[0])
            except ValueError:
                # If we couldn't parse an integer from the ID, just fall back
                # on the ID
                logging.warning("Unable to parse integer taxid from %s",
                                taxid)
        if taxid in taxa:
            continue
        taxa.add(taxid)
        yield record


def first_name_capture(records):
    """
    Take only the first whitespace-delimited word as the name of the sequence.
    Essentially removes any extra text from the sequence's description.
    """
    logging.info('Applying _first_name_capture generator: '
                 'making sure ID only contains the first whitespace-delimited '
                 'word.')
    whitespace = re.compile(r'\s+')
    for record in records:
        if whitespace.search(record.description):
            yield SeqRecord(record.seq, id=record.id,
                            description="")
        else:
            yield record


def include_from_file(records, handle):
    """
    Filter the records, keeping only sequences whose ID is contained in the
    handle.
    """
    ids = set(i.strip() for i in handle)

    for record in records:
        if record.id.strip() in ids:
            yield record


def exclude_from_file(records, handle):
    """
    Filter the records, keeping only sequences whose ID is not contained in the
    handle.
    """
    ids = set(i.strip() for i in handle)

    for record in records:
        if record.id.strip() not in ids:
            yield record


def isolate_region(sequences, start, end, gap_char='-'):
    """
    Replace regions before and after start:end with gap chars
    """
    # Check arguments
    if end <= start:
        raise ValueError(
            "start of slice must precede end (got start={0}, end={1})".format(
                start, end))

    for sequence in sequences:
        seq = sequence.seq
        start_gap = gap_char * start
        end_gap = gap_char * (len(seq) - end)
        seq = Seq(start_gap + str(seq[start:end]) + end_gap)
        sequence.seq = seq
        yield sequence
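
# Illustrative usage (a sketch, not part of the original module).  Columns
# outside the start:end slice are replaced with the gap character:
#
#   >>> recs = [SeqRecord(Seq("ACGTACGT"), id="s1")]
#   >>> [str(r.seq) for r in isolate_region(recs, 2, 5)]
#   ['--GTA---']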


def _cut_sequences(records, cut_slice):
    """
    Cut sequences given a slice.
    """
    for record in records:
        yield record[cut_slice]

def drop_columns(records, slices):
    """
    Drop all columns present in ``slices`` from records
    """
    for record in records:
        # Generate a set of indices to remove
        drop = set(i for s in slices
                   for i in range(*s.indices(len(record))))
        keep = [i not in drop for i in range(len(record))]
        record.seq = Seq(''.join(itertools.compress(record.seq, keep)))
        yield record
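
# Illustrative usage (a sketch, not part of the original module).  Columns
# covered by any of the slices are removed:
#
#   >>> recs = [SeqRecord(Seq("ACGTAC"), id="s1")]
#   >>> [str(r.seq) for r in drop_columns(recs, [slice(1, 3)])]
#   ['ATAC']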

def multi_cut_sequences(records, slices):
    # If only a single slice is specified, use _cut_sequences,
    # since this preserves per-letter annotations
    if len(slices) == 1:
        for sequence in _cut_sequences(records, slices[0]):
            yield sequence
    else:
        # For multiple slices, concatenate the slice results
        for record in records:
            pieces = (record[s] for s in slices)
            # SeqRecords support addition as concatenation
            yield reduce(lambda x, y: x + y, pieces)
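
# Illustrative usage (a sketch, not part of the original module).  Multiple
# slices are concatenated in order:
#
#   >>> recs = [SeqRecord(Seq("ACGTACGT"), id="s1")]
#   >>> [str(r.seq) for r in multi_cut_sequences(recs, [slice(0, 2), slice(6, 8)])]
#   ['ACGT']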

def _update_slices(record, slices):
    n = itertools.count().__next__
    # Generate a map from indexes in the specified sequence to those in the
    # alignment
    ungap_map = dict((n(), i) for i, base in enumerate(str(record.seq))
                     if base not in GAP_CHARS)
    def update_slice(s):
        """
        Maps a slice relative to ungapped record_id to a slice valid for the
        whole alignment.
        """
        start, end = s.start, s.stop
        if start is not None:
            try:
                start = ungap_map[start]
            except KeyError:
                raise KeyError("""No index {0} in {1}.""".format(
                    start, record.id))
        if end is not None:
            # We need the base in the slice identified by end, not the base
            # at end, otherwise insertions between end-1 and end will be
            # included.
            try:
                end = ungap_map[end - 1] + 1
            except KeyError:
                logging.warning("No index %d in %s. Keeping columns to "
                                "end of alignment.", end, record.id)
                end = None

        return slice(start, end)

    return [update_slice(s) for s in slices]
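
# Illustrative usage (a sketch, not part of the original module).  Ungapped
# positions 1:3 of the reference map to alignment columns 2:4:
#
#   >>> ref = SeqRecord(Seq("A-CG-T"), id="ref")
#   >>> _update_slices(ref, [slice(1, 3)])
#   [slice(2, 4, None)]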

def cut_sequences_relative(records, slices, record_id):
    """
    Cuts records to slices, indexed by non-gap positions in record_id
    """
    with _record_buffer(records) as r:
        try:
            record = next(i for i in r() if i.id == record_id)
        except StopIteration:
            raise ValueError("Record with id {0} not found.".format(record_id))

        new_slices = _update_slices(record, slices)
        for record in multi_cut_sequences(r(), new_slices):
            yield record

def multi_mask_sequences(records, slices):
    """
    Replace characters sliced by slices with gap characters.
    """
    for record in records:
        record_indices = list(range(len(record)))
        keep_indices = reduce(lambda i, s: i - frozenset(record_indices[s]),
                              slices, frozenset(record_indices))
        seq = ''.join(b if i in keep_indices else '-'
                      for i, b in enumerate(str(record.seq)))
        record.seq = Seq(seq)
        yield record
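
# Illustrative usage (a sketch, not part of the original module).  Sliced
# positions are replaced with '-':
#
#   >>> recs = [SeqRecord(Seq("ACGTAC"), id="s1")]
#   >>> [str(r.seq) for r in multi_mask_sequences(recs, [slice(1, 4)])]
#   ['A---AC']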

def mask_sequences_relative(records, slices, record_id):
    with _record_buffer(records) as r:
        try:
            record = next(i for i in r() if i.id == record_id)
        except StopIteration:
            raise ValueError("Record with id {0} not found.".format(record_id))

        new_slices = _update_slices(record, slices)
        for record in multi_mask_sequences(r(), new_slices):
            yield record


def lower_sequences(records):
    """
    Convert sequences to all lowercase.
    """
    logging.info('Applying _lower_sequences generator: '
                 'converting sequences to all lowercase.')
    for record in records:
        yield record.lower()


def upper_sequences(records):
    """
    Convert sequences to all uppercase.
    """
    logging.info('Applying _upper_sequences generator: '
                 'converting sequences to all uppercase.')
    for record in records:
        yield record.upper()


def prune_empty(records):
    """
    Remove any sequences which are entirely gaps ('-')
    """
    for record in records:
        if not all(c == '-' for c in str(record.seq)):
            yield record


def _reverse_annotations(old_record, new_record):
    """
    Copy annotations from old_record to new_record, reversing any
    lists / tuples / strings that run the length of the sequence.
    """
    # Copy the annotations over
    for k, v in list(old_record.annotations.items()):
        # Reverse if the annotation spans the sequence
        if isinstance(v, (tuple, list)) and len(v) == len(old_record):
            v = v[::-1]
        new_record.annotations[k] = v

    # Letter annotations must be lists / tuples / strings of the same
    # length as the sequence
    for k, v in list(old_record.letter_annotations.items()):
        assert len(v) == len(old_record)
        new_record.letter_annotations[k] = v[::-1]


def reverse_sequences(records):
    """
    Reverse the order of sites in sequences.
    """
    logging.info('Applying _reverse_sequences generator: '
                 'reversing the order of sites in sequences.')
    for record in records:
        rev_record = SeqRecord(record.seq[::-1], id=record.id,
                               name=record.name,
                               description=record.description)
        # Copy the annotations over
        _reverse_annotations(record, rev_record)

        yield rev_record


def reverse_complement_sequences(records):
    """
    Transform sequences into reverse complements.
    """
    logging.info('Applying _reverse_complement_sequences generator: '
                 'transforming sequences into reverse complements.')
    for record in records:
        rev_record = SeqRecord(record.seq.reverse_complement(),
                               id=record.id, name=record.name,
                               description=record.description)
        # Copy the annotations over
        _reverse_annotations(record, rev_record)

        yield rev_record


def ungap_sequences(records, gap_chars=GAP_TABLE):
    """
    Remove gaps from sequences, given an alignment.  ``gap_chars`` is a
    str.translate table mapping gap characters to None (see GAP_TABLE).
    """
    logging.info('Applying _ungap_sequences generator: '
                 'removing all gap characters.')
    for record in records:
        yield ungap_all(record, gap_chars)


def ungap_all(record, gap_chars=GAP_TABLE):
    """
    Return a new SeqRecord with all gap characters removed from the
    sequence.  ``gap_chars`` is a str.translate table (see GAP_TABLE).
    """
    record = SeqRecord(
        Seq(str(record.seq).translate(gap_chars)),
        id=record.id, description=record.description
    )
    return record
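
# Illustrative usage (a sketch, not part of the original module):
#
#   >>> str(ungap_all(SeqRecord(Seq("A-C.G"), id="s1")).seq)
#   'ACG'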


def _update_id(record, new_id):
    """
    Update a record id to new_id, also modifying the ID in record.description
    """
    old_id = record.id
    record.id = new_id

    # At least for FASTA, record ID starts the description
    record.description = re.sub('^' + re.escape(old_id), new_id, record.description)
    return record


def name_append_suffix(records, suffix):
    """
    Given a set of sequences, append a suffix for each sequence's name.
    """
    logging.info('Applying _name_append_suffix generator: '
                 'Appending suffix ' + suffix + ' to all '
                 'sequence IDs.')
    for record in records:
        new_id = record.id + suffix
        _update_id(record, new_id)
        yield record


def name_insert_prefix(records, prefix):
    """
    Given a set of sequences, insert a prefix for each sequence's name.
    """
    logging.info('Applying _name_insert_prefix generator: '
                 'Inserting prefix ' + prefix + ' for all '
                 'sequence IDs.')
    for record in records:
        new_id = prefix + record.id
        _update_id(record, new_id)
        yield record



def name_include(records, filter_regex):
    """
    Given a set of sequences, filter out any sequences with names
    that do not match the specified regular expression.
    """
    logging.info('Applying _name_include generator: '
                 'including only IDs matching ' + filter_regex +
                 ' in results.')
    regex = re.compile(filter_regex)
    for record in records:
        if regex.search(record.id) or regex.search(record.description):
            yield record


def name_exclude(records, filter_regex):
    """
    Given a set of sequences, filter out any sequences with names
    that match the specified regular expression.
    """
    logging.info('Applying _name_exclude generator: '
                 'excluding IDs matching ' + filter_regex + ' in results.')
    regex = re.compile(filter_regex)
    for record in records:
        if not regex.search(record.id) and not regex.search(record.description):
            yield record


def name_replace(records, search_regex, replace_pattern):
    """
    Given a set of sequences, replace all occurrences of search_regex
    with replace_pattern.

    If the ID and the first word of the description match, assume the
    description is FASTA-like and apply the transform to the entire
    description, then set the ID from the first word. If the ID and
    the first word of the description do not match, apply the transform
    to each individually.
    """
    regex = re.compile(search_regex)
    for record in records:
        # Guard against an empty description
        parts = record.description.split(None, 1)
        maybe_id = parts[0] if parts else ''
        if maybe_id == record.id:
            record.description = regex.sub(replace_pattern, record.description)
            record.id = record.description.split(None, 1)[0]
        else:
            record.id = regex.sub(replace_pattern, record.id)
            record.description = regex.sub(replace_pattern, record.description)
        yield record
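
# Illustrative usage (a sketch, not part of the original module).  When the
# ID starts the description, both are rewritten consistently:
#
#   >>> recs = [SeqRecord(Seq("ACGT"), id="seq_1", description="seq_1 sample")]
#   >>> r = next(name_replace(recs, r'_', '-'))
#   >>> (r.id, r.description)
#   ('seq-1', 'seq-1 sample')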


def seq_include(records, filter_regex):
    """
    Keep only sequences whose sequence matches the filter regular expression.
    """
    regex = re.compile(filter_regex)
    for record in records:
        if regex.search(str(record.seq)):
            yield record


def seq_exclude(records, filter_regex):
    """
    Discard any sequences whose sequence matches the filter regular expression.
    """
    regex = re.compile(filter_regex)
    for record in records:
        if not regex.search(str(record.seq)):
            yield record


def sample(records, k, random_seed=None):
    """Choose a length-``k`` subset of ``records``, retaining the input
    order.  If k > len(records), all are returned. If an integer
    ``random_seed`` is provided, sets ``random.seed()``

    """

    if random_seed is not None:
        random.seed(random_seed)

    result = []
    for i, record in enumerate(records):
        if len(result) < k:
            result.append(record)
        else:
            r = random.randint(0, i)
            if r < k:
                result[r] = record
    return result
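
# Illustrative usage (a sketch, not part of the original module).  The exact
# records chosen depend on the random seed; only the size is asserted here:
#
#   >>> recs = [SeqRecord(Seq("A"), id=str(i)) for i in range(10)]
#   >>> len(sample(recs, 3, random_seed=1))
#   3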


def head(records, head):
    """
    Limit results to the top N records.
    With the leading `-', print all but the last N records.
    """
    logging.info('Applying _head generator: '
                 'limiting results to top ' + head + ' records.')

    if head == '-0':
        for record in records:
            yield record
    elif '-' in head:
        with _record_buffer(records) as r:
            record_count = sum(1 for record in r())
            end_index = max(record_count + int(head), 0)
            for record in itertools.islice(r(), end_index):
                yield record
    else:
        for record in itertools.islice(records, int(head)):
            yield record
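
# Illustrative usage (a sketch, not part of the original module).  N is given
# as a string; a leading '-' keeps all but the last N records:
#
#   >>> recs = [SeqRecord(Seq("A"), id=str(i)) for i in range(5)]
#   >>> [r.id for r in head(recs, '2')]
#   ['0', '1']
#   >>> [r.id for r in head(recs, '-2')]
#   ['0', '1', '2']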

def tail(records, tail):
    """
    Limit results to the bottom N records.
    Use +N to output records starting with the Nth.
    """
    logging.info('Applying _tail generator: '
                 'limiting results to bottom ' + tail + ' records.')

    if tail == '+0':
        for record in records:
            yield record
    elif '+' in tail:
        tail = int(tail) - 1
        for record in itertools.islice(records, tail, None):
            yield record
    else:
        with _record_buffer(records) as r:
            record_count = sum(1 for record in r())
            start_index = max(record_count - int(tail), 0)
            for record in itertools.islice(r(), start_index, None):
                yield record
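
# Illustrative usage (a sketch, not part of the original module).  N is given
# as a string; '+N' starts output at the Nth record:
#
#   >>> recs = [SeqRecord(Seq("A"), id=str(i)) for i in range(5)]
#   >>> [r.id for r in tail(recs, '2')]
#   ['3', '4']
#   >>> [r.id for r in tail(recs, '+4')]
#   ['3', '4']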

# Squeeze-related
def gap_proportion(sequences, gap_chars='-'):
    """
    Generates a list with the proportion of gaps by index in a set of
    sequences.
    """
    aln_len = None
    gaps = []
    for i, sequence in enumerate(sequences):
        if aln_len is None:
            aln_len = len(sequence)
            gaps = [0] * aln_len
        elif len(sequence) != aln_len:
            raise ValueError(("Unexpected sequence length {0}. Is this "
                              "an alignment?").format(len(sequence)))

        # Update any gap positions in gap list
        for j, char in enumerate(sequence.seq):
            if char in gap_chars:
                gaps[j] += 1

    sequence_count = float(i + 1)
    gap_props = [i / sequence_count for i in gaps]
    return gap_props


def squeeze(records, gap_threshold=1.0):
    """
    Remove any alignment columns whose proportion of gap characters is at
    least ``gap_threshold``.  With the default of 1.0, only columns that
    are gaps in every sequence are removed.
    """
    with _record_buffer(records) as r:
        gap_proportions = gap_proportion(r())

        keep_columns = [g < gap_threshold for g in gap_proportions]

        for record in r():
            sequence = str(record.seq)
            # Trim
            squeezed = itertools.compress(sequence, keep_columns)
            yield SeqRecord(Seq(''.join(squeezed)), id=record.id,
                            description=record.description)
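
# Illustrative usage (a sketch, not part of the original module).  The
# all-gap second column is removed; the partially gapped fourth is kept:
#
#   >>> recs = [SeqRecord(Seq("A-C-"), id="a"), SeqRecord(Seq("A-CT"), id="b")]
#   >>> [str(r.seq) for r in squeeze(recs)]
#   ['AC-', 'ACT']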

def strip_range(records):
    """
    Cut off trailing /<start>-<stop> ranges from IDs.  Ranges must be 1-indexed and
    the stop integer must not be less than the start integer.
    """
    logging.info('Applying _strip_range generator: '
                 'removing /<start>-<stop> ranges from IDs')
    # Split up and be greedy.
    cut_regex = re.compile(r"(?P<id>.*)\/(?P<start>\d+)\-(?P<stop>\d+)")
    for record in records:
        name = record.id
        match = cut_regex.match(str(record.id))
        if match:
            sequence_id = match.group('id')
            start = int(match.group('start'))
            stop = int(match.group('stop'))
            if start > 0 and start <= stop:
                name = sequence_id
        yield SeqRecord(record.seq, id=name,
                        description='')
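
# Illustrative usage (a sketch, not part of the original module):
#
#   >>> recs = [SeqRecord(Seq("ACGT"), id="gene/1-4", description="")]
#   >>> next(strip_range(recs)).id
#   'gene'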


def transcribe(records, transcribe):
    """
    Perform transcription or back-transcription.
    transcribe must be one of the following:
        dna2rna
        rna2dna
    """
    logging.info('Applying _transcribe generator: '
                 'operation to perform is ' + transcribe + '.')
    for record in records:
        sequence = str(record.seq)
        description = record.description
        name = record.id
        # Note: Bio.Alphabet / IUPAC was removed from Biopython; a plain Seq
        # is sufficient for transcription and back-transcription.
        if transcribe == 'dna2rna':
            rna = Seq(sequence).transcribe()
            yield SeqRecord(rna, id=name, description=description)
        elif transcribe == 'rna2dna':
            dna = Seq(sequence).back_transcribe()
            yield SeqRecord(dna, id=name, description=description)

# Translate-related functions
class CodonWarningTable(object):
    """
    Translation table for codons tht prints a warning when an unknown
    codon is requested, then returns the value passed as missing_char
    """

    def __init__(self, wrapped, missing_char='X'):
        self.wrapped = wrapped
        self.missing_char = missing_char
        self.seen = set()

    def get(self, codon, missing=None):
        try:
            return self.__getitem__(codon)
        except KeyError:
            return missing

    def __getitem__(self, codon):
        if codon == '---':
            return '-'
        elif '-' in codon:
            if codon not in self.seen:
                logging.warning("Unknown Codon: %s", codon)
                self.seen.add(codon)
            return self.missing_char
        else:
            return self.wrapped.__getitem__(codon)

    def __contains__(self, value):
        return value in self.wrapped


def translate(records, translate):
    """
    Perform translation from generic DNA/RNA to proteins.  Bio.Seq
    does not perform back-translation because the codons would
    more-or-less be arbitrary.  Option to translate only up until
    reaching a stop codon.  translate must be one of the following:
        dna2protein
        dna2proteinstop
        rna2protein
        rna2proteinstop
    """
    logging.info('Applying translation generator: '
                 'operation to perform is ' + translate + '.')

    to_stop = translate.endswith('stop')

    source_type = translate[:3]

    # Get a translation table
    table = {'dna': CodonTable.ambiguous_dna_by_name["Standard"],
             'rna': CodonTable.ambiguous_rna_by_name["Standard"]}[source_type]

    # Handle ambiguities by replacing ambiguous codons with 'X'
    # TODO: this copy operation causes infinite recursion with python3.6 -
    # not sure why it was here to begin with.
    # table = copy.deepcopy(table)
    table.forward_table = CodonWarningTable(table.forward_table)

    for record in records:
        sequence = str(record.seq)
        seq = Seq(sequence)
        protein = seq.translate(table, to_stop=to_stop)
        yield SeqRecord(protein, id=record.id, description=record.description)
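
# Illustrative usage (a sketch, not part of the original module).  With
# 'dna2proteinstop' the trailing stop codon would be dropped ('MK'):
#
#   >>> recs = [SeqRecord(Seq("ATGAAATAG"), id="s1", description="")]
#   >>> str(next(translate(recs, 'dna2protein')).seq)
#   'MK*'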


def max_length_discard(records, max_length):
    """
    Discard any records that are longer than max_length.
    """
    logging.info('Applying _max_length_discard generator: '
                 'discarding records longer than %d.', max_length)
    for record in records:
        if len(record) > max_length:
            # Discard
            logging.debug('Discarding long sequence: %s, length=%d',
                record.id, len(record))
        else:
            yield record


def min_length_discard(records, min_length):
    """
    Discard any records that are shorter than min_length.
    """
    logging.info('Applying _min_length_discard generator: '
                 'discarding records shorter than %d.', min_length)
    for record in records:
        if len(record) < min_length:
            logging.debug('Discarding short sequence: %s, length=%d',
                record.id, len(record))
        else:
            yield record


def min_ungap_length_discard(records, min_length):
    """
    Discard any records that are shorter than min_length after removing gaps.
    """
    for record in records:
        if len(ungap_all(record)) >= min_length:
            yield record


def sort_length(source_file, source_file_type, direction=1):
    """
    Sort sequences by length. 1 is ascending (default) and 0 is descending.
    """
    direction_text = 'ascending' if direction == 1 else 'descending'

    logging.info('Indexing sequences by length: %s', direction_text)

    # Adapted from the Biopython tutorial example.

    # Get the lengths and ids, and sort on length
    len_and_ids = sorted((len(rec), rec.id)
                         for rec in SeqIO.parse(source_file, source_file_type))

    if direction == 0:
        ids = reversed([seq_id for (length, seq_id) in len_and_ids])
    else:
        ids = [seq_id for (length, seq_id) in len_and_ids]
    del len_and_ids  # free this memory

    # SeqIO.index does not handle gzip instances
    if isinstance(source_file, gzip.GzipFile):
        tmpfile = tempfile.NamedTemporaryFile()
        source_file.seek(0)
        tmpfile.write(source_file.read())
        tmpfile.seek(0)
        source_file = tmpfile

    record_index = SeqIO.index(source_file.name, source_file_type)

    for seq_id in ids:
        yield record_index[seq_id]


def sort_name(source_file, source_file_type, direction=1):
    """
    Sort sequences by name. 1 is ascending (default) and 0 is descending.
    """

    direction_text = 'ascending' if direction == 1 else 'descending'

    logging.info("Indexing sequences by name: %s", direction_text)

    # Adapted from the Biopython tutorial example.

    # Sort on id
    ids = sorted((rec.id) for rec in SeqIO.parse(source_file,
                                                 source_file_type))

    if direction == 0:
        ids = reversed(ids)

    # SeqIO.index does not handle gzip instances
    if isinstance(source_file, gzip.GzipFile):
        tmpfile = tempfile.NamedTemporaryFile()
        source_file.seek(0)
        tmpfile.write(source_file.read())
        tmpfile.seek(0)
        source_file = tmpfile

    record_index = SeqIO.index(source_file.name, source_file_type)

    for seq_id in ids:
        yield record_index[seq_id]