% Solve the three policy-regime models with Dynare. 'noclearall' stops
% Dynare from clearing the workspace between runs.
% NOTE(review): the NGDP model is run as GKsimp_ngdp, while its results are
% later read from GK_ngdp.mat -- confirm the .mod file saves under that name.
dynare('GKsimp', 'noclearall');
dynare('GKsimp_ngdp', 'noclearall');
dynare('ramsey', 'noclearall');
%% Setting up the IRFs and Loss function

% Policy regimes compared in this script:
%   MP{j}     - suffix of the S.<var>_<suffix> fields for regime j
%   labels{j} - legend entry for regime j
%   codes{j}  - .mat file loaded for regime j
MP = {'it','ngdp','opt'};
labels = {'IT','NGDP','Optimal'};
codes = {'GKsimp.mat','GK_ngdp.mat','ramsey.mat'};

n = 1000; % Number of periods in the IRFs (IRF columns must have n rows)

% Per-period discount factor used in every discounted loss sum.
% NOTE(review): presumably the model's beta -- confirm it matches the .mod files.
disc_factor = 0.99;
disc_vec = disc_factor.^(0:n-1)';   % n x 1, disc_vec(t) = disc_factor^(t-1)

% Per-shock undiscounted loss paths (4 shocks x n periods).
loss1_one = zeros(4,n);
loss2_one = zeros(4,n);
loss3_one = zeros(4,n);
loss4_one = zeros(4,n);

% Preallocate discounted-loss tables S.loss<q>_<regime> (n periods x 4 shocks).
for jj = 1:numel(MP)
    for q = 1:4
        S.(sprintf('loss%d_%s', q, MP{jj})) = zeros(n,4);
    end
end

% Hard-coded steady-state values, used below as log levels (exp(C_ss), exp(L_ss)).
% NOTE(review): presumably copied from the model's steady state -- confirm.
C_ss = 0.463965;
L_ss = -0.0334028;

% Tail-discount matrix: Dtail(t,s) = disc_vec(s-t+1) for s >= t, 0 otherwise.
% Hence (Dtail*x)(t) = sum_{s>=t} disc_vec(s-t+1)*x(s), the discounted sum of
% x from period t onward; row 1 is the full discounted sum. This replaces the
% per-period loops, which computed the same sums one entry at a time.
Dtail = triu(toeplitz(disc_vec));

% Column order of every IRF matrix in S (one column per shock).
shock_names = {'e_a','e_i','e_ksi','e_Ne'};
irf_vars = {'Y','I','K','N','prem','infl','i','R','C','L','X','di','gap'};

for k = 1:length(MP)
    pol = MP{k};

    % Load into a struct rather than the workspace. A variable stored in the
    % .mat file (e.g. a stale k, n, S or MP) then cannot overwrite this
    % script's state, and a missing IRF raises an error instead of silently
    % reusing another regime's workspace copy.
    D = load(codes{k});
    % Parameters used by the welfare losses. Take them from the file when it
    % has them; otherwise keep the workspace values, as the workspace load did.
    if isfield(D, 'chi'),    chi = D.chi;       end
    if isfield(D, 'varphi'), varphi = D.varphi; end

    % IRF: S.<var>_<regime> has one column per shock, in shock_names order.
    for v = 1:numel(irf_vars)
        cols = cellfun(@(sh) D.([irf_vars{v} '_' sh]), shock_names, ...
                       'UniformOutput', false);
        S.([irf_vars{v} '_' pol]) = [cols{:}];
    end
    % Price level response = cumulated inflation response (per shock column).
    S.(['price_' pol]) = cumsum(S.(['infl_' pol]), 1);

    inflk = S.(['infl_' pol]);
    if size(inflk, 1) ~= n
        error('IRFs in %s have %d periods, expected n = %d.', ...
              codes{k}, size(inflk, 1), n);
    end
    yk   = S.(['Y_' pol]);
    dik  = S.(['di_' pol]);
    gapk = S.(['gap_' pol]);
    ck   = S.(['C_' pol]);
    lk   = S.(['L_' pol]);

    % Case 1. Loss function (Infl + trend output gap + interest rate smoothing)
    L1 = 1/2*(inflk.^2 + 0.1*yk.^2 + 0.5*dik.^2);

    % Case 2. Loss function (Infl + output gap + interest rate smoothing)
    L2 = 1/2*(inflk.^2 + 0.5*gapk.^2 + 0.2*dik.^2);

    % Case 3. Loss function (Welfare). Period utility in terms of log
    % consumption and log labour: log C - chi*L^(1+varphi)/(1+varphi).
    util = @(logc, logl) logc - chi*exp(logl).^(1+varphi)./(1+varphi);
    % NOTE(review): this equals u_t - u_ss, so it is positive when utility is
    % above steady state (a welfare gain). Confirm the intended sign for a "loss".
    L3 = -(util(C_ss, L_ss) - util(C_ss + ck, L_ss + lk));

    % Experimenting (formula kept exactly as originally written)
    L4 = -(log(exp(ck)/exp(C_ss)) ...
           - chi*((exp(lk)/exp(L_ss) - exp(L_ss)).^(1+varphi))./(1+varphi)).^2;

    % Discounted loss from each period onward: row t, column = shock.
    S.(['loss1_' pol]) = Dtail*L1;
    S.(['loss2_' pol]) = Dtail*L2;
    S.(['loss3_' pol]) = Dtail*L3;
    S.(['loss4_' pol]) = Dtail*L4;

    % Keep the per-shock undiscounted paths (4 x n) for inspection.
    loss1_one = L1.';
    loss2_one = L2.';
    loss3_one = L3.';
    loss4_one = L4.';
end

%% Plotting (Response to a negative TFP shock by default)
% shock_col selects the shock column of the S.<var>_<regime> matrices:
% 1 = e_a (TFP), 2 = e_i, 3 = e_ksi, 4 = e_Ne. Change it for other shock responses.
shock_col = 1;
horizon = 20;   % number of periods plotted

% Capital, Interest rate, Loss, Real GDP (Choose which 'loss' function is
%                                         specified for ramsey, i.e. loss1, loss2 etc.)
VARS = {'K','i','loss3','Y'};

% Silence warnings while plotting; the previous state is restored at the end.
warn_state = warning('off', 'all');
x = 1:horizon;

% Y(t, regime, variable) = response of VARS{variable} under regime MP{regime}.
Y = zeros(numel(x), numel(MP), numel(VARS));
for k = 1:numel(MP)
    for j = 1:numel(VARS)
        Y(:,k,j) = S.(sprintf('%s_%s', VARS{j}, MP{k}))(x, shock_col);
    end
end

% Scale to percent. The interest-rate panel uses 400 (presumably annualising a
% quarterly rate -- confirm).
createfigure2x2(x,100*Y(:,:,1),400*Y(:,:,2),100*Y(:,:,3),100*Y(:,:,4));

% Legend: one entry per regime, in MP order.
Legend = labels(:);
[h,icons] = legend(Legend,'FontSize',10,'Orientation','horizontal');
set(icons(:),'LineWidth',2);
set(h,'Position',[0.322, 0.034, 0.376, 0.05]);

warning(warn_state);
