/*
 * Copyright (C) 2008-2010 Sébastien Villemot <sebastien.villemot@ens.fr>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <cassert>
#include <vector>
#include <dlfcn.h>

#include "SmolSolution.hh"

// Count of live SmolSolution objects: the constructor asserts it is zero
// before incrementing it, and the destructor decrements it.
int SmolSolution::instances(0);

// Look up SYMBOL in the shared object HANDLE and return its address.
// On failure the dlerror() message is printed to std::cerr and the program
// exits with EXIT_FAILURE.
// dlerror() is cleared before calling dlsym(), as recommended by POSIX, so that a
// stale error left by an earlier dl* call cannot be mistaken for a failure of
// this lookup.
static void *
smol_dlsym(void *handle, const char *symbol)
{
  dlerror(); // Discard any pending error state

  void *sym = dlsym(handle, symbol);

  const char *error = dlerror();
  if (error != NULL)
    {
      std::cerr << error << std::endl;
      exit(EXIT_FAILURE);
    }
  return sym;
}

// Build a SMOL solution for the given model specification.
// Only the (specification number, n) pairs listed in the switch below are
// accepted; any other combination throws UnsupportedSpec before any state is
// modified. On success, the shared object
// "smol/TestA<s>/kkm-A<s>-N<n>.so" is loaded, its paramupdate_, test_ and
// incoeffs_ entry points are resolved, paramupdate_ is invoked, and the
// coefficients are read from "smol/TestA<s>/coeff<n>.csv". For n == 10 the
// file coeff9.csv is used instead.
// Failure to load the shared object or resolve a symbol terminates the
// program with EXIT_FAILURE.
SmolSolution::SmolSolution(const ModelSpec &modspec_arg) throw (UnsupportedSpec)
  : ModelSolution(modspec_arg)
{
  const int s = modspec.spec_no();
  const int n = modspec.n;

  switch(s)
    {
    case 1:
    case 5:
      if (n != 2 && n != 4 && n != 6 && n != 8 && n != 10)
        throw UnsupportedSpec();
      break;
    case 2:
    case 6:
      if (n != 2 && n != 4 && n != 6 && n != 8)
        throw UnsupportedSpec();
      break;
    case 3:
    case 4:
    case 7:
    case 8:
      if (n != 2 && n != 4 && n != 6)
        throw UnsupportedSpec();
      break;
    default:
      throw UnsupportedSpec();
    }

  // instances is only incremented once UnsupportedSpec can no longer be
  // thrown; otherwise it would stay incremented for an object whose destructor
  // never runs.
  assert(!instances);
  instances++;

  // The validated ranges of s and n keep these names well within the buffers.
  char soname[50];
  snprintf(soname, sizeof(soname), "smol/TestA%d/kkm-A%d-N%d.so", s, s, n);

  // Load the code for this specification
  handle = dlopen(soname, RTLD_LAZY);
  if (!handle)
    {
      std::cerr << dlerror() << std::endl;
      exit(EXIT_FAILURE);
    }

  // Writing through void** is the POSIX-sanctioned way to store a dlsym()
  // result into a function pointer.
  *(void**) (&paramupdate) = smol_dlsym(handle, "paramupdate_");
  *(void**) (&test) = smol_dlsym(handle, "test_");
  *(void**) (&incoeffs) = smol_dlsym(handle, "incoeffs_");

  (*paramupdate)();

  // We set 30 here because it is the constraint imposed in the fortran function InCoeffs (in Init.f90).
  // The buffer is zero-filled so that the bytes following the terminating NUL
  // are deterministic instead of uninitialized stack contents.
  // NOTE(review): presumably InCoeffs reads the whole fixed-length buffer; confirm against Init.f90.
  char csvname[30] = "";
  snprintf(csvname, sizeof(csvname), "smol/TestA%d/coeff%d.csv", s, n == 10 ? 9 : n);
  (*incoeffs)(csvname);

  // Allocate temporary variables used by policy_func_internal; released in
  // the destructor
  state = new double[2*n];
  if (modspec.labor())
    control = new double[4*n];
  else
    control = new double[3*n];
}

// Unload the specification's shared object, release the instance slot and
// free the temporaries allocated by the constructor. A dlclose() failure is
// reported on std::cerr and terminates the program.
SmolSolution::~SmolSolution()
{
  const int status = dlclose(handle);
  if (status != 0)
    {
      const char *message = dlerror();
      std::cerr << message << std::endl;
      exit(EXIT_FAILURE);
    }

  // Lets a new SmolSolution pass the constructor's "no other instance" assertion
  instances--;

  delete[] state;
  delete[] control;
}

// Compute the current-period variables y_curr from the previous-period
// variables y_prev and the shocks, using the test_ routine of the loaded
// shared object.
//
// y_prev: entries [3n, 4n) are copied into state[0..n-1], and entries
//         [4n, 5n) are log-transformed and propagated as
//         rho*log(x) + sigma*(shocks[j] + shocks[n]) into state[n..2n-1].
// shocks: must hold at least n+1 entries; entry n is added to every
//         country's shock.
// y_curr: must hold at least 5n+1 entries with unit stride (asserted). It is
//         filled as follows:
//           [0, n)    <- control[n..2n-1]
//           [n, 2n)   <- control[2n..3n-1] if modspec.labor(), NaN otherwise
//           [2n, 3n)  <- control[j] - (1-delta)*state[j]
//           [3n, 4n)  <- control[0..n-1]
//           [4n, 5n)  <- exp(state[n..2n-1])
//           5n        <- marginal utility (uc) of country 0 times taus[0]
void
SmolSolution::policy_func_internal(const gsl_vector *y_prev, const gsl_vector *shocks, gsl_vector *y_curr)
{
  const int &n = modspec.n;
  const bool labor = modspec.labor(); // loop invariant

  for(int j = 0; j < n; j++)
    {
      state[j] = gsl_vector_get(y_prev, 3*n+j);
      state[n+j] = modspec.rho*log(gsl_vector_get(y_prev, 4*n+j))
        + modspec.sigma*(gsl_vector_get(shocks, j) + gsl_vector_get(shocks, n));
    }

  (*test)(state, control);

  for(int j = 0; j < n; j++)
    {
      gsl_vector_set(y_curr, j, control[n+j]);
      if (labor)
        gsl_vector_set(y_curr, n+j, control[2*n+j]);
      else
        gsl_vector_set(y_curr, n+j, NAN);
      gsl_vector_set(y_curr, 2*n+j, control[j] - (1-modspec.delta)*state[j]);
      gsl_vector_set(y_curr, 3*n+j, control[j]);
      gsl_vector_set(y_curr, 4*n+j, exp(state[n+j]));
    }

  // Compute marginal utility of country 0 for filling lambda.
  // Variable-length arrays are a non-standard compiler extension in C++, so a
  // single std::vector provides both n-element output buffers. n >= 2 is
  // guaranteed by the constructor, so both pointers below are valid.
  std::vector<double> marg_util(2*n);
  double *uc_curr = &marg_util[0];
  double *ul_curr = &marg_util[n];

  // marginal_utilities reads y_curr through its raw data pointer, which only
  // matches the logical layout when the stride is 1
  assert(y_curr->stride == 1);
  modspec.marginal_utilities(y_curr->data, uc_curr, ul_curr);
  gsl_vector_set(y_curr, 5*n, uc_curr[0] * modspec.taus[0]);
}

// Short identifier of this solution method.
std::string
SmolSolution::name() const
{
  const std::string method_name("SMOL");
  return method_name;
}
