summaryrefslogtreecommitdiff
path: root/kaptive.py
blob: 3cf607c3c15fc3d8f92758c3ae2f26e1761b0db8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
#!/usr/bin/env python3
"""
Copyright 2021 Ryan Wick (rrwick@gmail.com)
https://github.com/katholt/Kaptive

Kaptive is a tool which reports information about the K and O types for Klebsiella genome
assemblies. It will help a user to decide whether their Klebsiella sample has a known or novel
locus type, and if novel, how similar it is to a known type.

This script needs the following input files to run:
* A Genbank file with known locus types
* One or more assemblies in FASTA format

Example command:
kaptive.py -a path/to/assemblies/*.fasta -k k_loci_refs.gbk -o output_directory

For each input assembly file, Kaptive will identify the closest known locus type and report
information about the corresponding locus genes.

It generates the following output files:
* A FASTA file for each input assembly with the nucleotide sequences matching the closest locus type
* A table summarising the results for all input assemblies

Character codes indicate problems with the locus match:
* `?` indicates that the match was not in a single piece, possibly due to a poor match or
      discontiguous assembly
* `-` indicates that genes expected in the locus were not found
* `+` indicates that extra genes were found in the locus
* `*` indicates that one or more expected genes was found but with low identity

This file is part of Kaptive. Kaptive is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by the Free Software Foundation,
either version 3 of the License, or (at your option) any later version. Kaptive is distributed in
the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
details. You should have received a copy of the GNU General Public License along with Kaptive. If
not, see <http://www.gnu.org/licenses/>.
"""

import argparse
import copy
import fcntl
import gzip
import json
import multiprocessing
import os
import random
import re
import shutil
import subprocess
import sys
from collections import OrderedDict

from Bio import SeqIO

__version__ = '2.0.7'


def main():
    """Script execution starts here.

    Workflow: validate the environment and input files, parse the reference
    Genbank file into temporary FASTA files, then for each input assembly
    find the best-matching locus, extract the matching assembly pieces,
    search for the locus genes and write the per-assembly results.
    """
    args = get_argument_parser().parse_args()

    # Fail fast on a missing BLAST+ install or unreadable/malformed inputs.
    check_for_blast()
    check_files_exist(args.assembly + [args.k_refs] + [args.allelic_typing])
    check_assembly_format(args.assembly)
    fix_paths(args)

    output_table = not args.no_table
    output_json = not args.no_json

    temp_dir = make_temp_dir(args)
    ref_seqs, gene_seqs, ref_genes, ref_types = \
        parse_genbank(args.k_refs, temp_dir, args.locus_label, args.type_label)
    special_logic = load_special_logic(args.k_refs, ref_types)

    # Flatten the per-locus gene lists into one lookup keyed by full gene name.
    all_gene_dict = {}
    for gene_list in ref_genes.values():
        for gene in gene_list:
            all_gene_dict[gene.full_name] = gene

    refs = load_locus_references(ref_seqs, ref_genes, ref_types)
    type_gene_names = get_type_gene_names(args.allelic_typing)

    if output_table:
        create_table_file(args.out, type_gene_names)
    json_list = []

    for fasta_file in args.assembly:
        assembly = Assembly(fasta_file)
        best = get_best_locus_match(assembly, ref_seqs, refs, args.threads)
        # A placeholder 'None' locus stands in when nothing matched at all.
        if best is None:
            best = Locus('None', '', '', [])

        find_assembly_pieces(assembly, best, args)
        if not best.assembly_pieces:
            best = Locus('None', '', '', [])

        assembly_pieces_fasta = save_assembly_pieces_to_file(best, assembly, args.out)
        type_gene_results = type_gene_search(assembly_pieces_fasta, type_gene_names, args)
        # The pieces FASTA is needed for the allelic typing search above, so it
        # is only deleted afterwards when sequence output is suppressed.
        if args.no_seq_out and assembly_pieces_fasta is not None:
            os.remove(assembly_pieces_fasta)
        protein_blast(assembly, best, gene_seqs, args)
        apply_special_logic(best, special_logic, ref_genes)
        if best.type == 'unknown':
            best.type = 'unknown (' + best.name + ')'

        output(args.out, assembly, best, args, type_gene_names, type_gene_results,
               json_list, output_table, output_json, all_gene_dict)

    if output_json:
        write_json_file(args.out, json_list)

    clean_up(ref_seqs, gene_seqs, temp_dir)
    sys.exit(0)


def get_argument_parser():
    """Build and return the command-line argument parser for Kaptive."""
    arg_parser = argparse.ArgumentParser(
        description='Kaptive',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_arguments_to_parser(arg_parser)
    return arg_parser


def add_arguments_to_parser(parser):
    """Register all of Kaptive's command-line options on the given parser.

    Options are declared in the order they should appear in --help output.
    """
    parser.add_argument('--version', action='version', version='Kaptive v' + __version__,
                        help="Show Kaptive's version number and exit")
    parser.add_argument('-a', '--assembly', nargs='+', type=str, required=True,
                        help='FASTA file(s) for assemblies')
    parser.add_argument('-k', '--k_refs', type=str, required=True,
                        help='GenBank file with reference loci')
    parser.add_argument('-g', '--allelic_typing', type=str, required=False,
                        help='SRST2-formatted FASTA file of allelic typing genes to include in '
                             'results')
    parser.add_argument('-o', '--out', type=str, required=False, default='./kaptive_results',
                        help='Output directory/prefix')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Display detailed information about each assembly in stdout')
    # Default thread count is capped at 4 — more BLAST threads give
    # diminishing returns on typical inputs.
    parser.add_argument('-t', '--threads', type=int, required=False,
                        default=min(multiprocessing.cpu_count(), 4),
                        help='The number of threads to use for the BLAST searches')
    parser.add_argument('--no_seq_out', action='store_true',
                        help='Suppress output files of sequences matching locus')
    parser.add_argument('--no_table', action='store_true',
                        help='Suppress output of tab-delimited table')
    parser.add_argument('--no_json', action='store_true',
                        help='Suppress output of JSON file')
    parser.add_argument('--start_end_margin', type=int, required=False, default=10,
                        help='Missing bases at the ends of locus allowed in a perfect match.')
    # Gene-level thresholds are percentages (0-100), not fractions.
    parser.add_argument('--min_gene_cov', type=float, required=False, default=90.0,
                        help='minimum required %% coverage for genes')
    parser.add_argument('--min_gene_id', type=float, required=False, default=80.0,
                        help='minimum required %% identity for genes')
    parser.add_argument('--low_gene_id', type=float, required=False, default=95.0,
                        help='genes with a %% identity below this value will be flagged as low '
                             'identity')
    parser.add_argument('--min_assembly_piece', type=int, required=False, default=100,
                        help='minimum locus matching assembly piece to return')
    parser.add_argument('--gap_fill_size', type=int, required=False, default=100,
                        help='when separate parts of the assembly are found within this distance, '
                             'they will be merged')
    # The sentinel default triggers automatic label detection in parse_genbank.
    parser.add_argument('--locus_label', type=str, required=False,
                        default='automatically determined',
                        help='In the Genbank file, the source feature must have a note '
                             'identifying the locus name, starting with this label followed by '
                             'a colon (e.g. /note="K locus: KL1")')
    parser.add_argument('--type_label', type=str, required=False,
                        default='automatically determined',
                        help='In the Genbank file, the source feature must have a note '
                             'identifying the type name, starting with this label followed by '
                             'a colon (e.g. /note="K type: K1")')


def check_for_blast():
    """Quit with an error unless all required BLAST+ tools are on the PATH."""
    for tool in ('makeblastdb', 'blastn', 'tblastn'):
        if not find_program(tool):
            quit_with_error('could not find ' + tool + ' tool (part of BLAST+)')


def find_program(name):
    """Return True if an executable called *name* can be found on the PATH.

    Uses shutil.which() rather than shelling out to the external `which`
    command: this avoids spawning a subprocess, works on platforms without a
    `which` binary, and fixes the uncaught FileNotFoundError the previous
    implementation raised when `which` itself was missing.
    """
    return shutil.which(name) is not None


def fix_paths(args):
    """
    Changes the paths given by the user to absolute paths, which are easier to work with later.
    Also creates the output directory, if necessary.

    Mutates args in place: args.assembly, args.k_refs and args.out all become
    absolute paths, and a trailing-slash --out value gains the default
    'kaptive_results' prefix.
    """
    args.assembly = [os.path.abspath(x) for x in args.assembly]
    args.k_refs = os.path.abspath(args.k_refs)
    # endswith() safely handles an empty --out value, where the previous
    # args.out[-1] indexing raised an IndexError.
    if args.out.endswith('/'):
        args.out += 'kaptive_results'
    args.out = os.path.abspath(args.out)
    out_dir = os.path.dirname(args.out)
    # exist_ok avoids a race with another process creating the directory
    # between an existence check and makedirs().
    os.makedirs(out_dir, exist_ok=True)


def make_temp_dir(args):
    """Create (if needed) a per-run temporary directory and return its path.

    The name embeds the PID plus a random suffix so concurrent Kaptive runs
    sharing an output directory do not collide.
    """
    suffix = str(os.getpid()) + '_' + str(random.randint(0, 999999))
    temp_dir = os.path.join(os.path.dirname(args.out), 'kaptive_temp_' + suffix)
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    return temp_dir


def clean_up(ref_seqs, gene_seqs, temp_dir):
    """
    Remove the two temporary FASTA files, then delete the temp directory if
    (and only if) nothing else remains inside it. All failures are tolerated
    so cleanup never masks the real exit status.
    """
    for temp_fasta in (ref_seqs, gene_seqs):
        try:
            os.remove(temp_fasta)
        except OSError:
            pass
    try:
        if not os.listdir(temp_dir):
            os.rmdir(temp_dir)
    except FileNotFoundError:
        pass


def parse_genbank(genbank, temp_dir, locus_label, type_label):
    """
    This function reads the input Genbank file and produces two temporary FASTA files: one with the
    loci nucleotide sequences and one with the gene sequences.
    It returns the file paths for these two FASTA files along with a dictionary that links genes to
    loci.

    Returns a 4-tuple: (ref seqs FASTA path, gene seqs FASTA path,
    dict of locus name -> list of Gene, dict of locus name -> type name).
    Quits with an error on duplicate locus names or missing/ambiguous labels.
    """
    ref_genes, ref_types = {}, {}
    ref_seqs_filename = os.path.join(temp_dir, 'temp_ref_seqs.fasta')
    gene_seqs_filename = os.path.join(temp_dir, 'temp_gene_seqs.fasta')
    ref_seqs = open(ref_seqs_filename, 'wt')
    gene_seqs = open(gene_seqs_filename, 'wt')

    # 'automatically determined' is the argparse sentinel default: detect the
    # label from the file; otherwise validate the user-supplied one.
    if locus_label == 'automatically determined':
        locus_label = find_label(genbank, 'locus')
    else:
        check_label(genbank, locus_label)
    if type_label == 'automatically determined':
        # The type label is optional, so a missing one yields None here.
        type_label = find_label(genbank, 'type', required=False)
    else:
        check_label(genbank, type_label)

    for record in SeqIO.parse(genbank, 'genbank'):
        locus_name, type_name = '', ''
        for feature in record.features:
            if feature.type == 'source' and 'note' in feature.qualifiers:
                for note in feature.qualifiers['note']:
                    # Note: a locus-label match takes precedence over the
                    # 'Extra genes' form because it is tested first.
                    if note.startswith(locus_label):
                        locus_name = get_name_from_note(note, locus_label)
                    elif note.startswith('Extra genes'):
                        # e.g. 'Extra genes: foo' becomes 'Extra_genes_foo'.
                        locus_name = note.replace(':', '').replace(' ', '_')
                    elif type_label is not None and note.startswith(type_label):
                        type_name = get_name_from_note(note, type_label)
        if locus_name in ref_genes:
            quit_with_error('Duplicate reference locus name: ' + locus_name)
        ref_genes[locus_name] = []

        # Extra genes are only used for the gene search, not the nucleotide search.
        if not locus_name.startswith('Extra_genes'):
            ref_seqs.write('>' + locus_name + '\n')
            ref_seqs.write(add_line_breaks_to_sequence(str(record.seq), 60))
            ref_types[locus_name] = type_name

        # Gene numbers restart at 1 for each locus record.
        gene_num = 1
        for feature in record.features:
            if feature.type == 'CDS':
                gene = Gene(locus_name, gene_num, feature, record.seq)
                ref_genes[locus_name].append(gene)
                gene_num += 1
                gene_seqs.write(gene.get_fasta())
    ref_seqs.close()
    gene_seqs.close()
    return ref_seqs_filename, gene_seqs_filename, ref_genes, ref_types


def rreplace(s, old, new):
    """Replace only the last occurrence of *old* in *s* with *new*.

    https://stackoverflow.com/questions/2556108
    """
    return new.join(s.rsplit(old, 1))


def find_label(genbank, text, required=True):
    """
    Automatically finds the label in the Genbank file which contains the specified text. For
    example, if the text is 'locus', then the Genbank file must have exactly one possible label
    containing 'locus' that is present in a note qualifier in the source feature for every record.
    If not, Kaptive will quit with an error.

    Returns the label string, or None when nothing matched and required is
    False. The file is parsed twice: first to collect candidate labels, then
    to intersect them record by record.
    """
    # Pass 1: every note prefix (text before the colon) containing `text` is a
    # candidate label.
    possible_locus_labels = set()
    for record in SeqIO.parse(genbank, 'genbank'):
        for feature in record.features:
            if feature.type == 'source' and 'note' in feature.qualifiers:
                for note in feature.qualifiers['note']:
                    if ':' in note:
                        note = note.split(':')[0].strip()
                        if text in note:
                            possible_locus_labels.add(note)
    if not possible_locus_labels:
        if required:
            quit_with_error('None of the records contain a valid ' + text + ' label')
        else:
            return None
    # Pass 2: keep only the candidates present in every record; quit as soon
    # as the running intersection becomes empty.
    available_locus_labels = possible_locus_labels.copy()
    for record in SeqIO.parse(genbank, 'genbank'):
        locus_labels = set()
        for feature in record.features:
            if feature.type == 'source' and 'note' in feature.qualifiers:
                for note in feature.qualifiers['note']:
                    if ':' in note:
                        locus_labels.add(note.split(':')[0].strip())
        # 'Extra genes' records are exempt from the per-record label check.
        if any(x == 'Extra genes' for x in locus_labels):
            continue
        if not locus_labels:
            quit_with_error('no possible ' + text + ' labels were found for ' + record.name)
        previous_labels = available_locus_labels.copy()
        available_locus_labels = available_locus_labels.intersection(locus_labels)
        if not available_locus_labels:
            error_message = record.name + ' does not have a ' + text + ' label matching the ' \
                            'previous records\n'
            error_message += 'Previous record labels: ' + ', '.join(list(previous_labels)) + '\n'
            error_message += 'Labels in ' + record.name + ': ' + ', '.join(list(locus_labels))
            quit_with_error(error_message)
    # Exactly one surviving candidate is required; more is ambiguous.
    if len(available_locus_labels) > 1:
        error_message = 'multiple possible ' + text + ' labels were found: ' + \
                        ', '.join(list(available_locus_labels)) + '\n'
        error_message += 'Please use the --' + text + '_label option to specify which to use'
        quit_with_error(error_message)
    return list(available_locus_labels)[0]


def check_label(genbank, label):
    """
    Verify that every record's source feature carries a note beginning with
    the given label and followed by a non-empty value; quit with an error
    for the first record that lacks one.
    """
    for record in SeqIO.parse(genbank, 'genbank'):
        labelled = False
        for feature in record.features:
            if feature.type != 'source' or 'note' not in feature.qualifiers:
                continue
            for note in feature.qualifiers['note']:
                if note.startswith(label) and get_name_from_note(note, label):
                    labelled = True
        if not labelled:
            quit_with_error(record.name + ' is missing a label\n'
                            'The source feature must have a note qualifier beginning with "' +
                            label + ':" followed by the relevant info')


def get_name_from_note(full_note, locus_label):
    """
    Extracts the part of the note following the label (and any colons, spaces or equals signs).

    E.g. get_name_from_note('K locus: KL1', 'K locus') -> 'KL1'.
    """
    # lstrip() with a character set replaces the original one-character-at-a-
    # time while loop: it removes any leading run of ':', ' ' and '='.
    return full_note[len(locus_label):].strip().lstrip(': =')


def check_files_exist(filenames):
    """Verify that each non-None entry in the list is an existing file."""
    for name in (f for f in filenames if f is not None):
        check_file_exists(name)


def check_file_exists(filename):
    """Quit with an error unless the given path refers to an existing file."""
    if os.path.isfile(filename):
        return
    quit_with_error('could not find ' + filename)


def check_assembly_format(filenames):
    """Load each assembly FASTA and quit with an error on malformed input.

    A file is rejected when it yields no records at all or when any record
    has an empty sequence.
    """
    for assembly in filenames:
        records = load_fasta(assembly)
        if not records:
            quit_with_error('invalid FASTA file: ' + assembly)
        for _, seq in records:
            if not seq:
                quit_with_error('invalid FASTA file (contains a zero-length sequence): ' +
                                assembly)


def quit_with_error(message):
    """Print the given message to stderr and terminate with exit status 1."""
    sys.stderr.write('Error: %s\n' % message)
    sys.exit(1)


def get_best_locus_match(assembly, refs_fasta, refs, threads):
    """
    BLAST the assembly against every reference locus and return a shallow
    copy of the best match, or None when nothing hit at all.

    'Best' means the largest fraction of the locus covered by BLAST hits;
    ties are broken by fewer hits, then by higher mean hit identity.
    """
    for ref in refs.values():
        ref.clear()

    for hit in get_blast_hits(assembly.fasta, refs_fasta, threads):
        if hit.qseqid not in refs:
            quit_with_error('BLAST hit (' + hit.qseqid + ') not found in locus references')
        refs[hit.qseqid].add_blast_hit(hit)
    for ref in refs.values():
        ref.clean_up_blast_hits()

    ranked = sorted(refs.values(), reverse=True,
                    key=lambda r: (r.get_coverage(),
                                   -len(r.blast_hits),
                                   r.get_mean_blast_hit_identity()))
    best_ref = ranked[0]
    if best_ref.get_coverage() == 0.0:
        return None
    return copy.copy(best_ref)


def type_gene_search(assembly_pieces_fasta, type_gene_names, args):
    """
    BLAST the extracted locus sequences against the SRST2-formatted allelic
    typing genes and return, per typing gene, its single best hit (or None).

    A perfect (100% identity, 100% coverage) hit reports its allele number
    as-is; otherwise the top-bitscore hit is reported with a '*' suffix.
    """
    if not type_gene_names or not args.allelic_typing:
        return {}
    if not assembly_pieces_fasta:
        return {name: None for name in type_gene_names}

    makeblastdb(assembly_pieces_fasta)
    raw_hits = get_blast_hits(assembly_pieces_fasta, args.allelic_typing, args.threads,
                              type_genes=True)
    clean_blast_db(assembly_pieces_fasta)

    # Discard hits that are too short or too divergent to be meaningful.
    good_hits = [h for h in raw_hits
                 if h.query_cov >= args.min_gene_cov and h.pident >= args.min_gene_id]

    results = {}
    for gene_name in type_gene_names:
        gene_hits = sorted((h for h in good_hits if h.gene_name == gene_name),
                           key=lambda h: h.bitscore, reverse=True)
        if not gene_hits:
            results[gene_name] = None
            continue
        hit = next((h for h in gene_hits
                    if h.pident == 100.0 and h.query_cov == 100.0), gene_hits[0])
        if hit.pident == 100.0 and hit.query_cov == 100.0:
            hit.result = str(hit.allele_number)
        else:
            hit.result = str(hit.allele_number) + '*'
        results[gene_name] = hit

    return results


def find_assembly_pieces(assembly, locus, args):
    """
    This function uses the BLAST hits in the given locus type to find the corresponding pieces of
    the given assembly. It saves its results in the Locus object.

    Pipeline: hit -> assembly piece, merge overlaps, drop pieces shorter than
    --min_assembly_piece, bridge gaps up to --gap_fill_size, then record the
    mean identity over the kept pieces. No-op when the locus has no hits or
    nothing survives the length filter.
    """
    if not locus.blast_hits:
        return
    assembly_pieces = [x.get_assembly_piece(assembly) for x in locus.blast_hits]
    merged_pieces = merge_assembly_pieces(assembly_pieces)
    length_filtered_pieces = [x for x in merged_pieces if x.get_length() >= args.min_assembly_piece]
    if not length_filtered_pieces:
        return
    locus.assembly_pieces = fill_assembly_piece_gaps(length_filtered_pieces, args.gap_fill_size)

    # Now check to see if the biggest assembly piece seems to capture the whole locus. If so, this
    # is an ideal match.
    biggest_piece = sorted(locus.assembly_pieces, key=lambda z: z.get_length(), reverse=True)[0]
    start = biggest_piece.earliest_hit_coordinate()
    end = biggest_piece.latest_hit_coordinate()
    if good_start_and_end(start, end, locus.get_length(), args.start_end_margin):
        locus.assembly_pieces = [biggest_piece]

    # If it isn't the ideal case, we still want to check if the start and end of the locus were
    # found in the same contig. If so, fill all gaps in between so we include the entire
    # intervening sequence.
    else:
        earliest, latest, same_contig_and_strand = locus.get_earliest_and_latest_pieces()
        start = earliest.earliest_hit_coordinate()
        end = latest.latest_hit_coordinate()
        if good_start_and_end(start, end, locus.get_length(), args.start_end_margin) and \
           same_contig_and_strand:
            # One spanning piece from the earliest start to the latest end,
            # merged back in so overlapping pieces collapse together.
            gap_filling_piece = AssemblyPiece(assembly, earliest.contig_name, earliest.start,
                                              latest.end, earliest.strand)
            locus.assembly_pieces = merge_assembly_pieces(locus.assembly_pieces +
                                                          [gap_filling_piece])
    locus.identity = get_mean_identity(locus.assembly_pieces)


def protein_blast(assembly, locus, gene_seqs, args):
    """
    Conducts a BLAST search of all known locus proteins. Stores the results in the Locus
    object.

    Hits are filtered by --min_gene_cov / --min_gene_id, expected genes claim
    their best hits first (culling conflicting hits so weaker overlaps can't
    win), then the remaining hits become 'other' genes. Results are
    partitioned by whether each hit lies inside the locus assembly pieces.
    """
    hits = get_blast_hits(assembly.fasta, gene_seqs, args.threads, genes=True)
    hits = [x for x in hits if x.query_cov >= args.min_gene_cov and x.pident >= args.min_gene_id]

    # First pass: collect the best hit for each expected gene, then cull hits
    # conflicting with them in descending bitscore order, so the strongest
    # matches take priority.
    best_hits = []
    for expected_gene in locus.gene_names:
        best_hit = get_best_hit_for_query(hits, expected_gene, locus)
        if best_hit is not None:
            best_hits.append(best_hit)
    best_hits = sorted(best_hits, key=lambda x: x.bitscore, reverse=True)
    for best_hit in best_hits:
        if best_hit in hits:
            hits = cull_conflicting_hits(best_hit, hits)

    # Second pass over the culled hit list: record each expected gene's final
    # hit (or note it as missing) and flag low-identity matches.
    expected_hits = []
    for expected_gene in locus.gene_names:
        best_hit = get_best_hit_for_query(hits, expected_gene, locus)
        if not best_hit:
            locus.missing_expected_genes.append(expected_gene)
        else:
            best_hit.over_identity_threshold = best_hit.pident >= args.low_gene_id
            expected_hits.append(best_hit)
            hits = [x for x in hits if x is not best_hit]
            hits = cull_conflicting_hits(best_hit, hits)
    # Whatever survives culling is a gene hit not expected for this locus.
    other_hits = cull_all_conflicting_hits(hits)

    locus.expected_hits_inside_locus = [x for x in expected_hits
                                        if x.in_assembly_pieces(locus.assembly_pieces)]
    locus.expected_hits_outside_locus = [x for x in expected_hits
                                         if not x.in_assembly_pieces(locus.assembly_pieces)]
    locus.other_hits_inside_locus = [x for x in other_hits
                                     if x.in_assembly_pieces(locus.assembly_pieces)]
    locus.other_hits_outside_locus = [x for x in other_hits
                                      if not x.in_assembly_pieces(locus.assembly_pieces)]


def create_table_file(output_prefix, type_gene_names):
    """
    Creates the table file and writes a header line if necessary.
    If the file already exists and the header line is correct, then it does nothing (to allow
    multiple independent processes to append to the file).
    """
    table_path = output_prefix + '_table.txt'

    # An existing table with the right header is left untouched so that other
    # processes can keep appending to it.
    if os.path.isfile(table_path):
        with open(table_path, 'r') as existing_table:
            if existing_table.readline().strip().startswith('Assembly\tBest match locus'):
                return

    headers = ['Assembly',
               'Best match locus',
               'Best match type',
               'Match confidence',
               'Problems',
               'Coverage',
               'Identity',
               'Length discrepancy',
               'Expected genes in locus',
               'Expected genes in locus, details',
               'Missing expected genes',
               'Other genes in locus',
               'Other genes in locus, details',
               'Expected genes outside locus',
               'Expected genes outside locus, details',
               'Other genes outside locus',
               'Other genes outside locus, details']

    # One extra column per allelic-typing gene, when typing was requested.
    headers += list(type_gene_names or [])

    with open(table_path, 'w') as table:
        table.write('\t'.join(headers) + '\n')


def get_type_gene_names(type_genes_fasta):
    """
    Extracts the sorted, deduplicated gene names from an SRST2-formatted FASTA file.

    The gene name is taken from the second '__'-delimited field of each header line.
    Returns an empty list when no file was given. Quits with an error if the file
    does not look like an SRST2 database.
    """
    if not type_genes_fasta:
        return []
    gene_names = set()
    with open(type_genes_fasta, 'rt') as type_genes_db:
        for line in type_genes_db:
            if line.startswith('>'):
                try:
                    gene_names.add(line.split('>')[1].split('__')[1])
                except IndexError:
                    quit_with_error(type_genes_fasta + ' not formatted as an SRST2 database '
                                                       'FASTA file')
    if not gene_names:
        quit_with_error(type_genes_fasta + ' not formatted as an SRST2 database FASTA file')
    return sorted(gene_names)


def load_special_logic(ref_filename, ref_types):
    """
    If any of the reference loci have a type of 'special logic', that implies that a
    corresponding file exists to describe that logic. This function loads that special
    logic file if needed, returning a list of (locus, extra_loci, new_type) tuples.
    """
    if 'special logic' not in ref_types.values():
        return []
    assert ref_filename.endswith('.gbk')
    special_logic_filename = rreplace(ref_filename, '.gbk', '.logic')
    check_file_exists(special_logic_filename)
    special_logic = []
    with open(special_logic_filename, 'rt') as special_logic_file:
        for line in special_logic_file:
            parts = line.strip().split('\t')
            assert len(parts) == 3
            locus, extra_loci, new_type = parts
            if locus == 'locus':  # skip the header line
                continue
            extra_loci = [] if extra_loci.lower() == 'none' else sorted(extra_loci.split(','))
            special_logic.append((locus, extra_loci, new_type))
    return special_logic


def apply_special_logic(locus, special_logic, ref_genes):
    """
    This function has special logic for dealing with the locus -> type situations that depend on
    other genes in the genome.

    If the best-match locus has the type 'special logic', the final type is decided by which
    'Extra genes' loci are completely present elsewhere in the genome, as described by the
    loaded special logic rules.
    """
    if locus.type != 'special logic':
        return

    other_gene_names = [x.qseqid for x in locus.other_hits_outside_locus]
    extra_gene_names = sorted(n for n in other_gene_names if n.startswith('Extra_genes_'))

    # Look for any 'Extra genes' loci for which all of their genes have been found in this genome.
    found_loci = []
    for ref_locus, genes in ref_genes.items():
        if ref_locus.startswith('Extra_genes_'):
            locus_gene_names = [g.full_name for g in genes]
            if all(g in extra_gene_names for g in locus_gene_names):
                found_loci.append(ref_locus.replace('Extra_genes_', ''))
    found_loci = sorted(found_loci)  # sort once, after the loop (was re-sorted each iteration)

    # See if the combination of best-match-locus and extra-loci is represented in the special
    # logic, and if so, change the type.
    new_types = [new_type for locus_name, extra_loci, new_type in special_logic
                 if locus.name == locus_name and found_loci == extra_loci]
    if len(new_types) == 0:
        locus.type = 'unknown'
    elif len(new_types) == 1:
        locus.type = new_types[0]
    else:  # multiple matches - shouldn't happen!
        quit_with_error('redundancy in special logic file')


def output(output_prefix, assembly, locus, args, type_gene_names, type_gene_results,
           json_list, output_table, output_json, all_gene_dict):
    """
    Writes a line to the output table describing all that we've learned about the given locus and
    writes to stdout as well.
    """
    uncertainty_chars = locus.get_match_uncertainty_chars()
    gene_count = len(locus.gene_names)

    try:
        def frac_str(count):
            # e.g. '3 / 4 (75%)'; raises ZeroDivisionError when the locus has no genes.
            return f'{count} / {gene_count} ({float_to_str(100.0 * count / gene_count)}%)'
        expected_genes_in_locus_str = frac_str(len(locus.expected_hits_inside_locus))
        expected_genes_out_locus_str = frac_str(len(locus.expected_hits_outside_locus))
        missing_genes_str = frac_str(len(locus.missing_expected_genes))
    except ZeroDivisionError:
        expected_genes_in_locus_str = expected_genes_out_locus_str = missing_genes_str = ''

    output_to_stdout(assembly, locus, args.verbose, type_gene_names, type_gene_results,
                     uncertainty_chars, expected_genes_in_locus_str, expected_genes_out_locus_str,
                     missing_genes_str)
    if output_table:
        output_to_table(output_prefix, assembly, locus, type_gene_names, type_gene_results,
                        uncertainty_chars, expected_genes_in_locus_str,
                        expected_genes_out_locus_str)
    if output_json:
        add_to_json(assembly, locus, type_gene_names, type_gene_results, json_list,
                    uncertainty_chars, all_gene_dict)


def output_to_table(output_prefix, assembly, locus, type_gene_names, type_gene_results,
                    uncertainty_chars, expected_genes_in_locus_str, expected_genes_out_locus_str):
    """
    Appends one tab-delimited line describing this locus match to the output table.
    Assumes the table file (and its header) was already created by create_table_file().
    """
    line = [assembly.name,
            locus.name,
            locus.type,
            locus.get_match_confidence(),
            uncertainty_chars,
            locus.get_coverage_string(),
            locus.get_identity_string(),
            locus.get_length_discrepancy_string(),
            expected_genes_in_locus_str,
            get_gene_info_string(locus.expected_hits_inside_locus),
            ';'.join(locus.missing_expected_genes),
            str(len(locus.other_hits_inside_locus)),
            get_gene_info_string(locus.other_hits_inside_locus),
            expected_genes_out_locus_str,
            get_gene_info_string(locus.expected_hits_outside_locus),
            str(len(locus.other_hits_outside_locus)),
            get_gene_info_string(locus.other_hits_outside_locus)]

    # One extra column per allelic-typing gene: '-' when the gene was not found.
    for gene_name in type_gene_names:
        hit = type_gene_results[gene_name]
        line.append('-' if not hit else hit.result)

    table_path = output_prefix + '_table.txt'
    # Use a context manager so the file handle is closed even if the write fails
    # (the original open/close pair leaked the handle on an exception).
    with open(table_path, 'at') as table:
        table.write('\t'.join(line) + '\n')


def add_to_json(assembly, locus, type_gene_names, type_gene_results, json_list,
                uncertainty_chars, all_gene_dict):
    """
    Builds a JSON record (as OrderedDicts, to keep key order stable in the output)
    summarising everything learned about this assembly's best locus match, and
    appends it to json_list for later serialisation by write_json_file().
    """
    json_record = OrderedDict()
    json_record['Assembly name'] = assembly.name

    # Best match: locus name, type, confidence and the full reference sequence.
    match_dict = OrderedDict()
    match_dict['Locus name'] = locus.name
    match_dict['Type'] = locus.type
    match_dict['Match confidence'] = locus.get_match_confidence()

    reference_dict = OrderedDict()
    reference_dict['Length'] = len(locus.seq)
    reference_dict['Sequence'] = locus.seq
    match_dict['Reference'] = reference_dict
    json_record['Best match'] = match_dict

    # Each uncertainty character encodes one kind of problem with the match.
    problems = OrderedDict()
    problems['Locus assembled in multiple pieces'] = str('?' in uncertainty_chars)
    problems['Missing genes in locus'] = str('-' in uncertainty_chars)
    problems['Extra genes in locus'] = str('+' in uncertainty_chars)
    problems['At least one low identity gene'] = str('*' in uncertainty_chars)
    json_record['Problems'] = problems

    blast_results = OrderedDict()
    blast_results['Coverage'] = locus.get_coverage_string()
    blast_results['Identity'] = locus.get_identity_string()
    blast_results['Length discrepancy'] = locus.get_length_discrepancy_string()
    assembly_pieces = []
    for i, piece in enumerate(locus.assembly_pieces):
        assembly_piece = OrderedDict()
        assembly_piece['Contig name'] = piece.contig_name
        # Start is converted back to a 1-based coordinate for human-readable output.
        assembly_piece['Contig start position'] = piece.start + 1
        assembly_piece['Contig end position'] = piece.end
        assembly_piece['Contig strand'] = piece.strand
        piece_seq = piece.get_sequence()
        assembly_piece['Length'] = len(piece_seq)
        assembly_piece['Sequence'] = piece_seq
        assembly_pieces.append(assembly_piece)
    blast_results['Locus assembly pieces'] = assembly_pieces
    json_record['blastn result'] = blast_results

    # Index each category of hit by its query (gene) name for quick lookup below.
    expected_genes_in_locus = {x.qseqid: x for x in locus.expected_hits_inside_locus}
    expected_hits_outside_locus = {x.qseqid: x for x in locus.expected_hits_outside_locus}
    other_hits_inside_locus = {x.qseqid: x for x in locus.other_hits_inside_locus}
    other_hits_outside_locus = {x.qseqid: x for x in locus.other_hits_outside_locus}

    # Record a result for every gene in the reference locus, found or not.
    locus_genes = []
    for gene in locus.genes:
        gene_dict = OrderedDict()
        gene_name = gene.full_name
        gene_dict['Name'] = gene_name
        if gene_name in expected_genes_in_locus:
            gene_dict['Result'] = 'Found in locus'
        elif gene_name in expected_hits_outside_locus:
            gene_dict['Result'] = 'Found outside locus'
        else:
            gene_dict['Result'] = 'Not found'
        gene_dict['Reference'] = gene.get_reference_info_json_dict()

        if gene_name in expected_genes_in_locus or gene_name in expected_hits_outside_locus:
            if gene_name in expected_genes_in_locus:
                hit = expected_genes_in_locus[gene_name]
            else:
                hit = expected_hits_outside_locus[gene_name]
            gene_dict['tblastn result'] = hit.get_blast_result_json_dict(assembly)
            gene_dict['Match confidence'] = hit.get_match_confidence()
        else:
            gene_dict['Match confidence'] = 'Not found'

        locus_genes.append(gene_dict)
    json_record['Locus genes'] = locus_genes

    # Genes from other loci whose hits fall inside this locus's assembly pieces.
    extra_genes = OrderedDict()
    for gene_name, hit in other_hits_inside_locus.items():
        gene_dict = OrderedDict()
        gene = all_gene_dict[gene_name]
        gene_dict['Reference'] = gene.get_reference_info_json_dict()
        gene_dict['tblastn result'] = hit.get_blast_result_json_dict(assembly)
        extra_genes[gene_name] = gene_dict
    json_record['Other genes in locus'] = extra_genes

    # Genes from other loci found elsewhere in the assembly.
    other_genes = OrderedDict()
    for gene_name, hit in other_hits_outside_locus.items():
        gene_dict = OrderedDict()
        gene = all_gene_dict[gene_name]
        gene_dict['Reference'] = gene.get_reference_info_json_dict()
        gene_dict['tblastn result'] = hit.get_blast_result_json_dict(assembly)
        other_genes[gene_name] = gene_dict
    json_record['Other genes outside locus'] = other_genes

    if type_gene_names:
        allelic_typing = OrderedDict()
        for gene_name in type_gene_names:
            allelic_type = OrderedDict()
            if not type_gene_results[gene_name]:
                allelic_type['Allele'] = 'Not found'
            else:
                blast_hit = type_gene_results[gene_name]
                allele = blast_hit.result
                # A trailing '*' on the allele name marks an imperfect match.
                if allele.endswith('*'):
                    perfect_match = False
                    allele = allele[:-1]
                else:
                    perfect_match = True
                # Report numeric allele names as integers where possible.
                try:
                    allele = int(allele)
                except ValueError:
                    pass
                allelic_type['Allele'] = allele
                allelic_type['Perfect match'] = str(perfect_match)
                allelic_type['blastn result'] = blast_hit.get_blast_result_json_dict(assembly)
            allelic_typing[gene_name] = allelic_type
        json_record['Allelic_typing'] = allelic_typing

    json_list.append(json_record)


def write_json_file(output_prefix, json_list):
    """
    Writes (or appends to) the JSON output file, holding an exclusive lock while doing
    so, so that multiple independent processes can safely share one output file.
    """
    json_filename = output_prefix + '.json'
    if os.path.isfile(json_filename):
        # Merge with the existing records (when they parse) before rewriting the file.
        with open(json_filename, 'r+t') as json_out:
            fcntl.flock(json_out, fcntl.LOCK_EX)
            existing_data = json_out.read()
            try:
                json_list = json.loads(existing_data, object_pairs_hook=OrderedDict) + json_list
            except ValueError:
                pass  # unparseable contents are simply overwritten
            json_out.seek(0)
            json_out.write(json.dumps(json_list, indent=4) + '\n')
            json_out.truncate()
            fcntl.flock(json_out, fcntl.LOCK_UN)
    else:
        with open(json_filename, 'wt') as json_out:
            fcntl.flock(json_out, fcntl.LOCK_EX)
            json_out.write(json.dumps(json_list, indent=4) + '\n')
            fcntl.flock(json_out, fcntl.LOCK_UN)


def output_to_stdout(assembly, locus, verbose, type_gene_names, type_gene_results,
                     uncertainty_chars, expected_genes_in_locus_str, expected_genes_out_locus_str,
                     missing_genes_str):
    """
    Prints the results for one assembly: a detailed multi-section report when verbose,
    otherwise a single summary line.
    """
    def allele_result(gene_name):
        # A falsy entry in type_gene_results means the typing gene was not found.
        hit = type_gene_results[gene_name]
        return hit.result if hit else 'Not found'

    if not verbose:
        summary = f'{assembly.name}: {locus.name}{uncertainty_chars}'
        summary += ''.join(f', {g}={allele_result(g)}' for g in type_gene_names)
        print(summary)
        return

    print()
    title = f'Assembly: {assembly.name}'
    print(title)
    print('-' * len(title))
    print(f'    Best match locus: {locus.name}')
    print(f'    Best match type: {locus.type}')
    print(f'    Match confidence: {locus.get_match_confidence()}')
    print(f"    Problems: {uncertainty_chars if uncertainty_chars else 'None'}")
    print(f'    Coverage: {locus.get_coverage_string()}')
    print(f'    Identity: {locus.get_identity_string()}')
    print(f'    Length discrepancy: {locus.get_length_discrepancy_string()}')
    print()
    print_assembly_pieces(locus.assembly_pieces)
    print_gene_hits(f'Expected genes in locus: {expected_genes_in_locus_str}',
                    locus.expected_hits_inside_locus)
    print_gene_hits(f'Expected genes outside locus: {expected_genes_out_locus_str}',
                    locus.expected_hits_outside_locus)
    print(f'    Missing expected genes: {missing_genes_str}')
    for missing_gene in locus.missing_expected_genes:
        print(f'        {missing_gene}')
    print()
    print_gene_hits(f'Other genes in locus: {len(locus.other_hits_inside_locus)}',
                    locus.other_hits_inside_locus)
    print_gene_hits(f'Other genes outside locus: {len(locus.other_hits_outside_locus)}',
                    locus.other_hits_outside_locus)
    for gene_name in type_gene_names:
        print(f'    {gene_name} allele: {allele_result(gene_name)}')
    print()


def print_assembly_pieces(pieces):
    """Pretty-prints assembly pieces for verbose output, aligning the sequence previews."""
    print('    Locus assembly pieces:')
    if pieces:
        # Pad every header to the widest one so the sequence previews line up.
        width = max(len(p.get_nice_header()) for p in pieces)
        for p in pieces:
            print('        ' + p.get_nice_header().ljust(width) + '  ' + p.get_sequence_short())
    print()


def print_gene_hits(title, hits):
    """Pretty-prints gene BLAST hits for verbose output, with aligned columns."""
    print('    ' + title)
    if hits:
        # Column positions are based on the widest entry in each column.
        name_width = max(len(h.qseqid) for h in hits)
        contig_width = max(len(h.get_contig_details_string()) for h in hits)
        coverage_width = max(len(h.get_coverage_details_string()) for h in hits)
        # Extra padding is requested when any value reaches 100.0 so decimals align.
        cov_space = max(h.query_cov for h in hits) == 100.0
        id_space = max(h.pident for h in hits) == 100.0
        spacing_1 = name_width + 2
        spacing_2 = spacing_1 + contig_width + 2
        spacing_3 = spacing_2 + coverage_width + 2
        for h in hits:
            print('        ' + h.get_aligned_string(spacing_1, spacing_2, spacing_3,
                                                    cov_space, id_space))
    print()


def float_to_str(float_in):
    """
    Converts a float to a string: whole numbers get no decimal point, everything else
    is shown with exactly one decimal place.
    """
    as_int = int(float_in)
    return str(as_int) if float_in == as_int else '%.1f' % float_in


def get_blast_hits(database, query, threads, genes=False, type_genes=False):
    """
    Returns a list BlastHit objects for a search of the given query in the given database.

    With genes=True the search is run with tblastn and GeneBlastHit objects are
    returned; with type_genes=True TypeGeneBlastHit objects are returned; otherwise
    blastn is used and plain BlastHit objects are returned.
    Quits with an error message if BLAST fails.
    """
    if genes:
        command = ['tblastn',
                   '-db_gencode', '11',  # bacterial translation table
                   '-seg', 'no']         # don't filter out low complexity regions
    else:
        command = ['blastn', '-task', 'blastn']
    # This custom outfmt must stay in sync with the column order parsed in BlastHit.__init__.
    command += ['-db', database, '-query', query, '-num_threads', str(threads), '-outfmt',
                '6 qseqid sseqid qstart qend sstart send evalue bitscore length pident qlen qseq '
                'sseq']
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = process.communicate()
    out = convert_bytes_to_str(out)
    err = convert_bytes_to_str(err)
    # Any stderr output or a non-zero exit status is treated as a fatal BLAST failure.
    if err or process.returncode != 0:
        msg = command[0] + ' crashed!\n'

        # A known crash can occur with tblastn and recent versions of BLAST+ when multiple threads
        # are used. Check for this case and display an informative error message if so.
        version = get_blast_version(command[0])
        bad_version = re.match(r'2\.(?:[4-9]|1[01])\.\d+$', version)  # v2.4.x through v2.11.x
        if threads > 1 and bad_version and (not err or err.startswith('terminate called')):
            msg += '\nYou are using BLAST+ v' + version + ' which may crash when running with '
            msg += 'multiple threads.\n\n'
            msg += 'To avoid this issue, try one of the following:\n'
            msg += '  1) Use an unaffected version of BLAST+ (v2.3.0 or earlier should work)\n'
            msg += '  2) Run Kaptive with "--threads 1" (will probably be slower)\n'
            if err:
                msg += '\nRaw error:\n' + err
            quit_with_error(msg)
        elif err:
            quit_with_error(command[0] + ' encountered an error:\n' + err)
        else:
            quit_with_error(msg)

    # Wrap each tabular output line in the appropriate BlastHit subclass.
    if genes:
        blast_hits = [GeneBlastHit(line) for line in line_iterator(out)]
    elif type_genes:
        blast_hits = [TypeGeneBlastHit(line) for line in line_iterator(out)]
    else:
        blast_hits = [BlastHit(line) for line in line_iterator(out)]
    return blast_hits


def get_blast_version(program):
    """
    Returns the version number string of the given BLAST program (e.g. '2.2.31'),
    or an empty string if it cannot be determined.
    """
    process = subprocess.Popen([program, '-version'],
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, _ = process.communicate()
    out = convert_bytes_to_str(out)
    try:
        # Expected output begins like 'tblastn: 2.2.31+' - take the number after the colon.
        return out.split(': ')[1].split()[0].split('+')[0]
    except IndexError:
        return ''


def get_best_hit_for_query(blast_hits, query_name, locus):
    """
    Given a list of BlastHits, this function returns the best hit for the given query, based first
    on whether or not the hit is in the assembly pieces, then on bit score.
    It returns None if no BLAST hits match that query.
    """
    matching_hits = [hit for hit in blast_hits if hit.qseqid == query_name]
    if not matching_hits:
        return None
    # In-locus hits beat out-of-locus hits; bit score breaks ties within each group.
    return max(matching_hits,
               key=lambda hit: (hit.in_assembly_pieces(locus.assembly_pieces), hit.bitscore))


def cull_conflicting_hits(hit_to_keep, blast_hits):
    """
    Returns the subset of the given BLAST hits which do not conflict (overlap too much
    on the same part of the assembly) with the hit being kept.
    """
    return [candidate for candidate in blast_hits
            if not candidate.conflicts(hit_to_keep)]


def cull_all_conflicting_hits(blast_hits):
    """
    Returns a (potentially) reduced list of BLAST hits in which no two hits conflict,
    greedily keeping the highest-bitscore hits first.

    The caller's list is left unmodified (previously it was sorted in place and drained,
    a surprising side effect on the argument).
    """
    remaining = sorted(blast_hits, key=lambda x: x.bitscore, reverse=True)
    kept_hits = []
    while remaining:
        best = remaining.pop(0)
        kept_hits.append(best)
        remaining = cull_conflicting_hits(best, remaining)
    return kept_hits


def merge_assembly_pieces(pieces):
    """
    Takes a list of AssemblyPiece objects and returns another list of AssemblyPiece objects where
    the overlapping pieces have been merged.
    """
    merged_pieces = []
    while True:
        # One round: greedily fold every combinable piece into the current one.
        merged_pieces = []
        merges_this_round = 0
        remaining = pieces
        while remaining:
            current = remaining[0]
            leftover = []
            for candidate in remaining[1:]:
                combined = current.combine(candidate)
                if combined:
                    current = combined
                    merges_this_round += 1
                else:
                    leftover.append(candidate)
            merged_pieces.append(current)
            remaining = leftover
        # Repeat until a full round produces no merges.
        if merges_this_round == 0:
            return merged_pieces
        pieces = merged_pieces


def fill_assembly_piece_gaps(pieces, max_gap_fill_size):
    """
    This function takes a list of assembly pieces, and if any of them are close enough to each
    other, the gap will be merged in.
    It assumes that all given pieces are from the same assembly.
    """
    # Group the pieces by (contig, strand): gaps can only be filled between pieces
    # that sit on the same contig and strand.
    grouped = {}
    for piece in pieces:
        grouped.setdefault((piece.contig_name, piece.strand), []).append(piece)

    fixed_pieces = []
    for (contig, strand), group in grouped.items():
        sorted_group = sorted(group, key=lambda p: p.start)
        assembly = sorted_group[0].assembly

        # Find every gap between consecutive pieces that is small enough to fill,
        # and make a new piece covering each such gap.
        gap_fillers = []
        max_end = sorted_group[0].end
        for piece in sorted_group[1:]:
            if max_end < piece.start <= max_end + max_gap_fill_size:
                gap_fillers.append(AssemblyPiece(assembly, contig, max_end, piece.start, strand))
            max_end = max(max_end, piece.end)

        fixed_pieces += merge_assembly_pieces(group + gap_fillers)
    return fixed_pieces


def get_mean_identity(pieces):
    """Returns the mean identity (weighted by sequence length) for a list of assembly pieces."""
    weighted_identity = 0.0
    total_length = 0
    for piece in pieces:
        for hit in piece.blast_hits:
            total_length += hit.length
            weighted_identity += hit.length * hit.pident
    # No hits at all (or all-zero identities) means there is nothing to average over.
    if weighted_identity == 0.0:
        return 0.0
    return weighted_identity / total_length


def reverse_complement(seq):
    """Given a DNA sequence, this function returns the reverse complement sequence."""
    # Build the result in one pass with join: repeated string += is quadratic.
    return ''.join(complement_base(base) for base in reversed(seq))


def complement_base(base):
    """Given a DNA base (including IUPAC ambiguity codes), returns its complement."""
    pairs = dict(zip('ATGCatgcRYSWKMryswkmBDHVbdhvNn.-?',
                     'TACGtacgYRSWMKyrswmkVHDBvhdbNn.-?'))
    # Any unrecognised character complements to 'N'.
    return pairs.get(base, 'N')


def save_assembly_pieces_to_file(locus, assembly, output_prefix):
    """
    Creates a single FASTA file for all of the assembly pieces.
    Assumes all assembly pieces are from the same assembly.
    Returns the filename, or None when there are no pieces to save.
    """
    if not locus.assembly_pieces:
        return None
    fasta_file_name = f'{output_prefix}_{assembly.name}.fasta'
    with open(fasta_file_name, 'w') as fasta_file:
        for piece in locus.assembly_pieces:
            fasta_file.write(f'>{assembly.name}_{piece.get_header()}\n')
            fasta_file.write(add_line_breaks_to_sequence(piece.get_sequence(), 60))
    return fasta_file_name


def add_line_breaks_to_sequence(sequence, length):
    """
    Wraps sequences to the defined length. All resulting sequences end in a line break.
    An empty sequence yields an empty string (no trailing line break).
    """
    if not sequence:
        return ''
    # Slice fixed-size chunks in a single pass: repeatedly re-slicing the remainder
    # (as before) is quadratic in the sequence length.
    return '\n'.join(sequence[i:i + length]
                     for i in range(0, len(sequence), length)) + '\n'


def line_iterator(string_with_line_breaks):
    """
    Iterates over a string containing line breaks, one line at a time.
    Any text after the final line break is not yielded.
    """
    pieces = string_with_line_breaks.split('\n')
    # split() leaves whatever follows the last '\n' (possibly '') as the final
    # element, which this function does not yield.
    yield from pieces[:-1]


def load_locus_references(fasta, ref_genes, ref_types):
    """Returns a dictionary of: key = locus name, value = Locus object"""
    loci = {}
    for name, sequence in load_fasta(fasta):
        loci[name] = Locus(name, ref_types[name], sequence, ref_genes[name])
    return loci


def load_fasta(filename):
    """Returns (name, sequence) tuples for the given fasta file (gzipped or plain text)."""
    open_func = gzip.open if get_compression_type(filename) == 'gz' else open
    fasta_seqs = []
    name = ''
    seq_parts = []
    with open_func(filename, 'rt') as fasta_file:
        for line in fasta_file:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):  # header line = start of new contig
                if name:
                    # Only the first whitespace-delimited token of the header is kept.
                    fasta_seqs.append((name.split()[0], ''.join(seq_parts)))
                name = line[1:]
                seq_parts = []
            else:
                seq_parts.append(line)
    if name:
        fasta_seqs.append((name.split()[0], ''.join(seq_parts)))
    return fasta_seqs


def good_start_and_end(start, end, length, allowed_margin):
    """
    Checks whether the given start and end coordinates are within the accepted margin of error.
    """
    return (start <= allowed_margin
            and end >= length - allowed_margin
            and start < end)


def get_gene_info_string(gene_hit_list):
    """Returns a single semicolon-delimited string summarising the gene hits in the given list."""
    return ';'.join(f'{hit.qseqid},{hit.pident}%' for hit in gene_hit_list)


def is_contig_name_spades_format(contig_name):
    """
    Returns whether or not the contig name appears to be in the SPAdes/Velvet format.
    Example: NODE_5_length_150905_cov_4.42519
    """
    parts = contig_name.split('_')
    if len(parts) <= 5:
        return False
    return parts[0] == 'NODE' and parts[2] == 'length' and parts[4] == 'cov'


def get_nice_contig_name(contig_name):
    """
    For a contig with a SPAdes/Velvet format, this function returns a simplified string that is
    just NODE_XX where XX is the contig number.
    For any other format, this function trims off everything following the first whitespace.
    """
    if is_contig_name_spades_format(contig_name):
        contig_number = contig_name.split('_')[1]
        return 'NODE_' + contig_number
    return contig_name.split()[0]


class BlastHit(object):
    """
    Stores the BLAST hit output mostly verbatim. However, it does convert the BLAST ranges
    (1-based, inclusive end) to Python ranges (0-based, exclusive end).
    """
    def __init__(self, hit_string):
        # Expected tab-delimited columns (must match the -outfmt string built in
        # get_blast_hits): qseqid sseqid qstart qend sstart send evalue bitscore
        # length pident qlen qseq sseq
        parts = hit_string.split('\t')
        self.qseqid = parts[0]
        self.sseqid = parts[1]
        self.qstart = int(parts[2]) - 1  # query range becomes 0-based, exclusive end
        self.qend = int(parts[3])
        self.sstart = int(parts[4])
        self.send = int(parts[5])
        # BLAST reports reverse-strand hits with sstart > send; normalise so that
        # sstart < send and record the strand separately.
        if self.sstart <= self.send:
            self.strand = '+'
        else:
            self.sstart, self.send = self.send, self.sstart
            self.strand = '-'
        self.sstart -= 1  # subject range becomes 0-based, exclusive end
        self.evalue = float(parts[6])
        self.bitscore = float(parts[7])
        self.length = int(parts[8])  # alignment length as reported by BLAST
        self.pident = float(parts[9])
        # Query coverage (%): length of the reported aligned query sequence over
        # the full query length.
        self.query_cov = 100.0 * len(parts[11]) / float(parts[10])
        self.sseq = parts[12]

    def __repr__(self):
        return self.qseqid + ', ' + self.get_contig_details_string() + ', ' + \
               self.get_coverage_details_string() + ', ' + self.get_identity_details_string()

    def get_contig_details_string(self):
        """Returns a string describing the hit's range and strand in the contig."""
        return 'Contig: ' + get_nice_contig_name(self.sseqid) + ' (' + str(self.sstart) + '-' + \
               str(self.send) + ', ' + self.strand + ' strand)'

    def get_coverage_string(self):
        """Returns the query coverage formatted to two decimal places, e.g. '95.32%'."""
        return '%.2f' % self.query_cov + '%'

    def get_coverage_details_string(self, extra_space=False):
        """Returns 'Cov: XX.XX%', optionally padded so values align under '100.00%'."""
        first_part = 'Cov: '
        second_part = self.get_coverage_string()
        # A 6-character value (e.g. '99.99%') gets one extra space to line up with
        # the 7-character '100.00%'.
        if len(second_part) == 6 and extra_space:
            first_part += ' '
        return first_part + second_part

    def get_identity_string(self):
        """Returns the percent identity formatted to two decimal places, e.g. '98.76%'."""
        return '%.2f' % self.pident + '%'

    def get_identity_details_string(self, extra_space=False):
        """Returns 'ID: XX.XX%', optionally padded so values align under '100.00%'."""
        first_part = 'ID: '
        second_part = self.get_identity_string()
        # Same alignment padding logic as get_coverage_details_string.
        if len(second_part) == 6 and extra_space:
            first_part += ' '
        return first_part + second_part

    def get_aligned_string(self, spacing_1, spacing_2, spacing_3, cov_space, id_space):
        """Returns a string describing the hit with spaced parts for alignment."""
        aligned_string = self.qseqid + '  '
        aligned_string = aligned_string.ljust(spacing_1)
        aligned_string += self.get_contig_details_string() + '  '
        aligned_string = aligned_string.ljust(spacing_2)
        aligned_string += self.get_coverage_details_string(cov_space) + '  '
        aligned_string = aligned_string.ljust(spacing_3)
        aligned_string += self.get_identity_details_string(id_space)
        return aligned_string

    def get_assembly_piece(self, assembly):
        """Returns the piece of the assembly which corresponds to this BLAST hit."""
        return AssemblyPiece(assembly, self.sseqid, self.sstart, self.send, self.strand, [self])

    def get_query_range(self):
        """Produces an IntRange object for the hit query."""
        return IntRange([(self.qstart, self.qend)])

    def in_assembly_pieces(self, assembly_pieces):
        """
        Returns True if the hit is in (or at least overlaps with) any of the given assembly pieces.
        """
        for piece in assembly_pieces:
            if piece.overlaps(self.sseqid, self.sstart, self.send):
                return True
        return False

    def get_blast_result_json_dict(self, assembly):
        """Returns an OrderedDict of this hit's BLAST metrics for JSON output."""
        blast_results = OrderedDict()
        blast_results['Coverage'] = self.get_coverage_string()
        blast_results['Identity'] = self.get_identity_string()
        blast_results['Contig name'] = self.sseqid
        blast_results['Contig start position'] = self.sstart
        blast_results['Contig end position'] = self.send
        blast_results['Contig strand'] = self.strand
        blast_results['Bit score'] = self.bitscore
        blast_results['E-value'] = self.evalue
        return blast_results


class GeneBlastHit(BlastHit):
    """This class adds a few gene-specific things to the BlastHit class."""
    def __init__(self, hit_string):
        BlastHit.__init__(self, hit_string)
        # Set to True later, once the hit has been compared against the identity threshold.
        self.over_identity_threshold = False

    def conflicts(self, other):
        """
        Returns whether or not this hit conflicts with the other hit.
        A conflict is defined as the hits overlapping by more than 50% of the shortest hit's
        length (the code tests strictly greater than 0.5, so exactly 50% is not a conflict).
        A hit is not considered to conflict with itself.
        """
        if self is other:
            return False
        if self.sseqid != other.sseqid:
            return False
        overlap = max(0, min(self.send, other.send) - max(self.sstart, other.sstart))
        min_length = min(self.send - self.sstart, other.send - other.sstart)
        if min_length <= 0:
            # A zero-length hit cannot meaningfully conflict (this previously caused a
            # ZeroDivisionError).
            return False
        return overlap / min_length > 0.5

    def get_blast_result_json_dict(self, assembly):
        """
        Returns an OrderedDict of this hit's BLAST details (including the nucleotide and
        protein sequences) for the JSON output.
        """
        blast_results = super(GeneBlastHit, self).get_blast_result_json_dict(assembly)
        nuc_seq = assembly.contigs[self.sseqid][self.sstart:self.send]
        if self.strand == '-':
            nuc_seq = reverse_complement(nuc_seq)
        blast_results['Nucleotide length'] = len(nuc_seq)
        blast_results['Protein length'] = len(self.sseq)
        blast_results['Nucleotide sequence'] = nuc_seq
        blast_results['Protein sequence'] = self.sseq
        return blast_results

    def get_match_confidence(self):
        """Returns a qualitative confidence level based on the hit's coverage and identity."""
        cov = self.query_cov
        ident = self.pident
        if cov == 100.0 and ident >= 99.0:
            confidence = 'Very high'
        elif cov >= 99.0 and ident >= 95.0:
            confidence = 'High'
        elif cov >= 97.0 and ident >= 95.0:
            confidence = 'Good'
        elif cov >= 95.0 and ident >= 85.0:
            confidence = 'Low'
        else:
            confidence = 'None'
        return confidence


class TypeGeneBlastHit(BlastHit):
    """This class adds a couple type gene-specific things to the BlastHit class."""
    def __init__(self, hit_string):
        BlastHit.__init__(self, hit_string)
        # Query names are expected to look like 'prefix__gene__allele'; anything that
        # fails to parse gets blank/zero placeholders instead.
        try:
            name_parts = self.qseqid.split('__')
            self.gene_name, self.allele_number = name_parts[1], int(name_parts[2])
        except (IndexError, ValueError):
            self.gene_name = ''
            self.allele_number = 0

    def get_blast_result_json_dict(self, assembly):
        """Returns an OrderedDict of this hit's BLAST details for the JSON output."""
        return OrderedDict([('Coverage', self.get_coverage_string()),
                            ('Identity', self.get_identity_string()),
                            ('Assembly piece name', self.sseqid),
                            ('Assembly piece start position', self.sstart),
                            ('Assembly piece end position', self.send),
                            ('Assembly piece strand', self.strand),
                            ('Bit score', self.bitscore),
                            ('E-value', self.evalue),
                            ('Length', len(self.sseq)),
                            ('Sequence', self.sseq)])


class Locus(object):
    """
    Stores one reference locus (its sequence and genes) along with the search results for
    that locus in the assembly currently being typed.
    """
    def __init__(self, name, type_name, seq, genes):
        self.name = name
        self.type = type_name
        self.seq = seq
        self.genes = genes
        self.gene_names = [x.full_name for x in genes]
        # All of the per-assembly state starts out empty; clear() defines it in one place
        # (previously the same nine assignments were duplicated here and in clear()).
        self.clear()

    def __repr__(self):
        return 'Locus ' + self.name

    def get_length(self):
        """Returns the locus sequence length."""
        return len(self.seq)

    def add_blast_hit(self, hit):
        """Adds a BLAST hit and updates the hit ranges."""
        self.blast_hits.append(hit)
        self.hit_ranges.add_range(hit.qstart, hit.qend)

    def get_mean_blast_hit_identity(self):
        """Returns the mean identity (weighted by hit length) for all BLAST hits in the locus."""
        identity_sum = 0.0
        length_sum = 0
        for hit in self.blast_hits:
            length_sum += hit.length
            identity_sum += hit.length * hit.pident
        # Guard on length_sum (the divisor), not identity_sum, so that hits which all have
        # 0% identity cannot cause a ZeroDivisionError.
        if length_sum == 0:
            return 0.0
        else:
            return identity_sum / length_sum

    def clear(self):
        """
        Clears everything in the Locus object relevant to a particular assembly - gets it ready
        for the next assembly.
        """
        self.blast_hits = []
        self.hit_ranges = IntRange()
        self.assembly_pieces = []
        self.identity = 0.0
        self.expected_hits_inside_locus = []
        self.missing_expected_genes = []
        self.expected_hits_outside_locus = []
        self.other_hits_inside_locus = []
        self.other_hits_outside_locus = []

    def get_coverage(self):
        """Returns the % of this locus which is covered by BLAST hits in the given assembly."""
        try:
            return 100.0 * self.hit_ranges.get_total_length() / len(self.seq)
        except ZeroDivisionError:
            return 0.0

    def get_coverage_string(self):
        """Returns the coverage as a percentage string with two decimal places."""
        return '%.2f' % self.get_coverage() + '%'

    def get_identity_string(self):
        """Returns the identity as a percentage string with two decimal places."""
        return '%.2f' % self.identity + '%'

    def clean_up_blast_hits(self):
        """
        This function removes unnecessary BLAST hits from self.blast_hits.
        For each BLAST hit, we keep it if it offers new parts of the locus. If, on the other
        hand, it lies entirely within an existing hit (in locus positions), we ignore it. Since
        we first sort the BLAST hits longest to shortest, this strategy will prioritise long hits
        over short ones.
        """
        self.blast_hits.sort(key=lambda x: x.length, reverse=True)
        kept_hits = []
        range_so_far = IntRange()
        for hit in self.blast_hits:
            hit_range = hit.get_query_range()
            if not range_so_far.contains(hit_range):
                range_so_far.merge_in_range(hit_range)
                kept_hits.append(hit)
        self.blast_hits = kept_hits

    def get_match_uncertainty_chars(self):
        """
        Returns the character code which indicates uncertainty with how this locus was found in
        the current assembly.
        '?' means the locus was found in multiple discontinuous assembly pieces.
        '-' means that one or more expected genes were missing.
        '+' means that one or more additional genes were found in the locus assembly parts.
        '*' means that at least one of the expected genes in the locus is low identity.
        """
        uncertainty_chars = ''
        if len(self.assembly_pieces) > 1:
            uncertainty_chars += '?'
        if self.missing_expected_genes:
            uncertainty_chars += '-'
        if self.other_hits_inside_locus:
            uncertainty_chars += '+'
        # all() on an empty list is True, so no hits inside the locus means no '*'.
        if not all([x.over_identity_threshold for x in self.expected_hits_inside_locus]):
            uncertainty_chars += '*'
        return uncertainty_chars

    def get_length_discrepancy(self):
        """
        Returns an integer of the base discrepancy between the locus in the assembly and the
        reference locus sequence.
        E.g. if the assembly match was 5 bases shorter than the reference, this returns -5.
        This function only applies to cases where the locus was found in a single piece. In
        other cases, it returns None.
        """
        if len(self.assembly_pieces) != 1:
            return None
        only_piece = self.assembly_pieces[0]
        a_start = only_piece.start
        a_end = only_piece.end
        start = only_piece.earliest_hit_coordinate()
        end = only_piece.latest_hit_coordinate()
        expected_length = end - start
        actual_length = a_end - a_start
        return actual_length - expected_length

    def get_length_discrepancy_string(self):
        """
        Returns the length discrepancy, not as an integer but as a string with a sign and units.
        """
        length_discrepancy = self.get_length_discrepancy()
        if length_discrepancy is None:
            return 'n/a'
        length_discrepancy_string = str(length_discrepancy) + ' bp'
        if length_discrepancy > 0:
            length_discrepancy_string = '+' + length_discrepancy_string
        return length_discrepancy_string

    def get_earliest_and_latest_pieces(self):
        """
        Returns the AssemblyPiece with the earliest coordinate (closest to the locus start) and
        the AssemblyPiece with the latest coordinate (closest to the locus end)
        """
        # sorted() is kept (rather than min/max) so tie-breaking behaviour is unchanged:
        # the earliest piece is the FIRST minimal one, the latest piece the LAST maximal one.
        earliest_piece = sorted(self.assembly_pieces, key=lambda x: x.earliest_hit_coordinate())[0]
        latest_piece = sorted(self.assembly_pieces, key=lambda x: x.latest_hit_coordinate())[-1]
        same_contig_and_strand = earliest_piece.contig_name == latest_piece.contig_name and \
            earliest_piece.strand == latest_piece.strand

        # Even though the pieces are on the same contig and strand, we still need to check whether
        # the earliest piece comes before the latest piece in that contig.
        if same_contig_and_strand:
            if earliest_piece.strand == '+' and earliest_piece.start > latest_piece.end:
                same_contig_and_strand = False
            elif earliest_piece.strand == '-' and earliest_piece.start < latest_piece.end:
                same_contig_and_strand = False
        return earliest_piece, latest_piece, same_contig_and_strand

    def get_match_confidence(self):
        """
        These confidence thresholds match those specified in the paper supp. text, with the
        addition of two new top-level categories: perfect and very high
        """
        single_piece = len(self.assembly_pieces) == 1
        cov = self.get_coverage()
        ident = self.identity
        missing = len(self.missing_expected_genes)
        extra = len(self.other_hits_inside_locus)
        if single_piece and cov == 100.0 and ident == 100.0 and missing == 0 and extra == 0 and \
                self.get_length_discrepancy() == 0:
            confidence = 'Perfect'
        elif single_piece and cov >= 99.0 and ident >= 95.0 and missing == 0 and extra == 0:
            confidence = 'Very high'
        elif single_piece and cov >= 99.0 and missing <= 3 and extra == 0:
            confidence = 'High'
        elif (single_piece or cov >= 95.0) and missing <= 3 and extra <= 1:
            confidence = 'Good'
        elif (single_piece or cov >= 90.0) and missing <= 3 and extra <= 2:
            confidence = 'Low'
        else:
            confidence = 'None'
        return confidence


class Assembly(object):
    """Holds one assembly: its FASTA path, its contig sequences and its BLAST database."""
    def __init__(self, fasta_file):
        """Loads in an assembly and builds a BLAST database for it (if necessary)."""
        self.fasta = fasta_file
        # Fixed a dead double assignment: self.name was first set to fasta_file and then
        # immediately overwritten; only the extension-stripped form is meaningful.
        self.name = strip_extensions(fasta_file)
        self.contigs = {x[0]: x[1] for x in load_fasta(fasta_file)}  # key = name, value = sequence
        self.blast_db_already_present = self.blast_database_exists()
        if not self.blast_db_already_present:
            makeblastdb(self.fasta)

    def __del__(self):
        # Only delete the BLAST database files if this object was the one that made them.
        if not self.blast_db_already_present:
            clean_blast_db(self.fasta)

    def __repr__(self):
        return self.name

    def blast_database_exists(self):
        """Returns whether or not a BLAST database already exists for this assembly."""
        return all(os.path.isfile(self.fasta + extension)
                   for extension in ('.nin', '.nhr', '.nsq'))


class AssemblyPiece(object):
    """
    This class describes a piece of an assembly: which contig the piece is on and what the range
    is.
    """
    def __init__(self, assembly, contig_name, contig_start, contig_end, strand, blast_hits=None):
        self.assembly = assembly
        self.contig_name = contig_name
        self.start = contig_start  # 0-based, inclusive
        self.end = contig_end      # 0-based, exclusive (Python slice convention)
        self.strand = strand       # '+' or '-'
        if not blast_hits:
            blast_hits = []
        self.blast_hits = blast_hits

    def __repr__(self):
        return self.assembly.name + '_' + self.get_header()

    def get_header(self):
        """Returns a descriptive string for the FASTA header when saving this piece to file."""
        nice_contig_name = get_nice_contig_name(self.contig_name)
        return nice_contig_name + '_' + str(self.start + 1) + '_to_' + str(self.end) + \
            '_' + self.strand + '_strand'

    def get_nice_header(self):
        """Like get_header, but uses spaces/parentheses instead of underscores for readability."""
        nice_contig_name = get_nice_contig_name(self.contig_name)
        return nice_contig_name + ' (' + str(self.start + 1) + '-' + str(self.end) + \
            ', ' + self.strand + ' strand)'

    def get_bandage_range(self):
        """Returns the assembly piece in a Bandage path format."""
        if is_contig_name_spades_format(self.contig_name):
            name = self.contig_name.split('_')[1]
        else:
            name = self.contig_name.split()[0]
        return '(' + str(self.start + 1) + ') ' + name + '+ (' + str(self.end) + ')'

    def get_sequence(self):
        """Returns the DNA sequence for this piece of the assembly."""
        seq = self.assembly.contigs[self.contig_name][self.start:self.end]
        if self.strand == '+':
            return seq
        else:
            return reverse_complement(seq)

    def get_length(self):
        """Returns the sequence length for this piece."""
        return self.end - self.start

    def get_sequence_short(self):
        """Returns a shortened format of the sequence"""
        seq = self.get_sequence()
        length = len(seq)
        if len(seq) > 9:
            seq = seq[0:6] + '...' + seq[-6:]
        return seq + ' (' + str(length) + ' bp)'

    def combine(self, other):
        """
        If this assembly piece and the other can be combined, this function returns the combined
        piece. If they can't, it returns None.
        To be able to combine, pieces must be overlapping or directly adjacent and on the same
        strand.
        """
        if self.contig_name != other.contig_name or self.strand != other.strand:
            return None
        combined = IntRange([(self.start, self.end)])
        combined.add_range(other.start, other.end)
        if len(combined.ranges) == 1:
            new_start, new_end = combined.ranges[0]
            combined_hits = self.blast_hits + other.blast_hits
            return AssemblyPiece(self.assembly, self.contig_name, new_start, new_end, self.strand,
                                 combined_hits)
        else:
            return None

    def overlaps(self, contig_name, start, end):
        """Returns whether this assembly piece overlaps with the given parameters."""
        return self.contig_name == contig_name and self.start < end and start < self.end

    def earliest_hit_coordinate(self):
        """Returns the lowest query start coordinate in the BLAST hits (None if no hits)."""
        if not self.blast_hits:
            return None
        # min() over the plain integer coordinates replaces a sort-then-index.
        return min(x.qstart for x in self.blast_hits)

    def latest_hit_coordinate(self):
        """Returns the highest query end coordinate in the BLAST hits (None if no hits)."""
        if not self.blast_hits:
            return None
        return max(x.qend for x in self.blast_hits)


class IntRange(object):
    """
    This class contains one or more integer ranges. Overlapping ranges will be merged together.
    It stores its ranges in a Python-like fashion where the last value in each range is
    exclusive.
    """
    def __init__(self, ranges=None):
        self.ranges = []
        self.add_ranges(ranges if ranges else [])
        self.simplify()

    def __repr__(self):
        return str(self.ranges)

    def add_range(self, start, end):
        """Adds a single range."""
        self.add_ranges([(start, end)])

    def add_ranges(self, ranges):
        """Adds multiple ranges (list of tuples)."""
        self.ranges += ranges
        self.simplify()

    def merge_in_range(self, other):
        """Merges the other IntRange object into this one."""
        self.add_ranges(other.ranges)

    def get_total_length(self):
        """Returns the number of integers in the ranges."""
        return sum(end - start for start, end in self.ranges)

    def simplify(self):
        """Collapses overlapping (or directly adjacent) ranges together."""
        # Normalise each range so start < end, dropping empty ranges entirely.
        normalised = sorted((min(r), max(r)) for r in self.ranges if r[0] != r[1])
        merged = []
        for start, end in normalised:
            if merged and start <= merged[-1][1]:
                # This range overlaps or touches the previous one, so extend it.
                merged[-1] = (merged[-1][0], max(merged[-1][1], end))
            else:
                merged.append((start, end))
        self.ranges = merged

    def contains(self, other):
        """Returns True if the other IntRange is entirely contained within this IntRange."""
        return all(any(this_start <= other_start and other_end <= this_end
                       for this_start, this_end in self.ranges)
                   for other_start, other_end in other.ranges)


class Gene(object):
    """This class prepares and stores a gene taken from the input Genbank file."""
    def __init__(self, locus_name, num, feature, k_locus_seq):
        self.locus_name = locus_name
        self.feature = feature
        self.full_name = locus_name + '_' + str(num).zfill(2)
        qualifiers = feature.qualifiers
        # Optional qualifiers default to None when the Genbank record lacks them.
        self.gene_name = qualifiers.get('gene', [None])[0]
        if self.gene_name is not None:
            self.full_name += '_' + self.gene_name
        self.product = qualifiers.get('product', [None])[0]
        self.ec_number = qualifiers.get('EC_number', [None])[0]
        extracted_seq = feature.extract(k_locus_seq)
        self.prot_seq = str(extracted_seq.translate(table=11))
        self.nuc_seq = str(extracted_seq)

    def get_fasta(self):
        """
        Returns the FASTA version of this gene: a header line followed by sequence lines (of
        protein sequence) ending in a line break.
        """
        header_line = '>' + self.full_name + '\n'
        return header_line + add_line_breaks_to_sequence(self.prot_seq, 60)

    def get_reference_info_json_dict(self):
        """Returns an OrderedDict describing this reference gene for the JSON output."""
        reference_dict = OrderedDict()
        for key, value in [('Gene', self.gene_name),
                           ('Product', self.product),
                           ('EC number', self.ec_number)]:
            if value:
                reference_dict[key] = value
        reference_dict['Nucleotide length'] = len(self.nuc_seq)
        reference_dict['Protein length'] = len(self.prot_seq)
        reference_dict['Nucleotide sequence'] = self.nuc_seq
        reference_dict['Protein sequence'] = self.prot_seq
        return reference_dict


def convert_bytes_to_str(bytes_or_str):
    """
    This function is for both Python2 and Python3. If the input is a str, it just returns that
    same str. If not, it assumes its bytes (and we're in Python3) and it returns it as a str.
    """
    if not isinstance(bytes_or_str, str):
        return bytes_or_str.decode()
    return bytes_or_str


def makeblastdb(fasta):
    """
    If the FASTA file is not compressed, this just runs makeblastdb. If it is compressed,
    it runs gunzip and pipes into makeblastdb.
    """
    if ' ' in fasta:
        print('WARNING: spaces in file paths may not work in BLAST', file=sys.stderr)
    if get_compression_type(fasta) == 'gz':
        # Stream the decompressed FASTA into makeblastdb via a pipe: '-in -' makes
        # makeblastdb read from stdin, while '-out'/'-title' keep the database files
        # named after the original (compressed) file path.
        gunzip_command = ['gunzip', '-c', fasta]
        makeblastdb_command = ['makeblastdb', '-dbtype', 'nucl', '-in', '-', '-out', fasta,
                               '-title', fasta]
        gunzip = subprocess.Popen(gunzip_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        makeblastdb_process = subprocess.Popen(makeblastdb_command, stdin=gunzip.stdout,
                                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Closing our handle to gunzip's stdout lets gunzip receive SIGPIPE if
        # makeblastdb exits early (standard shell-pipeline idiom).
        gunzip.stdout.close()
        # NOTE(review): gunzip's stderr is piped but never read and its return code is
        # never checked, so a gunzip failure here would be silently ignored — verify.
        _, err = makeblastdb_process.communicate()
    else:  # plain text
        makeblastdb_command = ['makeblastdb', '-dbtype', 'nucl', '-in', fasta]
        makeblastdb_process = subprocess.Popen(makeblastdb_command, stdout=subprocess.PIPE,
                                               stderr=subprocess.PIPE)
        _, err = makeblastdb_process.communicate()
    # Any stderr output at all (even a warning) is treated as fatal.
    if err:
        quit_with_error('makeblastdb encountered an error:\n' + convert_bytes_to_str(err))


def remove_if_exists(filename):
    """Deletes the given file, silently doing nothing if it cannot be removed."""
    try:
        os.remove(filename)
    except OSError:
        pass


def clean_blast_db(fasta):
    """Deletes the BLAST database files that were made for this assembly FASTA."""
    for extension in ['.nsq', '.nhr', '.nin']:
        remove_if_exists(fasta + extension)


def get_compression_type(filename):
    """
    Attempts to guess the compression (if any) on a file using the first few bytes.
    http://stackoverflow.com/questions/13044562

    Bug fix: each magic number is now a single bytes object. They were previously tuples
    of one-byte values, and bytes.startswith() with a tuple matches if the file starts
    with ANY element, so e.g. any file whose first byte was 'B' (0x42) was mistaken for
    bzip2. The read length is also now taken from the magic values, not the dict keys
    (which always gave 3 and truncated zip's 4-byte magic).
    """
    magic_dict = {'gz': b'\x1f\x8b\x08',
                  'bz2': b'\x42\x5a\x68',
                  'zip': b'\x50\x4b\x03\x04'}
    max_len = max(len(x) for x in magic_dict.values())
    with open(filename, 'rb') as unknown_file:
        file_start = unknown_file.read(max_len)
    compression_type = 'plain'
    for file_type, magic_bytes in magic_dict.items():
        if file_start.startswith(magic_bytes):
            compression_type = file_type
    if compression_type == 'bz2':
        quit_with_error('cannot use bzip2 format - use gzip instead')
    if compression_type == 'zip':
        quit_with_error('cannot use zip format - use gzip instead')
    return compression_type


def strip_extensions(filename):
    """
    This function removes extensions from a file name. Examples:
      assembly.fasta -> assembly
      assembly.fa.gz -> assembly
      genome.assembly.fa.gz -> genome.assembly
    """
    name = os.path.basename(filename)
    if name.lower().endswith('.gz'):
        name = name[:-3]
    # At most one FASTA-style extension is stripped, matching the original if/elif chain.
    for extension in ('.fa', '.fna', '.fas', '.fasta'):
        if name.lower().endswith(extension):
            name = name[:-len(extension)]
            break
    return name


# Script entry point: main() is defined earlier in the file (outside this view).
if __name__ == '__main__':
    main()