%%%%%%%%%%%%%%%%%%%% GLM Optimization function %%%%%%%%%%%%%%%%%%%%%
function [beta, rounds]=GLM_opt(X, y, lambda, family, alpha, tol)
% GLM_OPT  Newton's algorithm for solving the l2-penalized GLM.
%
%   [beta, rounds] = GLM_opt(X, y, lambda, family)
%   [beta, rounds] = GLM_opt(X, y, lambda, family, alpha)
%   [beta, rounds] = GLM_opt(X, y, lambda, family, alpha, tol)
%
%   Minimizes, starting from beta = 0, the penalized negative log-likelihood
%       GLM_negLogLik(X, y, beta, family, alpha) + 1/2 * sum(Lambda .* beta.^2)
%   If lambda = 0, it yields the MLE.
%
%   Inputs:
%     X      - n-by-d design matrix.
%     y      - n responses (any shape with n elements; reshaped to n-by-1).
%     lambda - penalty weight: a scalar (applied to every coefficient) or
%              d values (componentwise penalty).
%     family - 'binomial', 'poisson' or 'gaussian' (case-insensitive);
%              anything else raises 'not implemented yet'.
%     alpha  - optional; empty, scalar, or n values. Expanded to an n-by-1
%              vector and passed to GLM_mean / GLM_negLogLik (presumably a
%              per-observation offset -- confirm against those helpers).
%              Default: zeros.
%     tol    - optional; stop once the relative change of the penalized
%              negative log-likelihood is <= tol. Default (also used when
%              tol is empty): 1e-5.
%
%   Outputs:
%     beta   - d-by-1 coefficient vector: the last accepted iterate.
%     rounds - number of Newton updates performed (at most 5000).
%
%   Warnings are silenced while the function runs; the caller's warning
%   state is restored on exit, including when an error is thrown.

% Save the caller's warning configuration and restore it on any exit path.
% A plain "warning off ... warning on" pair would re-enable warnings the
% caller had disabled and would leave everything off after an error.
warnState = warning('off', 'all');
restoreWarnings = onCleanup(@() warning(warnState)); %#ok<NASGU>

if nargin < 5
    alpha = [];
end
if nargin < 6 || isempty(tol)
    % the required tolerance
    tol = 1e-5;
end

[n, d] = size(X);
y = reshape(y, n, 1);
if isempty(alpha)
    alpha = zeros(n, 1);
elseif numel(alpha) == 1
    alpha = alpha * ones(n, 1);
else
    alpha = reshape(alpha, n, 1);
end

if numel(lambda) == 1
    Lambda = lambda * ones(d, 1);
else
    Lambda = reshape(lambda, d, 1); % could be componentwise
end

% the penalized negative log-likelihood
pLL = @(beta)( GLM_negLogLik(X, y, beta, family, alpha) + 1/2 * sum(Lambda.*(beta.^2)));
betaCur = zeros(d, 1);
% the relative error at each step
Err = 1;
rounds = 0;
curLL = pLL(betaCur);

C2 = diag(Lambda);
while rounds < 5000 && Err > tol && curLL < 1e+9
    % Sometimes some beta reaches inf, making the negative log-likelihood
    % blow up; the curLL bound stops the iteration in that case.

    mu = GLM_mean(X, betaCur, family, alpha);

    % Per-observation Newton weights, kept as a vector: forming the
    % equivalent n-by-n diagonal matrix needs O(n^2) memory.
    switch lower(family)
       case {'binomial'}
          w = mu.*(1-mu);
       case 'poisson'
          w = mu;
       case 'gaussian'
          w = ones(size(mu));
       otherwise
          error('not implemented yet')
    end

    % X' * diag(w) * X, computed by scaling the rows of X.
    C1 = X' * bsxfun(@times, w(:), X);
    B = C1 + C2;

    % Newton step: solve B * step = X'*(y - mu) - Lambda .* betaCur.
    wGrad = B \ (X'*(y-mu)-Lambda.*betaCur);
    if any(isnan(wGrad)) || any(isinf(wGrad))
        % Non-finite step: keep the current iterate and stop.
        break;
    end
    betaNew = betaCur + wGrad;

    newLL = pLL(betaNew);
    % Relative change of the objective; tol in the denominator guards
    % against curLL being very close to 0 (perfect discrimination).
    Err = abs(curLL-newLL)/(tol+abs(curLL));

    betaCur = betaNew;
    curLL = newLL;
    rounds = rounds + 1;
end
% sprintf(['rounds: ', num2str(rounds), '; rel err: ', num2str(Err), ', neg loglik: ', num2str(curLL)])
beta = betaCur;
