/*
 * Copyright (C) 2008-2010 Sébastien Villemot <sebastien.villemot@ens.fr>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <iostream>
#include <cassert>
#include <cmath>

#include <gsl/gsl_blas.h>

#include "mat.h"

#include "PerSolution.hh"

/*
 * Loads the perturbation (KKK) decision rules for the model described by
 * modspec_arg from the MAT-file "per/KKKLOG_m<spec>N<n>.mat".
 *
 * first_order_arg: when true, policy evaluation skips the second-order
 *                  terms (HH2).
 *
 * Throws UnsupportedSpec if the (specification number, n) combination is not
 * one of those handled below. Exits the process if the MAT-file cannot be
 * opened.
 */
PerSolution::PerSolution(const ModelSpec &modspec_arg, bool first_order_arg) throw (UnsupportedSpec)
  : ModelSolution(modspec_arg), first_order(first_order_arg), nw(4*modspec.n+1), nz(2*modspec.n)
{
  const int s = modspec.spec_no();
  const int n = modspec.n;

  switch(s)
    {
    case 1:
    case 5:
      if (n != 2 && n != 4 && n != 6 && n != 8 && n != 10)
        throw UnsupportedSpec();
      break;
    case 2:
    case 6:
      if (n != 2 && n != 4 && n != 6 && n != 8)
        throw UnsupportedSpec();
      break;
    case 3:
    case 4:
    case 7:
    case 8:
      if (n != 2 && n != 4 && n != 6)
        throw UnsupportedSpec();
      break;
    default:
      throw UnsupportedSpec();
    }

  // Retrieve MATLAB arrays
  char filename[50];
  snprintf(filename, sizeof(filename), "per/KKKLOG_m%dN%d.mat", s, n);

  MATFile *mfp = matOpen(filename, "r");
  if (mfp == NULL)
    {
      std::cerr << "PerSolution::PerSolution: can't open " << filename << std::endl;
      exit(EXIT_FAILURE);
    }

  mxArray *HH0mx = matGetVariable(mfp, "HH0");
  mxArray *HH1mx = matGetVariable(mfp, "HH1");
  mxArray *HH2mx = matGetVariable(mfp, "HH2");
  mxArray *wSSmx = matGetVariable(mfp, "wSS");
  mxArray *zSSmx = matGetVariable(mfp, "zSS");
  assert(HH0mx != NULL && HH1mx != NULL && HH2mx != NULL && wSSmx != NULL && zSSmx != NULL);
  matClose(mfp);

  // Test array types and dimensions
  assert(mxGetClassID(HH0mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(HH0mx) == 2
         && mxGetDimensions(HH0mx)[0] == nw && mxGetDimensions(HH0mx)[1] == 1);
  assert(mxGetClassID(wSSmx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(wSSmx) == 2
         && mxGetDimensions(wSSmx)[0] == nw && mxGetDimensions(wSSmx)[1] == 1);
  assert(mxGetClassID(zSSmx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(zSSmx) == 2
         && mxGetDimensions(zSSmx)[0] == nz && mxGetDimensions(zSSmx)[1] == 1);
  assert(mxGetClassID(HH1mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(HH1mx) == 2
         && mxGetDimensions(HH1mx)[0] == nw && mxGetDimensions(HH1mx)[1] == nz);
  assert(mxGetClassID(HH2mx) == mxDOUBLE_CLASS && mxGetNumberOfDimensions(HH2mx) == 3
         && mxGetDimensions(HH2mx)[0] == nw && mxGetDimensions(HH2mx)[1] == nz
         && mxGetDimensions(HH2mx)[2] == nz);

  // Convert to GSL vectors and matrices.
  // Each view is bound to a named local: taking the address of a member of a
  // temporary view is ill-formed, and a pointer into it would dangle once the
  // temporary is destroyed at the end of the statement.
  gsl_vector_view HH0_view = gsl_vector_view_array(mxGetPr(HH0mx), nw);
  HH0 = gsl_vector_alloc(nw);
  gsl_vector_memcpy(HH0, &HH0_view.vector);

  gsl_vector_view wSS_view = gsl_vector_view_array(mxGetPr(wSSmx), nw);
  wSS = gsl_vector_alloc(nw);
  gsl_vector_memcpy(wSS, &wSS_view.vector);

  gsl_vector_view zSS_view = gsl_vector_view_array(mxGetPr(zSSmx), nz);
  zSS = gsl_vector_alloc(nz);
  gsl_vector_memcpy(zSS, &zSS_view.vector);

  // MATLAB stores the nw x nz array HH1 column-major, which a row-major GSL
  // view sees as its nz x nw transpose; transpose it back while copying.
  gsl_matrix_view HH1_prime = gsl_matrix_view_array(mxGetPr(HH1mx), nz, nw);
  HH1 = gsl_matrix_alloc(nw, nz);
  gsl_matrix_transpose_memcpy(HH1, &HH1_prime.matrix);

  // HH2 is an nw x nz x nz MATLAB array in column-major order: element
  // (i, j, k) lives at offset i + j*nw + k*nw*nz. Split it into nw matrices.
  const double *HH2data = mxGetPr(HH2mx);
  HH2 = (gsl_matrix **) malloc(sizeof(gsl_matrix *) * nw);
  assert(HH2 != NULL);
  for(int i = 0; i < nw; i++)
    {
      HH2[i] = gsl_matrix_alloc(nz, nz);
      for(int j = 0; j < nz; j++)
        for(int k = 0; k < nz; k++)
          gsl_matrix_set(HH2[i], j, k, HH2data[i + j*nw + k*nw*nz]);
    }

  // Free MX arrays (all data has been copied into GSL storage above)
  mxDestroyArray(HH0mx);
  mxDestroyArray(HH1mx);
  mxDestroyArray(HH2mx);
  mxDestroyArray(wSSmx);
  mxDestroyArray(zSSmx);

  // Allocate temporary variables used by policy_func_internal
  state = gsl_vector_alloc(nz);
  control = gsl_vector_alloc(nw);
  tmp = gsl_vector_alloc(nz);
}

/*
 * Releases every GSL object allocated by the constructor: the scratch
 * vectors, the nw second-order matrices (plus the array holding them), the
 * first-order matrix and the constant/steady-state vectors.
 */
PerSolution::~PerSolution()
{
  // Scratch vectors
  gsl_vector_free(control);
  gsl_vector_free(state);
  gsl_vector_free(tmp);

  // Second-order coefficients: one matrix per control, then the pointer array
  for(int idx = nw - 1; idx >= 0; --idx)
    gsl_matrix_free(HH2[idx]);
  free(HH2);

  // First-order and constant terms, steady states
  gsl_matrix_free(HH1);
  gsl_vector_free(HH0);
  gsl_vector_free(zSS);
  gsl_vector_free(wSS);
}

/*
 * Evaluates the perturbation (KKK) policy function for one period.
 *
 * y_prev: previous-period variables; entries 3n..4n-1 and 4n..5n-1 are read
 *         (through log) to build the state vector.
 * shocks: innovations; for each j < n, shocks[j] + shocks[n] is scaled by
 *         sigma and added to state entry n+j.
 * y_curr: receives entries 0..5n (entries n..2n-1 are NAN when the model
 *         has no labor).
 *
 * Only the member scratch vectors (state, control, tmp) are modified besides
 * y_curr.
 */
void
PerSolution::policy_func_internal(const gsl_vector *y_prev, const gsl_vector *shocks, gsl_vector *y_curr)
{
  const int &n = modspec.n;

  // Convert y_prev and shocks to KKK's (log) representation of state space
  for(int j = 0; j < n; j++)
    {
      gsl_vector_set(state, j, log(gsl_vector_get(y_prev, 3*n+j)));
      gsl_vector_set(state, n+j, modspec.rho*log(gsl_vector_get(y_prev, 4*n+j))
                     + modspec.sigma*(gsl_vector_get(shocks, j) + gsl_vector_get(shocks, n)));
    }
  // The decision rules are expressed in deviations from the steady state zSS
  gsl_vector_sub(state, zSS);

  // Compute KKK's approximation of control variables, beginning with 2nd order
  // terms: control[j] = 0.5 * state' * HH2[j] * state
  if (first_order)
    gsl_vector_set_zero(control);
  else
    for(int j = 0; j < nw; j++)
      {
        gsl_blas_dgemv(CblasNoTrans, 0.5, HH2[j], state, 0, tmp);
        gsl_blas_ddot(state, tmp, gsl_vector_ptr(control, j));
      }
  // Add the first-order term HH1 * state, the constant term and the steady state
  gsl_blas_dgemv(CblasNoTrans, 1, HH1, state, 1, control);
  gsl_vector_add(control, HH0);
  gsl_vector_add(control, wSS);

  // Convert to our representation of y_curr
  for(int j = 0; j < n; j++)
    {
      gsl_vector_set(y_curr, j, exp(gsl_vector_get(control, j)));
      if (modspec.labor())
        gsl_vector_set(y_curr, n+j, exp(gsl_vector_get(control, n+j)));
      else
        gsl_vector_set(y_curr, n+j, NAN);
      gsl_vector_set(y_curr, 2*n+j, exp(gsl_vector_get(control, 3*n+1+j)));
      gsl_vector_set(y_curr, 3*n+j, exp(gsl_vector_get(control, 2*n+1+j)));
      // state holds deviations from zSS: add the steady state back before
      // converting to a level (the controls get wSS added back above)
      gsl_vector_set(y_curr, 4*n+j, exp(gsl_vector_get(state, n+j) + gsl_vector_get(zSS, n+j)));
    }
  gsl_vector_set(y_curr, 5*n, exp(gsl_vector_get(control, 2*n)));
}

/*
 * Short identifier of this solution method: "PER1" when only the
 * first-order approximation is used, "PER2" otherwise.
 */
std::string
PerSolution::name() const
{
  return first_order ? "PER1" : "PER2";
}
