#!/usr/bin/python

import re
import math

def loadTree(filename):
	"""Read a Newick-format tree file and return the dictionary-based tree.

	Before parsing, semicolons (trailing or otherwise), all whitespace
	(including newlines) and bracketed "[...]" comment clauses are removed.
	The cleaned string is then interpreted by parseTree().
	"""
	with open(filename, 'r') as f:
		exp = f.read()

	exp = exp.replace(';', '')  # ignore trailing (or other) semi-colons
	exp = re.sub(r'\s+', '', exp)  # ignore whitespace, newlines included
	# Ignore bracketed clauses. Each "[...]" is matched on its own, so tree
	# text lying between two separate comments is preserved.
	exp = re.sub(r'\[[^\]]*\]', '', exp)

	return parseTree(exp)


def makeLeaf(name, length):
	"""Build the node for a single leaf: no children, only a name and length."""
	return dict(left=None, right=None, name=name, length=length)


def parseTree(exp):
	"""Parse a Newick-format string recursively into the dictionary tree.

	Each clause is expected to be of the general form (a:x,b:y):z, where a and
	b may themselves be subtrees in the same format. The tree must therefore
	be strictly binary (bifurcating). A leaf is "name:length".

	A missing branch length, on a leaf or on an internal clause, defaults to 0.
	Lengths may use scientific notation (e.g. 1e-05, 1E+05).

	Returns a dict with keys 'name', 'left', 'right' and 'length'. Internal
	nodes are named "internal"; leaves are built with makeLeaf().
	"""
	if ',' not in exp:  # if this is a leaf
		# Split at the LAST colon so only the trailing field is the length.
		name, sep, length = exp.rpartition(':')
		if not sep:  # no ':' at all -> bare leaf name without a length
			name, length = exp, ''
		length = float(length) if length else 0
		return makeLeaf(name, length)

	# Peel the branch length off this clause if present, e.g.
	# "(a:1,b:2):0.5" -> tree "(a:1,b:2)", length 0.5.
	length = 0
	m = re.search(r'(?P<tree>\(.+\)):(?P<length>[eE+\-\d.]+)$', exp)
	if m:
		length = float(m.group('length'))
		exp = m.group('tree')

	# Use parseExp to split "(a,b)" into its left and right hand sides.
	lhs, rhs = parseExp(exp)

	# Now package into a tree data structure
	return {"name": "internal",
			"left": parseTree(lhs),  # recursively set up subtrees
			"right": parseTree(rhs),
			"length": length}


def parseExp(exp):
	"""Split a parenthesised Newick clause "(a,b)" into [a, b].

	The first and last characters of exp are dropped; they are assumed to be
	the surrounding parentheses. The remainder is split at the first comma
	that is not nested inside parentheses, and a and b may themselves be
	Newick subtrees.

	If there is no such comma, the right side is ''. Any further top-level
	commas stay in the right side, so only binary clauses split cleanly.
	"""
	inner = exp[1:-1]  # get rid of surrounding parens
	depth = 0  # paren nesting depth; only commas at depth 0 separate a and b
	for i, c in enumerate(inner):
		if c == '(':
			depth += 1
		elif c == ')':
			depth -= 1
		elif c == ',' and depth == 0:
			# Slice once instead of rebuilding both halves char by char.
			return [inner[:i], inner[i + 1:]]
	return [inner, '']


def readAlignment(filename):
	"""Read an alignment in (loosely) Phylip format into a dict of sequences.

	The first non-blank line is a header holding two whitespace-separated
	fields (number of sequences and number of columns); these are read but
	not validated.

	Every later non-blank line must start with a taxon name followed by
	sequence characters. Whitespace inside the sequence is ignored. Lines
	repeating a name are appended to that taxon's sequence, so interleaved
	blocks work. Blank lines are skipped.

	Returns a dict mapping taxon name -> concatenated sequence string.
	"""
	taxa = None
	columns = None
	sequences = {}

	with open(filename, 'r') as f:
		for line in f:
			words = line.split()  # lines can have whitespace
			if not words:
				continue  # blank line, e.g. between interleaved blocks
			if taxa is None:
				# first line tells how many seqs and cols
				taxa, columns = words
			else:
				name = words[0]  # we require lines to start with the taxon name
				sequences[name] = sequences.get(name, '') + ''.join(words[1:])

	return sequences


def initTree(tree, aln):
	"""Attach the alignment rows to the leaves of tree as per-site state lists.

	Each leaf's 'data' field becomes one single-element list per alignment
	column, e.g. "AC" -> [['A'], ['C']]. Nodes named 'internal' are not
	modified. Raises KeyError if a leaf's name is missing from aln.
	Returns None.
	"""
	pending = [tree]
	while pending:
		node = pending.pop()
		if node['name'] == 'internal':
			# push right first so that the left subtree is visited first
			pending.append(node['right'])
			pending.append(node['left'])
		else:
			node['data'] = [[ch] for ch in aln[node['name']]]


def downPass(tree):
	"""Fitch down-pass.

	Returns the minimum number of mutations needed to explain the sequence
	data on this tree topology, and assigns state sets to internal nodes.

	Leaves (any node not named 'internal') must already carry 'data': one list
	of possible states per site, as produced by initTree(). Sites are processed
	in postorder. For each internal node and site:

	- If the two children share a state, the node gets their intersection.
	- Otherwise the node gets their union and one mutation is counted.

	Each node's 'data' is a list of sorted state lists, in the same format as
	the leaves.

	Raises ValueError if the two children have different numbers of sites.
	"""
	if tree['name'] != 'internal':
		return 0  # leaves keep their observed states

	mutations = downPass(tree['left']) + downPass(tree['right'])

	left_data = tree['left']['data']
	right_data = tree['right']['data']
	if len(left_data) != len(right_data):
		raise ValueError('sequences under this node have different lengths')

	data = []
	for left_states, right_states in zip(left_data, right_data):
		common = set(left_states) & set(right_states)
		if common:
			data.append(sorted(common))
		else:
			# No shared state: a change must occur on one of the two branches.
			data.append(sorted(set(left_states) | set(right_states)))
			mutations += 1

	tree['data'] = data
	return mutations

if __name__ == '__main__':
	import sys

	# Input files may be given on the command line:
	#   script.py [tree_file] [alignment_file]
	# They default to tree4.txt and seqs.aln.
	tree_file = sys.argv[1] if len(sys.argv) > 1 else 'tree4.txt'
	aln_file = sys.argv[2] if len(sys.argv) > 2 else 'seqs.aln'

	tree = loadTree(tree_file)
	aln = readAlignment(aln_file)

	initTree(tree, aln)  # attach per-site state lists to the leaves
	# print(...) with a single argument behaves the same on Python 2 and 3.
	print('Number of mutations: %d' % (downPass(tree)))
	print('Sequence at root: %s' % ','.join(['/'.join(x) for x in tree['data']]))
