% Agglomerative clustering with constraints
%inputs: datain.X = k*N matrix whose N columns are the k-dimensional
%           co-ordinates of the points to be clustered
%        datain.constraints = 1*N cell of arrays containing any number of
%           constraints for each of the N points.  Each constraint in each
%           array is simply the index of the other point in X.
%        datain.link = link type, i.e. how the distance between clusters is
%           calculated; currently only 'complete' is supported
%        datain.thresh = threshold on the (squared Euclidean) distance for
%           merging; when the minimum distance between clusters is no
%           longer below this value the alg stops.
%
%outputs: model.clusters = 1*M cells of clusters
%         clusters{1,m}.X =  k*p coords of points in cluster
%         clusters{1,m}.indices = 1*p original indices of points
%         clusters{1,m}.constraints = remaining constraints with remaining
%            clusters
%
%e.g.: >> model=agglomerativeclusteringwconst(datain)
%
%datain = 
%
%              X: [2x45 double]
%           link: 'complete'
%    constraints: {1x45 cell}
%         thresh: 2.2500
%
%model = 
%
%    clusters: {[1x1 struct]  [1x1 struct]  [1x1 struct]  [1x1 struct]}
%
%>> model.clusters{1,1}
%
%ans = 
%
%              X: [2x15 double]
%        indices: [1 6 9 2 12 15 10 3 8 11 14 5 4 7 13]
%    constraints: [3 2]

%
%  Written by Mark Austin on 18/11/2004

function model = agglomerativeclusteringwconst(datain)
% Agglomerative clustering under cannot-link constraints.
%
%   datain.X            k*N matrix, one point per column
%   datain.constraints  1*N cell; cell i lists the indices of the points
%                       that point i must not share a cluster with
%   datain.link         link type, passed on to reformdis
%   datain.thresh       a merge happens only while the smallest
%                       inter-cluster distance (squared Euclidean) is
%                       strictly below this value
%
%   model.clusters      1*M cell of structs with fields X, indices and
%                       constraints (constraints hold cluster indices)
%
% The merge loop only runs while more than two clusters remain.

npts = size(datain.X,2);

% every point starts out as its own cluster; preallocating means an empty
% data set yields an empty cluster list instead of an undefined variable
clusters = cell(1,npts);
for p = 1:npts
    pt.X = datain.X(:,p);
    pt.indices = p;
    pt.constraints = datain.constraints{1,p};
    clusters{1,p} = pt;
end

% initial pairwise squared Euclidean distances; the diagonal stays NaN so
% that min() never pairs a cluster with itself
dis = NaN(npts,npts);
for j = 1:npts-1
    for k = j+1:npts
        delta = clusters{1,j}.X - clusters{1,k}.X;
        dis(j,k) = sum(delta(:).^2);
        dis(k,j) = dis(j,k);
    end
end

% impose cannot-link constraints
dis = imposecannotlink(dis,clusters);

while numel(clusters) > 2
    % closest pair of clusters: min() works column-wise and skips NaN
    [colmin, rowidx] = min(dis);
    [c, col] = min(colmin);
    row = rowidx(col);
    x = min(row,col);   % surviving cluster (smaller index)
    y = max(row,col);   % cluster that is absorbed and removed

    % stop once nothing is closer than the threshold; this also covers
    % c being NaN or Inf (only constrained pairs left)
    if ~(c < datain.thresh)
        break
    end

    % merge y into x, propagating the constraints of both clusters
    merged.X = [clusters{1,x}.X clusters{1,y}.X];
    merged.indices = [clusters{1,x}.indices clusters{1,y}.indices];
    merged.constraints = [clusters{1,x}.constraints clusters{1,y}.constraints];
    clusters{1,x} = merged;
    clusters(y) = [];

    % drop row and column y from the distance matrix
    dis(y,:) = [];
    dis(:,y) = [];

    % constraints are cluster indices: references to y now mean x, and
    % every index above y moved down by one
    for j = 1:numel(clusters)
        cons = clusters{1,j}.constraints;
        cons(cons == y) = x;
        above = cons > y;
        cons(above) = cons(above) - 1;
        clusters{1,j}.constraints = cons;
    end

    % recompute row/column x for the merged cluster, then re-impose the
    % constraints symmetrically
    dis = reformdis(dis,datain.link,clusters,x);
    dis = imposecannotlink(dis,clusters);
end

model.clusters = clusters;
return

%%%%%%%%%%%%%%%%%%%
function dis = imposecannotlink(dis,clusters)
% Mark every constrained cluster pair as unmergeable.
% Both dis(i,c) and dis(c,i) are set to Inf, so a constraint listed by
% only one side is still honoured and can never fall below any threshold.
% Constraint indices larger than the matrix size are ignored.

n = size(dis,2);
for i = 1:numel(clusters)
    c = clusters{1,i}.constraints;
    c = c(c <= n & c ~= i);
    dis(i,c) = Inf;
    dis(c,i) = Inf;
end
% keep the diagonal at NaN so a cluster is never paired with itself
dis(1:n+1:end) = NaN;

%%%%%%%%%%%%%%%%%%%
function dis=reformdis(dis,link,clusters,x)
% Refresh the distance matrix after cluster x has changed.
%
%   dis       square matrix of inter-cluster distances
%   link      link type; only 'complete' is handled, any other value
%             returns dis unchanged
%   clusters  1*M cell of structs with fields X (k*p co-ords) and
%             constraints (indices of clusters that must not be merged
%             with this one)
%   x         index of the cluster whose row and column are recomputed
%
% For 'complete' linkage the distance between two clusters is the largest
% squared Euclidean distance between any pair of their points.

if ~strcmp(link,'complete')
    return
end

nx = size(clusters{1,x}.X,2);
for j=1:numel(clusters)
    if j==x
        continue
    end
    nj = size(clusters{1,j}.X,2);
    dis2 = NaN(nj,nx);
    for g=1:nj
        for h=1:nx
            dis2(g,h)=sum((clusters{1,j}.X(:,g)-clusters{1,x}.X(:,h)).^2);
        end
    end
    dis(j,x)=max(max(dis2));
    dis(x,j)=dis(j,x);
end

% re-impose the cannot-link constraints. Both directions are marked so a
% constraint held by only one of the two clusters is still honoured, and
% Inf is used so that no threshold can ever exceed the mark. Constraint
% indices beyond the matrix size are ignored; self references are skipped
% so the diagonal is left untouched.
n = size(dis,2);
for i=1:numel(clusters)
    c = clusters{1,i}.constraints;
    c = c(c<=n & c~=i);
    dis(i,c)=Inf;
    dis(c,i)=Inf;
end
