% Agglomerative clustering with constraints demo
% the demo asks for manual input of 45 points via ginput and then graphs
% each step of the algorithm broken down by key presses
%inputs: datain.X = 2*N matrix containing x,y co-ords of points for
%           classification
%        datain.constraints = 1*N cell of arrays containing any number of
%           constraints for each of the N points.  Each constraint in each
%           array is simply the index of the other point in X.
%        datain.link = link type, i.e. how the distance between clusters is
%           calculated; currently only 'complete' is implemented
%        datain.thresh = threshold on the minimum (squared Euclidean)
%           distance between clusters; merging stops once the minimum
%           distance is no longer below this value.
%
%outputs: model.clusters = 1*M cells of clusters
%         clusters{1,m}.X =  2*p coords of points in cluster
%         clusters{1,m}.indices = 1*p original indices of points
%         clusters{1,m}.constraints = remaining constraints with remaining
%            clusters
%
%e.g.: >> agglomerativeclusteringwconstdemo
%
%datain = 
%
%              X: [2x45 double]
%           link: 'complete'
%    constraints: {1x45 cell}
%         thresh: 2.2500
%
%model = 
%
%    clusters: {[1x1 struct]  [1x1 struct]  [1x1 struct]  [1x1 struct]}
%
%>> model.clusters{1,1}
%
%ans = 
%
%              X: [2x15 double]
%        indices: [1 6 9 2 12 15 10 3 8 11 14 5 4 7 13]
%    constraints: [3 2]

%
%  Written by Mark Austin on 18/11/2004

function varargout = agglomerativeclusteringwconstdemo
% Interactive demo of complete-link agglomerative clustering with
% cannot-link constraints.
%
%   agglomerativeclusteringwconstdemo
%   model = agglomerativeclusteringwconstdemo
%
% 45 points are read with ginput, three hard-coded cannot-link pairs are
% attached, and the closest pair of clusters is merged repeatedly.  After
% every step the current clustering is plotted and the demo waits for a key
% press.  Merging stops when the smallest inter-cluster (squared) distance
% is no longer below datain.thresh, or when only two clusters remain.
%
% Output (optional): model.clusters, a 1*M cell of structs with fields
%   X (2*p coordinates), indices (1*p original point indices) and
%   constraints (indices of clusters this cluster must not merge with).

% ---- gather the input data ------------------------------------------------
datain.X=ginput(45)';   % ginput returns N*2, stored as 2*N
datain.link='complete';
npts=size(datain.X,2);

% every point starts with no constraints; cannot-link pairs are stored on
% both points of each pair
datain.constraints=cell(1,npts);
for p=1:npts
    datain.constraints{1,p}=[];
end
cannotlink=[15 16; 1 31; 45 30];
for r=1:size(cannotlink,1)
    a=cannotlink(r,1);
    b=cannotlink(r,2);
    datain.constraints{1,a}=[datain.constraints{1,a} b];
    datain.constraints{1,b}=[datain.constraints{1,b} a];
end
datain.thresh=1.5^2;    % compared against squared distances

% ---- one singleton cluster per point --------------------------------------
clusters=cell(1,npts);
for p=1:npts
    clusters{1,p}.X=datain.X(:,p);
    clusters{1,p}.indices=p;
    clusters{1,p}.constraints=datain.constraints{1,p};
end

% ---- initial squared-distance matrix (NaN on the diagonal) ----------------
dis=NaN(npts,npts);
for j=1:npts-1
    for k=j+1:npts
        dis(j,k)=sum((clusters{1,j}.X(1:end)-clusters{1,k}.X(1:end)).^2);
        dis(k,j)=dis(j,k);
    end
end

% ---- impose cannot-link constraints ---------------------------------------
% Inf (rather than "largest distance + 0.01") guarantees a constrained pair
% can never pass the c<thresh test, whatever the scale of the data.  Both
% directions are written so the matrix stays symmetric.
for p=1:npts
    cons=clusters{1,p}.constraints;
    for q=1:numel(cons)
        other=cons(q);
        if other>=1&&other<=npts&&other~=p
            dis(p,other)=Inf;
            dis(other,p)=Inf;
        end
    end
end

figure
run=1;
while run==1&&numel(clusters)>2
    % locate the closest pair of clusters; min skips the NaN diagonal
    [colmin,rowofmin]=min(dis);
    [c,x]=min(colmin);
    y=rowofmin(x);
    if y<x
        % keep x as the smaller index so that deleting y never shifts x
        tmp=y;
        y=x;
        x=tmp;
    end

    if c<datain.thresh
        % merge cluster y into cluster x
        clusters{1,x}.X=[clusters{1,x}.X clusters{1,y}.X];
        clusters{1,x}.indices=[clusters{1,x}.indices clusters{1,y}.indices];
        % propagate constraints
        clusters{1,x}.constraints=[clusters{1,x}.constraints clusters{1,y}.constraints];

        % drop cluster y and its row/column of the distance matrix
        clusters(y)=[];
        dis(y,:)=[];
        dis(:,y)=[];

        % renumber constraints: y now lives in x, indices above y shift down
        for k=1:numel(clusters)
            cons=clusters{1,k}.constraints;
            cons(cons==y)=x;
            shifted=cons>y;
            cons(shifted)=cons(shifted)-1;
            % a cluster cannot be constrained against itself
            cons(cons==k)=[];
            clusters{1,k}.constraints=cons;
        end

        % finish reforming dis by recomputing row and column x
        dis=reformdis(dis,datain.link,clusters,x);
    else
        run=0;
    end

    % ---- draw the current state -------------------------------------------
    for d=1:numel(clusters)
        if d==x
            marker='o';     % cluster x of the closest pair
        else
            marker='x';
        end
        plot(datain.X(1,clusters{1,d}.indices),datain.X(2,clusters{1,d}.indices),marker);
        hold all
    end
    % dotted line between the two points of every cannot-link constraint;
    % plotting the endpoints directly also works for vertical segments
    for w=1:numel(datain.constraints)
        for z=1:numel(datain.constraints{1,w})
            v=datain.constraints{1,w}(z);
            plot([datain.X(1,v) datain.X(1,w)],[datain.X(2,v) datain.X(2,w)],'b:')
        end
    end
    hold off

    axis([0 3 0 3]);
    pause
end

model.clusters=clusters;
% only hand the result back when asked, so a bare call stays silent
if nargout>0
    varargout{1}=model;
end
return

%%%%%%%%%%%%%%%%%%%
function dis=reformdis(dis,link,clusters,x)
% Recompute row and column x of the cluster distance matrix after a merge
% and re-impose the cannot-link constraints.
%
%   dis      : square matrix of squared inter-cluster distances
%   link     : linkage type; only 'complete' is handled, any other value
%              returns dis unchanged
%   clusters : 1*M cell of structs with fields X (2*p coordinates) and
%              constraints (indices of clusters that must not be merged
%              with this one)
%   x        : index of the cluster whose distances must be refreshed
%
% Returns dis with dis(x,j)=dis(j,x) set to the largest squared distance
% between any point of cluster x and any point of cluster j (complete
% link), and Inf written for every constrained pair of clusters.

if ~strcmp(link,'complete')
    return
end

nclus=numel(clusters);
Xx=clusters{1,x}.X;
for j=1:nclus
    if j==x
        continue
    end
    Xj=clusters{1,j}.X;
    % squared distance from every point of cluster j to every point of x
    dis2=NaN(size(Xj,2),size(Xx,2));
    for h=1:size(Xx,2)
        dis2(:,h)=sum((Xj-repmat(Xx(:,h),1,size(Xj,2))).^2,1)';
    end
    dis(j,x)=max(dis2(:));
    dis(x,j)=dis(j,x);
end

% Cannot-link pairs get Inf in both directions: a finite penalty could
% still fall below the merge threshold, and a one-sided entry would leave
% the pair selectable through the other half of the matrix.
for i=1:nclus
    cons=clusters{1,i}.constraints;
    for k=1:numel(cons)
        other=cons(k);
        % skip out-of-range entries and self-references (keeps the
        % diagonal untouched)
        if other>=1&&other<=size(dis,2)&&other~=i
            dis(i,other)=Inf;
            dis(other,i)=Inf;
        end
    end
end
