#! /usr/bin/python

#################################################################################
# This program will do the following:
#
# 1. Read in fasta formatted contig file, created using IDBA-UD assembler (Peng et al. 2012, Bioinformatics v. 28)
# 2. Read in pileup-formatted alignment (via mpileup in samtools) and calculate coverage for each sequence
# 3. Read in phymmbl output
# 4. For each sequence in the fasta file, merge information from all three
# 
# Outputs are: 
# 1. Fasta-formatted file with coverage and phymmbl information contained in the header
# 2. Tab delimited file with coverage and phymmbl information for each contig
#
# Usage: python merge_phymmbl_coverage.py <fasta contig file> <mpileup output> <phymmbl output>
#
# **note** This script currently works only with the specific output formats of samtools mpileup and IDBA-UD
#
# Created 26aug2013 by Daniel S. Jones
# Please contact me at dsjones <at> umn.edu with any questions
# Disclaimer: This code was hastily written, and while not terribly elegant, it gets the job done
#################################################################################



import functools
import re
import string
import sys

####################
# FUNCTIONS
####################	

def fasta2dictionary(fastafile):
    """Parse FASTA records into two dictionaries keyed by short sequence name.

    The short name is the text between '>' and the first space of a header
    line.  The sequence lines of a record are right-stripped and concatenated.
    Blank lines are ignored.  A header followed directly by another header
    gets an empty sequence.

    Args:
        fastafile: an iterable of text lines (e.g. an open file).

    Returns:
        (fastadic, fastanameinfo):
        - fastadic maps name -> sequence string.
        - fastanameinfo maps name -> full header line without the leading '>'.

    Raises:
        ValueError: if sequence data appears before the first header line.
    """
    fastadic = {}        # name -> concatenated sequence
    fastanameinfo = {}   # name -> full description line
    name = None          # name of the record currently being read
    chunks = []          # sequence lines of the current record
    n = 0                # number of records seen

    for line in fastafile:
        line = line.rstrip()
        if not line:
            continue

        if line[0] == '>':
            # Close the previous record before starting a new one; `chunks`
            # is reset so a record with no sequence lines cannot inherit the
            # previous record's sequence.
            if name is not None:
                fastadic[name] = ''.join(chunks)
            fullnameline = line[1:]
            name = fullnameline.split(' ')[0]
            fastanameinfo[name] = fullnameline
            chunks = []
            n += 1
        else:
            if name is None:
                raise ValueError(
                    "sequence data found before the first '>' header line")
            chunks.append(line)

    if name is not None:
        fastadic[name] = ''.join(chunks)

    print("number of fasta-formatted sequences in input: %s" % n)

    return fastadic, fastanameinfo


def mpile_coverage(mpile):
    """Compute the mean read depth of every sequence in mpileup output.

    Each non-blank line is split on whitespace:
    - field 1 is the sequence name;
    - field 2 is the position;
    - field 4 is the read depth at that position.
    A sequence's depths are summed and divided by the largest position seen
    for it.

    NOTE(review): samtools mpileup by default omits zero-depth positions, so a
    sequence whose end is uncovered gets a denominator shorter than its true
    length, and its coverage is overestimated.  True lengths are not
    available in this function.

    Args:
        mpile: an iterable of mpileup text lines (e.g. an open file).

    Returns:
        dict mapping sequence name -> coverage (float).
    """
    depth_totals = {}   # name -> summed depth over all listed positions
    seq_lengths = {}    # name -> largest position listed for the sequence
    line_count = 0

    for line in mpile:
        fields = line.split()   # splits on all whitespace
        if not fields:
            # Tolerate blank lines (e.g. a trailing empty line at end of file).
            continue
        name = fields[0]
        position = int(fields[1])
        depth = int(fields[3])

        depth_totals[name] = depth_totals.get(name, 0) + depth
        # Use the largest position rather than the last one, so the result
        # does not depend on line order.
        seq_lengths[name] = max(position, seq_lengths.get(name, 0))
        line_count += 1

    print("%s lines in mpileup file" % line_count)
    print("%s number of seqs in mpileup file" % len(depth_totals))

    mpile_cov_dic = {}
    for name in depth_totals:
        mpile_cov_dic[name] = float(depth_totals[name]) / seq_lengths[name]
        print("item= %s mpile_cov_dic %s" % (name, mpile_cov_dic[name]))

    return mpile_cov_dic


def read_phymmbl(phymmbl):
    """Index phymmBL result lines by their first whitespace-separated field.

    Args:
        phymmbl: an iterable of phymmBL output lines (e.g. an open file).

    Returns:
        dict mapping the first field of each non-blank line (the query id)
        -> list of the remaining fields, as strings.  If an id repeats, the
        later line wins.
    """
    phymmbl_dic = {}
    for line in phymmbl:
        fields = line.split()
        if not fields:
            # Skip blank lines instead of failing with IndexError.
            continue
        phymmbl_dic[fields[0]] = fields[1:]
    return phymmbl_dic

def create_attribute_line(id, fastanameinfo, mpiledic, phymmbldic):
    """Build the tab-separated attribute line for one contig.

    Fields, in order:
    1. the contig name;
    2-3. fields 3 and 6 of its FASTA description after splitting on
       whitespace and '_'.  Assuming an IDBA-UD style header such as
       'contig-100_0 length_1549 read_count_2147', these are the length
       (1549) and the read count (2147);
    4. the coverage from mpiledic;
    5+. every field of the contig's phymmBL record.

    Args:
        id: short contig name (key into the dictionaries).
        fastanameinfo: name -> full FASTA description (without '>').
        mpiledic: name -> coverage, or None when no mpileup data was supplied.
            Coverage is written as '0' when unavailable; a warning is printed
            only when a supplied dict lacks the contig.
        phymmbldic: name -> list of phymmBL fields, or None when no phymmBL
            data was supplied.  No phymmBL fields are written when
            unavailable; a warning is printed only when a supplied dict lacks
            the contig.

    Returns:
        The fields joined by tabs.

    Raises:
        KeyError: if id is not in fastanameinfo.
        IndexError: if the description has fewer than 7 fields after splitting.
    """
    f2 = re.split(r'\s|_', fastanameinfo[id])
    fields = [id, f2[3], f2[6]]

    if mpiledic is not None and id in mpiledic:
        fields.append(str(mpiledic[id]))
    else:
        fields.append('0')
        if mpiledic is not None:
            # Some contigs receive no reads and so are absent from mpileup.
            print("WARNING: sequence %s missing from mpileup output" % id)

    if phymmbldic is not None:
        if id in phymmbldic:
            fields.extend(phymmbldic[id])
        else:
            print("WARNING: sequence %s missing from phymmbl output" % id)

    return '\t'.join(fields)



################################
# MAIN PROGRAM
################################	


USAGE = ("Usage: python merge_phymmbl_coverage.py <fasta contig file> "
         "<mpileup output> <phymmbl output>")


def _open_optional_input(position, label):
    """Open the optional input file named by sys.argv[position] for reading.

    Returns the open file.  Returns None, after printing a warning, when the
    argument is missing or the file cannot be opened.  Because an unopenable
    name is treated as "not provided", a placeholder name can stand in for
    the mpileup file when only phymmbl output should be merged.
    """
    if len(sys.argv) <= position:
        print("WARNING: no %s file provided" % label)
    else:
        path = sys.argv[position]
        try:
            return open(path, 'r')
        except IOError:
            print("WARNING: could not open %s file %s" % (label, path))
    print("proceeding without...")
    print("note: to include phymmbl output but no mpileup, give a placeholder "
          "name (e.g. 'none') in place of the mpileup file")
    print(USAGE)
    return None


# The contig FASTA file is the only required argument.
if len(sys.argv) < 2:
    print("Error: input/output files not properly designated")
    print(USAGE)
    sys.exit(1)
inFile = sys.argv[1]


######
# READING FASTA FORMATTED CONTIGS INTO A DICTIONARY
######

with open(inFile, 'r') as fastafile:
    fastaseq, fastanameinfo = fasta2dictionary(fastafile)


######
# READ IN MPILEUP AND CALCULATE COVERAGE
######

# None (rather than an undefined name) tells the output stage that no
# mpileup data is available.
mpile_cov_dic = None
mpileupfile = _open_optional_input(2, 'mpileup')
if mpileupfile is not None:
    with mpileupfile:
        mpile_cov_dic = mpile_coverage(mpileupfile)


######
# READ IN PHYMMBL
######

phymmbldic = None
phymmblfile = _open_optional_input(3, 'phymmbl')
if phymmblfile is not None:
    with phymmblfile:
        phymmbldic = read_phymmbl(phymmblfile)


# Output files are named after the contig file.  They are opened only after
# all inputs have been read, so a parsing error does not leave empty output
# files behind.
outFile = '%s.coverage_phymmbl.fna' % inFile        # fasta-formatted outfile
outFile2 = '%s.coverage_phymmbl_tsv.out' % inFile   # tab delimited outfile
fastaout = open(outFile, 'w')
tabout = open(outFile2, 'w')


######
# COMBINE ALL AND CREATE OUTPUT
######

def cmpbynum(a, b):
    """Old-style comparator ordering names by their second '_'-separated field.

    For contig names like 'contig-100_7' this is the trailing index (7).  The
    field is compared numerically rather than as text, so '_10' sorts after
    '_2'.  Returns int(field_of_a) - int(field_of_b): negative, zero or
    positive, as the cmp protocol expects.
    """
    def numeric_field(label):
        return int(label.split('_')[1])

    return numeric_field(a) - numeric_field(b)

# Write the header line of the tab-delimited output file.  The fields are
# joined explicitly so that exactly one tab (and no spaces) separates the
# column names.
tsv_header = [
    "Sequence_name", "Length_bp", "NumReads_fromassembler",
    "coverage_BWAmapping", "best_phymmBL_match", "score",
    "genus", "genus_conf", "family", "family_conf",
    "order", "order_conf", "class", "class_conf",
    "phylum", "phylum_conf",
]
tabout.write('\t'.join(tsv_header) + '\n')

# Contigs are written sorted numerically by the second '_'-separated field of
# their names (the custom comparator cmpbynum), not in dictionary order.
for item in sorted(fastaseq, key=functools.cmp_to_key(cmpbynum)):
    line = create_attribute_line(item, fastanameinfo, mpile_cov_dic, phymmbldic)
    # FASTA record: '>' + attribute line, then the sequence on its own line,
    # with no stray spaces at the end of the header or start of the sequence.
    fastaout.write('>' + line + '\n' + fastaseq[item] + '\n')
    tabout.write(line + '\n')

# Close the outputs so all buffered data is flushed to disk.
fastaout.close()
tabout.close()



