from sympy import *
from scipy import *
from model.symbolic import *
import extern.qz

def second_order_taylorisation(expr,shocks,covariances):
    '''Add the second-order shock correction to ``expr``, then zero the shocks.

    Computes

        expr + 1/2 * sum_{i,j} H[i,j] * covariances[i,j]

    where H is the Hessian of ``expr`` with respect to ``shocks``. Afterwards,
    every symbol in ``shocks`` is substituted by 0 (in both ``expr`` and the
    Hessian terms).

    NOTE(review): presumably this is the second-order approximation of the
    expectation of ``expr`` for zero-mean shocks with the given covariance
    matrix -- confirm with callers.

    :param expr: sympy expression depending on ``shocks``.
    :param shocks: sequence of sympy symbols to differentiate with respect to.
    :param covariances: 2-D indexable with a ``shape`` attribute; it is indexed
        with the Hessian's indices, so it should not be larger than
        ``len(shocks)`` x ``len(shocks)``.
    :returns: the corrected sympy expression evaluated at zero shocks.
    '''
    h = hessian(expr,shocks)
    resp = expr
    for i in range(covariances.shape[0]):
        for j in range(covariances.shape[1]):
            # 1/2 * d2(expr)/(d shock_i d shock_j) * cov[i, j]
            resp = resp + h[i,j] * covariances[i,j] / 2
    # Evaluate the whole expansion at the point where all shocks are zero.
    for s in shocks:
        resp = resp.subs(s,0)
    return resp

def solve_triangular_system(sdict,key=None,level=0):
    '''Resolve, in place, a dict of definitions that refer to each other.

    Each value of ``sdict`` may reference other keys of ``sdict``. A referenced
    symbol is first resolved recursively and then substituted into the
    referencing expression.

    - With ``key=None`` (the public entry point), every entry is sympified and
      atoms whose class is exactly ``Parameter`` or ``Variable`` are resolved
      and substituted.
    - With ``key`` given (recursive call), only ``sdict[key]`` is processed,
      and only atoms whose class is exactly ``Parameter`` are resolved.

    :param sdict: dict mapping symbols to expressions (or values accepted by
        ``sympify``); it is modified in place.
    :param key: entry to resolve, or None to resolve all entries.
    :param level: current recursion depth, used to detect cyclic definitions.
    :returns: ``sdict`` itself.
    :raises ValueError: when the recursion depth exceeds ``len(sdict) + 1``,
        which only happens when the definitions reference each other cyclically.
    :raises KeyError: when a referenced Parameter/Variable is not a key of
        ``sdict``.
    '''
    if level > len(sdict)+1:
        # A chain of distinct references cannot be longer than the number of
        # entries, so exceeding it means some definition refers back to itself.
        raise ValueError("system is not triangular !")
    if key is None:
        for k in sdict:
            sdict[k] = sympify(sdict[k])
            _substitute_known(sdict, k, (Parameter, Variable), 1)
    else:
        sdict[key] = sympify(sdict[key])
        _substitute_known(sdict, key, (Parameter,), level+1)
    return sdict


def _substitute_known(sdict, key, classes, next_level):
    '''Resolve every atom of sdict[key] whose exact class is in ``classes``, then substitute it.'''
    # Exact class comparison (not isinstance) is kept on purpose, to match the
    # original filtering.
    unknown = [a for a in sdict[key].atoms() if a.__class__ in classes]
    for a in unknown:
        solve_triangular_system(sdict, a, level=next_level)
        sdict[key] = sdict[key].subs(a, sdict[a])
    
    
def construct_4_blocks_matrix(blocks):
    '''Assemble a 2x2 block matrix from four sub-matrices.

    input : [[A1,A2],[A3,A4]]  ->  output :  [ A1 A2 ]
                                             [ A3 A4 ]

    Blocks in the same block-row must have the same number of rows. Blocks in
    the same block-column must have the same number of columns.

    :param blocks: nested sequence ``[[A1, A2], [A3, A4]]`` whose elements
        expose a two-element ``shape`` and can be slice-assigned into the
        result of ``zeros``.
    :returns: a matrix of shape ``(p1+p3, q1+q2)`` created with ``zeros`` and
        filled block by block. NOTE(review): ``zeros`` comes from the module's
        star imports (presumably numpy's, re-exported by ``from scipy import *``)
        -- confirm.
    :raises ValueError: on a dimension mismatch between blocks.
    '''
    A1 = blocks[0][0]
    A2 = blocks[0][1]
    A3 = blocks[1][0]
    A4 = blocks[1][1]

    p1, q1 = A1.shape
    p2, q2 = A2.shape
    p3, q3 = A3.shape
    p4, q4 = A4.shape
    # Same block-row => same row count; same block-column => same column count.
    if p1 != p2 or p3 != p4 or q1 != q3 or q2 != q4:
        raise ValueError('dimension mismatch')

    m = zeros((p1+p3, q1+q2))
    m[0:p1, 0:q1] = A1
    m[0:p1, q1:q1+q2] = A2
    m[p1:p1+p3, 0:q1] = A3
    m[p1:p1+p3, q1:q1+q2] = A4
    return m
    
def sympy_to_dynare_string(sexpr):
    '''Render an expression as text using Dynare operator spelling.

    The object is converted with ``str``. Then sympy's equality ``==`` becomes
    Dynare's ``=``, and sympy's power operator ``**`` becomes ``^``.
    '''
    text = str(sexpr)
    for sympy_token, dynare_token in (("==", "="), ("**", "^")):
        text = text.replace(sympy_token, dynare_token)
    return text