import scipy
from compilator.compilator import *
from misc.calculus import *
from copy import *
from sympy import *

class Solver:
    """Numerical front end for a model.

    Bundles a model with parameter values, initial guesses and a shock
    covariance matrix.  Offers the following:

    - steady-state computation;
    - Jacobian evaluation at the steady state;
    - an Uhlig-toolkit solution;
    - export to XML and to Dynare / Dynare++ mod files.
    """

    # Class-level placeholders.  __init__ always replaces the dicts with
    # per-instance objects, so these shared defaults are never mutated.
    model = None
    init_values = dict()
    parameters_values = dict()
    compilator = None

    def __init__(self, model, parameters_values=None, init_values=None,
                 covariances=None, lookup=False):
        """Attach a model and its numerical values.

        Args:
            model: model object; its ``variables``, ``exovariables`` and
                ``parameters`` sequences are read here.
            parameters_values: mapping parameter -> value.  Parameters
                missing from it are set to 0 (the mapping is filled in place).
            init_values: mapping variable -> initial guess.  Endogenous and
                exogenous variables missing from it are set to 0 (filled in
                place).
            covariances: shock covariance matrix, used by
                ``export_to_modfile``.
            lookup: when True, ``init_values`` and ``parameters_values`` are
                read from the *caller's* global namespace instead of the
                arguments.  On failure a message is printed and empty
                mappings are used.
        """
        import inspect

        self.model = model
        self.covariances = covariances

        if lookup:
            # Start from fresh dicts so a failed lookup never falls back to
            # (and then mutates) the class-level dicts shared by all Solvers.
            self.init_values = dict()
            self.parameters_values = dict()
            frame = None
            try:
                frame = inspect.currentframe().f_back
                self.init_values = frame.f_globals["init_values"]
                self.parameters_values = frame.f_globals["parameters_values"]
            except (KeyError, AttributeError):
                print("Parameters values or init values were not defined correctly")
            finally:
                # Break the reference cycle created by holding a frame object.
                del frame
        else:
            self.init_values = init_values if init_values is not None else dict()
            self.parameters_values = (parameters_values
                                      if parameters_values is not None else dict())

        # Every symbol of the model gets a value, defaulting to 0.
        for v in model.variables:
            self.init_values.setdefault(v, 0)
        for v in model.exovariables:
            self.init_values.setdefault(v, 0)
        for p in model.parameters:
            self.parameters_values.setdefault(p, 0)

        # Lazily computed caches, see get_parameters_dict / get_init_dict.
        self.parm_dict = None
        self.init_dict = None
        self.compilator = Compilator(model)

    def get_parameters_dict(self):
        """Return parameter values after ``solve_triangular_system``.

        Presumably this resolves parameters defined in terms of other
        parameters (TODO confirm against ``solve_triangular_system``).
        The result is computed once and cached.
        """
        if self.parm_dict is None:
            self.parm_dict = solve_triangular_system(self.parameters_values.copy())
        return self.parm_dict

    def get_init_dict(self):
        """Return initial values merged with the resolved parameter values.

        Parameter values override initial values with the same key.  The
        merged mapping is passed through ``solve_triangular_system`` and
        cached.
        """
        if self.init_dict is None:
            merged = self.init_values.copy()
            merged.update(self.get_parameters_dict().copy())
            self.init_dict = solve_triangular_system(merged)
        return self.init_dict

    def find_steady_state(self, return_dict=True):
        """Solve the static model for endogenous and exogenous variables.

        Starts from ``get_init_dict()``.  If the squared residual norm at the
        starting point is below 1e-5, the starting point is returned
        unchanged; otherwise scipy's Anderson mixing solver is run.

        Returns:
            A dict symbol -> value when ``return_dict`` is True.  Otherwise
            an array holding the endogenous values followed by the exogenous
            ones.
        """
        import scipy.optimize

        init_dict = self.get_init_dict()
        variables = self.model.variables
        exovariables = self.model.exovariables
        parameters = self.model.parameters

        parm = zeros(len(parameters))
        for i, p in enumerate(parameters):
            parm[i] = init_dict[p]

        x = zeros(len(variables))
        for i, v in enumerate(variables):
            x[i] = init_dict[v]
        px = x.shape[0]

        y = zeros(len(exovariables))
        for i, v in enumerate(exovariables):
            y[i] = init_dict[v]

        z0 = array(scipy.r_[x, y])
        # Compiled residual function of the static model: f(x, y, parm).
        f = self.compilator.get_gaps_static_python()

        def residuals(z):
            return f(z[:px], z[px:], parm)

        test = residuals(z0)
        if dot(test.T, test) < 0.00001:
            res = z0
        else:
            res = scipy.optimize.anderson(residuals, z0, alpha=0.001)

        if not return_dict:
            return array(res)
        r_dict = dict()
        for i, v in enumerate(variables):
            r_dict[v] = res[i]
        for i, v in enumerate(exovariables):
            r_dict[v] = res[i + px]
        return r_dict

    def toxml(self):
        """Return an ``ET`` element ``<solver>`` with parameter and initial values."""
        xmlsolver = ET.Element("solver")
        xmlparmvalues = ET.SubElement(xmlsolver, "initparams")
        xmlinitvalues = ET.SubElement(xmlsolver, "initvalues")
        for p in self.parameters_values:
            xmlp = ET.Element('parameter', {'name': str(p), 'value': sympy_to_dynare_string(self.parameters_values[p])})
            xmlparmvalues.append(xmlp)
        for v in self.init_values:
            # Name the variable itself (previously the last parameter's name
            # was written for every variable).
            xmlv = ET.Element('variable', {'name': str(v), 'value': sympy_to_dynare_string(self.init_values[v])})
            xmlinitvalues.append(xmlv)
        return xmlsolver

    def toxmlmodfile(self):
        """Serialize the model and solver XML under one ``<xmlmodfile>`` root."""
        xmlmodfile = ET.Element("xmlmodfile")
        xmlmodfile.append(self.model.toxml())
        xmlmodfile.append(self.toxml())
        return ET.tostring(xmlmodfile)

    def export_to_modfile(self, options=None, append=""):
        """Return the model as Dynare or Dynare++ mod-file text.

        Args:
            options: optional dict overriding the defaults:

                - ``steady`` (False): append ``steady;`` (dynare only);
                - ``check`` (False): append ``check;`` (dynare only);
                - ``dest`` ('dynare'): either 'dynare' or 'dynare++';
                - ``order`` (1): approximation order (dynare++ only).

                The caller's dict is not modified.
            append: extra text added at the end of the file.

        Raises:
            ValueError: if ``dest`` is neither 'dynare' nor 'dynare++'.
        """
        default = {'steady': False, 'check': False, 'dest': 'dynare', 'order': 1}
        opts = dict(default)
        if options:
            opts.update(options)
        dest = opts['dest']
        if dest not in ('dynare', 'dynare++'):
            raise ValueError("unknown export destination: " + str(dest))

        init_block = ["// Generated by DareDare"]
        init_block.append("// Model basename : " + self.model.fname)
        init_block.append("")
        init_block.append("var " + ",".join(map(str, self.model.variables)) + ";")
        init_block.append("")
        init_block.append("varexo " + ",".join(map(str, self.model.shocks)) + ";")
        init_block.append("")
        init_block.append("parameters " + ",".join(map(str, self.model.parameters)) + ";")
        for p in self.model.parameters:
            init_block.append(p.name + " = " + str(self.parameters_values[p]) + ";")

        model_block = ["model;"]
        for eq in self.model.equations:
            s = str(eq).replace("==", "=")
            s = s.replace("**", "^")
            # Only one-period leads/lags are translated.
            s = s.replace("_b1", "(-1)")
            s = s.replace("_f1", "(+1)")
            model_block.append(s + ";")
        model_block.append("end;")

        shocks = self.model.shocks
        if dest == 'dynare':
            # Upper triangle only: the covariance matrix is symmetric.
            shocks_block = ["shocks;"]
            n_rows, n_cols = self.covariances.shape[0], self.covariances.shape[1]
            for i in range(n_rows):
                for j in range(i, n_cols):
                    cov = self.covariances[i, j]
                    if cov == 0:
                        continue
                    if i == j:
                        shocks_block.append("var " + str(shocks[i]) + " = " + str(cov) + " ;")
                    else:
                        shocks_block.append("var " + str(shocks[i]) + "," + str(shocks[j]) + " = " + str(cov) + " ;")
            shocks_block.append("end;")
        else:  # dynare++
            shocks_matrix = [" ".join(str(s) for s in self.covariances[i, :])
                             for i in range(self.covariances.shape[0])]
            shocks_block = ["vcov = [", ";\n".join(shocks_matrix) + "\n", "];"]

        initval_block = ["initval;"]
        for v in self.model.variables:
            if v in self.init_values:
                initval_block.append(str(v) + " = " + str(self.init_values[v]).replace('**', '^') + ";")
        initval_block.append("end;")

        model_text = "\n".join(init_block)
        model_text += "\n\n" + "\n".join(model_block)
        model_text += "\n\n" + "\n".join(shocks_block)
        model_text += "\n\n" + "\n".join(initval_block)

        if opts['steady'] and dest == 'dynare':
            model_text += "\n\nsteady;\n"
        if opts['check'] and dest == 'dynare':
            model_text += "\n\ncheck;\n"
        if dest == 'dynare++':
            model_text += "\n\norder = " + str(opts['order']) + ";\n"

        model_text += "\n" + append
        return model_text

    def get_jacobian(self, ss_dict):
        """Return the model Jacobian evaluated at ``ss_dict`` and the parameters.

        Each (exogenous) variable and its one-period lead and lag are
        replaced by the same value ``ss_dict[v]``.
        """
        parm_dict = self.get_parameters_dict()
        mjac = self.model.get_jacobian()
        for v in (self.model.variables + self.model.exovariables):
            mjac = mjac.subs(v, ss_dict[v])
            mjac = mjac.subs(v(+1), ss_dict[v])
            mjac = mjac.subs(v(-1), ss_dict[v])
        for d in self.model.parameters:
            mjac = mjac.subs(d, parm_dict[d])
        return mjac

    @staticmethod
    def _at_steady_state(exprs, wrt, ss_dict):
        """Differentiate ``exprs`` w.r.t. ``wrt``, substitute ``ss_dict``, return a float64 matrix."""
        jac = exprs.jacobian(wrt)
        for v in ss_dict:
            jac = jac.subs(v, ss_dict[v])
        return scipy.matrix(mat(jac), dtype='float64')

    def solve_with_uhligs_toolkit(self, print_matrices=False):
        """Linearise around the steady state and solve with ``UhligToolkit``.

        Equations containing at least one endogenous variable (at any lag)
        are treated as endogenous.  The others form the exogenous block,
        whose coefficients are checked: future exogenous must be -1 and
        future shocks must be 1.

        Raises:
            ValueError: if the exogenous equations are not in that form.
        """
        from extern.uhlig import UhligToolkit

        r_dict = self.find_steady_state()
        ss_dict = dict()
        for v in r_dict:
            ss_dict[v] = r_dict[v]
            ss_dict[v(+1)] = r_dict[v]
            ss_dict[v(-1)] = r_dict[v]
        ss_dict.update(self.get_parameters_dict())

        endo_eq = []
        exo_eq = []
        for eq in self.model.equations:
            is_endo = any(isinstance(v, Variable) and (v(-v.lag) in self.model.variables)
                          for v in eq.atoms())
            (endo_eq if is_endo else exo_eq).append(eq)

        f = Matrix([eq.gap() for eq in endo_eq]).T
        if len(exo_eq) > 0:
            g = Matrix([eq.gap() for eq in exo_eq]).T
        else:
            g = Matrix([[]])

        variables_f = [v(+1) for v in self.model.variables]
        variables_p = self.model.variables
        variables_b = [v(-1) for v in self.model.variables]
        shocks_f = [v(+1) for v in self.model.shocks]
        exovariables_f = [v(+1) for v in self.model.exovariables]
        exovariables_p = self.model.exovariables

        ev = self._at_steady_state
        F = ev(f, variables_f, ss_dict)
        G = ev(f, variables_p, ss_dict)
        H = ev(f, variables_b, ss_dict)
        L = ev(f, exovariables_f, ss_dict)
        M = ev(f, exovariables_p, ss_dict)

        # Explicit checks (not assert) so they are not stripped under -O.
        Ntest = ev(g, exovariables_f, ss_dict)
        if not (Ntest == -1).all():
            raise ValueError("incorrect exogenous equation specification : coefficient of future exogenous must be -1")

        N = ev(g, exovariables_p, ss_dict)

        Stest = ev(g, shocks_f, ss_dict)
        if not (Stest == 1).all():
            raise ValueError("incorrect exogenous equation specification : coefficient of future shocks must be 1")

        return UhligToolkit(F, G, H, L, M, N, None, print_matrices=print_matrices)

    def print_uhlig_solution(self, solver):
        """Print the PP/QQ and NN matrices of an Uhlig solution as tables.

        The shock column of the second table is always printed as 1, the
        coefficient enforced by ``solve_with_uhligs_toolkit``.
        """
        rows = [str(v) for v in self.model.variables]
        columns_pp = [str(v(-1)) for v in self.model.variables]
        columns_qq = [str(v) for v in self.model.exovariables]
        output = "\t| " + "\t\t\t".join(columns_pp) + "\t| " + "\t\t".join(columns_qq) + "\n"
        for i, row in enumerate(rows):
            pp_row = "\t\t".join(map(str, solver.PP[i, :].tolist()[0]))
            qq_row = "\t\t".join(map(str, solver.QQ[i, :].tolist()[0]))
            output += row + "\t|" + pp_row + "\t|" + qq_row + "\n"

        rows = [str(v(+1)) for v in self.model.exovariables]
        columns_nn = [str(v) for v in self.model.exovariables]
        columns = [str(v(+1)) for v in self.model.shocks]
        output2 = "\t| " + "\t\t\t".join(columns_nn) + "\t| " + "\t\t".join(columns) + "\n"
        for i, row in enumerate(rows):
            nn_row = "\t\t".join(map(str, solver.NN[i, :].tolist()[0]))
            output2 += row + "\t|" + nn_row + "\t|" + str(1) + "\n"

        print(output)
        print(output2)