This repository has been archived on 2021-11-25. You can view files and clone it, but cannot push or open issues or pull requests.
K-Means/K-MeansPlus.cpp
2017-03-06 18:22:41 +08:00

180 lines
3.6 KiB
C++

#include <algorithm>
#include <climits>
#include <cmath>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
using namespace std;
/// A single observation to be clustered.
struct Point
{
    vector<double> val; ///< Coordinates of the point, one entry per attribute.
    int id = 0;         ///< Identifier; defaults to 0 so a fresh Point never holds an indeterminate value.
    string tag;         ///< Label read together with the point and echoed when clusters are printed.
};
/// Run parameters, filled in by main() from standard input.
int N;            ///< Number of points to read.
int Attr;         ///< Number of coordinates read for each point.
int K;            ///< Number of seed indices read (one centre per seed).
int MaxIteration; ///< Upper bound on the number of K-Means iterations.

/// Every input point, stored in the order it was read.
vector<Point> vec;

/// Current cluster centres; seeded from vec and overwritten by KMeans().
vector<Point> center;

/// Most recently accepted clustering: each inner vector holds indices into vec.
vector<vector<int>> Clusters;
double getDistance(Point& a,Point& b)
{
if(a.val.size()!=b.val.size()) return -1;
double dis=0;
size_t sz=a.val.size();
for(size_t i=0;i<sz;i++)
{
dis+=pow(a.val.at(i)-b.val.at(i),2);
}
dis=sqrt(dis);
return dis;
}
/// Sum of every coordinate of every point referenced by a cluster.
///
/// @param cluster Indices into the global vec.
/// @return The accumulated coordinate sum (0 for an empty cluster).
static double clusterCoordinateSum(const vector<int>& cluster)
{
    double total=0;
    for(int idx:cluster)                  // each point in the cluster
    {
        for(double coord:vec.at(idx).val) // each dimension of the point
        {
            total+=coord;
        }
    }
    return total;
}

/// Ordering used to put clusters into a canonical order before comparing
/// two clusterings: a cluster sorts first when the sum of all coordinates of
/// all its points is smaller.
///
/// Takes const references because it is handed to std::sort as a comparator
/// and never modifies its arguments.
///
/// @param a Indices (into vec) of the first cluster's points.
/// @param b Indices (into vec) of the second cluster's points.
/// @return true when a's coordinate sum is strictly less than b's.
bool ClusterCmp(const vector<int>& a,const vector<int>& b)
{
    return clusterCoordinateSum(a)<clusterCoordinateSum(b);
}
void KMeans()
{
for(int IterTime=0;IterTime<MaxIteration;IterTime++)
{
vector<vector<int>> newCluster(center.size());
for(size_t i=0;i<vec.size();i++)
{
int minID=-1;
double minDistance=INT_MAX;
for(size_t j=0;j<center.size();j++)
{
double dis=getDistance(vec.at(i),center.at(j));
if(minDistance>dis)
{
minID=j;
minDistance=dis;
}
}
newCluster.at(minID).push_back(i);
}
sort(newCluster.begin(),newCluster.end(),ClusterCmp);
#ifdef _OUTPUT_ITER
cout<<"===============Iteration "<<IterTime<<"================="<<endl;
for(size_t i=0; i<newCluster.size(); i++)
{
cout<<"Cluster "<<i<<":"<<endl;
for(size_t j=0; j<newCluster.at(i).size(); j++)
{
cout<<"Point "<<vec.at(newCluster.at(i).at(j)).id<<" Tag="<<vec.at(newCluster.at(i).at(j)).tag<<endl;
}
}
#endif
if(newCluster==Clusters)
{
break;
}
else
{
Clusters=newCluster;
for(size_t i=0;i<Clusters.size();i++)/// i 类群
{
for(int j=0;j<Attr;j++) /// j 点的维度
{
double sum=0;
for(size_t k=0;k<Clusters.at(i).size();k++) /// k 类群中的点
{
sum+=vec.at(Clusters.at(i).at(k)).val.at(j);
}
center.at(i).val.at(j)=sum/Clusters.at(i).size();
}
}
}
}
cout<<"End"<<endl;
for(size_t i=0;i<Clusters.size();i++)
{
cout<<"Cluster "<<i<<":"<<endl;
for(size_t j=0;j<Clusters.at(i).size();j++)
{
cout<<"Point "<<vec.at(Clusters.at(i).at(j)).id<<" Tag="<<vec.at(Clusters.at(i).at(j)).tag<<endl;
}
}
}
/// Program entry point.
///
/// Reads, from standard input:
///   1. N, Attr, K and MaxIteration;
///   2. N points, each as Attr numbers followed by a tag;
///   3. K seed indices (0-based positions in the list of points just read).
/// The seeds become the initial centres and KMeans() is run.
///
/// @return 0 on success, 1 when the input is malformed or out of range.
int main()
{
    cout<<"Please Input N,Attr,K,MaxIteration"<<endl;
    if(!(cin>>N>>Attr>>K>>MaxIteration)||N<=0||Attr<=0||K<=0)
    {
        cerr<<"Invalid input: N, Attr and K must be positive integers followed by MaxIteration"<<endl;
        return 1;
    }

    vec.reserve(N);
    for(int i=0;i<N;i++)
    {
        Point p;
        p.val.reserve(Attr);
        for(int j=0;j<Attr;j++)
        {
            double tmp;
            if(!(cin>>tmp))
            {
                cerr<<"Invalid input: could not read attribute "<<j<<" of point "<<i<<endl;
                return 1;
            }
            p.val.push_back(tmp);
        }
        if(!(cin>>p.tag))
        {
            cerr<<"Invalid input: could not read tag of point "<<i<<endl;
            return 1;
        }
        p.id=i;
        vec.push_back(p);
    }

    cout<<"Please Select Seed From Input..."<<endl;
    center.reserve(K);
    for(int i=0;i<K;i++)
    {
        int temp;
        // Seeds are 0-based indices into the points read above; reject anything
        // else here instead of letting vec.at() throw an uncaught exception.
        if(!(cin>>temp)||temp<0||temp>=N)
        {
            cerr<<"Invalid input: seed "<<i<<" must be an index in [0, "<<N-1<<"]"<<endl;
            return 1;
        }
        center.push_back(vec.at(temp));
    }

    KMeans();
    return 0;
}
/**
15 3 3 100
1 1 0.5 中国
0.3 0 0.19 日本
0 0.15 0.13 韩国
0.24 0.76 0.25 伊朗
0.3 0.76 0.06 沙特
1 1 0 伊拉克
1 0.76 0.5 卡塔尔
1 0.76 0.5 阿联酋
0.7 0.76 0.25 乌兹别克斯坦
1 1 0.5 泰国
1 1 0.25 越南
1 1 0.5 阿曼
0.7 0.76 0.5 巴林
0.7 0.68 1 朝鲜
1 1 0.5 印尼
1 9 12
*/