#!/usr/bin/env python3

import argparse
import bz2
import contextlib
import gzip
import os, sys
import shutil
import subprocess
import tarfile
import time
import urllib.parse

#######################################
############# FUNCTIONS ###############
#######################################

def ucsc_download(input_var,output):
	"""Download a UCSC genome assembly into directory *output* with wget.

	The per-chromosome archive ``chromFa.tar.gz`` is tried first. If wget
	fails for it, any file wget left behind is removed and the single-file
	``<input_var>.fa.gz`` assembly is fetched instead.

	input_var: UCSC assembly ID, e.g. "hg38".
	output:    existing directory to download into.
	Returns the path of the downloaded (still compressed) file.
	Raises RuntimeError if both downloads fail, and FileNotFoundError if
	the wget executable cannot be found.
	"""
	print (time.strftime("%H:%M:%S")+" Downloading from UCSC: "+input_var)
	base_url = "http://hgdownload.soe.ucsc.edu/goldenPath/"+input_var+"/bigZips/"
	chromFa_url = base_url+"chromFa.tar.gz"
	assembly_url = base_url+input_var+".fa.gz"

	chromFa_file = output+"/"+"chromFa.tar.gz"
	assembly_file = output+"/"+assembly_url.split("/")[-1]

	# wget's stderr is still kept in log.tmp for troubleshooting. Success is
	# judged by wget's exit status: searching the log for "ERROR" missed
	# failures such as a missing wget or an unresolvable host.
	# An argument list (no shell) keeps paths with spaces intact.
	with open(output+"/log.tmp",'w') as log:
		status = subprocess.run(["wget","-c",chromFa_url,"-O",chromFa_file],stderr=log).returncode
	if status == 0:
		return chromFa_file

	# No usable per-chromosome archive: discard whatever wget -O wrote and
	# fall back to the single-file assembly.
	if os.path.exists(chromFa_file):
		os.remove(chromFa_file)
	status = subprocess.run(["wget","-c",assembly_url,"-O",assembly_file],stderr=subprocess.DEVNULL).returncode
	if status != 0:
		raise RuntimeError("Could not download "+chromFa_url+" or "+assembly_url+" (wget exit status "+str(status)+")")
	return assembly_file

def file_download (input_var,output):
	"""Download the file at URL *input_var* into directory *output* with wget.

	The local file name is the last component of the URL path. Any query
	string or fragment is left out, so the later suffix checks (".gz",
	".tar", ...) still work.

	Returns the local path of the downloaded file.
	Raises ValueError if no file name can be derived from the URL, and
	RuntimeError if wget exits with a non-zero status.
	"""
	print (time.strftime("%H:%M:%S")+" Downloading: "+input_var)

	file_name = urllib.parse.urlsplit(input_var).path.rsplit("/",1)[-1]
	if not file_name:
		raise ValueError("Cannot derive a file name from URL: "+input_var)
	assembly_file = output+"/"+file_name

	# Pass an argument list instead of a shell string: URLs commonly contain
	# '&' or '?', which the shell would otherwise interpret.
	result = subprocess.run(["wget","-c",input_var,"-O",assembly_file],stderr=subprocess.DEVNULL)
	if result.returncode != 0:
		raise RuntimeError("wget failed (exit status "+str(result.returncode)+") for "+input_var)

	return assembly_file

def _decompress_in_place(path,suffix,opener):
	"""Replace compressed *path* by its decompressed copy (path minus *suffix*).

	Mirrors `gzip -d` / `bzip2 -d`: the compressed file is deleted on
	success. On failure the partial output is removed and the error is
	re-raised instead of being silenced.
	"""
	target = path[:-len(suffix)]
	try:
		with opener(path,'rb') as src, open(target,'wb') as dst:
			shutil.copyfileobj(src,dst,1024*1024)
	except Exception:
		if os.path.exists(target):
			os.remove(target)
		raise
	os.remove(path)
	return target

def _concat_tar_fasta(tar_path,target):
	"""Write every regular "*.fa" member of *tar_path*, sorted by member name, into *target*."""
	with tarfile.open(tar_path) as archive:
		members = sorted(
			(m for m in archive.getmembers() if m.isfile() and m.name.endswith(".fa")),
			key=lambda m: m.name)
		with open(target,'wb') as dst:
			for member in members:
				# Stream straight from the archive: nothing is extracted to
				# disk, so member paths never touch the filesystem.
				with archive.extractfile(member) as src:
					shutil.copyfileobj(src,dst,1024*1024)

def uncompress_assembly(assembly_file,output):
	"""Turn an assembly file into a plain FASTA at <output>/tmp.fa.

	Suffixes are handled in order:
	  * ".gz", then ".bz2", are decompressed in place (the compressed file
	    is removed);
	  * a resulting ".tar" archive has all of its "*.fa" members, at any
	    depth and sorted by member name, concatenated into tmp.fa, and the
	    archive is removed;
	  * any other file is moved to tmp.fa unchanged.

	Nothing is extracted into the current working directory. Stray *.fa
	files there are neither included nor deleted, and other *.tar files in
	*output* are left alone.

	Returns the path <output>/tmp.fa.
	Raises OSError, or the decompressor's error, if the input is missing
	or corrupt.
	"""
	print (time.strftime("%H:%M:%S")+" Uncompressing: "+assembly_file)

	# endswith, not substring tests: ".gz" elsewhere in a path must not match.
	if assembly_file.endswith(".gz"):
		assembly_file = _decompress_in_place(assembly_file,".gz",gzip.open)
	if assembly_file.endswith(".bz2"):
		assembly_file = _decompress_in_place(assembly_file,".bz2",bz2.open)

	tmp_fa = output+"/tmp.fa"
	if assembly_file.endswith(".tar"):
		_concat_tar_fasta(assembly_file,tmp_fa)
		os.remove(assembly_file)
	else:
		shutil.move(assembly_file,tmp_fa)

	return tmp_fa

def fasta_split(assembly_file,output,regex,label=None):
	"""Split a multi-FASTA file into canonical and non-canonical sequences.

	Each record is classified by its header line (the text after ">"):
	  * if *regex* contains "^", the text after the last "^" is a substring
	    that marks CANONICAL records; every other record is non-canonical;
	  * otherwise *regex* is a substring that marks NON-CANONICAL records;
	    every other record is canonical.
	Despite its name, *regex* is matched as a plain substring.

	Files written inside *output*:
	  <label>_canonical.fa           canonical records, in input order
	  <label>_canonical/<header>     one file per canonical record
	  <label>_nonCanonical.fa        non-canonical records (only if any)
	  <label>_nonCanonical/<header>  one file per non-canonical record
	The <label>_canonical directory is always created. Files written by
	this call are truncated first, so re-running does not duplicate
	records. Text before the first header line is ignored.

	The input is streamed line by line, so memory use does not grow with
	genome size.

	label: prefix for output names. Defaults to the module-level ``label``
	       set by the command-line section, or "assembly" if there is none.
	"""
	if label is None:
		# The script's own call relies on the global chosen on the command line.
		label = globals().get("label","assembly")
	print (time.strftime("%H:%M:%S")+" Splitting into canonical and non canonical")

	# "^text" selects canonical records; bare "text" selects non-canonical ones.
	caret_mode = "^" in regex
	pattern = regex.split("^")[-1] if caret_mode else regex

	out_dir = {True: output+"/"+label+"_canonical",
		False: output+"/"+label+"_nonCanonical"}
	out_fa = {True: output+"/"+label+"_canonical.fa",
		False: output+"/"+label+"_nonCanonical.fa"}
	os.makedirs(out_dir[True],exist_ok=True)

	with contextlib.ExitStack() as stack:
		fasta = stack.enter_context(open(assembly_file,'r'))
		combined = {}     # category -> handle of its combined .fa (opened lazily)
		created = set()   # per-record files already written during this call
		record_fh = None  # per-record file of the record currently being copied
		targets = ()
		try:
			for line in fasta:
				if line.startswith(">"):
					if record_fh is not None:
						record_fh.close()
						record_fh = None
					# Classify on the header only, never on sequence text.
					header = line[1:].rstrip("\n")
					matched = pattern in header
					canonical = matched if caret_mode else not matched
					if canonical not in combined:
						# Created on first use, so the non-canonical directory
						# exists whenever there is at least one such record.
						os.makedirs(out_dir[canonical],exist_ok=True)
						combined[canonical] = stack.enter_context(open(out_fa[canonical],'w'))
					record_path = out_dir[canonical]+"/"+header
					# 'w' on first use so re-runs start clean; 'a' when the same
					# header appears again within this input.
					record_fh = open(record_path,'a' if record_path in created else 'w')
					created.add(record_path)
					targets = (combined[canonical],record_fh)
				elif record_fh is None:
					# Text before the first header belongs to no record.
					continue
				for fh in targets:
					fh.write(line)
		finally:
			if record_fh is not None:
				record_fh.close()


#######################################
############# ARGUMENTS ###############
#######################################

parser = argparse.ArgumentParser(description='PREPARE ASSEMBLY v1.0.0')
parser.add_argument ('--input','-i',type=str,help='Path or URL to multi-FASTA file (.fa, .fa.gz, .fa.tar.gz, .bz2)', default="")
parser.add_argument ('--ucsc','-u',type=str,help='UCSC assembly ID (e.g. hg38)', default="")
parser.add_argument ('--label','-l',type=str,help='Label for output files', default='assembly')
parser.add_argument ('--outdir','-o',type=str,help='Path to output directory, default current directory',default='.')
parser.add_argument ('--regex','-r',type=str,help='REGEX to filter canonical or non-canonical chromosomes. REGEX to filter canonical chromosomes must be preceded by ^. (Default: "_" for UCSC assemblies, for Ensembl use ^dna:chromosome).', default="_")
args = parser.parse_args()

#####################################
######### 1. CHECK INPUT ############
#####################################

# Running with no arguments at all only shows the usage text.
if len(sys.argv)==1:
	parser.print_help()
	sys.exit(1)

# Exactly one input source is required: a file/URL (-i) or a UCSC ID (-u).
if args.input!="" and args.ucsc!="":
	print ("###########################")
	print (" Choose -u or -i, not both ")
	print ("###########################\n")
	parser.print_help()
	sys.exit(1)
elif args.input!="":
	input_var = args.input
	ucsc = False
elif args.ucsc!="":
	input_var = args.ucsc
	ucsc = True
else:
	print ("###########################")
	print ("  Please choose -u or -i   ")
	print ("###########################\n")
	parser.print_help()
	sys.exit(1)

# Output prefix for both -i and -u (documented default: 'assembly').
label = args.label
outdir = args.outdir

# makedirs also creates missing parents and does not go through the shell,
# so nested paths and paths with spaces work.
os.makedirs(outdir,exist_ok=True)

regex = args.regex

#####################################
######## 2. DOWNLOAD INPUT ##########
#####################################

if ucsc:
	assembly_file = ucsc_download(input_var,outdir)
	print (time.strftime("%H:%M:%S")+" Downloaded: "+assembly_file)
elif input_var.lower().startswith(("http://","https://","ftp://")):
	# Prefix check also accepts https:// URLs, which the former "http:/"
	# substring test treated as local paths.
	assembly_file = file_download(input_var,outdir)
else:
	# Local file. The next steps decompress/rename their input in place and
	# the cleanup deletes the result, so hand them a copy inside outdir
	# instead of consuming the user's original file.
	source = input_var
	if not os.path.isfile(source):
		# Backward compatibility: paths used to be resolved relative to outdir.
		source = outdir+"/"+input_var
	assembly_file = outdir+"/tmp_input_"+os.path.basename(source)
	shutil.copyfile(source,assembly_file)


#####################################
########## 3. UNCOMPRESS ############
#####################################

assembly_file = uncompress_assembly(assembly_file,outdir)
print (time.strftime("%H:%M:%S")+" Uncompressed: "+assembly_file)

#####################################
########## 4. FASTA SPLIT ###########
#####################################

fasta_split(assembly_file,outdir,regex)

#####################################
############# 5. RM TMP #############
#####################################

# Delete only the scratch files this pipeline creates. A "*tmp*" glob would
# also remove unrelated user files in outdir whose names contain "tmp".
for scratch_name in ("log.tmp","tmp.fa","tmp.fa.fai"):
	try:
		os.remove(outdir+"/"+scratch_name)
	except FileNotFoundError:
		pass  # not every run creates every scratch file

print (time.strftime("%H:%M:%S")+" Finished")
