/*
 * Copyright (C) 2009-2010 Sébastien Villemot <sebastien.villemot@ens.fr>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <cmath>
#include <cassert>
#include <iostream>

#include <gsl/gsl_blas.h>

#include "mat.h"

#include "MRGalSolution.hh"

/*
 * Loads the precomputed solution for the given model specification from the
 * MAT-file "mrgal/PI_A<s>N<n>.mat" (s = modspec.spec_no(), n = modspec.n) and
 * allocates the workspace used by policy_func_internal().
 *
 * Throws UnsupportedSpec when the (s, n) pair is not one of the combinations
 * accepted by the switch below. Terminates the process with EXIT_FAILURE if
 * the file cannot be opened. Presence, type and dimensions of the arrays read
 * from the file are verified through assert() only, hence not at all when
 * NDEBUG is defined.
 *
 * nr is the number of terms that policy_func_internal() stores in REG:
 * 1 constant, 2n linear terms, (2n-1)n cross products and 2n squares.
 */
MRGalSolution::MRGalSolution(const ModelSpec &modspec_arg) throw (UnsupportedSpec)
  : ModelSolution(modspec_arg), nr(1+2*modspec.n+(2*modspec.n-1)*modspec.n+2*modspec.n)
{
  const int s = modspec.spec_no();
  const int n = modspec.n;

  // Reject (spec, number of countries) pairs that are not accepted
  switch(s)
    {
    case 1:
    case 5:
      if (n != 2 && n != 4 && n != 6 && n != 8 && n != 10)
        throw UnsupportedSpec();
      break;
    case 2:
    case 6:
      if (n != 2 && n != 4 && n != 6 && n != 8)
        throw UnsupportedSpec();
      break;
    case 3:
    case 4:
    case 7:
    case 8:
      if (n != 2 && n != 4 && n != 6)
        throw UnsupportedSpec();
      break;
    default:
      throw UnsupportedSpec();
    }

  // Specs 1, 2, 5 and 6 read C_max/C_min/regvec; the other ones read mumax/mumin/regvec_labor
  const bool has_regvec = (s == 1 || s == 2 || s == 5 || s == 6);
  // dr_coeff holds n blocks of nr coefficients for specs 1 and 5, and n+1 blocks otherwise
  const size_t dr_coeff_size = ((s == 1 || s == 5) ? n : n+1)*nr;

  // Retrieve MATLAB arrays
  char filename[50];
  snprintf(filename, sizeof(filename), "mrgal/PI_A%dN%d.mat", s, n);

  MATFile *mfp = matOpen(filename, "r");
  if (mfp == NULL)
    {
      std::cerr << "MRGalSolution::MRGalSolution: can't open " << filename << std::endl;
      exit(EXIT_FAILURE);
    }

  mxArray *amax_mx = matGetVariable(mfp, "amax");
  mxArray *amin_mx = matGetVariable(mfp, "amin");
  mxArray *kmax_mx = matGetVariable(mfp, "kmax");
  mxArray *kmin_mx = matGetVariable(mfp, "kmin");
  mxArray *regvec_mx; // Stores either regvec or regvec_labor
  mxArray *alt_max_mx; // Stores either C_max or mumax
  mxArray *alt_min_mx; // Stores either C_min or mumin
  if (has_regvec)
    {
      alt_max_mx = matGetVariable(mfp, "C_max");
      alt_min_mx = matGetVariable(mfp, "C_min");
      regvec_mx = matGetVariable(mfp, "regvec");
    }
  else
    {
      alt_max_mx = matGetVariable(mfp, "mumax");
      alt_min_mx = matGetVariable(mfp, "mumin");
      regvec_mx = matGetVariable(mfp, "regvec_labor");
    }

  // Longest possible name is "dr_coeff_A<d>N10": 14 characters plus the terminating NUL
  char dr_coeff_name[15];
  snprintf(dr_coeff_name, sizeof(dr_coeff_name), "dr_coeff_A%dN%d", s, n);
  mxArray *dr_coeff_mx = matGetVariable(mfp, dr_coeff_name);
  assert(alt_max_mx != NULL && alt_min_mx != NULL && amax_mx != NULL && amin_mx != NULL && kmax_mx != NULL && kmin_mx != NULL
         && regvec_mx != NULL && dr_coeff_mx != NULL);
  matClose(mfp);

  // Test array types and dimensions: the six bounds must be real double scalars
  assert(mxGetClassID(amax_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(amax_mx) == 2
         && mxGetDimensions(amax_mx)[0] == 1 && mxGetDimensions(amax_mx)[1] == 1);
  assert(mxGetClassID(amin_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(amin_mx) == 2
         && mxGetDimensions(amin_mx)[0] == 1 && mxGetDimensions(amin_mx)[1] == 1);
  assert(mxGetClassID(kmax_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(kmax_mx) == 2
         && mxGetDimensions(kmax_mx)[0] == 1 && mxGetDimensions(kmax_mx)[1] == 1);
  assert(mxGetClassID(kmin_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(kmin_mx) == 2
         && mxGetDimensions(kmin_mx)[0] == 1 && mxGetDimensions(kmin_mx)[1] == 1);
  assert(mxGetClassID(alt_max_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(alt_max_mx) == 2
         && mxGetDimensions(alt_max_mx)[0] == 1 && mxGetDimensions(alt_max_mx)[1] == 1);
  assert(mxGetClassID(alt_min_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(alt_min_mx) == 2
         && mxGetDimensions(alt_min_mx)[0] == 1 && mxGetDimensions(alt_min_mx)[1] == 1);

  // regvec is 9x1, regvec_labor is 165xn
  assert(mxGetClassID(regvec_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(regvec_mx) == 2);
  if (has_regvec)
    assert(mxGetDimensions(regvec_mx)[0] == 9 && mxGetDimensions(regvec_mx)[1] == 1);
  else
    assert(mxGetDimensions(regvec_mx)[0] == 165 && mxGetDimensions(regvec_mx)[1] == n);

  // dr_coeff must be a column vector of the expected length
  assert(mxGetClassID(dr_coeff_mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(dr_coeff_mx) == 2
         && mxGetDimensions(dr_coeff_mx)[0] == dr_coeff_size && mxGetDimensions(dr_coeff_mx)[1] == 1);

  // Convert to doubles and GSL vectors
  amax = mxGetPr(amax_mx)[0];
  amin = mxGetPr(amin_mx)[0];
  kmax = mxGetPr(kmax_mx)[0];
  kmin = mxGetPr(kmin_mx)[0];

  if (has_regvec)
    {
      C_max = mxGetPr(alt_max_mx)[0];
      C_min = mxGetPr(alt_min_mx)[0];
      regvec = gsl_vector_alloc(9);
      gsl_vector_memcpy(regvec, &gsl_vector_view_array(mxGetPr(regvec_mx), 9).vector);
    }
  else
    {
      mumax = mxGetPr(alt_max_mx)[0];
      mumin = mxGetPr(alt_min_mx)[0];
      regvec_labor = gsl_matrix_alloc(165, n);
      /* MATLAB stores the 165xn array column by column, which read as a
         row-major GSL matrix is its nx165 transpose; transpose it back while
         copying. The view is created inside the call so that it is still
         alive while the copy runs. */
      gsl_matrix_transpose_memcpy(regvec_labor, &gsl_matrix_view_array(mxGetPr(regvec_mx), n, 165).matrix);
    }

  dr_coeff = gsl_vector_alloc(dr_coeff_size);
  gsl_vector_memcpy(dr_coeff, &gsl_vector_view_array(mxGetPr(dr_coeff_mx), dr_coeff_size).vector);

  // Free MX arrays (their contents have been copied above)
  mxDestroyArray(amax_mx);
  mxDestroyArray(amin_mx);
  mxDestroyArray(kmax_mx);
  mxDestroyArray(kmin_mx);
  mxDestroyArray(alt_max_mx);
  mxDestroyArray(alt_min_mx);
  mxDestroyArray(regvec_mx);
  mxDestroyArray(dr_coeff_mx);

  // Allocate temporary variables
  sd = gsl_vector_alloc(2*modspec.n);
  REG = gsl_vector_alloc(nr);
  if (has_regvec)
    REGC = gsl_vector_alloc(9);
  else
    {
      REG_A = gsl_matrix_alloc(n, 9);
      REG_K = gsl_matrix_alloc(n, 9);
      REG_MU = gsl_vector_alloc(9);
      REGMAT = gsl_matrix_alloc(n, 165);
    }
}

/*
 * Releases the GSL vectors and matrices owned by the object. Which objects
 * exist depends on the specification number, exactly as in the constructor:
 * specs 1, 2, 5 and 6 own regvec and REGC, the other ones own regvec_labor,
 * REG_A, REG_K, REG_MU and REGMAT.
 */
MRGalSolution::~MRGalSolution()
{
  // Objects common to all specifications
  gsl_vector_free(dr_coeff);
  gsl_vector_free(sd);
  gsl_vector_free(REG);

  // Spec-dependent objects
  switch(modspec.spec_no())
    {
    case 1:
    case 2:
    case 5:
    case 6:
      gsl_vector_free(regvec);
      gsl_vector_free(REGC);
      break;
    default:
      gsl_matrix_free(regvec_labor);
      gsl_matrix_free(REG_A);
      gsl_matrix_free(REG_K);
      gsl_vector_free(REG_MU);
      gsl_matrix_free(REGMAT);
      break;
    }
}

/*
 * Evaluates the stored decision rules.
 *
 * y_prev:  previous-period variables; only capital (offset 3n) and technology
 *          (offset 4n) are read here, the raw data is also passed to
 *          modspec.production() for specs 1 and 5.
 * shocks:  elements 0..n-1 are added, country by country, to element n before
 *          being scaled by modspec.sigma.
 * y_curr:  output. Layout by offset: consumption (0), labor (n), investment
 *          (2n), capital (3n), technology (4n), each of length n, followed by
 *          lambda at index 5n. Labor is set to NAN for specs 1 and 5.
 *
 * Both y_prev and y_curr must have unit stride since they are accessed through
 * raw pointers. The member workspaces sd, REG, REGC, REG_A, REG_K, REG_MU and
 * REGMAT are overwritten.
 */
void
MRGalSolution::policy_func_internal(const gsl_vector *y_prev, const gsl_vector *shocks, gsl_vector *y_curr)
{
  const int n = modspec.n;
  const int s = modspec.spec_no();

  assert(y_prev->stride == 1 && y_curr->stride == 1);

  double *c_curr = gsl_vector_ptr(y_curr, 0);
  double *l_curr = gsl_vector_ptr(y_curr, n);
  double *i_curr = gsl_vector_ptr(y_curr, 2*n);
  double *k_curr = gsl_vector_ptr(y_curr, 3*n);
  double *a_curr = gsl_vector_ptr(y_curr, 4*n);

  const double *k_prev = gsl_vector_const_ptr(y_prev, 3*n);
  const double *a_prev = gsl_vector_const_ptr(y_prev, 4*n);

  // Convert y_prev and shocks to Pichler's rescaled representation of state space, and set tech shock in y_curr
  // (sd holds capital in [0,n) and technology in [n,2n), both mapped linearly so that min -> -1 and max -> 1)
  for(int j = 0; j < n; j++)
    {
      gsl_vector_set(sd, j, 2.0*(k_prev[j]-kmin) / (kmax-kmin) - 1.0);
      a_curr[j] = exp(modspec.rho*log(a_prev[j]) + modspec.sigma*(gsl_vector_get(shocks, j) + gsl_vector_get(shocks, n)));
      gsl_vector_set(sd, n+j, 2.0*(a_curr[j]-amin)/(amax-amin) - 1.0);
    }

  // Construct REG vector (Chebychev polynomials of state vars up to degree 2):
  // constant, linear terms, cross products, then 2x^2-1 for each state var
  int jjj = 0;
  gsl_vector_set(REG, jjj++, 1.0);
  for(int i = 0; i < 2*n; i++)
    gsl_vector_set(REG, jjj++, gsl_vector_get(sd, i));
  for(int j = 0; j < 2*n-1; j++)
    for(int i = j+1; i < 2*n; i++)
      gsl_vector_set(REG, jjj++, gsl_vector_get(sd, i) * gsl_vector_get(sd, j));
  for(int i = 0; i < 2*n; i++)
    {
      double x = gsl_vector_get(sd, i);
      gsl_vector_set(REG, jjj++, 2.0*x*x-1.0);
    }

  // Compute capital level (j-th block of nr coefficients in dr_coeff)
  for(int j = 0; j < n; j++)
    gsl_blas_ddot(REG, &gsl_vector_const_subvector(dr_coeff, j*nr, nr).vector,
                  k_curr + j);

  // Compute investment
  for(int j = 0; j < n; j++)
    i_curr[j] = k_curr[j] - (1.0-modspec.delta)*k_prev[j];

  if (s == 1 || s == 2 || s == 5 || s == 6)
    {
      // Compute aggregate consumption
      double C;
      if (s == 1 || s == 5)
        {
          // No coefficient block for C in these specs: sum output net of
          // capital accumulation and adjustment costs over countries
          double f[n], fk[n], fl[n];
          modspec.production(y_prev->data, y_curr->data, f, fk, fl);
          C = 0;
          for(int j = 0; j < n; j++)
            C += a_curr[j] * f[j] - k_curr[j] + k_prev[j] - modspec.phi/2.0*k_prev[j]*pow(k_curr[j]/k_prev[j]-1.0, 2.0);
        }
      else
        gsl_blas_ddot(REG, &gsl_vector_const_subvector(dr_coeff, n*nr, nr).vector, &C);

      // Compute consumption for country 1
      // (degree 8 Chebychev polynomial in rescaled C, built with the three-term recurrence)
      double sd_C = 2.0*(C-C_min)/(C_max-C_min)-1.0;
      gsl_vector_set(REGC, 0, 1.0);
      gsl_vector_set(REGC, 1, sd_C);
      for(int i = 2; i < 9; i++)
        gsl_vector_set(REGC, i, 2.0*sd_C*gsl_vector_get(REGC, i-1) - gsl_vector_get(REGC, i-2));
      gsl_blas_ddot(REGC, regvec, c_curr);

      // Compute consumption for other countries and labor for all
      if (s == 1 || s == 5)
        {
          const ModelSpecA1A5 &ms = dynamic_cast<const ModelSpecA1A5 &>(modspec);
          for(int j = 1; j < n; j++)
            c_curr[j] = pow(modspec.taus[0] / modspec.taus[j] * pow(c_curr[0], -1.0/ms.gammas[0]), -ms.gammas[j]);
          for(int j = 0; j < n; j++)
            l_curr[j] = NAN;
        }
      else
        {
          const ModelSpecA2A6 &ms = dynamic_cast<const ModelSpecA2A6 &>(modspec);
          for(int j = 1; j < n; j++)
            c_curr[j] = pow(modspec.taus[0] / modspec.taus[j] * pow(c_curr[0], -1.0/ms.gammas[0]), -ms.gammas[j]);
          for(int j = 0; j < n; j++)
            l_curr[j] = pow(1.0/ms.b[j]*pow(c_curr[j], -1.0/ms.gammas[j]) * a_curr[j]
                            * (1.0-modspec.alpha) * modspec.A * pow(k_prev[j], modspec.alpha),
                            1.0/(modspec.alpha+1.0/ms.etas[j]));
        }

      // Set lambda
      double uc_curr[n], ul_curr[n];
      modspec.marginal_utilities(y_curr->data, uc_curr, ul_curr);
      gsl_vector_set(y_curr, 5*n, uc_curr[0] * modspec.taus[0]);
    }
  else // specs A3,A4,A7,A8
    {
      // Compute lambda (the last coefficient block gives its logarithm)
      gsl_blas_ddot(REG, &gsl_vector_const_subvector(dr_coeff, n*nr, nr).vector, gsl_vector_ptr(y_curr, 5*n));
      gsl_vector_set(y_curr, 5*n, exp(gsl_vector_get(y_curr, 5*n)));
      double sd_mu = 2.0*(gsl_vector_get(y_curr, 5*n) - mumin)/(mumax-mumin) - 1.0;

      // Compute degree 8 chebychev polynomials of state variables and lambda
      // (row j of REG_K / REG_A holds the polynomials of degree 0..8 for country j)
      gsl_vector_set_all(&gsl_matrix_column(REG_A, 0).vector, 1.0);
      gsl_vector_set_all(&gsl_matrix_column(REG_K, 0).vector, 1.0);
      gsl_vector_set(REG_MU, 0, 1.0);

      gsl_vector_memcpy(&gsl_matrix_column(REG_K, 1).vector, &gsl_vector_const_subvector(sd, 0, n).vector);
      gsl_vector_memcpy(&gsl_matrix_column(REG_A, 1).vector, &gsl_vector_const_subvector(sd, n, n).vector);
      gsl_vector_set(REG_MU, 1, sd_mu);

      for(int i = 2; i < 9; i++)
        {
          for(int j = 0; j < n; j++)
            {
              gsl_matrix_set(REG_K, j, i, 2.0*gsl_vector_get(sd, j)*gsl_matrix_get(REG_K, j, i-1) - gsl_matrix_get(REG_K, j, i-2));
              gsl_matrix_set(REG_A, j, i, 2.0*gsl_vector_get(sd, n+j)*gsl_matrix_get(REG_A, j, i-1) - gsl_matrix_get(REG_A, j, i-2));
            }
          gsl_vector_set(REG_MU, i, 2.0*sd_mu*gsl_vector_get(REG_MU, i-1) - gsl_vector_get(REG_MU, i-2));
        }

      // Compute degree 8 chebychev multivariate polynomials of state variables and lambda
      // The triples (i1,i2,i3) with i1+i2+i3 <= 8 fill the 165 columns of REGMAT.
      // Column counti is written element by element: keeping a pointer into a
      // temporary column view across statements would leave it dangling.
      int counti = 0;
      for(int i1 = 0; i1 <= 8; i1++)
        for(int i2 = 0; i2 <= 8-i1; i2++)
          for(int i3 = 0; i3 <= 8-i1-i2; i3++)
            {
              const double t_mu = gsl_vector_get(REG_MU, i1);
              for(int j = 0; j < n; j++)
                gsl_matrix_set(REGMAT, j, counti,
                               t_mu * gsl_matrix_get(REG_A, j, i2) * gsl_matrix_get(REG_K, j, i3));
              counti++;
            }

      // Compute labor
      for(int j = 0; j < n; j++)
        gsl_blas_ddot(&gsl_matrix_const_row(REGMAT, j).vector, &gsl_matrix_const_column(regvec_labor, j).vector,
                      l_curr+j);

      // Compute consumption from labor, capital and technology
      if (s == 3 || s == 7)
        {
          const ModelSpecA3A7 &ms = dynamic_cast<const ModelSpecA3A7 &>(modspec);
          for(int j = 0; j < n; j++)
            c_curr[j] = (ms.psi/(1.0-ms.psi))*(ms.time_endowment-l_curr[j])
              * a_curr[j] * (1.0-modspec.alpha) * modspec.A
              * pow(k_prev[j], modspec.alpha)
              * pow(l_curr[j], -modspec.alpha);
        }
      else
        {
          const ModelSpecA4A8 &ms = dynamic_cast<const ModelSpecA4A8 &>(modspec);
          for(int j = 0; j < n; j++)
            c_curr[j] = pow((ms.b[j] * pow(ms.time_endowment-l_curr[j], -1.0/ms.xis[j]))
                            / (a_curr[j] * (1.0-modspec.alpha) * modspec.A
                               * pow(l_curr[j], ms.mus[j]-1.0)
                               * pow(modspec.alpha * pow(k_prev[j], ms.mus[j])
                                     + (1.0-modspec.alpha) * pow(l_curr[j], ms.mus[j]),
                                     1.0/ms.mus[j]-1.0)
                               )
                            , -ms.xis[j]);
        }
    }
}

// Returns the identifier of this solution method, "MRGAL".
std::string
MRGalSolution::name() const
{
  const std::string label("MRGAL");
  return label;
}
