Xmipp  v3.23.11-Nereus
Classes | Public Member Functions | Protected Types | Protected Member Functions | Protected Attributes | List of all members
Solver Class Reference
Inheritance diagram for Solver:
Inheritance graph
[legend]
Collaboration diagram for Solver:
Collaboration graph
[legend]

Classes

struct  SolutionInfo
 

Public Member Functions

 Solver ()
 
virtual ~Solver ()
 
void Solve (int l, const QMatrix &Q, const double *p_, const schar *y_, double *alpha_, double Cp, double Cn, double eps, SolutionInfo *si, int shrinking)
 

Protected Types

enum  { LOWER_BOUND, UPPER_BOUND, FREE }
 

Protected Member Functions

double get_C (int i)
 
void update_alpha_status (int i)
 
bool is_upper_bound (int i)
 
bool is_lower_bound (int i)
 
bool is_free (int i)
 
void swap_index (int i, int j)
 
void reconstruct_gradient ()
 
virtual int select_working_set (int &i, int &j)
 
virtual double calculate_rho ()
 
virtual void do_shrinking ()
 

Protected Attributes

int active_size
 
schar * y
 
double * G
 
char * alpha_status
 
double * alpha
 
const QMatrix * Q
 
const double * QD
 
double eps
 
double Cp
 
double Cn
 
double * p
 
int * active_set
 
double * G_bar
 
int l
 
bool unshrink
 

Detailed Description

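SMO-type decomposition solver (with second-order working-set selection, see select_working_set()) for the dual quadratic program

$$\min_{\alpha}\ \tfrac{1}{2}\,\alpha^{\top} Q\,\alpha + p^{\top}\alpha \quad\text{subject to}\quad y^{\top}\alpha = \Delta,\qquad 0 \le \alpha_i \le C_i,$$

where $y_i \in \{+1,-1\}$, $C_i = C_p$ if $y_i = +1$ and $C_i = C_n$ otherwise (see get_C()), and $\Delta$ is fixed by the feasible starting point alpha_ (every two-variable update preserves $y^{\top}\alpha$). Solve() iterates until the maximal violating pair satisfies the tolerance eps (or the iteration cap is reached). It then writes the solution back into alpha_ and stores rho, the objective value and the bounds Cp/Cn in the SolutionInfo.

Definition at line 399 of file svm.cpp.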

Member Enumeration Documentation

◆ anonymous enum

anonymous enum
protected
Enumerator
LOWER_BOUND 
UPPER_BOUND 
FREE 

Definition at line 419 of file svm.cpp.

Constructor & Destructor Documentation

◆ Solver()

Solver::Solver ( )
inline

Definition at line 401 of file svm.cpp.

401 {};

◆ ~Solver()

virtual Solver::~Solver ( )
inlinevirtual

Definition at line 402 of file svm.cpp.

402 {};

Member Function Documentation

◆ calculate_rho()

double Solver::calculate_rho ( )
protectedvirtual

Definition at line 972 of file svm.cpp.

973 {
974  double r;
975  int nr_free = 0;
976  double ub = INF, lb = -INF, sum_free = 0;
977  for(int i=0;i<active_size;i++)
978  {
979  double yG = y[i]*G[i];
980 
981  if(is_upper_bound(i))
982  {
983  if(y[i]==-1)
984  ub = min(ub,yG);
985  else
986  lb = max(lb,yG);
987  }
988  else if(is_lower_bound(i))
989  {
990  if(y[i]==+1)
991  ub = min(ub,yG);
992  else
993  lb = max(lb,yG);
994  }
995  else
996  {
997  ++nr_free;
998  sum_free += yG;
999  }
1000  }
1001 
1002  if(nr_free>0)
1003  r = sum_free/nr_free;
1004  else
1005  r = (ub+lb)/2;
1006 
1007  return r;
1008 }
int active_size
Definition: svm.cpp:416
void min(Image< double > &op1, const Image< double > &op2)
double * G
Definition: svm.cpp:418
#define i
bool is_upper_bound(int i)
Definition: svm.cpp:444
void max(Image< double > &op1, const Image< double > &op2)
schar * y
Definition: svm.cpp:417
bool is_lower_bound(int i)
Definition: svm.cpp:445
#define INF
Definition: svm.cpp:43

◆ do_shrinking()

void Solver::do_shrinking ( )
protectedvirtual

Definition at line 911 of file svm.cpp.

912 {
913  int i;
914  double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
915  double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
916 
917  // find maximal violating pair first
918  for(i=0;i<active_size;i++)
919  {
920  if(y[i]==+1)
921  {
922  if(!is_upper_bound(i))
923  {
924  if(-G[i] >= Gmax1)
925  Gmax1 = -G[i];
926  }
927  if(!is_lower_bound(i))
928  {
929  if(G[i] >= Gmax2)
930  Gmax2 = G[i];
931  }
932  }
933  else
934  {
935  if(!is_upper_bound(i))
936  {
937  if(-G[i] >= Gmax2)
938  Gmax2 = -G[i];
939  }
940  if(!is_lower_bound(i))
941  {
942  if(G[i] >= Gmax1)
943  Gmax1 = G[i];
944  }
945  }
946  }
947 
948  if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
949  {
950  unshrink = true;
951  reconstruct_gradient();
952  active_size = l;
953  info("*");
954  }
955 
956  for(i=0;i<active_size;i++)
957  if (be_shrunk(i, Gmax1, Gmax2))
958  {
959  active_size--;
960  while (active_size > i)
961  {
962  if (!be_shrunk(active_size, Gmax1, Gmax2))
963  {
964  swap_index(i,active_size);
965  break;
966  }
967  active_size--;
968  }
969  }
970 }
int active_size
Definition: svm.cpp:416
int l
Definition: svm.cpp:429
void swap_index(int i, int j)
Definition: svm.cpp:456
double * G
Definition: svm.cpp:418
#define i
bool is_upper_bound(int i)
Definition: svm.cpp:444
schar * y
Definition: svm.cpp:417
bool is_lower_bound(int i)
Definition: svm.cpp:445
#define INF
Definition: svm.cpp:43
double eps
Definition: svm.cpp:424
bool unshrink
Definition: svm.cpp:430
void reconstruct_gradient()
Definition: svm.cpp:468

◆ get_C()

double Solver::get_C ( int  i)
inlineprotected

Definition at line 432 of file svm.cpp.

433  {
434  return (y[i] > 0)? Cp : Cn;
435  }
double Cp
Definition: svm.cpp:425
#define i
double Cn
Definition: svm.cpp:425
schar * y
Definition: svm.cpp:417

◆ is_free()

bool Solver::is_free ( int  i)
inlineprotected

Definition at line 446 of file svm.cpp.

446 { return alpha_status[i] == FREE; }
#define i
char * alpha_status
Definition: svm.cpp:420

◆ is_lower_bound()

bool Solver::is_lower_bound ( int  i)
inlineprotected

Definition at line 445 of file svm.cpp.

445 { return alpha_status[i] == LOWER_BOUND; }
#define i
char * alpha_status
Definition: svm.cpp:420

◆ is_upper_bound()

bool Solver::is_upper_bound ( int  i)
inlineprotected

Definition at line 444 of file svm.cpp.

444 { return alpha_status[i] == UPPER_BOUND; }
#define i
char * alpha_status
Definition: svm.cpp:420

◆ reconstruct_gradient()

void Solver::reconstruct_gradient ( )
protected

Definition at line 468 of file svm.cpp.

469 {
470  // reconstruct inactive elements of G from G_bar and free variables
471 
472  if(active_size == l) return;
473 
474  int i,j;
475  int nr_free = 0;
476 
477  for(j=active_size;j<l;j++)
478  G[j] = G_bar[j] + p[j];
479 
480  for(j=0;j<active_size;j++)
481  if(is_free(j))
482  nr_free++;
483 
484  if(2*nr_free < active_size)
485  info("\nWARNING: using -h 0 may be faster\n");
486 
487  if (nr_free*l > 2*active_size*(l-active_size))
488  {
489  for(i=active_size;i<l;i++)
490  {
491  const Qfloat *Q_i = Q->get_Q(i,active_size);
492  for(j=0;j<active_size;j++)
493  if(is_free(j))
494  G[i] += alpha[j] * Q_i[j];
495  }
496  }
497  else
498  {
499  for(i=0;i<active_size;i++)
500  if(is_free(i))
501  {
502  const Qfloat *Q_i = Q->get_Q(i,l);
503  double alpha_i = alpha[i];
504  for(j=active_size;j<l;j++)
505  G[j] += alpha_i * Q_i[j];
506  }
507  }
508 }
int active_size
Definition: svm.cpp:416
int l
Definition: svm.cpp:429
double * G
Definition: svm.cpp:418
virtual Qfloat * get_Q(int column, int len) const =0
double * G_bar
Definition: svm.cpp:428
double * alpha
Definition: svm.cpp:421
#define i
double * p
Definition: svm.cpp:426
float Qfloat
Definition: svm.cpp:18
#define j
const QMatrix * Q
Definition: svm.cpp:422
bool is_free(int i)
Definition: svm.cpp:446

◆ select_working_set()

int Solver::select_working_set ( int &  i,
int &  j 
)
protectedvirtual

Definition at line 792 of file svm.cpp.

793 {
794  // return i,j such that
795  // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
796  // j: minimizes the decrease of obj value
797  // (if quadratic coefficeint <= 0, replace it with tau)
798  // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
799 
800  double Gmax = -INF;
801  double Gmax2 = -INF;
802  int Gmax_idx = -1;
803  int Gmin_idx = -1;
804  double obj_diff_min = INF;
805 
806  for(int t=0;t<active_size;t++)
807  if(y[t]==+1)
808  {
809  if(!is_upper_bound(t))
810  if(-G[t] >= Gmax)
811  {
812  Gmax = -G[t];
813  Gmax_idx = t;
814  }
815  }
816  else
817  {
818  if(!is_lower_bound(t))
819  if(G[t] >= Gmax)
820  {
821  Gmax = G[t];
822  Gmax_idx = t;
823  }
824  }
825 
826  int i = Gmax_idx;
827  const Qfloat *Q_i = NULL;
828  if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
829  Q_i = Q->get_Q(i,active_size);
830 
831  for(int j=0;j<active_size;j++)
832  {
833  if(y[j]==+1)
834  {
835  if (!is_lower_bound(j))
836  {
837  double grad_diff=Gmax+G[j];
838  if (G[j] >= Gmax2)
839  Gmax2 = G[j];
840  if (grad_diff > 0)
841  {
842  double obj_diff;
843  double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
844  if (quad_coef > 0)
845  obj_diff = -(grad_diff*grad_diff)/quad_coef;
846  else
847  obj_diff = -(grad_diff*grad_diff)/TAU;
848 
849  if (obj_diff <= obj_diff_min)
850  {
851  Gmin_idx=j;
852  obj_diff_min = obj_diff;
853  }
854  }
855  }
856  }
857  else
858  {
859  if (!is_upper_bound(j))
860  {
861  double grad_diff= Gmax-G[j];
862  if (-G[j] >= Gmax2)
863  Gmax2 = -G[j];
864  if (grad_diff > 0)
865  {
866  double obj_diff;
867  double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
868  if (quad_coef > 0)
869  obj_diff = -(grad_diff*grad_diff)/quad_coef;
870  else
871  obj_diff = -(grad_diff*grad_diff)/TAU;
872 
873  if (obj_diff <= obj_diff_min)
874  {
875  Gmin_idx=j;
876  obj_diff_min = obj_diff;
877  }
878  }
879  }
880  }
881  }
882 
883  if(Gmax+Gmax2 < eps || Gmin_idx == -1)
884  return 1;
885 
886  out_i = Gmax_idx;
887  out_j = Gmin_idx;
888  return 0;
889 }
int active_size
Definition: svm.cpp:416
double * G
Definition: svm.cpp:418
virtual Qfloat * get_Q(int column, int len) const =0
#define i
bool is_upper_bound(int i)
Definition: svm.cpp:444
const double * QD
Definition: svm.cpp:423
#define TAU
Definition: svm.cpp:44
float Qfloat
Definition: svm.cpp:18
#define j
const QMatrix * Q
Definition: svm.cpp:422
schar * y
Definition: svm.cpp:417
bool is_lower_bound(int i)
Definition: svm.cpp:445
#define INF
Definition: svm.cpp:43
double eps
Definition: svm.cpp:424

◆ Solve()

void Solver::Solve ( int  l,
const QMatrix &  Q,
const double *  p_,
const schar *  y_,
double *  alpha_,
double  Cp,
double  Cn,
double  eps,
SolutionInfo *  si,
int  shrinking 
)

Definition at line 510 of file svm.cpp.

513 {
514  this->l = l;
515  this->Q = &Q;
516  QD=Q.get_QD();
517  clone(p, p_,l);
518  clone(y, y_,l);
519  clone(alpha,alpha_,l);
520  this->Cp = Cp;
521  this->Cn = Cn;
522  this->eps = eps;
523  unshrink = false;
524 
525  // initialize alpha_status
526  {
527  alpha_status = new char[l];
528  for(int i=0;i<l;i++)
529  update_alpha_status(i);
530  }
531 
532  // initialize active set (for shrinking)
533  {
534  active_set = new int[l];
535  for(int i=0;i<l;i++)
536  active_set[i] = i;
537  active_size = l;
538  }
539 
540  // initialize gradient
541  {
542  G = new double[l];
543  G_bar = new double[l];
544  int i;
545  for(i=0;i<l;i++)
546  {
547  G[i] = p[i];
548  G_bar[i] = 0;
549  }
550  for(i=0;i<l;i++)
551  if(!is_lower_bound(i))
552  {
553  const Qfloat *Q_i = Q.get_Q(i,l);
554  double alpha_i = alpha[i];
555  int j;
556  for(j=0;j<l;j++)
557  G[j] += alpha_i*Q_i[j];
558  if(is_upper_bound(i))
559  for(j=0;j<l;j++)
560  G_bar[j] += get_C(i) * Q_i[j];
561  }
562  }
563 
564  // optimization step
565 
566  int iter = 0;
567  int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
568  int counter = min(l,1000)+1;
569 
570  while(iter < max_iter)
571  {
572  // show progress and do shrinking
573 
574  if(--counter == 0)
575  {
576  counter = min(l,1000);
577  if(shrinking) do_shrinking();
578  info(".");
579  }
580 
581  int i,j;
582  if(select_working_set(i,j)!=0)
583  {
584  // reconstruct the whole gradient
585  reconstruct_gradient();
586  // reset active set size and check
587  active_size = l;
588  info("*");
589  if(select_working_set(i,j)!=0)
590  break;
591  else
592  counter = 1; // do shrinking next iteration
593  }
594 
595  ++iter;
596 
597  // update alpha[i] and alpha[j], handle bounds carefully
598 
599  const Qfloat *Q_i = Q.get_Q(i,active_size);
600  const Qfloat *Q_j = Q.get_Q(j,active_size);
601 
602  double C_i = get_C(i);
603  double C_j = get_C(j);
604 
605  double old_alpha_i = alpha[i];
606  double old_alpha_j = alpha[j];
607 
608  if(y[i]!=y[j])
609  {
610  double quad_coef = QD[i]+QD[j]+2*Q_i[j];
611  if (quad_coef <= 0)
612  quad_coef = TAU;
613  double delta = (-G[i]-G[j])/quad_coef;
614  double diff = alpha[i] - alpha[j];
615  alpha[i] += delta;
616  alpha[j] += delta;
617 
618  if(diff > 0)
619  {
620  if(alpha[j] < 0)
621  {
622  alpha[j] = 0;
623  alpha[i] = diff;
624  }
625  }
626  else
627  {
628  if(alpha[i] < 0)
629  {
630  alpha[i] = 0;
631  alpha[j] = -diff;
632  }
633  }
634  if(diff > C_i - C_j)
635  {
636  if(alpha[i] > C_i)
637  {
638  alpha[i] = C_i;
639  alpha[j] = C_i - diff;
640  }
641  }
642  else
643  {
644  if(alpha[j] > C_j)
645  {
646  alpha[j] = C_j;
647  alpha[i] = C_j + diff;
648  }
649  }
650  }
651  else
652  {
653  double quad_coef = QD[i]+QD[j]-2*Q_i[j];
654  if (quad_coef <= 0)
655  quad_coef = TAU;
656  double delta = (G[i]-G[j])/quad_coef;
657  double sum = alpha[i] + alpha[j];
658  alpha[i] -= delta;
659  alpha[j] += delta;
660 
661  if(sum > C_i)
662  {
663  if(alpha[i] > C_i)
664  {
665  alpha[i] = C_i;
666  alpha[j] = sum - C_i;
667  }
668  }
669  else
670  {
671  if(alpha[j] < 0)
672  {
673  alpha[j] = 0;
674  alpha[i] = sum;
675  }
676  }
677  if(sum > C_j)
678  {
679  if(alpha[j] > C_j)
680  {
681  alpha[j] = C_j;
682  alpha[i] = sum - C_j;
683  }
684  }
685  else
686  {
687  if(alpha[i] < 0)
688  {
689  alpha[i] = 0;
690  alpha[j] = sum;
691  }
692  }
693  }
694 
695  // update G
696 
697  double delta_alpha_i = alpha[i] - old_alpha_i;
698  double delta_alpha_j = alpha[j] - old_alpha_j;
699 
700  for(int k=0;k<active_size;k++)
701  {
702  G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
703  }
704 
705  // update alpha_status and G_bar
706 
707  {
708  bool ui = is_upper_bound(i);
709  bool uj = is_upper_bound(j);
710  update_alpha_status(i);
711  update_alpha_status(j);
712  int k;
713  if(ui != is_upper_bound(i))
714  {
715  Q_i = Q.get_Q(i,l);
716  if(ui)
717  for(k=0;k<l;k++)
718  G_bar[k] -= C_i * Q_i[k];
719  else
720  for(k=0;k<l;k++)
721  G_bar[k] += C_i * Q_i[k];
722  }
723 
724  if(uj != is_upper_bound(j))
725  {
726  Q_j = Q.get_Q(j,l);
727  if(uj)
728  for(k=0;k<l;k++)
729  G_bar[k] -= C_j * Q_j[k];
730  else
731  for(k=0;k<l;k++)
732  G_bar[k] += C_j * Q_j[k];
733  }
734  }
735  }
736 
737  if(iter >= max_iter)
738  {
739  if(active_size < l)
740  {
741  // reconstruct the whole gradient to calculate objective value
742  reconstruct_gradient();
743  active_size = l;
744  info("*");
745  }
746  fprintf(stderr,"\nWARNING: reaching max number of iterations\n");
747  }
748 
749  // calculate rho
750 
751  si->rho = calculate_rho();
752 
753  // calculate objective value
754  {
755  double v = 0;
756  int i;
757  for(i=0;i<l;i++)
758  v += alpha[i] * (G[i] + p[i]);
759 
760  si->obj = v/2;
761  }
762 
763  // put back the solution
764  {
765  for(int i=0;i<l;i++)
766  alpha_[active_set[i]] = alpha[i];
767  }
768 
769  // juggle everything back
770  /*{
771  for(int i=0;i<l;i++)
772  while(active_set[i] != i)
773  swap_index(i,active_set[i]);
774  // or Q.swap_index(i,active_set[i]);
775  }*/
776 
777  si->upper_bound_p = Cp;
778  si->upper_bound_n = Cn;
779 
780  info("\noptimization finished, #iter = %d\n",iter);
781 
782  delete[] p;
783  delete[] y;
784  delete[] alpha;
785  delete[] alpha_status;
786  delete[] active_set;
787  delete[] G;
788  delete[] G_bar;
789 }
int active_size
Definition: svm.cpp:416
int l
Definition: svm.cpp:429
void min(Image< double > &op1, const Image< double > &op2)
double * G
Definition: svm.cpp:418
virtual Qfloat * get_Q(int column, int len) const =0
double * G_bar
Definition: svm.cpp:428
double * alpha
Definition: svm.cpp:421
glob_prnt iter
double Cp
Definition: svm.cpp:425
#define i
ql0001_ & k(htemp+1),(cvec+1),(atemp+1),(bj+1),(bl+1),(bu+1),(x+1),(clamda+1), &iout, infoqp, &zero,(w+1), &lenw,(iw+1), &leniw, &glob_grd.epsmac
virtual int select_working_set(int &i, int &j)
Definition: svm.cpp:792
double * p
Definition: svm.cpp:426
bool is_upper_bound(int i)
Definition: svm.cpp:444
const double * QD
Definition: svm.cpp:423
#define TAU
Definition: svm.cpp:44
void max(Image< double > &op1, const Image< double > &op2)
float Qfloat
Definition: svm.cpp:18
void update_alpha_status(int i)
Definition: svm.cpp:436
int * active_set
Definition: svm.cpp:427
#define j
double Cn
Definition: svm.cpp:425
const QMatrix * Q
Definition: svm.cpp:422
schar * y
Definition: svm.cpp:417
virtual double * get_QD() const =0
virtual void do_shrinking()
Definition: svm.cpp:911
bool is_lower_bound(int i)
Definition: svm.cpp:445
fprintf(glob_prnt.io, "\)
double eps
Definition: svm.cpp:424
bool unshrink
Definition: svm.cpp:430
virtual double calculate_rho()
Definition: svm.cpp:972
double get_C(int i)
Definition: svm.cpp:432
void reconstruct_gradient()
Definition: svm.cpp:468
char * alpha_status
Definition: svm.cpp:420
double * delta

◆ swap_index()

void Solver::swap_index ( int  i,
int  j 
)
protected

Definition at line 456 of file svm.cpp.

457 {
458  Q->swap_index(i,j);
459  swap(y[i],y[j]);
460  swap(G[i],G[j]);
461  swap(alpha_status[i],alpha_status[j]);
462  swap(alpha[i],alpha[j]);
463  swap(p[i],p[j]);
464  swap(active_set[i],active_set[j]);
465  swap(G_bar[i],G_bar[j]);
466 }
double * G
Definition: svm.cpp:418
double * G_bar
Definition: svm.cpp:428
double * alpha
Definition: svm.cpp:421
#define i
double * p
Definition: svm.cpp:426
int * active_set
Definition: svm.cpp:427
#define j
const QMatrix * Q
Definition: svm.cpp:422
schar * y
Definition: svm.cpp:417
virtual void swap_index(int i, int j) const =0
char * alpha_status
Definition: svm.cpp:420

◆ update_alpha_status()

void Solver::update_alpha_status ( int  i)
inlineprotected

Definition at line 436 of file svm.cpp.

437  {
438  if(alpha[i] >= get_C(i))
439  alpha_status[i] = UPPER_BOUND;
440  else if(alpha[i] <= 0)
441  alpha_status[i] = LOWER_BOUND;
442  else alpha_status[i] = FREE;
443  }
double * alpha
Definition: svm.cpp:421
#define i
double get_C(int i)
Definition: svm.cpp:432
char * alpha_status
Definition: svm.cpp:420

Member Data Documentation

◆ active_set

int* Solver::active_set
protected

Definition at line 427 of file svm.cpp.

◆ active_size

int Solver::active_size
protected

Definition at line 416 of file svm.cpp.

◆ alpha

double* Solver::alpha
protected

Definition at line 421 of file svm.cpp.

◆ alpha_status

char* Solver::alpha_status
protected

Definition at line 420 of file svm.cpp.

◆ Cn

double Solver::Cn
protected

Definition at line 425 of file svm.cpp.

◆ Cp

double Solver::Cp
protected

Definition at line 425 of file svm.cpp.

◆ eps

double Solver::eps
protected

Definition at line 424 of file svm.cpp.

◆ G

double* Solver::G
protected

Definition at line 418 of file svm.cpp.

◆ G_bar

double* Solver::G_bar
protected

Definition at line 428 of file svm.cpp.

◆ l

int Solver::l
protected

Definition at line 429 of file svm.cpp.

◆ p

double* Solver::p
protected

Definition at line 426 of file svm.cpp.

◆ Q

const QMatrix* Solver::Q
protected

Definition at line 422 of file svm.cpp.

◆ QD

const double* Solver::QD
protected

Definition at line 423 of file svm.cpp.

◆ unshrink

bool Solver::unshrink
protected

Definition at line 430 of file svm.cpp.

◆ y

schar* Solver::y
protected

Definition at line 417 of file svm.cpp.


The documentation for this class was generated from the following file: