%---------------------------------------------------------
% This is one of the main files to perform GIST.
% All GIST functions and scripts written and maintained by Y. She
%---------------------------------------------------------
% TISP for spectral analysis
clear all; clc;   % start from a clean workspace (clear all also removes globals)

% Post-processing switches (1 = on, 0 = off)
DoMedPlot = 1;    % plot the median estimated amplitude at each frequency
DoIdenPlot = 1;   % plot the per-frequency identification rate
DoSave = 1;       % save the figures (fig/png/epsc) into pathinit
% Output folder for the diary log and the saved figures. Built with fullfile so
% the path separator matches the host OS (the former '.\res\' was Windows-only).
pathinit = fullfile('.', 'res')
%---------------------------------------------------------
% Sparse signal setup
initN = 100;               % number of training samples
initD = 500;               % twice the number of frequency bins
sigmasqGrid = [1, 4, 8]    % noise variances swept by the experiment (echoed)

% Ground truth, one entry per sinusoidal component
trueFreqs = [0.248 0.25 0.252  0.398 0.4];   % frequencies
trueAmps  = [2 4 3 3.5 3];                   % amplitudes
truePhis  = [pi/4 pi/6 pi/3 pi/5 pi/2];      % phases

% Positions of the true atoms: trueFreqs*initD, followed by the same positions
% shifted by initD/2. NOTE(review): in this file `id` only appears in
% commented-out screening checks; confirm whether the called TISP scripts read it.
id = trueFreqs*initD;
id = [id, id + initD/2];


% Data model
fs = 1;                 % sampling frequency (Hz)
family = 'gaussian';    % model family passed to generateSpecData / GLM_negLogLik
valN = 2000;            % validation sample size
tstN = 2000;            % test sample size

% Estimation method (both settings are echoed to the console)
thresholdingWay = 'hybrid'   % other choices include 'soft', 'hard', ...
grouped = true               % false drops the group structure (grps = [])

% Parameters for tuning
tuningMethod = 'selCV' % alternative: 'validation'

% Hybrid TISP only. Hybrid-TISP has two regularization parameters, but the
% ridge parameter (eta) is usually not a sensitive one: 'simple' uses a small
% eta and is fast and effective. If time allows, a fuller search such as '3+1'
% is recommended, as it gives much better prediction performance.
hybridTuningWay = 'simple'; % or '3+1'

% Selective cross-validation only
scvbiascorrt0 = 'bic'  % or 'aic'
scvbiascorrt1 = 'bic'
matchLocEst0 = 'naive';
matchLocEst1 = 'df';
unpenCorr = false;


% Parameters for screening
nzExpected = 0.25* initN; % expected number of nonzeros; exploits sparsity for dimension reduction and efficient computation
screening = 'tisp'; %'none'; %


output = false; %true  (NOTE(review): presumably a verbosity switch read by the called TISP scripts)
times = 50; % number of replicates in performing the experiment

outputfilename = sprintf('Res_%s.txt', datestr(clock,30))

% Make sure the output folder exists: diary (here) and saveas (later) both
% write into pathinit and fail if the folder is missing.
if ~exist(pathinit, 'dir')
    mkdir(pathinit);
end
% Bind the diary to the log file but keep it off; later `diary on` / `diary off`
% pairs append only the per-experiment summaries to this file.
diary(fullfile(pathinit,outputfilename));
diary off
% Accumulated TISP solver time. Presumably incremented by the TISP scripts run
% below; this loop only resets and reports it. Declared once, outside the loop.
global totalTispTime;
for sigmasq = sigmasqGrid
    % ===== One complete experiment (`times` replicates) per noise variance =====
    allStartTime = clock;
    totalTispTime=0;
    t_total = 0;   % accumulated fitting time over replicates (tic/toc)
    %-------------------------------------------------------
    % Spectral Estimation Experiment
    disp('..................................................');
    disp('         TISP for spectral analysis               ');
    disp('..................................................');
    disp(['thresholdingWay = ' thresholdingWay]);
    disp(['sigmasq = ' num2str(sigmasq)]);

    % Legacy seeding kept on purpose: switching to rng(0) would change the
    % random draws and make results incomparable with earlier runs.
    randn('state', 0); rand('state', 0);
    optBetas = [];  % beta estimate of each run (one column per run)
    optBetas_rawscale = []; % beta estimate divided by trnScales (raw scale, before X standardization) of each run
    optAmps = []; % amplitude estimate of each frequency for all runs
    % predScreened_Mat = [];
    detectN = 0; % correctly detected atoms via screening (screening summary is currently disabled)
    modelErrs2= -ones(times,1); % prediction error (plain-mse - sigmasq); -1 marks an unfilled run

    for timeInd = 1:times
        disp(['timeInd = ', num2str(timeInd)]);
        % X is centered and scaled, Y is centered. For a Gaussian model this
        % makes the intercept (zero-frequency response) vanish.
        centerX = 1; scaleX = 1; centerY = 1;

        % Generate training / validation / test data
        [alphaTrue,betaTrue,X,valX,tstX,trnCenters,trnScales,y,valY,tstY,trnCenterY,d,D,n,fd,newF,grps,betaTrue_rawscale] = ...
            generateSpecData(initN, valN, tstN, initD, trueFreqs, trueAmps, truePhis, sigmasq, fs, family,...
            centerX, scaleX, centerY);

        if ~grouped, grps = []; end   % ungrouped form: no group structure

        t1 = tic;

        % Fold split used by selective cross-validation tuning ('selCV')
        nCV = 5;
        minTh0 = 0;
        [nCV, dataIndsCV, dataIndsCVStarts, dataIndsCVEnds] = CVSplit(size(X, 1), nCV);

        % Scalar flag replaces the repeated element-wise &/| string tests
        isHybrid = any(strcmp(thresholdingWay, {'hybrid', 'hybrid-prop'}));
        if isHybrid
            Nu = [];     % hybrid variants handle the ridge parameter during their own tuning
        else
            Nu = 1e-8;   % small ridge; usually better than 0 (helpful in decorrelation)
        end
        % Parameters for TISP
        intercptchoice = 'none'; % no intercept thanks to the centering step
        nPoints = 25;  % number of different lambdas
        maxIT = 2e+2;  % max iterations (Omega in the paper; default 1e3)
        errBnd = 1e-5; % error tolerance
        nzUBnd = n/2;  % upper bound on the number of nonzero predictors; if the current estimate has more, stop decreasing lambda
        updating = 'synchronous'; % use the synchronous form
        relaxWay = 1;  % use the relaxation form

        % Solve TISP. These scripts run in this workspace and are expected to
        % leave betaOpt and alphaOpt behind (both are read below).
        if isHybrid
            % hybrid hard-ridge
            run hybridTISP_Tuned_singledataset;
        else
            % soft or other thresholding rules
            run TISPPath_GLM; % compute the solution path
            scvbiascorrt = scvbiascorrt1; % which can be 'aic', 'bic', or []
            matchLocEst = matchLocEst1;
            run TISPTuning_GLM; % find the optimal parameter
        end
        t_each = toc(t1);
        t_total = t_total + t_each;

        % Record result
        optBetas = [optBetas, betaOpt];
        optBetas_rawscale = [optBetas_rawscale, betaOpt./trnScales'];
    %     predScreened_Mat = [predScreened_Mat predScreened(:)];
    %     detectN = detectN + sum(ismember(id,predScreened));

        % Compute the prediction error (MSE*) on the test dataset:
        % 2 * (mean negative log-likelihood) - sigmasq
        modelErr2 = 2*(GLM_negLogLik(tstX,tstY,betaOpt,family,alphaOpt*ones(size(tstX,1),1))/size(tstY,1))' - sigmasq;
        modelErrs2(timeInd) = modelErr2;
    end

    %---------------------------------------------------
    % Summarize the results on the raw (unstandardized) scale.
    % NOTE(review): rows 1:initD/2-1 are paired with rows initD/2+1:end, and row
    % initD/2 is left unpaired. This assumes the coefficient vector has initD-1
    % rows laid out this way; confirm against generateSpecData (which also returns D).
    % optBetas_sc = optBetas./repmat(trnScales',1,times); % scale the beta back
    optBetas_sc = optBetas_rawscale;
    tmp1 = (optBetas_sc(1:initD/2-1,:)).^2 + (optBetas_sc(initD/2+1:end,:)).^2;
    optAmps = [tmp1.^.5;optBetas_sc(initD/2,:)]; % amplitudes of freqs for all runs
    % betaTrue_sc = betaTrue./trnScales'; % scale the true beta back
    betaTrue_sc = betaTrue_rawscale;
    tmp2 = (betaTrue_sc(1:initD/2-1)).^2 + (betaTrue_sc(initD/2+1:end)).^2;
    ampTrue = [tmp2.^.5;betaTrue_sc(initD/2)]; % true amplitudes of freqs
    % [pSpa, spaErr, pNz, sucIden] = SparStat(optBetas, betaTrue, eye(size(optBetas,1)));
    [pSpa, spaErr, pNz, sucIden] = SparStat(optAmps, ampTrue, eye(size(optAmps,1)));
    summres2 = mean([pSpa', spaErr', pNz'], 1)';
    res = [median(modelErrs2), 100-summres2(3,:)'*100, 100-summres2(1,:)'*100, sucIden*100];

    totalTime=etime(clock,allStartTime); % wall-clock time of this experiment (not printed)
    t_mean = t_total/times;
    diary on;
    % %g instead of %d: sigmasq need not be an integer (see the sqrt(10) case below)
    fprintf('N:%d,D:%d,sigmasq:%g\n',initN,initD,sigmasq);
    disp(['Test error and Sparsity recovery', ' (method: ', thresholdingWay, ', tuning: ', tuningMethod, ', grouped=', num2str(grouped), ', sigma^2=', num2str(sigmasq), ')']);
    disp('pred-ERR, Miss, FA, JD');
    disp(num2str(res));

    disp(['Total time: ', num2str(t_total)]);
    disp(['Average time: ' num2str(t_mean)]);
    disp(['Tisp time: ' num2str(totalTispTime)]);
    disp(['Average Tisp time: ' num2str(totalTispTime/times)]);
    disp(['Percent: ' num2str(100*totalTispTime/t_total)]);   % share of fitting time spent in TISP, in %
    diary off;

    % Summarize screening
    % missProb = 100*(1-detectN/length(id)/times);
    % fprintf('N:%d,D:%d,sigmasq:%d,nzBnd:%d\n',initN,initD,sigmasq,nzExpected);
    % disp(['Screening', ' (method: ', thresholdingWay, ', tuning: ', tuningMethod, ', grouped=', num2str(grouped), ', sigma^2=', num2str(sigmasq), ')'])
    % disp('Miss')
    % disp(num2str(missProb))

    %-------------------------------------------------------------------
    % Plot the estimated amplitude spectrum (TISP result)
    ampMd = median(optAmps,2);                 % median amplitude per frequency over runs
    ampProb = 100*sum(optAmps>1e-5,2)/times;   % % of runs in which each frequency is nonzero
    % betaSummarized = median(optBetas, 2); % We can plot the median estimate from all the runs
    % betaSummarized = betaSummarized./trnScales';
    % SP_est = betaSummarized.^2;
    % SP_est = SP_est(1:D/2)+[SP_est(D/2+1:end); 0];
    % Amp_est = sqrt(SP_est);
    % % True PSD
    % SP_true = (trueAmps.^2); % why do we have a factor of 2 there?

    % Title shared by both figures; sigma^2 = sqrt(10) gets a symbolic label
    if abs(sigmasq-10^.5) >1e-2
        plotTitleStr = sprintf('GIST, N=%d, 2D=%d, $\\sigma^{2}$=%g, average runTime:%0.2fs',initN,initD,sigmasq,t_mean);
    else
        plotTitleStr = sprintf('GIST, N=%d, 2D=%d, $\\sigma^{2}=\\sqrt{10}$, average runTime:%0.2fs',initN,initD,t_mean);
    end

    if DoMedPlot
        %%%%%%%%%%%%%%%%%%%% median plot
        H=figure; hold on;
        % Draw the estimate first so that `axis` picks up its y-range ...
        stem(fd, ampMd,'b-','Marker','none','LineWidth',1.5);
        v = axis; axis([0.23,0.421,v(3),v(4)]);
        % ... mark the true frequencies with full-height dotted red lines ...
        stem(trueFreqs, (v(4) + 0) .* (trueAmps~=0),'r:', 'marker', 'none','LineWidth',1)
        % ... and redraw the estimate so it sits on top of those markers.
        stem(fd, ampMd,'b-','Marker','none','LineWidth',1.5);
        % Full property names: 'fonts' is ambiguous (FontSize vs FontSmoothing).
        xlabel('Frequency (Hz)','FontSize',12);
        ylabel('Amplitude','FontSize',12);
        title(plotTitleStr,'Interpreter','latex','FontSize',14);
        hold off;

        if DoSave
            % save median plot
            savePath = sprintf('GIST_N_%d_2D_%d_var_%d',initN,initD,floor(sigmasq));
            savePath = fullfile(pathinit,savePath);
            saveas(H,savePath,'fig');
            saveas(H,savePath,'png');
            saveas(H,savePath,'epsc');
            close all;
        end
    end
    if DoIdenPlot
        %%%%%%%%%%%%%%%%%%%% identification plot
        H2=figure; hold on;
        axis([0.23,0.421,0,100]);
        stem(trueFreqs, (100 + 0) .* (trueAmps~=0),'r:', 'marker', '*','LineWidth',1)
        stem(fd, ampProb,'b-','Marker','none','LineWidth',1.5);

        xlabel('Frequency (Hz)','FontSize',12);
        ylabel('Identification rate (percentages)','FontSize',12);
        title(plotTitleStr,'Interpreter','latex','FontSize',14);
        hold off;
        if DoSave
            % save identification plot
            savePath2 = sprintf('GIST_N_%d_2D_%d_var_%d_iden',initN,initD,floor(sigmasq));
            savePath2 = fullfile(pathinit,savePath2);
            saveas(H2,savePath2,'fig');
            saveas(H2,savePath2,'png');
            saveas(H2,savePath2,'epsc');
            close all;
        end
    end
end