from xml.etree import ElementTree as ET
from symbolic import *
import copy
import inspect
import misc.calculus
import re 

class Model:
    '''
    Symbolic model made of endogenous variables, exogenous variables,
    shocks, parameters and equations, with helpers to check it against the
    Uhlig / Dynare conventions and to derive transformed models from it.

    Dating convention used throughout: for a symbol ``v``, ``v.lag`` is its
    date offset (positive = future), and ``v(k)`` is the same symbol shifted
    by ``k`` periods, so ``v(-v.lag)`` is its current-dated version.
    '''

    # Class-level defaults, kept for backward compatibility only.
    # __init__ gives every instance its own containers, so these shared
    # objects are never mutated through an instance.
    fname = None # Model basename
    variables = []
    exovariables = []
    shocks = []
    covariances = []
    parameters = []
    equations = []
    init_values = dict()
    variables_order = dict()
    commands = None # keep some instructions to treat the model, not sure how to do it

    def __init__(self, fname, lookup=False):
        '''
        Create a model called ``fname``.

        :param fname: model basename (derived models append a suffix to it)
        :param lookup: when True, read ``variables``, ``exovariables``,
            ``shocks``, ``parameters`` and ``equations`` from the global
            namespace of the caller's module. Lists read this way are the
            caller's own objects, not copies.
        '''
        self.fname = fname
        # Fresh per-instance containers: without them every Model created
        # without ``lookup`` would append into the same class-level lists.
        self.variables = []
        self.exovariables = []
        self.shocks = []
        self.covariances = []
        self.parameters = []
        self.equations = []
        self.init_values = dict()
        self.variables_order = dict()
        self.commands = None
        if lookup:
            frame = inspect.currentframe().f_back
            try:
                caller_globals = frame.f_globals
                self.variables = caller_globals["variables"]
                self.exovariables = caller_globals["exovariables"]
                self.shocks = caller_globals["shocks"]
                self.parameters = caller_globals["parameters"]
                self.equations = caller_globals["equations"]
            except KeyError:
                # Best-effort lookup: report and keep whatever was found.
                print("Variables, exovariables, shocks or parameters were not defined correctly")
            finally:
                # Break the reference cycle created by holding a frame.
                del frame

    def copy(self):
        '''
        Return a new Model whose lists/dicts are shallow copies of ours.

        The containers are new objects, but the symbols and equations they
        hold are shared with ``self``.
        '''
        c = Model(self.fname)
        c.variables = copy.copy(self.variables)
        c.exovariables = copy.copy(self.exovariables)
        c.covariances = copy.copy(self.covariances)
        c.shocks = copy.copy(self.shocks)
        c.parameters = copy.copy(self.parameters)
        c.equations = copy.copy(self.equations)
        c.init_values = copy.copy(self.init_values)
        c.variables_order = copy.copy(self.variables_order)
        c.commands = copy.copy(self.commands)
        return c

    def check_all(self, print_info=False, print_eq_info=False, model_type="uhlig"):
        '''
        Number every equation (``eq.n``), compute its ``eq.info`` via
        ``set_info`` and store model-wide counts in ``self.info``.

        :param print_info: print the model-wide counts
        :param print_eq_info: print each equation's info dict
        :param model_type: currently unused; kept for backward compatibility
        '''
        for i, eq in enumerate(self.equations):
            eq.n = i
            self.set_info(eq)
        if print_eq_info:
            for eq in self.equations:
                print("Eq (" + str(eq.n) + ") : " + str(eq.info))
        self.info = {
            "n_variables": len(self.variables),
            "n_exovariables": len(self.exovariables),
            "n_shocks": len(self.shocks),
            "n_equations": len(self.equations),
        }
        if print_info:
            print("Model check : " + self.fname)
            for k in self.info:
                print("\t" + k + "\t\t" + str(self.info[k]))

    def check_type(self, model_type="uhlig"):
        '''
        Verify model compliance with the uhlig/dynare conventions.

        Any other ``model_type`` is silently ignored. Requires ``check_all``
        to have been run first (uses ``self.info`` and ``eq.info``).
        Raises ValueError on the first violation found.
        '''
        if model_type == "uhlig":
            self.check_type_uhlig()
        elif model_type == "dynare":
            self.check_type_dynare()

    def _check_equation_count(self):
        '''Raise ValueError unless n_equations == n_variables + n_exovariables.'''
        info = self.info
        if info['n_equations'] != info['n_variables'] + info['n_exovariables']:
            raise ValueError("Number of equations must equal number of variables and exovariables")

    def _check_lag_range(self, symbols, lowest, highest, message):
        '''
        Raise ValueError(message) if any symbol's lag is outside
        [lowest, highest]. An empty ``symbols`` list is accepted.
        '''
        for s in symbols:
            if s.lag < lowest or s.lag > highest:
                raise ValueError(message)

    def check_type_uhlig(self):
        '''
        Check the Uhlig conventions:

        * number of equations == number of variables + exovariables;
        * exogenous equations: shocks only with lag 1, exovariables only
          with lag 0 or 1;
        * endogenous equations: exovariables only with lag 0 or 1,
          endogenous variables only with lag -1, 0 or 1.

        Symbol categories absent from an equation impose no constraint.
        Requires ``check_all`` to have been run. Raises ValueError.
        '''
        self._check_equation_count()
        for eq in self.equations:
            info = eq.info
            if info['exogenous']:
                self._check_lag_range(info['all_shocks'], 1, 1,
                    "In exogenous equations, shocks must only appear with lag 1")
                self._check_lag_range(info['all_exo'], 0, 1,
                    "In exogenous equations, exogenous variables must only appear with lag 0 and 1")
            else:
                self._check_lag_range(info['all_exo'], 0, 1,
                    "In endogenous equations, exogenous variables must only appear with lag 0 and 1")
                self._check_lag_range(info['all_endo'], -1, 1,
                    "In endogenous equations, endogenous variables must only appear with lag  between -1 and 1")

    def check_type_dynare(self):
        '''
        Check the Dynare conventions: number of equations == number of
        variables + exovariables, and no shock appears with a positive lag.

        Requires ``check_all`` to have been run. Raises ValueError.
        '''
        self._check_equation_count()
        for eq in self.equations:
            if any(s.lag > 0 for s in eq.info['all_shocks']):
                raise ValueError('Exogenous variable can not appear with lead strictly greater than 0')
        return None

    def set_info(self, eq):
        '''
        Compute and store in ``eq.info`` the lead/lag structure of ``eq``.

        Keys: ``max_lag``, ``min_lag``, ``expected`` (some variable has a
        positive lag), ``exogenous`` (every variable, once current-dated,
        is an exovariable or a shock), and the lists ``all_endo``,
        ``all_exo``, ``all_shocks`` of the dated symbols found in ``eq``.

        Raises ValueError if a Variable atom belongs to none of
        variables/exovariables/shocks.
        '''
        all_vars = [] # will contain all variables
        all_endo = []
        all_exo = []
        all_shocks = []
        for a in eq.atoms():
            if not isinstance(a, Variable):
                continue
            all_vars.append(a)
            current = a(-a.lag)
            if current in self.variables:
                all_endo.append(a)
            elif current in self.exovariables:
                all_exo.append(a)
            elif current in self.shocks:
                all_shocks.append(a)
            else:
                raise ValueError("something impossible happened")

        # Lists (not map iterators) so that max() and min() can both run.
        all_vars_c = [v(-v.lag) for v in all_vars]
        lags = [v.lag for v in all_vars]
        info = {}
        # These information don't depend on the system of equations
        info['max_lag'] = max(lags)
        info['min_lag'] = min(lags)
        info['expected'] = (max(lags) > 0)
        # These information depend on the model
        info['exogenous'] = set(all_vars_c).issubset(set(self.exovariables).union(self.shocks))
        info['all_endo'] = all_endo
        info['all_exo'] = all_exo
        info['all_shocks'] = all_shocks
        eq.info = info

    def map_function_to_expression(self, f, expr):
        '''
        Rebuild ``expr`` bottom-up, applying ``f`` to every leaf.

        A leaf is a node with an empty ``_args``. Inner nodes are rebuilt
        with their own class from the transformed children.
        '''
        if len(expr._args) == 0:
            return f(expr)
        args = [self.map_function_to_expression(f, a) for a in expr._args]
        return expr.__class__(*args)

    def current_equations(self):
        '''
        Return copies of the equations where current-dated shocks are
        replaced by 0 and every Variable is moved to its current date.
        '''
        def f(x):
            # NOTE(review): only symbols equal to an entry of self.shocks
            # (i.e. current-dated shocks) are zeroed; a shock at another
            # date is re-dated, not zeroed. Confirm this is intended.
            if x in self.shocks:
                return 0
            elif x.__class__ == Variable:
                return x(-x.lag)
            else:
                return x
        return [self.map_function_to_expression(f, eq) for eq in self.equations]

    def future_variables(self):
        '''
        Find endogenous variables appearing with a positive lag.

        Returns ``[f_vars, f_eq, f_eq_n]``:
          f_vars : list of endogenous variables with lag > 0 (no duplicates,
                   order unspecified)
          f_eq   : list of equations containing such variables
          f_eq_n : list of indices of those equations (same order as f_eq)
        '''
        f_eq_n = [] # indices of future equations
        f_eq = [] # future equations (same order)
        f_vars = set() # future variables
        for i, eq in enumerate(self.equations):
            f_var = [a for a in eq.atoms()
                     if a.__class__ == Variable
                     and a.lag > 0
                     and a(-a.lag) in self.variables]
            if f_var:
                f_eq_n.append(i)
                f_eq.append(eq)
                f_vars.update(f_var)
        return [list(f_vars), f_eq, f_eq_n]

    def get_jacobian(self):
        '''
        Return the jacobian of the equations (via ``eq.gap()``) with
        respect to all_variables(1), all_variables, all_variables(-1),
        where all_variables == variables + exovariables.
        '''
        all_variables = self.variables + self.exovariables
        vec = ([v(1) for v in all_variables]
               + all_variables
               + [v(-1) for v in all_variables])
        vec = Matrix([vec])
        f = Matrix([eq.gap() for eq in self.equations])
        return f.T.jacobian(vec)

    def toxml(self):
        '''
        Serialize the model to an ElementTree ``<model>`` element with
        ``variables``, ``exovariables``, ``shocks``, ``parameters`` and
        ``equations`` children, each holding its items' ``toxml()``.
        '''
        xmlmodel = ET.Element("model")
        sections = [
            ("variables", self.variables),
            ("exovariables", self.exovariables),
            ("shocks", self.shocks),
            ("parameters", self.parameters),
            ("equations", self.equations),
        ]
        for tag, items in sections:
            node = ET.SubElement(xmlmodel, tag)
            for item in items:
                node.append(item.toxml())
        return xmlmodel

    def to_uhlig_model(self):
        '''
        Return a copy of the model (named ``<fname>_uhlig``) in which every
        shock appearing in an endogenous equation is replaced by a new
        exogenous variable ``z_<shock>`` with ``z(+1) = shock``; one such
        exogenous variable and its defining equation are added per shock.

        Requires ``check_all`` to have been run (uses ``eq.info``).
        Raises ValueError if the model already has exogenous variables.
        '''
        umodel = self.copy()
        umodel.fname = umodel.fname + "_uhlig"
        if umodel.exovariables:
            raise ValueError("The model already has exogenous variables defined")
        # shock (dated) -> replacing exogenous variable (dated): a shock at
        # lag L maps to z at lag L + 1, so the shock at lag -1 maps to z.
        endo_shocks = {}
        for i, eq in enumerate(umodel.equations):
            if eq.info['exogenous']:
                continue
            for s in eq.info['all_shocks']:
                v = Variable('z_' + s.basename)   # the exogenous variable such that : z(+1) = epsilon
                lagged = s(-s.lag - 1)
                if lagged not in endo_shocks:
                    endo_shocks[lagged] = v
                if s not in endo_shocks:
                    endo_shocks[s] = v(s.lag + 1)
            umodel.equations[i] = eq.subs_dict(endo_shocks)
        all_shocks_c = set(v(-v.lag) for v in endo_shocks.keys())
        for s in all_shocks_c:
            # NOTE(review): endo_shocks[s] exists only if the shock appeared
            # current-dated in some endogenous equation; otherwise this is a
            # KeyError. The meaning of ``.p()`` is defined outside this file.
            umodel.exovariables.append(endo_shocks[s].p())
            umodel.equations.append(Equation(endo_shocks[s], s))
        return umodel

    def process_portfolio_model(self):
        '''
        Return a copy of the model (named ``<fname>_pf``) in which forward
        equations are replaced by their second-order taylorisation w.r.t.
        the exovariables, and matrix A of new variables ``A_ij`` plus the
        corresponding moment conditions are added.

        Uses ``self.covariances`` to replace cross products of shocks.
        '''
        pf_model = self.copy()    # work on a copy, never on self
        pf_model.fname = pf_model.fname + "_pf"

        [f_vars, f_eq, f_eq_n] = pf_model.future_variables()

        f_vars_c = [v(-v.lag) for v in f_vars] # current dates of future variables

        forward = set(f_eq_n)
        nf_eq_n = [n for n in range(len(self.equations)) if n not in forward] # indices of non forward equations

        # gap() removes the equal sign so the jacobian can be computed
        nf_eq = [self.equations[n].gap() for n in nf_eq_n] # non forward equations (differences)
        f_eq = [eq.gap() for eq in f_eq]

        # NOTE(review): A has len(nf_eq) rows but is indexed by future
        # variable (A[k, :]) and right-multiplies f_X, whose columns are the
        # future variables: this presumes len(f_vars) == len(nf_eq).
        # Names like "A_111" are ambiguous once a dimension reaches 10.
        A = Matrix(len(nf_eq), len(self.exovariables),
                   lambda i, j: Variable("A_" + str(i + 1) + str(j + 1)))
        for a in A:
            pf_model.variables.append(a)

        # Thanks to non forward equations, forward variables variations are
        # approximated by A * shocks (first order approximation)
        f = Matrix(nf_eq)
        f_X = f.jacobian(f_vars_c)
        for v in pf_model.exovariables:
            f_X = f_X.subs(v, 0)
        f_E = f.jacobian(Matrix(self.exovariables))
        for v in pf_model.exovariables:
            f_E = f_E.subs(v, 0)
        f_ = f_X * A + f_E

        # Then future variables are substituted in forward equations
        # which are taylorized twice with respect to shocks
        # Cross products of shocks are replaced by covariances
        exo_vector = Matrix(self.exovariables)
        s_f_eq = [] # future equations with substitutions
        for eq in f_eq:
            for k, v in enumerate(f_vars):
                dv = (A[k, :] * exo_vector)[0, 0]
                eq = eq.subs(v, v + dv)
            # NOTE(review): ``Calculus`` is not imported by name in this
            # file (only ``misc.calculus`` is); presumably it comes from
            # ``from symbolic import *`` — confirm.
            seq = Calculus.second_order_taylorisation(eq, self.exovariables, self.covariances)
            # NOTE(review): other equations are ``Equation`` objects; this
            # stores an ``Equality`` — confirm it supports gap()/toxml().
            s_f_eq.append(Equality(0, seq))

        for eq in f_:
            pf_model.equations.append(Equation(0, eq, "Another moment condition"))

        for n, seq in zip(f_eq_n, s_f_eq):
            pf_model.equations[n] = seq

        return pf_model