Xmipp  v3.23.11-Nereus
classify_kmeans_2d.cpp
Go to the documentation of this file.
1 /***************************************************************************
2  *
3  * Authors: Tomas Majtner tmajtner@cnb.csic.es (2017)
4  *
5  * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
20  * 02111-1307 USA
21  *
22  * All comments concerning this program package may be sent to the
23  * e-mail address 'xmipp@cnb.csic.es'
24  ***************************************************************************/
25 
26 #include <algorithm>
27 #include <fstream>
28 #include <random>
29 #include "data/filters.h"
30 #include "core/metadata_vec.h"
31 #include "core/xmipp_image.h"
32 #include "classify_kmeans_2d.h"
34 
35 // Read arguments ==========================================================
// NOTE(review): the signature line (listing line 36) is absent from this
// rendering; the tooltip residue lists it as `void readParams()`.
// Copies the command-line values declared in defineParams() into members:
// -i -> fnSel, -o -> fnOut, -k -> K, -c -> fnClusters, -p -> fnPoints,
// -m -> maxObjects.
37 {
38  fnSel = getParam("-i");
39  fnOut = getParam("-o");
40  K = getIntParam("-k");
41  fnClusters = getParam("-c");
42  fnPoints = getParam("-p");
43  maxObjects = getIntParam("-m");
44 }
45 
46 // Show ====================================================================
// NOTE(review): the signature line (listing line 47) is absent from this
// rendering.
// Prints the parsed parameters to std::cerr, one per line; prints nothing
// when verbose == 0.
48 {
49  if (verbose==0)
50  return;
51  std::cerr
52  << "Input selfile: " << fnSel << std::endl
53  << "Output selfile: " << fnOut << std::endl
54  << "Number of clusters: " << K << std::endl
55  << "Filename with clusters: " << fnClusters << std::endl
56  << "Filename with points: " << fnPoints << std::endl
57  << "Threshold for number of particles: " << maxObjects << std::endl
58  ;
59 }
60 
61 // Usage ===================================================================
// NOTE(review): the signature line (listing line 62) is absent from this
// rendering; the tooltip residue lists it as `void defineParams()`.
// Declares the usage line and the parameters later read by readParams():
// -i and -k are mandatory; -o, -c and -p default to output.xmd,
// clusters.xmd and points.xmd; -m defaults to -1, which KMeans treats as
// "never fix the cluster positions".
63 {
64  addUsageLine("Clusters a set of images");
65  addParamsLine(" -i <selfile> : Selfile containing images to be clustered");
66  addParamsLine(" [-o <image=\"output.xmd\">] : Output selfile");
67  addParamsLine(" -k <int> : Number of clusters");
68  addParamsLine(" [-c <image=\"clusters.xmd\">] : Filename with clusters");
69  addParamsLine(" [-p <image=\"points.xmd\">] : Filename with points");
70  addParamsLine(" [-m <int=\"-1\">] : Threshold for number of particles after which the position of clusters will be fixed");
71 }
72 
73 
/// A single sample for k-means: an identifier, the id of the cluster it is
/// currently assigned to (-1 while unassigned) and its feature vector.
class KPoint
{
private:
    int id_point;
    int id_cluster;
    std::vector<double> values;

public:
    /// Builds point @p id_point holding a copy of @p values.
    /// The point starts unassigned (cluster id -1).
    KPoint(int id_point, const std::vector<double>& values)
        : id_point(id_point), id_cluster(-1), values(values)
    {
    }

    /// Identifier given at construction.
    int getID() const
    {
        return id_point;
    }

    /// Records the cluster this point is assigned to.
    void setCluster(int id_cluster)
    {
        this->id_cluster = id_cluster;
    }

    /// Cluster this point is assigned to, or -1 if none yet.
    int getCluster() const
    {
        return id_cluster;
    }

    /// Feature value at @p index (not bounds-checked).
    double getValue(int index) const
    {
        return values[index];
    }

    /// Number of feature values. Derived from the vector itself so it stays
    /// correct after addValue() (a separately cached counter used to go stale).
    int getTotalValues() const
    {
        return static_cast<int>(values.size());
    }

    /// Appends one more feature value.
    void addValue(double value)
    {
        values.push_back(value);
    }
};
124 
125 
126 class Cluster
127 {
128 private:
129  int id_cluster;
130  std::vector<double> central_values;
131  std::vector<KPoint> points;
132 
133 public:
134  Cluster(int id_cluster, KPoint point)
135  {
136  this->id_cluster = id_cluster;
137 
138  int total_values = point.getTotalValues();
139 
140  for (int i = 0; i < total_values; i++)
141  central_values.push_back(point.getValue(i));
142 
143  points.push_back(point);
144  }
145 
147  {
148  points.push_back(point);
149  }
150 
151  bool removePoint(int id_point)
152  {
153  int total_points = points.size();
154 
155  for (int i = 0; i < total_points; i++)
156  {
157  if(points[i].getID() == id_point)
158  {
159  points.erase(points.begin() + i);
160  return true;
161  }
162  }
163  return false;
164  }
165 
166  double getCentralValue(int index)
167  {
168  return central_values[index];
169  }
170 
171  void setCentralValue(int index, double value)
172  {
173  central_values[index] = value;
174  }
175 
177  {
178  return points[index];
179  }
180 
182  {
183  return points.size();
184  }
185 
186  int getID()
187  {
188  return id_cluster;
189  }
190 };
191 
192 
193 class KMeans
194 {
195 private:
196  int K; // number of clusters
197  int total_values, total_points, maxIterations, maxObjects;
198  std::vector<Cluster> clusters;
199 
200  // return ID of nearest center (uses euclidean distance)
201  int getIDNearestCenter(KPoint point)
202  {
203  double sum = 0.0, min_dist;
204  int id_cluster_center = 0;
205 
206  for (int i = 0; i < total_values; i++)
207  sum += ((clusters[0].getCentralValue(i) - point.getValue(i)) *
208  (clusters[0].getCentralValue(i) - point.getValue(i)));
209 
210  min_dist = sqrt(sum);
211 
212  for (int i = 1; i < K; i++)
213  {
214  double dist;
215  sum = 0.0;
216 
217  for (int j = 0; j < total_values; j++)
218  sum += ((clusters[i].getCentralValue(j) - point.getValue(j)) *
219  (clusters[i].getCentralValue(j) - point.getValue(j)));
220 
221  dist = sqrt(sum);
222 
223  if (dist < min_dist)
224  {
225  min_dist = dist;
226  id_cluster_center = i;
227  }
228  }
229 
230  return id_cluster_center;
231  }
232 
233 public:
234  KMeans(int K, int total_points, int total_values, int maxIterations,
235  int maxObjects)
236  {
237  this->K = K;
238  this->total_points = total_points;
239  this->total_values = total_values;
240  this->maxIterations = maxIterations;
241  this->maxObjects = maxObjects;
242  }
243 
244  std::vector<Cluster> run(std::vector<KPoint> & points,
246  {
247  // create clusters and choose K points as their centers centers
248  std::vector<int> prohibited_indexes;
249  std::random_device rd; // Will be used to obtain a seed for the random number engine
250  std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
251  std::uniform_int_distribution<> udistr(0,total_points-1);
252 
253  if (total_points == 0)
254  {
255  std::ostringstream msg;
256  msg << "Division by zero: total_points == 0";
257  throw std::runtime_error(msg.str());
258  }
259 
260  for (int i = 0; i < K; i++)
261  {
262  while (true)
263  {
264  int index_point = udistr(gen);
265 
266  if (find(prohibited_indexes.begin(), prohibited_indexes.end(),
267  index_point) == prohibited_indexes.end())
268  {
269  prohibited_indexes.push_back(index_point);
270  points[index_point].setCluster(i);
271  Cluster cluster(i, points[index_point]);
272  clusters.push_back(cluster);
273  break;
274  }
275  }
276  }
277 
278  // if clusters already exists, load their computed centers
279  std::fstream savedClusters(fnClusters.c_str());
280  if (savedClusters.good())
281  {
282  std::string line;
283  for (int i = 0; i < K; i++)
284  {
285  std::getline(savedClusters, line);
286  std::stringstream ss(line);
287  double point_value;
288  for (int j = 0; j < total_values; j++)
289  {
290  ss >> point_value;
291  clusters[i].setCentralValue(j, point_value);
292  }
293  }
294  }
295 
296  int iter = 1;
297  while(true)
298  {
299  bool done = true;
300 
301  // when total_points is larger than threshold value maxObjects
302  // we only assign labels to point based on nearest center
303  if ((maxObjects != -1) && (total_points > maxObjects))
304  {
305  for (int i = 0; i < total_points; i++)
306  {
307  int nearest_center = getIDNearestCenter(points[i]);
308  points[i].setCluster(nearest_center);
309  clusters[nearest_center].addPoint(points[i]);
310  }
311  }
312  else // perform k-means clustering
313  {
314  // associates each point to the nearest center
315  for (int i = 0; i < total_points; i++)
316  {
317  int old_cluster = points[i].getCluster();
318  int nearest_center = getIDNearestCenter(points[i]);
319 
320  if (old_cluster != nearest_center)
321  {
322  if (old_cluster != -1)
323  clusters[old_cluster].removePoint(points[i].getID());
324 
325  points[i].setCluster(nearest_center);
326  clusters[nearest_center].addPoint(points[i]);
327  done = false;
328  }
329  }
330 
331  // recalculating the center of each cluster
332  for (int i = 0; i < K; i++)
333  {
334  for (int j = 0; j < total_values; j++)
335  {
336  int total_points = clusters[i].getTotalPoints();
337  double sum = 0.0;
338 
339  if (total_points > 0)
340  {
341  for (int p = 0; p < total_points; p++)
342  sum += clusters[i].getPoint(p).getValue(j);
343 
344  clusters[i].setCentralValue(j, sum / total_points);
345  }
346  }
347  }
348  }
349  if (done == true || iter >= maxIterations) break;
350  iter++;
351  }
352 
353  // This code is removing outliers, whose distance from centroid is
354  // 1.5*stddev, its efficiency depends strongly on feature extraction
355  /*
356  double dist, sum, stddev;
357  std::vector<double> cluster_point_dist;
358  for (int i = 0; i < K; i++)
359  {
360  dist = 0.0;
361  int points_orig_total = clusters[i].getTotalPoints();
362 
363  for (int p = 0; p < points_orig_total; p++)
364  {
365  sum = 0.0;
366  for (int j = 0; j < total_values; j++)
367  sum += pow(clusters[i].getCentralValue(j) -
368  clusters[i].getPoint(p).getValue(j), 2.0);
369 
370  cluster_point_dist.push_back(sqrt(sum));
371  dist += sqrt(sum) / points_orig_total;
372  }
373 
374  for (int p = 0; p < points_orig_total; p++)
375  stddev += pow(cluster_point_dist[p] - dist, 2.0);
376 
377  stddev = sqrt(stddev / points_orig_total);
378 
379  int pp = 0;
380  for (int p = 0; p < points_orig_total; p++)
381  {
382  // Swich this condition for taking only eliminated particles
383  if ((cluster_point_dist[p] > (dist + 1.5*stddev)) ||
384  (cluster_point_dist[p] < (dist - 1.5*stddev)))
385  clusters[i].removePoint(clusters[i].getPoint(pp).getID());
386  else pp++;
387  }
388  }
389  */
390 
391  std::ofstream saveData;
392 
393  // saving clusters
394  saveData.open(fnClusters.c_str());
395  for (int i = 0; i < K; i++)
396  {
397  for (int j = 0; j < total_values; j++)
398  saveData << clusters[i].getCentralValue(j) << " ";
399  saveData << std::endl;
400  }
401  saveData.close();
402 
403  // saving points
404  saveData.open(fnPoints.c_str());
405  for (int i = 0; i < total_points; i++)
406  {
407  for (int j = 0; j < total_values; j++)
408  saveData << points[i].getValue(j) << " ";
409  saveData << std::endl;
410  }
411  saveData.close();
412 
413  return clusters;
414  }
415 };
416 
// NOTE(review): the signature line (listing line 417) and listing lines 422
// and 428 are absent from this rendering; `ef`, used below, is presumably
// declared on line 422 — TODO confirm. Tooltip residue lists `void run()`.
// Main routine. Steps:
// 1. Extract a feature vector for every image in fnSel.
// 2. Min-max normalise the features across images.
// 3. Append the points stored by earlier runs.
// 4. Cluster all points with KMeans.
// 5. Write a "classes" summary block plus one image block per class to fnOut.
418 {
419  MetaDataVec SF, MDsummary, MDclass, MDallDone;
420  FileName fnImg, fnClass, fnallDone;
421  Image<double> I, Imasked;
423  CorrelationAux aux;
424  std::vector<std::vector<double> > fvs;
425  std::vector<double> fv, fv_temp;
426  std::vector<KPoint> points;
427  std::vector<Cluster> clusters;
// NOTE(review): no rand() call is visible in this routine or in KMeans (which
// seeds std::mt19937 from std::random_device); this seed may be unused unless
// a library routine called here relies on rand() — confirm.
429  srand (time(nullptr));
430 
431  // reading new images from input file
// For each image: read it, set the origin to its centre, centre it
// translationally, and append its entropy feature vector to fvs.
432  SF.read(fnSel);
433  for (size_t objId : SF.ids())
434  {
435  SF.getValue(MDL_IMAGE, fnImg, objId);
436  I.read(fnImg);
437  I().setXmippOrigin();
438  centerImageTranslationally(I(), aux);
439  fv.clear();
440  ef.extractEntropy(I(), Imasked(), fv);
441  //ef.extractZernike(I(), fv);
442  //ef.extractLBP(I(), fv);
443  //ef.extractVariance(I(), fv);
444  //ef.extractGranulo(I(), fv);
445  //ef.extractRamp(I(), fv);
446  //ef.extractHistDist(I(), fv);
447  fvs.push_back(fv);
448  }
449 
// Min-max normalise each feature (column i) across all new images to [0, 1].
// The feature count is taken from the last image's fv.
// NOTE(review): a constant feature (max_item == min_item) divides by zero
// here and produces NaN/inf; consider skipping or zeroing such features.
450  double min_item, max_item;
451  for(int i = 0; i < fv.size(); i++)
452  {
453  fv_temp.clear();
454  for (int j = 0; j < fvs.size(); j++)
455  fv_temp.push_back(fvs[j][i]);
456 
457  max_item = *std::max_element(fv_temp.begin(), fv_temp.end());
458  min_item = *std::min_element(fv_temp.begin(), fv_temp.end());
459  for (int j = 0; j < fvs.size(); j++)
460  fvs[j][i] = (fvs[j][i] - min_item) / (max_item - min_item);
461  }
462 
// Wrap each normalised vector in a KPoint with ids 1..N in selfile order.
// fvs is consumed from the front; erase(begin()) is O(N) per call.
463  int allItems = 0;
464  for (size_t objId : SF.ids())
465  {
466  allItems++;
467  KPoint p(allItems, fvs.front());
468  points.push_back(p);
469  fvs.erase(fvs.begin());
470  }
471 
472  // preparing all the paths to external files
// Output, cluster, point and allDone files are resolved relative to the
// directory that contains fnSel.
473  std::size_t extraPath = fnSel.find_last_of("/");
474  fnOut = fnSel.substr(0, extraPath+1) + fnOut.c_str();
475  fnClusters = fnSel.substr(0, extraPath+1) + fnClusters.c_str();
476  fnPoints = fnSel.substr(0, extraPath+1) + fnPoints.c_str();
477  fnallDone = fnSel.substr(0, extraPath+1) + "allDone.xmd";
478 
479  // loading all the stored points from file (their count is unknown here)
// Each line holds fv.size() space-separated values. Stored points get ids
// following the new images. A line shorter than 2 characters ends the loop.
// NOTE(review): savedPoints stays open (read/write mode) while KMeans::run
// rewrites fnPoints below.
480  std::vector<double> fv_load;
481  std::fstream savedPoints(fnPoints.c_str());
482  std::string line;
483  while (savedPoints.good())
484  {
485  std::getline(savedPoints, line);
486  if (line.size() < 2) break;
487  allItems++;
488  std::stringstream ss(line);
489  fv_load.clear();
490  double point_value;
491  for (int j = 0; j < fv.size(); j++)
492  {
493  ss >> point_value;
494  fv_load.push_back(point_value);
495  }
496  KPoint p(allItems, fv_load);
497  points.push_back(p);
498  }
499 
500  // performing k-means clustering
// allItems is passed both as the number of points and as the iteration cap.
501  KMeans kmeans(K, allItems, fv.size(), allItems, maxObjects);
502  clusters = kmeans.run(points, fnClusters, fnPoints);
503 
504  // for cycle writing output file
// For each cluster, do the following:
// - Append its size and centroid (as "[v0, v1, ...]") to the classes block.
// - Copy the allDone.xmd rows of its members into a class%06d_images block.
// NOTE(review): `fv.size()-1` is unsigned and underflows when fv is empty
// (e.g. no input images).
// NOTE(review): MDsummary is rewritten to fnOut and allDone.xmd is re-read on
// every cluster iteration; both look loop-invariant.
// NOTE(review): point ids are used as MDallDone object ids. New images get
// ids 1..N and stored points come after them; presumably allDone.xmd uses the
// same numbering — verify.
505  for (int i = 0; i < clusters.size(); i++)
506  {
507  size_t total_points_cluster = clusters[i].getTotalPoints();
508 
509  size_t ii = MDsummary.addObject();
510  MDsummary.setValue(MDL_REF, i+1, ii);
511  MDsummary.setValue(MDL_CLASS_COUNT, total_points_cluster, ii);
512 
513  std::ostringstream clusterValues;
514  clusterValues << "[";
515  for (int j = 0; j < fv.size()-1; j++)
516  clusterValues << clusters[i].getCentralValue(j) << ", ";
517  clusterValues << clusters[i].getCentralValue(fv.size()-1) << "]";
518 
519  MDsummary.setValue(MDL_KMEANS2D_CENTROID, clusterValues.str(), ii);
520  MDsummary.write(formatString("classes@%s", fnOut.c_str()), MD_APPEND);
521  MDclass.clear();
522 
523  std::ifstream f(fnallDone.c_str());
524  if (f.good()) MDallDone.read(fnallDone);
525 
526  for (int j = 0; j < total_points_cluster; j++)
527  {
528  MDRowVec row;
529  MDallDone.getRow(row, clusters[i].getPoint(j).getID());
530  size_t recId = MDclass.addRow(row);
531  MDclass.setValue(MDL_REF, i+1, recId);
532  }
533  MDclass.write(formatString("class%06d_images@%s", i+1,
534  fnOut.c_str()), MD_APPEND);
535  }
536 }
void extractEntropy(const MultidimArray< double > &I, MultidimArray< double > &Imasked, std::vector< double > &fv)
Extracting entropy features.
double getCentralValue(int index)
void centerImageTranslationally(MultidimArray< double > &I, CorrelationAux &aux)
Definition: filters.cpp:3212
void read(const FileName &inFile, const std::vector< MDLabel > *desiredLabels=nullptr, bool decomposeStack=true) override
void sqrt(Image< double > &op)
std::vector< SelLine >::iterator find(std::vector< SelLine > &text, const std::string &img_name)
Definition: selfile.cpp:553
KMeans(int K, int total_points, int total_values, int maxIterations, int maxObjects)
void readParams()
Read argument.
void run()
Main routine.
void write(const FileName &outFile, WriteModeMetaData mode=MD_OVERWRITE) const
glob_prnt iter
virtual IdIteratorProxy< false > ids()
std::unique_ptr< MDRow > getRow(size_t id) override
#define i
size_t addRow(const MDRow &row) override
KPoint(int id_point, std::vector< double > &values)
void clear() override
Centroid of a cluster for the KMEANS2D classification.
KPoint getPoint(int index)
void addPoint(KPoint point)
const char * getParam(const char *param, int arg=0)
viol index
void setCentralValue(int index, double value)
bool setValue(const MDObject &mdValueIn, size_t id)
size_t addObject() override
double * f
bool removePoint(int id_point)
double getValue(int index)
int verbose
Verbosity level.
#define j
bool getValue(MDObject &mdValueOut, size_t id) const override
Class to which the image belongs (int)
void defineParams()
Define parameters.
Number of images assigned to the same class as this image.
void addValue(double value)
String formatString(const char *format,...)
int read(const FileName &name, DataMode datamode=DATA, size_t select_img=ALL_IMAGES, bool mapData=false, int mode=WRITE_READONLY)
Cluster(int id_cluster, KPoint point)
void addUsageLine(const char *line, bool verbatim=false)
std::vector< Cluster > run(std::vector< KPoint > &points, FileName fnClusters, FileName fnPoints)
int getIntParam(const char *param, int arg=0)
int getTotalValues()
Name of an image (std::string)
void setCluster(int id_cluster)
void addParamsLine(const String &line)