Xmipp  v3.23.11-Nereus
Public Member Functions | List of all members
KMeans Class Reference

Public Member Functions

 KMeans (int K, int total_points, int total_values, int maxIterations, int maxObjects)
 
std::vector< Cluster > run (std::vector< KPoint > &points, FileName fnClusters, FileName fnPoints)
 

Detailed Description

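K-means clustering of `KPoint` objects into `K` clusters; the clustering is performed by `run()`, which also saves the cluster centers and the point values to files.

Definition at line 193 of file classify_kmeans_2d.cpp.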

Constructor & Destructor Documentation

◆ KMeans()

// Definition at line 234 of file classify_kmeans_2d.cpp (defined inline in the class body).
/**
 * @brief Store the clustering configuration; no clustering happens until run().
 *
 * @param K              Number of clusters run() creates.
 * @param total_points   Number of points run() processes from its input vector.
 * @param total_values   Number of values per point (and per cluster center).
 * @param maxIterations  Upper bound on assignment/update iterations in run().
 * @param maxObjects     If not -1 and total_points exceeds it, run() only labels
 *                       points with their nearest center instead of iterating.
 */
inline KMeans::KMeans(int K, int total_points, int total_values, int maxIterations, int maxObjects)
    : K(K),
      total_points(total_points),
      total_values(total_values),
      maxIterations(maxIterations),
      maxObjects(maxObjects)
{
}
constexpr int K

Member Function Documentation

◆ run()

std::vector<Cluster> KMeans::run ( std::vector< KPoint > &  points,
FileName  fnClusters,
FileName  fnPoints 
)
inline

Definition at line 244 of file classify_kmeans_2d.cpp.

246  {
247  // create clusters and choose K points as their centers centers
248  std::vector<int> prohibited_indexes;
249  std::random_device rd; // Will be used to obtain a seed for the random number engine
250  std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
251  std::uniform_int_distribution<> udistr(0,total_points-1);
252 
253  if (total_points == 0)
254  {
255  std::ostringstream msg;
256  msg << "Division by zero: total_points == 0";
257  throw std::runtime_error(msg.str());
258  }
259 
260  for (int i = 0; i < K; i++)
261  {
262  while (true)
263  {
264  int index_point = udistr(gen);
265 
266  if (find(prohibited_indexes.begin(), prohibited_indexes.end(),
267  index_point) == prohibited_indexes.end())
268  {
269  prohibited_indexes.push_back(index_point);
270  points[index_point].setCluster(i);
271  Cluster cluster(i, points[index_point]);
272  clusters.push_back(cluster);
273  break;
274  }
275  }
276  }
277 
278  // if clusters already exists, load their computed centers
279  std::fstream savedClusters(fnClusters.c_str());
280  if (savedClusters.good())
281  {
282  std::string line;
283  for (int i = 0; i < K; i++)
284  {
285  std::getline(savedClusters, line);
286  std::stringstream ss(line);
287  double point_value;
288  for (int j = 0; j < total_values; j++)
289  {
290  ss >> point_value;
291  clusters[i].setCentralValue(j, point_value);
292  }
293  }
294  }
295 
296  int iter = 1;
297  while(true)
298  {
299  bool done = true;
300 
301  // when total_points is larger than threshold value maxObjects
302  // we only assign labels to point based on nearest center
303  if ((maxObjects != -1) && (total_points > maxObjects))
304  {
305  for (int i = 0; i < total_points; i++)
306  {
307  int nearest_center = getIDNearestCenter(points[i]);
308  points[i].setCluster(nearest_center);
309  clusters[nearest_center].addPoint(points[i]);
310  }
311  }
312  else // perform k-means clustering
313  {
314  // associates each point to the nearest center
315  for (int i = 0; i < total_points; i++)
316  {
317  int old_cluster = points[i].getCluster();
318  int nearest_center = getIDNearestCenter(points[i]);
319 
320  if (old_cluster != nearest_center)
321  {
322  if (old_cluster != -1)
323  clusters[old_cluster].removePoint(points[i].getID());
324 
325  points[i].setCluster(nearest_center);
326  clusters[nearest_center].addPoint(points[i]);
327  done = false;
328  }
329  }
330 
331  // recalculating the center of each cluster
332  for (int i = 0; i < K; i++)
333  {
334  for (int j = 0; j < total_values; j++)
335  {
336  int total_points = clusters[i].getTotalPoints();
337  double sum = 0.0;
338 
339  if (total_points > 0)
340  {
341  for (int p = 0; p < total_points; p++)
342  sum += clusters[i].getPoint(p).getValue(j);
343 
344  clusters[i].setCentralValue(j, sum / total_points);
345  }
346  }
347  }
348  }
349  if (done == true || iter >= maxIterations) break;
350  iter++;
351  }
352 
353  // This code is removing outliers, whose distance from centroid is
354  // 1.5*stddev, its efficiency depends strongly on feature extraction
355  /*
356  double dist, sum, stddev;
357  std::vector<double> cluster_point_dist;
358  for (int i = 0; i < K; i++)
359  {
360  dist = 0.0;
361  int points_orig_total = clusters[i].getTotalPoints();
362 
363  for (int p = 0; p < points_orig_total; p++)
364  {
365  sum = 0.0;
366  for (int j = 0; j < total_values; j++)
367  sum += pow(clusters[i].getCentralValue(j) -
368  clusters[i].getPoint(p).getValue(j), 2.0);
369 
370  cluster_point_dist.push_back(sqrt(sum));
371  dist += sqrt(sum) / points_orig_total;
372  }
373 
374  for (int p = 0; p < points_orig_total; p++)
375  stddev += pow(cluster_point_dist[p] - dist, 2.0);
376 
377  stddev = sqrt(stddev / points_orig_total);
378 
379  int pp = 0;
380  for (int p = 0; p < points_orig_total; p++)
381  {
382  // Swich this condition for taking only eliminated particles
383  if ((cluster_point_dist[p] > (dist + 1.5*stddev)) ||
384  (cluster_point_dist[p] < (dist - 1.5*stddev)))
385  clusters[i].removePoint(clusters[i].getPoint(pp).getID());
386  else pp++;
387  }
388  }
389  */
390 
391  std::ofstream saveData;
392 
393  // saving clusters
394  saveData.open(fnClusters.c_str());
395  for (int i = 0; i < K; i++)
396  {
397  for (int j = 0; j < total_values; j++)
398  saveData << clusters[i].getCentralValue(j) << " ";
399  saveData << std::endl;
400  }
401  saveData.close();
402 
403  // saving points
404  saveData.open(fnPoints.c_str());
405  for (int i = 0; i < total_points; i++)
406  {
407  for (int j = 0; j < total_values; j++)
408  saveData << points[i].getValue(j) << " ";
409  saveData << std::endl;
410  }
411  saveData.close();
412 
413  return clusters;
414  }
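Cross-reference tooltips generated by Doxygen for names used in the listings above. They appear to be matched by name only, so they do not necessarily refer to the entities the code actually uses:
- `std::vector< SelLine >::iterator find(std::vector< SelLine > &text, const std::string &img_name)` — Definition: selfile.cpp:553. The `find` call in `run()` passes two iterators and an `int`, so it does not use this overload.
- `glob_prnt iter`
- `#define i`
- `#define j`
- `constexpr int K`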

The documentation for this class was generated from the following file: