
def quick(l):
    '''Sort a copy of ``l`` by successive transpositions, recording each one.

    The scan moves left to right.  Whenever the current item is smaller than
    its left neighbour, it is swapped with the first earlier item that is
    greater than it, and the scan stays on the same position.

    Returns ``[sorted_items, exchanges]``.  ``exchanges`` lists the swapped
    position pairs ``(earlier, current)`` in the order they were applied.
    The input itself is not modified.
    '''
    items = list(l)
    swaps = []
    pos = 0
    while pos < len(items):
        if pos > 0 and items[pos] < items[pos - 1]:
            # range() is ascending, so the first match is the smallest index.
            target = next(u for u in range(pos) if items[u] > items[pos])
            items[target], items[pos] = items[pos], items[target]
            swaps.append((target, pos))
        else:
            pos += 1
    return [items, swaps]

class Sum(list):
    '''A linear combination of terms with integer coefficients.

    The terms are the list items, and ``self.coeffs[k]`` is the coefficient
    of ``self[k]``.  Terms are arbitrary objects with the only restriction
    that they must be hashable, because ``collect`` indexes them in a dict.
    '''

    def __init__(self, elts, coeffs=None):
        '''Build the sum of ``elts`` weighted by ``coeffs`` (default: all 1).

        An element that is itself a ``Sum`` is flattened.  Each of its terms
        is added with its own coefficient multiplied by the outer one.
        '''
        # Materialize once.  Otherwise an iterator would be exhausted by the
        # default-coefficient computation before the terms are read.
        elts = list(elts)
        if coeffs is None:
            coeffs = [1] * len(elts)
        self.coeffs = []
        for i, e in enumerate(elts):
            if isinstance(e, Sum):
                for j, p in enumerate(e):
                    self.append(p)
                    self.coeffs.append(coeffs[i] * e.coeffs[j])
            else:
                self.append(e)
                self.coeffs.append(coeffs[i])

    def diff(self, i, n=1):
        '''Return the ``n``-th derivative with respect to ``i``, collected.

        Every term must provide ``diff(i)``.  ``n == 0`` returns a collected
        copy of the sum.

        Raises:
            ValueError: if ``n`` is negative.
        '''
        if n < 0:
            raise ValueError(
                'derivative order must be non-negative, got {0!r}'.format(n))
        result = self.collect() if n == 0 else self
        # Differentiate once per order.  This is iterative rather than
        # recursive, so large n cannot exhaust the stack.
        for _ in range(n):
            result = Sum([e.diff(i) for e in result], result.coeffs).collect()
        return result

    def collect(self):
        '''Merge equal terms by adding their coefficients.

        Returns a new ``Sum``.  Each distinct term keeps the position of its
        first occurrence.  A dict maps each term to its slot, so the cost is
        linear in the number of terms.
        '''
        slot_of = {}
        new_elements = []
        new_coeffs = []
        for k, e in enumerate(self):
            c = self.coeffs[k]
            if e in slot_of:
                new_coeffs[slot_of[e]] += c
            else:
                slot_of[e] = len(new_elements)
                new_elements.append(e)
                new_coeffs.append(c)
        return Sum(new_elements, new_coeffs)

    def _latex_(self):
        '''Render as LaTeX: terms joined by ' + '.

        A term whose coefficient is not 1 is prefixed by the coefficient and
        a space.
        '''
        terms = []
        for k, e in enumerate(self):
            c = self.coeffs[k]
            if c == 1:
                terms.append(e._latex_())
            else:
                terms.append(str(c) + ' ' + e._latex_())
        return ' + '.join(terms)

class Bracket(list):
    '''This class represents a multilinear function applied to its arguments.

    Its content is stored as a list of tuples of variable names.  For
    instance, "[ ('a','e'),('e',) ]" represents "f^{2}[V_ae,V_e]".

    The derivative with respect to a variable 'x' is computed in two steps.
    First, 'x' is appended successively to each of the tuples.  Second, a
    new tuple containing only 'x' is added.  Hence the result is a sum of
    brackets.  Each bracket of the resulting sum is reordered in lexical
    order with respect to its tuples.
    '''

    def diff(self, i):
        '''Differentiate with respect to variable ``i``.

        Returns a collected ``Sum`` of reordered brackets.
        '''
        new_brackets = []
        # One term per existing argument, with ``i`` appended to it ...
        for k in range(len(self)):
            n = Bracket(self)
            n[k] = n[k] + (i,)
            n.reorder()
            new_brackets.append(n)
        # ... plus one term with the extra argument ``(i,)``.
        b = Bracket(list(self) + [(i,)])
        b.reorder()
        new_brackets.append(b)
        return Sum(new_brackets).collect()

    def reorder(self):
        '''Normalize in place and return None.

        Each argument tuple is sorted, then the arguments themselves are
        sorted.
        '''
        for k, arg in enumerate(self):
            self[k] = tuple(sorted(arg))
        self.sort()

    def is_ascending(self):
        '''True if the concatenated indices never decrease.'''
        flat = self.flat_indices()
        return not any(flat[k] < flat[k - 1] for k in range(1, len(flat)))

    def order_swaps(self):
        '''Return the transpositions that ``quick`` uses to sort the indices.'''
        return quick(self.flat_indices())[1]

    def __hash__(self):
        # Hash the same content that list equality compares.
        return tuple(self).__hash__()

    def flat_indices(self):
        '''Concatenate all argument tuples into a single list of indices.'''
        return [idx for arg in self for idx in arg]

    def _latex_(self):
        '''Render as ``f^{(n)}[V_{...},...]``, with n = number of arguments.

        Every 's' in the output is replaced by ``\\sigma``.
        '''
        s = "f^{{({n})}}[{content}]"
        content = []
        for el in self:
            index = ' '.join([str(e) for e in el])
            content.append('V_{{{index}}}'.format(index=index))
        s = s.format(n=len(self), content=','.join(content))
        return s.replace('s', '\\sigma')

class BracketG(Bracket):
    '''A ``Bracket`` applied to the "g" function.

    The single-index arguments ('t',), ('y',) and ('z',) are treated as
    constants.  ``diff`` never extends them, and ``_latex_`` renders them
    as I_p, epsilon and I_sigma respectively.
    '''

    def diff(self, i):
        '''Differentiate with respect to ``i``; returns an uncollected ``Sum``.

        Differentiating by 'p' also adds a term with an extra ('t',)
        argument.  Differentiating by 's' also adds terms with an extra
        ('y',) and an extra ('z',) argument.
        '''
        result = []
        for pos, arg in enumerate(self):
            if arg in (('t',), ('y',), ('z',)):
                continue
            extended = BracketG(self)
            extended[pos] = arg + (i,)
            extended.reorder()
            result.append(extended)

        if i == 'p':
            appended = [(i,), ('t',)]
        elif i == 's':
            appended = [(i,), ('y',), ('z',)]
        else:
            appended = [(i,)]
        for extra in appended:
            bra = BracketG(list(self) + [extra])
            bra.reorder()
            result.append(bra)
        return Sum(result)

    def _latex_(self,type='f'):
        '''Return the LaTeX form ``g^{( n )}_{ ind }[ args ]``.

        ``type`` is accepted but not used.
        '''
        letters = []
        args = []
        for arg in self:
            if arg == ('t',):
                letters.append('p')
                args.append('I_p')
            elif arg == ('y',):
                letters.append('u')
                args.append('q')
            elif arg == ('z',):
                letters.append('\\sigma')
                args.append('I_s')
            else:
                letters.append('a')
                args.append('g_{{ {0} }}'.format(''.join(arg)))
        # 's' becomes sigma.  The placeholder 'q' (from ('y',)) becomes epsilon.
        content = ','.join(args).replace('s', '\\sigma ').replace('q', '\\epsilon')
        template = 'g^{{( {n} )}}_{{ {ind} }}[ {content} ]'
        return template.format(n=len(self), ind=''.join(letters), content=content)

    def stoch_order(self):
        '''Counts the number of epsilons'''
        return sum(arg.count('y') for arg in self)

def kill_s(sum_expr):
    '''Simplifies expressions using g_s=0.

    Drops every term (bracket) of ``sum_expr`` that either has an argument
    tuple containing 's' exactly once, or has exactly one ('z',) argument.
    The surviving terms keep their coefficients.
    '''
    kept = []
    kept_coeffs = []
    for pos, bra in enumerate(sum_expr):
        coeff = sum_expr.coeffs[pos]
        has_single_s = any(arg.count('s') == 1 for arg in bra)
        if has_single_s or bra.count(('z',)) == 1:
            continue
        kept.append(bra)
        kept_coeffs.append(coeff)
    return Sum(kept, kept_coeffs)

def get_nondecreasing_indices(args,order):
    '''List the non-decreasing index tuples over ``args`` for orders 1..order.

    Entry k of the result holds tuples of length k+1.  Level 1 is one
    1-tuple per element of ``args``.  Each further level appends to every
    tuple of the previous level each element of ``args`` that is >= its
    last item.  The outer loop runs over ``args``, so the tuples of a level
    are grouped by their last item.
    '''
    levels = [[(a,) for a in args]]
    for _ in range(2, order + 1):
        previous = levels[-1]
        levels.append([prefix + (a,)
                       for a in args
                       for prefix in previous
                       if prefix[-1] <= a])
    return levels

def get_K(expr):
    '''Split a sum of f-brackets by number of arguments.

    Returns ``[resp, other]``:
      * ``resp`` holds the terms with more than one argument;
      * ``other`` holds the terms with exactly one argument.
    Terms with no argument go to neither.  Each term keeps its coefficient
    from ``expr``.  Previously both parts were rebuilt with coefficient 1,
    which silently dropped multiplicities.  A plain list (no ``coeffs``) is
    treated as having all coefficients equal to 1.
    '''
    coeffs = getattr(expr, 'coeffs', None)
    if coeffs is None:
        coeffs = [1] * len(expr)
    resp_terms, resp_coeffs = [], []
    other_terms, other_coeffs = [], []
    for k, h in enumerate(expr):
        if len(h) > 1:
            resp_terms.append(h)
            resp_coeffs.append(coeffs[k])
        elif len(h) == 1:
            other_terms.append(h)
            other_coeffs.append(coeffs[k])
    return [Sum(resp_terms, resp_coeffs), Sum(other_terms, other_coeffs)]

def get_L(expr):
    '''Split a sum of g-brackets into ``[resp, other]``, keeping coefficients.

    ``special`` is the bracket with n arguments, all equal to ('a',), where
    n is the total number of indices in the first term of ``expr``.
      * ``other`` holds the single-argument terms and ``special``.
      * ``resp`` holds the remaining terms with more than one argument.
    Terms with no argument go to neither.  Each term keeps its coefficient
    from ``expr``.  Previously both parts were rebuilt with coefficient 1,
    which silently dropped multiplicities.  A plain list (no ``coeffs``) is
    treated as having all coefficients equal to 1.

    Raises:
        IndexError: if ``expr`` is empty.
    '''
    n = sum(len(el) for el in expr[0])
    # Multiplying a bracket already produced a plain list.  A plain list
    # compares equal to any bracket with the same arguments.
    special = [('a',)] * n
    coeffs = getattr(expr, 'coeffs', None)
    if coeffs is None:
        coeffs = [1] * len(expr)
    resp_terms, resp_coeffs = [], []
    other_terms, other_coeffs = [], []
    for k, h in enumerate(expr):
        if len(h) == 1 or h == special:
            other_terms.append(h)
            other_coeffs.append(coeffs[k])
        elif len(h) > 1:
            resp_terms.append(h)
            resp_coeffs.append(coeffs[k])
    return [Sum(resp_terms, resp_coeffs), Sum(other_terms, other_coeffs)]

def compute_derivatives(n, vars):
    '''Symbolically differentiate f and g with respect to ``vars``.

    Returns ``[eqs_f, eqs_g, ndt]``:
      * ``eqs_f[0]`` and ``eqs_g[0]`` are the undifferentiated Sums.
      * For each order k >= 1, ``eqs_f[k]`` and ``eqs_g[k]`` are lists of
        collected Sums.  Entry j is the derivative along the non-decreasing
        index tuple ``ndt[k][j]``; ``ndt[0]`` is empty.
    Order 1 is always produced; higher orders run up to ``n``.
    '''
    eqs_f = [Sum([Bracket([])])]
    eqs_g = [Sum([BracketG([])])]
    eqs_f.append([eqs_f[0].diff(v) for v in vars])
    eqs_g.append([eqs_g[0].diff(v) for v in vars])
    ndt = [[], [(v,) for v in vars]]
    for _ in range(n - 1):
        level_f, level_g, level_ndt = [], [], []
        # Only extend an index tuple with variables >= its last index, so
        # each mixed derivative is computed once.
        for ind, expr_f, expr_g in zip(ndt[-1], eqs_f[-1], eqs_g[-1]):
            for v in vars:
                if v >= ind[-1]:
                    level_f.append(expr_f.diff(v).collect())
                    level_g.append(expr_g.diff(v).collect())
                    level_ndt.append(ind + (v,))
        eqs_f.append(level_f)
        eqs_g.append(level_g)
        ndt.append(level_ndt)
    return [eqs_f, eqs_g, ndt]

###############################################
###############################################
#                                             #
# code below generates Python (numpy) source. #
#                                             #
###############################################
###############################################

def compile_bracket(bra):
    '''Translate one bracket into a Python expression string.

    * ``BracketG`` becomes ``mdot(g_<letters>,[args])``.  The letters are
      p/e/s for the constant arguments ('t',)/('y',)/('z',), which compile
      to I_p/I_e/I_s.  Every other argument compiles to ``g_<indices>`` with
      letter 'a'.
    * ``Bracket`` becomes ``mdot(f_<n>,[V_...])``, where n is the total
      number of flat indices.
    If the flat indices are not ascending, ``.swapaxes(i,j)`` calls are
    appended, one per transposition from ``order_swaps``.

    Raises:
        TypeError: if ``bra`` is not a ``Bracket`` (an unexpected type used
            to fail later with UnboundLocalError).
    '''
    if isinstance(bra, BracketG):
        letters = []
        content = []
        for el in bra:
            if el == ('t',):
                letters.append('p')
                content.append('I_p')
            elif el == ('y',):
                letters.append('e')
                content.append('I_e')
            elif el == ('z',):
                letters.append('s')
                content.append('I_s')
            else:
                letters.append('a')
                content.append('g_{0}'.format(''.join(el)))
        txt = 'mdot(g_{0},[{1}])'.format(''.join(letters), ','.join(content))
    elif isinstance(bra, Bracket):
        # NOTE(review): f is indexed by the total number of flat indices,
        # while Bracket._latex_ labels the same term f^{(len(bra))}.  Confirm
        # which derivative tensor mdot expects here.
        n = len(bra.flat_indices())
        content = ['V_{0}'.format(''.join(e)) for e in bra]
        txt = 'mdot(f_{0},[{1}])'.format(n, ','.join(content))
    else:
        raise TypeError(
            'compile_bracket expects a Bracket, got {0}'.format(type(bra).__name__))
    if not bra.is_ascending():
        # NOTE(review): swap positions index the flat index list.  If the
        # mdot result carries a leading output axis, these may need a +1
        # offset; confirm against mdot.
        txt += ''.join('.swapaxes({0},{1})'.format(a, b)
                       for a, b in bra.order_swaps())
    return txt

def compile_expectation(expr):
    '''Compile the expectation of a ``BracketG``, or of a ``Sum`` of them.

    For a bracket, let ``ns = stoch_order()`` be its number of epsilon
    ('y') indices:
      * ns == 0: the compiled bracket itself;
      * ns == 1: a zero array of shape (n_v, n_<non-epsilon indices>...);
      * ns >= 2: ``np.tensordot`` of the compiled bracket with a Sigma
        tensor, contracting the epsilon positions with Sigma axes 0..ns-1.
    For a ``Sum``, the terms are compiled recursively and joined with
    ' + '.  A term gets a ``c*`` prefix unless c == 1; zero terms are
    skipped.  An empty Sum gives ''.

    Raises:
        TypeError: if ``expr`` is neither a ``BracketG`` nor a ``Sum``.
    '''
    if isinstance(expr, BracketG):
        bra = expr
        ns = bra.stoch_order()
        if ns == 0:
            return compile_bracket(bra)
        elif ns == 1:
            return 'np.zeros((n_v,{0}))'.format(','.join(['n_' + i for i in bra.flat_indices() if i != 'y']))
        inds = bra.flat_indices()
        # NOTE(review): these are positions in the flat index list.  If the
        # compiled bracket has a leading n_v axis (as the ns == 1 shape above
        # suggests), the tensordot axes may need a +1 offset.  Confirm
        # against mdot.
        eps_inds = [str(k) for k, idx in enumerate(inds) if idx == 'y']
        sig_indices = [str(k) for k in range(ns)]
        # Sigma suffix = total number of flat indices.  The old code got this
        # value from a list-comprehension variable leaking into the function
        # scope, which only works on Python 2 (NameError on Python 3).
        # NOTE(review): confirm whether the suffix should rather be ns.
        sigma_order = len(inds)
        resp = compile_bracket(bra)
        return 'np.tensordot( {0}, Sigma_{1}, axes=( ({2}), ({3}) ) )'.format(
            resp, sigma_order, ','.join(eps_inds), ','.join(sig_indices))
    if isinstance(expr, Sum):
        pieces = []
        for k, term in enumerate(expr):
            c = expr.coeffs[k]
            if c == 1:
                pieces.append(compile_expectation(term))
            elif c != 0:
                # Any non-zero coefficient is kept, including negative ones.
                pieces.append('{0}*{1}'.format(c, compile_expectation(term)))
        return ' + '.join(pieces)
    raise TypeError(
        'compile_expectation expects a BracketG or a Sum, got {0}'.format(type(expr).__name__))
        
def compile_sum(s):
    '''Compile a ``Sum`` of brackets into a Python expression string.

    Each term is compiled with ``compile_bracket``.  It is prefixed by
    ``c*`` unless its coefficient c is 1; terms with coefficient 0 are
    skipped.  The pieces are joined with ' + '; an empty sum gives ''.
    '''
    pieces = []
    for k, bra in enumerate(s):
        c = s.coeffs[k]
        if c == 1:
            pieces.append(compile_bracket(bra))
        elif c != 0:
            # Any non-zero coefficient is kept, including negative ones.
            pieces.append('{0}*{1}'.format(c, compile_bracket(bra)))
    return ' + '.join(pieces)

def write_order(order,vars,eqs_f,eqs_g,ndt):
    '''Emit the Python source that solves for the derivatives of one order.

    For each index tuple ``ndt[order][k]``, the emitted code does three
    things:
      * it defines ``K_<inds>`` and ``L_<inds>`` from the multi-argument
        parts of ``eqs_f``/``eqs_g``;
      * it solves for ``g_<inds>``, with ``solve_sylvester`` for the first
        tuple and ``A_inv`` for the others;
      * when ``order < max_order``, it builds ``V_<inds>`` via ``build_V``.
    Returns the generated source as a string.
    '''
    code = "{br}Computing order {order}\n\n".format(br='\n#' + 10*'-', order=order)
    code += 'order = {0}\n'.format(order)
    sizes = ','.join(['n_' + v for v in vars])
    for k in range(len(eqs_f[order])):
        eq_f = eqs_f[order][k]
        eq_g = eqs_g[order][k]
        inds = ndt[order][k]
        sub = ''.join(inds)
        code += '\n#--- Computing derivatives {0}\n\n'.format(inds)
        K_eq = get_K(eq_f)[0]
        K_name = 'K_' + sub
        code += '{lhs} = {rhs}\n'.format(lhs=K_name, rhs=compile_sum(K_eq))
        [L_eq, L_other] = get_L(eq_g)
        L_name = 'L_' + sub
        code += '{lhs} = {rhs}\n\n'.format(lhs=L_name, rhs=compile_sum(L_eq))
        if k == 0:
            code += '#We need to solve the infamous sylvester equation\n'
            code += '#A = f_a + sdot(f_d,g_a)\n'
            code += '#B = f_d\n'
            code += '#C = g_a\n'
            code += 'D = {K} + sdot(f_d,{L})\n'.format(K=K_name, L=L_name)
            code += 'g_{sub} = solve_sylvester(A,B,C,D)\n'.format(sub=sub)
        else:
            code += '#We solve A*X + const = 0\n'
            code += 'const = sdot(f_d,{L}) + {K}\n'.format(L=L_name, K=K_name)
            code += 'g_{sub} = - sdot(A_inv, const)\n'.format(sub=sub)
        # compile_sum returns its terms without a leading ' + ', so the
        # separator must be added here.  Without it the output was
        # 'L_aamdot(...)', which is invalid Python.
        other = compile_sum(L_other)
        y_rhs = L_name + (' + ' + other if other else '')
        code += '\nif order < max_order:\n'
        code += '    Y = {0}\n'.format(y_rhs)
        code += '    Z = g_{sub}\n'.format(sub=sub)
        code += '    V_{sub} = build_V(Y,Z,({sizes}))\n'.format(sub=sub, sizes=sizes)
    return code

def gen_code(max_order,vars):
    '''Return the generated solver source for orders 2..max_order.

    The derivatives are computed once with ``compute_derivatives``.  The
    per-order chunks from ``write_order`` are concatenated.
    '''
    eqs_f, eqs_g, ndt = compute_derivatives(max_order, vars)
    chunks = [write_order(o, vars, eqs_f, eqs_g, ndt)
              for o in range(2, max_order + 1)]
    return ''.join(chunks)

def sort_order(expr):
    '''Group the terms of ``expr`` by stochastic order.

    Returns a list whose entry k holds the terms with ``stoch_order() == k``,
    in their original order, for k = 0..max order.  Entries may be empty.
    An empty ``expr`` gives [].  The leftover debug ``print`` (a Python 2
    print statement) was removed.
    '''
    orders = [term.stoch_order() for term in expr]
    if not orders:
        return []
    return [[term for term, o in zip(expr, orders) if o == level]
            for level in range(max(orders) + 1)]


if __name__ == '__main__':
    # For every derivative of g up to order 5 in ('a', 'e', 's'), print a
    # blank line, its LaTeX form, and the compiled expectation after
    # removing the g_s = 0 terms.
    # Use gen_code(max_order, vars) instead to produce the solver code.
    # Single-argument print(...) behaves the same on Python 2 and 3.
    [eqs_f, eqs_g, ndt] = compute_derivatives(5, ('a', 'e', 's'))

    for order in range(len(eqs_g)):
        for eq in eqs_g[order]:
            print('')
            print(eq._latex_())
            print(compile_expectation(kill_s(eq)))