Xmipp  v3.23.11-Nereus
mean_shift.cpp
Go to the documentation of this file.
1 /***************************************************************************
2  *
3 * Authors: J.R. Bilbao-Castro (jrbcast@ace.ual.es)
4 * Updated by: J.M. de la Rosa Trevin (jmdelarosa@cnb.csic.es)
5  *
6  * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
7  *
8  * This program is free software; you can redistribute it and/or modify
9  * it under the terms of the GNU General Public License as published by
10  * the Free Software Foundation; either version 2 of the License, or
11  * (at your option) any later version.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
21  * 02111-1307 USA
22  *
23  * All comments concerning this program package may be sent to the
24  * e-mail address 'xmipp@cnb.csic.es'
25  ***************************************************************************/
26 
#include "mean_shift.h"

#include <vector>

#include "core/xmipp_program.h"
29 
// Per-thread argument record handed (as void*) to thread_process_plane() by apply().
// NOTE(review): the struct header and the `input` / `output` MultidimArray<double>* members
// (dereferenced as thrParams->input / thrParams->output in thread_process_plane) are not
// visible in this listing -- confirm they are present in the real source.
31 {
// Index of this worker, in [0, numThreads); worker 0 also prints the progress counter.
32  unsigned int myID;
// Total number of workers; the Y (row) range is split into one slab per worker.
33  unsigned int numThreads;
// Spatial-domain sigma (<hs> of --mean_shift), copied from the filter by apply().
34  double sigma_s;
// Range-domain (intensity) sigma (<hr> of --mean_shift), copied from the filter by apply().
35  double sigma_r;
// true selects the flat-kernel "fast" variant (--fast); false the Gaussian-weighted one.
38  bool fast;
39 };
40 
41 void * thread_process_plane( void * args )
42 {
43  auto * thrParams = (ThreadProcessPlaneArgs *) args;
44 
45  unsigned int myID = thrParams->myID;
46  unsigned int numThreads = thrParams->numThreads;
47  double sigma_s = thrParams->sigma_s;
48  double sigma_r = thrParams->sigma_r;
49  MultidimArray<double> &input = *thrParams->input;
50  MultidimArray<double> &output = *thrParams->output;
51  bool fast = thrParams->fast;
52 
53  if( !fast )
54  {
55  sigma_s /= 3.0;
56  sigma_s = CEIL( sigma_s );
57  sigma_r /= 3.0;
58  }
59  else
60  {
61  sigma_s = CEIL( sigma_s );
62  }
63 
64  double curr_I;
65  double inv_2_sigma_s_2 = 1.0/( 2.0* sigma_s * sigma_s);
66  double inv_2_sigma_r_2 = 1.0/( 2.0* sigma_r * sigma_r);
67  double sigma_r_2 = sigma_r * sigma_r;
68  auto _3_sigma_s = (int)(3 * sigma_s);
69  double _3_sigma_r = 3.0 * sigma_r;
70  int y_min = input.startingY();
71  int y_max = input.finishingY();
72  int x_min = input.startingX();
73  int x_max = input.finishingX();
74  int z_min = input.startingZ();
75  int z_max = input.finishingZ();
76  int curr_x, curr_y, curr_z;
77  int prev_x, prev_y, prev_z;
78  double error;
79 
80  int myFirstY, myLastY;
81 
82  int numElems = y_max-y_min+1;
83  numElems /= numThreads;
84 
85  myFirstY = myID * numElems + y_min;
86 
87  if( myID == numThreads -1 )
88  myLastY = y_max;
89  else
90  myLastY = myFirstY + numElems - 1;
91 
92  if( myID == 0 )
93  {
94  std::cerr << "Progress (over a total of " << z_max - z_min + 1 << " slices): ";
95  }
96 
97  if( fast )
98  {
99  for (int k=z_min; k<=z_max; k++)
100  {
101  if( myID == 0 )
102  {
103  if( z_min < 0 )
104  std::cerr << k - z_min << " ";
105  else
106  std::cerr << k + z_min << " ";
107  }
108 
109  for (int i=myFirstY; i<=myLastY; i++)
110  {
111  for (int j=x_min; j<=x_max; j++)
112  {
113  // x == j
114  // y == i
115  // z == k
116 
117  int xc = j;
118  int yc = i;
119  int zc = k;
120 
121  int xcOld, ycOld, zcOld;
122  double YcOld=0;
123  double Yc = A3D_ELEM( input, k, i, j);
124  int iters =0;
125  double shift;
126  int num;
127  double mY;
128 
129  do
130  {
131  xcOld = xc;
132  ycOld = yc;
133  zcOld = zc;
134 
135  double mx=0;
136  double my=0;
137  double mz=0;
138  mY=0;
139  num=0;
140 
141  for( int z_j = -(int)sigma_s ; z_j <=(int)sigma_s ; z_j++ )
142  {
143  int z2 = zc + z_j;
144  if( z2 >= z_min && z2 <= z_max)
145  {
146  int z_j_2 = z_j * z_j; // Speed-up
147 
148  for( int y_j = -(int)sigma_s ; y_j <= (int)sigma_s ; y_j++ )
149  {
150  int y2 = yc + y_j;
151  if( y2 >= y_min && y2 <= y_max)
152  {
153  int y_j_2 = y_j * y_j; // Speed-up
154 
155  for( int x_j = -(int)sigma_s ; x_j <= (int)sigma_s ; x_j++ )
156  {
157  int x2 = xc + x_j;
158  if( x2 >= x_min && x2 <= x_max)
159  {
160 
161  if( x_j*x_j+y_j_2+z_j_2 <= sigma_s*sigma_s )
162  {
163  double Y2 = A3D_ELEM( input, z2, y2, x2);
164  double dY = Yc - Y2;
165 
166  if( dY*dY <= sigma_r_2 )
167  {
168  mx += x2;
169  my += y2;
170  mz += z2;
171  mY += Y2;
172  num++;
173  }
174  }
175  }
176  }
177  }
178  }
179  }
180  }
181 
182  double num_ = 1.0/num;
183  Yc = mY*num_;
184  xc = (int) (mx*num_+0.5);
185  yc = (int) (my*num_+0.5);
186  zc = (int) (mz*num_+0.5);
187  int dx = xc-xcOld;
188  int dy = yc-ycOld;
189  int dz = zc-zcOld;
190 
191  double dY = Yc-YcOld;
192 
193  shift = dx*dx+dy*dy+dz*dz*dY*dY;
194  iters++;
195  }
196  while(shift>3 && iters<100);
197 
198  A3D_ELEM( output, k,i,j ) = Yc;
199  }
200  }
201  }
202  }
203  else
204  {
205  for (int k=z_min; k<=z_max; k++)
206  {
207  if( myID == 0 )
208  std::cerr << k << " ";
209  for (int i=myFirstY; i<=myLastY; i++)
210  {
211  for (int j=x_min; j<=x_max; j++)
212  {
213  // x == j
214  // y == i
215  // z == k
216 
217  curr_x = j;
218  curr_y = i;
219  curr_z = k;
220  curr_I = A3D_ELEM( input, k, i, j);
221 
222  double sum_denom;
223  double I_sum;
224  do
225  {
226  // Check neighbourhood
227  double x_sum = 0, y_sum = 0, z_sum = 0;
228  sum_denom = 0;
229  I_sum = 0;
230 
231  for( int z_j = curr_z - _3_sigma_s ; z_j <= curr_z + _3_sigma_s ; z_j++ )
232  {
233  if( z_j >= z_min && z_j <= z_max)
234  {
235  for( int y_j = curr_y - _3_sigma_s ; y_j <= curr_y + _3_sigma_s ; y_j++ )
236  {
237  if( y_j >= y_min && y_j <= y_max)
238  {
239  for( int x_j = curr_x - _3_sigma_s ; x_j <= curr_x + _3_sigma_s ; x_j++ )
240  {
241  if( x_j >= x_min && x_j <= x_max)
242  {
243  double I_j = A3D_ELEM( input, z_j, y_j, x_j);
244  double I_dist=fabs(I_j - curr_I);
245 
246  if( I_dist <= _3_sigma_r )
247  {
248  // Take this point into account
249  double eucl_dist = (curr_x - x_j)*(curr_x - x_j)+(curr_y - y_j)*(curr_y - y_j)+(curr_z - z_j)*(curr_z - z_j);
250  double aux = exp(-(eucl_dist*inv_2_sigma_s_2) - (I_dist*I_dist*inv_2_sigma_r_2));
251  x_sum += x_j*aux;
252  y_sum += y_j*aux;
253  z_sum += z_j*aux;
254  I_sum += I_j*aux;
255 
256  sum_denom += aux;
257  }
258  }
259  }
260  }
261  }
262  }
263  }
264 
265  prev_x = (int)round(curr_x);
266  prev_y = (int)round(curr_y);
267  prev_z = (int)round(curr_z);
268 
269  double isum_denom=1.0/sum_denom;
270  curr_x = (int)round(x_sum*isum_denom);
271  curr_y = (int)round(y_sum*isum_denom);
272  curr_z = (int)round(z_sum*isum_denom);
273  curr_I = I_sum*isum_denom;
274 
275  error = fabs(prev_x - curr_x) + fabs(prev_y - curr_y) + fabs(prev_z - curr_z);
276  }
277  while(error>0);
278 
279  A3D_ELEM( output, k,i,j ) = curr_I;
280  }
281  }
282  }
283  }
284  return nullptr;
285 }
286 
287 /* Define params ------------------------------------------------------------------- */
289 {
290  program->addParamsLine("== Mean shift ==");
291  program->addParamsLine("[--mean_shift <hr> <hs> <iter=1>] : Filter based on the mean-shift algorithm");
292  program->addParamsLine(" :+ hs: Sigma for the range domain");
293  program->addParamsLine(" :+ hr: Sigma for the spatial domain");
294  program->addParamsLine(" :+ iter: Number of iterations to be used");
295  program->addParamsLine(" alias -t;");
296  program->addParamsLine("[--thr <n=1>] : Number of processing threads");
297  program->addParamsLine("[--fast] : Use faster processing (avoid gaussian calculations)");
298  program->addParamsLine("[--save_iters] : Save result image/volume for each iteration");
299 
300 }
301 
303 {
304  sigma_r = program->getDoubleParam("--mean_shift", 0);//hr
305  sigma_s = program->getDoubleParam("--mean_shift", 1);//hs
306  iters = program->getIntParam("--mean_shift", 2);//iters
307  fast = program->checkParam("--fast");
308  numThreads = program->getIntParam("--thr");
309  save_iters = program->checkParam("--save_iters");
310  verbose = program->verbose;
311 }
312 
314 {
316 
317  auto * th_ids = new pthread_t[numThreads];
318  auto * th_args = new ThreadProcessPlaneArgs[numThreads];
319 
320  for( int iter = 0 ; iter < iters ; ++iter )
321  {
322  if (verbose)
323  std::cout << formatString("Running iteration %d/%d", iter+1, iters) << std::endl;
324 
325  for( int nt = 0; nt < numThreads ; nt++ )
326  {
327  th_args[nt].myID = nt;
328  th_args[nt].numThreads = numThreads;
329  th_args[nt].sigma_s = sigma_s;
330  th_args[nt].sigma_r = sigma_r;
331  th_args[nt].input = &input;
332  th_args[nt].output = &output;
333  th_args[nt].fast = fast;
334 
335  pthread_create( (th_ids+nt), nullptr, thread_process_plane, (void *)(th_args+nt));
336  }
337 
338  for( int nt = 0; nt < numThreads ; nt++ )
339  {
340  pthread_join( *(th_ids+nt), nullptr);
341  }
342 
343  if( save_iters )
344  {
345  FileName fn_aux = "xxx"; //fixme
346  fn_aux = fn_aux.insertBeforeExtension( std::string("_iter_") + integerToString(iter) );
347 
348  if (verbose)
349  std::cout << "Saving intermidiate file: " << fn_aux << std::endl;
350 
351  output.write( fn_aux );
352  }
353 
354  //update input for next iteration
355  input = output;
356  }
357  delete[] th_ids;
358  delete[] th_args;
359 }
#define yc
double getDoubleParam(const char *param, int arg=0)
FileName insertBeforeExtension(const String &str) const
void * thread_process_plane(void *args)
Definition: mean_shift.cpp:41
static void defineParams(XmippProgram *program)
Definition: mean_shift.cpp:288
String integerToString(int I, int _width, char fill_with)
MultidimArray< double > * input
Definition: mean_shift.cpp:36
glob_prnt iter
#define i
ql0001_ & k(htemp+1),(cvec+1),(atemp+1),(bj+1),(bl+1),(bu+1),(x+1),(clamda+1), &iout, infoqp, &zero,(w+1), &lenw,(iw+1), &leniw, &glob_grd.epsmac
#define A3D_ELEM(V, k, i, j)
unsigned int numThreads
Definition: mean_shift.cpp:33
void apply(MultidimArray< double > &img)
Definition: mean_shift.cpp:313
#define CEIL(x)
Definition: xmipp_macros.h:225
void write(const FileName &fn) const
double dx
int verbose
Verbosity level.
void readParams(XmippProgram *program)
Definition: mean_shift.cpp:302
#define j
void error(char *s)
Definition: tools.cpp:107
int round(double x)
Definition: ap.cpp:7245
String formatString(const char *format,...)
bool checkParam(const char *param)
int getIntParam(const char *param, int arg=0)
MultidimArray< double > * output
Definition: mean_shift.cpp:37
void addParamsLine(const String &line)
#define xc