/*
 * Copyright (C) 2008-2010 Sébastien Villemot <sebastien.villemot@ens.fr>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <unistd.h>

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <vector>

#include "ModelSpec.hh"
#include "SolutionTester.hh"

#include "smol/SmolSolution.hh"
#include "per/PerSolution.hh"
#include "mrgal/MRGalSolution.hh"
#include "mmj/CGASolution.hh"
#include "mmj/SSA1Solution.hh"

int
main(int argc, char** argv)
{
  arc_resid_t arc_resid = MichelARC;
  bool dump = false;
  int opt;
  int begin_spec = 1, end_spec = 0, test_no = 0;
  int N1 = 1000;  // default nbr of points for Test 1
  int N2 = 10000; // default nbr of points for Test 2
  int N3 = 10000; // default nbr of points for Test 3
  char *selected_solution = NULL;
  unsigned long rng_seed = 0;
  SolutionTester::integration_method int_meth = SolutionTester::GaussHermiteAndMonomialDegree5;

  while((opt = getopt(argc, argv, "aAb:B:dgmqr:s:t:u:v:w:")) != -1)
    {
      switch(opt)
        {
        case 'a':
          arc_resid = BenARC;
          break;
        case 'A':
          arc_resid = PaulARC;
          break;
        case 'b':
          begin_spec = atoi(optarg);
          break;
        case 'B':
          begin_spec = atoi(optarg);
          end_spec = begin_spec;
          break;
        case 'd':
          dump = true;
          break;
        case 'g':
          int_meth = SolutionTester::GaussHermiteOnly;
          break;
        case 'm':
          int_meth = SolutionTester::MonomialDegree5Only;
          break;
        case 'q':
          int_meth = SolutionTester::QuasiMonteCarloOnly;
          break;
        case 'r':
          rng_seed = atoi(optarg);
          break;
        case 's':
          selected_solution = optarg;
          break;
        case 't':
          test_no = atoi(optarg);
          break;
        case 'u':
          N1 = atoi(optarg);
          break;
        case 'v':
          N2 = atoi(optarg);
          break;
        case 'w':
          N3 = atoi(optarg);
          break;
        default:
          std::cerr << "Usage: " << argv[0] << " [-a] [-A] [-b INTEGER] [-B INTEGER] [-d] [-g] [-m] [-q] [-r INTEGER] [-s NAME] [-t INTEGER] [-u INTEGER] [-v INTEGER] [-w INTEGER]" << std::endl;
          exit(EXIT_FAILURE);
        }
    }

  double alpha = 0.36, beta = 0.99, delta = 0.025, sigma = 0.01,
    rho = 0.95, phi = 0.5, time_endowment = 2.5;

  std::vector<ModelSpec *> models;

  // Construct models
  for(int n = 2; n <= 10; n += 2)
    models.push_back(new ModelSpecA1(n, alpha, beta, delta, rho, sigma, phi, 1, arc_resid));

  for(int n = 2; n <= 8; n += 2)
    models.push_back(new ModelSpecA2(n, alpha, beta, delta, rho, sigma, phi, 0.25, 0.1, arc_resid));

  for(int n = 2; n <= 6; n += 2)
    models.push_back(new ModelSpecA3(n, alpha, beta, delta, rho, sigma, phi, time_endowment, 0.25, arc_resid));

  for(int n = 2; n <= 6; n += 2)
    models.push_back(new ModelSpecA4(n, alpha, beta, delta, rho, sigma, phi, time_endowment, 0.25, -0.2, 0.83, arc_resid));

  for(int n = 2; n <= 10; n += 2)
    models.push_back(new ModelSpecA5(n, alpha, beta, delta, rho, sigma, phi, 0.25, 1, arc_resid));

  for(int n = 2; n <= 8; n += 2)
    models.push_back(new ModelSpecA6(n, alpha, beta, delta, rho, sigma, phi, 0.25, 1, 0.1, 1, arc_resid));

  for(int n = 2; n <= 6; n += 2)
    models.push_back(new ModelSpecA7(n, alpha, beta, delta, rho, sigma, phi, time_endowment, 0.25, 1, arc_resid));

  for(int n = 2; n <= 6; n += 2)
    models.push_back(new ModelSpecA8(n, alpha, beta, delta, rho, sigma, phi, time_endowment, 0.2, 0.4, -0.3, 0.3, 0.75, 0.9, arc_resid));

  if (begin_spec < 1 || begin_spec > models.size())
    {
      std::cerr << "Invalid specification number " << begin_spec << std::endl;
      exit(EXIT_FAILURE);
    }

  if (end_spec == 0)
    {
      end_spec = models.size();
    }

  if (end_spec < begin_spec || end_spec > models.size())
    {
      std::cerr << "Invalid specification number " << begin_spec << std::endl;
      exit(EXIT_FAILURE);
    }

  std::cout << std::setprecision(10);

#ifdef DEBUG
  for(std::vector<ModelSpec *>::const_iterator it = models.begin();
      it != models.end(); ++it)
    {
      // Display model specification
      (*it)->write_output(std::cout);

      // Fetch steady state and display it
      const int n = (*it)->n;
      gsl_vector *y_ss = gsl_vector_alloc(5*n+1);
      (*it)->steady_state(y_ss);
      std::cout << "Steady state:" << std::endl;
      (*it)->write_state(std::cout, y_ss);

      // Compute max theoretical error at steady state
      gsl_vector *shocks = gsl_vector_alloc(n+1);
      gsl_vector_set_zero(shocks);
      gsl_vector *fwd = gsl_vector_alloc(n);
      (*it)->forward_part(y_ss, y_ss, y_ss, shocks, fwd);
      double max_err = (*it)->max_error(y_ss, y_ss, fwd, shocks);
      std::cout << "Max theoretical error at steady state: " << max_err << std::endl;

      // Compute policy functions errors at steady state with zero shock
      gsl_vector *y_curr = gsl_vector_alloc(5*n+1);
      gsl_vector *y_next = gsl_vector_alloc(5*n+1);

      // Determine the solution methods to be tested
      std::vector<ModelSolution *> solutions;
      solutions.push_back(new SmolSolution(**it));
      solutions.push_back(new PerSolution(**it));
      solutions.push_back(new MRGalSolution(**it));
      solutions.push_back(new CGASolution(**it));

      for(std::vector<ModelSolution *>::iterator it2 = solutions.begin();
          it2 != solutions.end(); ++it2)
        {
          (*it2)->policy_func(y_ss, shocks, y_curr);
          (*it2)->policy_func(y_curr, shocks, y_next);

          // Display max error at steady state
          (*it)->forward_part(y_ss, y_curr, y_next, shocks, fwd);
          double max_err = (*it)->max_error(y_ss, y_curr, fwd, shocks);
          std::cout << "Max error for " << (*it2)->name() << " at steady state with zero shock: " << max_err << std::endl;

          delete *it2;
        }

      std::cout << std::endl;

      gsl_vector_free(shocks);
      gsl_vector_free(fwd);
      gsl_vector_free(y_curr);
      gsl_vector_free(y_next);
      gsl_vector_free(y_ss);
    }
#endif

  gsl_vector *thresholds = gsl_vector_alloc(3);
  gsl_vector_set(thresholds, 0, 0.05);
  gsl_vector_set(thresholds, 1, 0.5);
  gsl_vector_set(thresholds, 2, 0.95);

  // Perform tests
  for(std::vector<ModelSpec *>::const_iterator it = models.begin() + begin_spec - 1;
      it != models.begin() + end_spec; ++it)
    {
      const int n = (*it)->n;

      // Display model specification
      std::cout << "*** ";
      (*it)->write_output(std::cout);
      std::cout << std::endl;

      // Determine the solution methods to be tested
      std::vector<ModelSolution *> solutions;
      if (selected_solution == NULL)
        {
          solutions.push_back(new PerSolution(**it, true));
          solutions.push_back(new SSA1Solution(**it));
          solutions.push_back(new PerSolution(**it, false));
          solutions.push_back(new MRGalSolution(**it));
          solutions.push_back(new CGASolution(**it));
          solutions.push_back(new SmolSolution(**it));
        }
      else
        {
          if (!strcmp(selected_solution, "per2"))
            solutions.push_back(new PerSolution(**it, false));
          else if (!strcmp(selected_solution, "per1"))
            solutions.push_back(new PerSolution(**it, true));
          else if (!strcmp(selected_solution, "smol"))
            solutions.push_back(new SmolSolution(**it));
          else if (!strcmp(selected_solution, "mrgal"))
            solutions.push_back(new MRGalSolution(**it));
          else if (!strcmp(selected_solution, "cga"))
            solutions.push_back(new CGASolution(**it));
          else if (!strcmp(selected_solution, "ssa1"))
            solutions.push_back(new SSA1Solution(**it));
          else
            {
              std::cerr << "Unknown solution method " << selected_solution << std::endl;
              exit(EXIT_FAILURE);
            }
        }

      // For DHM stat: test the all the Euler equations together
      SolutionTester::equation_blocks_type eb;
      std::vector<int> all_euler_eqs;
      for(int j = 0; j < n; j++)
        all_euler_eqs.push_back(2*n+j);
      eb.push_back(all_euler_eqs);

      for(std::vector<ModelSolution *>::iterator it2 = solutions.begin();
          it2 != solutions.end(); ++it2)
        {
          SolutionTester st(**it2, int_meth, dump);

          double max_err;
          if (!test_no || test_no == 1)
            {
              std::cout << "Accuracy test 1 on " << (*it2)->name() << ":" << std::endl;
              std::cout << " Nbr of points " << N1 << std::endl;
              std::cout << " Radius 0.01: ";
              st.accuracy_test1(N1, 0.01, max_err);
              std::cout << max_err << std::endl;
              std::cout << " Radius 0.10: ";
              st.accuracy_test1(N1, 0.1, max_err);
              std::cout << max_err << std::endl;
              std::cout << " Radius 0.30: ";
              st.accuracy_test1(N1, 0.3, max_err);
              std::cout << max_err << std::endl;
              std::cout << std::endl;
            }

          double mean_err;
          if (!test_no || test_no == 2)
            {
              std::cout << "Accuracy test 2 on " << (*it2)->name() << ":" << std::endl;
              std::cout << " Nbr of points " << N2 << std::endl;
              st.accuracy_test2(N2, 200, max_err, mean_err, NULL, rng_seed);
              std::cout << " Max:  " << max_err << std::endl
                        << " Mean: " << mean_err << std::endl;
              std::cout << std::endl;
            }

          if (test_no == 3)
            {
              gsl_matrix *freqs = gsl_matrix_alloc(3, eb.size());
              std::cout << "Accuracy test 3 on " << (*it2)->name() << ":" << std::endl;
              std::cout << " Nbr of points " << N3 << std::endl;
              st.accuracy_test3(200, N3, 200, eb, true, true, false, thresholds, freqs, rng_seed);
              std::cout << " P <= 0.05: " << gsl_matrix_get(freqs, 0, 0) << std::endl;
              std::cout << " P <= 0.50: " << gsl_matrix_get(freqs, 1, 0) << std::endl;
              std::cout << " P <= 0.95: " << gsl_matrix_get(freqs, 2, 0) << std::endl;
              gsl_matrix_free(freqs);
              std::cout << std::endl;
            }

          delete *it2;
        }

      std::cout << std::endl;

    }

  // Cleanup
  for(std::vector<ModelSpec *>::iterator it = models.begin();
      it != models.end(); ++it)
    delete (*it);
  gsl_vector_free(thresholds);

#ifdef DEBUG
  SolutionTester::hypersphere_draw_test(16, 0.2, 0.3, 100000);
#endif

  return EXIT_SUCCESS;
}
