Xmipp  v3.23.11-Nereus
svm_classifier.cpp
Go to the documentation of this file.
1 /***************************************************************************
2  *
3  * Authors: Vahid Abrishami (vabrishami@cnb.csic.es)
4  *
5  * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
20  * 02111-1307 USA
21  *
22  * All comments concerning this program package may be sent to the
23  * e-mail address 'xmipp@cnb.csic.es'
24  ***************************************************************************/
#include "svm_classifier.h"

#include <algorithm>
#include <cstdio>
#include <vector>

#include "core/multidim_array.h"
#include "core/xmipp_filename.h"
28 
29 #ifdef UNUSED // detected as unused 29.6.2018
// Returns true if `element` is found in inputArray, false otherwise.
// The comparison is an exact `==` on doubles (no tolerance).
// NOTE(review): the loop header (original line 32) is missing from this
// listing; presumably FOR_ALL_DIRECT_ELEMENTS_IN_ARRAY1D(inputArray) supplies
// the index `i` used below — confirm against the real file.
30 bool findElementIn1DArray(MultidimArray<double> &inputArray,double element)
31 {
33  if (DIRECT_A1D_ELEM(inputArray,i)==element)
34  return true;
35  return false;
36 }
37 #endif
38 
// Configure the libsvm training parameters. The kernel coefficient `gamma`
// and the penalty `C` come from the arguments; every other field is fixed
// below. Also resets the model and training-problem pointers to NULL.
// NOTE(review): the signature (original line 39, presumably
// `void SVMClassifier::setParameters(double c, double gamma)`) and original
// lines 41-42 (presumably svm_type / kernel_type) are missing from this
// listing — confirm against the real file.
// NOTE(review): because model/prob.x/prob.y are overwritten with NULL, calling
// this on an object that already trained or loaded a model leaks that memory.
40 {
43  param.degree = 2;
44  param.gamma = gamma;
45  param.coef0 = 0;
46  param.nu = 0.1;
47  param.cache_size = 1000;
48  param.C = c;
49  param.eps = 0.001;
50  param.p = 0.1;
51  param.shrinking = 1;
// Probability estimates are enabled; predict() relies on them through
// svm_predict_probability to compute its score.
52  param.probability = 1;
// No per-class weights.
53  param.nr_weight = 0;
54  param.weight_label = NULL;
55  param.weight = NULL;
// Mark "no model / no training data yet" so the destructor's NULL checks
// on prob.x / prob.y hold.
56  model=NULL;
57  prob.y=NULL;
58  prob.x=NULL;
59 }
60 
// Copy-assign another classifier's parameters, training problem and model.
// NOTE(review): the signature (original line 61, presumably
// `SVMClassifier & SVMClassifier::operator=(const SVMClassifier &other)`)
// is missing from this listing.
// NOTE(review): this is a SHALLOW copy and is unsafe as written:
//  - `prob=other.prob` makes both objects share prob.x / prob.y, which the
//    destructor delete[]s, so both objects free the same arrays (double free);
//  - `new svm_model(*other.model)` copies only the struct, so the model's
//    internal arrays are shared as well;
//  - `param=other.param` shares weight_label / weight pointers (NULL when set
//    by setParameters);
//  - no self-assignment guard: for `a = a`, `delete model` frees the object
//    that `*other.model` then dereferences;
//  - no NULL check on other.model;
//  - `delete model` may not match how libsvm allocated it (svm_load_model /
//    svm_train appear to use malloc in libsvm — verify).
// Consider deleting copy assignment or implementing a real deep copy.
62 {
63  param=other.param;
64  prob=other.prob;
65  delete model;
66  model=new svm_model(*other.model);
67  return *this;
68 }
69 
// Destructor: releases the training-problem arrays allocated by SVMTrain
// (the label array prob.y and each of the prob.l sparse rows in prob.x).
// NOTE(review): the signature (original line 70) and original lines 72-73
// are missing from this listing; presumably they release `model` and `param`
// (this file references svm_free_and_destroy_model and svm_destroy_param) —
// confirm.
// NOTE(review): the NULL checks assume setParameters() ran first; nothing
// visible here initialises prob.x / prob.y / prob.l otherwise.
71 {
74  if (prob.y!=NULL)
75  delete [] prob.y;
76  if (prob.x!=NULL)
77  {
// Each row was allocated separately in SVMTrain.
78  for(int i=0;i<prob.l;i++)
79  delete [] prob.x[i];
80  delete [] prob.x;
81  }
82 }
83 
// Build a libsvm training problem from trainSet and label, then validate the
// parameters. trainSet holds one sample per row (YSIZE rows, XSIZE
// features); label holds the class of each row. Exits the process with
// status 1 if svm_check_parameter rejects the configuration.
// NOTE(review): the signature (original line 84; the header declares the
// second parameter as `lable`) and original line 116 (presumably
// `model=svm_train(&prob,&param);`) are missing from this listing.
// NOTE(review): calling SVMTrain a second time leaks the previous prob.x /
// prob.y arrays, since they are reallocated without being freed.
85 {
86 
87  prob.l = YSIZE(trainSet);
88  prob.y = new double[prob.l];
// One more row pointer than prob.l is allocated; only the first prob.l are
// filled below.
89  prob.x = new svm_node *[prob.l+1];
90  const char *error_msg;
91  for (size_t i=0;i<YSIZE(trainSet);i++)
92  {
// XSIZE+1 nodes: room for every feature plus the terminator.
93  prob.x[i]=new svm_node[XSIZE(trainSet)+1];
94  int cnt = 0;
95  for (size_t j=0;j<XSIZE(trainSet);j++)
96  {
// Sparse libsvm encoding: zero features are skipped and feature indices
// are 1-based (j+1).
// NOTE(review): the zero test uses trainSet(i,j) while the value is read
// with DIRECT_A2D_ELEM; the two accessors may index differently — verify.
97  if (trainSet(i,j)==0)
98  continue;
99  else
100  {
101  prob.x[i][cnt].value=DIRECT_A2D_ELEM(trainSet,i,j);
102  prob.x[i][cnt].index=j+1;
103  cnt++;
104  }
105  }
// Row terminator (index -1). The value 2 stored here is presumably ignored
// by libsvm — confirm.
106  prob.x[i][cnt].index=-1;
107  prob.x[i][cnt].value=2;
108  prob.y[i] = DIRECT_A1D_ELEM(label,i);
109  }
110  error_msg = svm_check_parameter(&prob,&param);
111  if(error_msg)
112  {
113  fprintf(stderr,"ERROR: %s\n",error_msg);
114  exit(1);
115  }
117 }
118 double SVMClassifier::predict(MultidimArray<double> &featVec,double &score)
119 {
120  svm_node *x_space;
121  int cnt=0;
122  int nr_class=svm_get_nr_class(model);
123  double *prob_estimates=new double[nr_class];
124  x_space=new svm_node[XSIZE(featVec)+1];
125 
126  for (size_t i=0;i<XSIZE(featVec);i++)
127  {
128  if (DIRECT_A1D_ELEM(featVec,i)==0)
129  continue;
130  else
131  {
132  x_space[cnt].value=DIRECT_A1D_ELEM(featVec,i);
133  x_space[cnt].index=i+1;
134  cnt++;
135  }
136  }
137  x_space[cnt].index=-1;
138  double label=svm_predict_probability(model,x_space,prob_estimates);
139  // Extracting the probability of the selected class
140  score=prob_estimates[0];
141  for (int i=1;i<nr_class;++i)
142  if (prob_estimates[i]>score)
143  score=prob_estimates[i];
144  delete [] prob_estimates;
145  delete [] x_space;
146  return label;
147 }
148 void SVMClassifier::SaveModel(const FileName &fnModel)
149 {
150  if (model->l!=0)
151  svm_save_model(fnModel.c_str(),model);
152 }
153 void SVMClassifier::LoadModel(const FileName &fnModel)
154 {
155  model=svm_load_model(fnModel.c_str());
156 }
157 
#ifdef UNUSED // detected as unused 29.6.2018
/** Number of classes known to the current SVM model, as reported by libsvm. */
int SVMClassifier::getNumClasses()
{
    const int numClasses = svm_get_nr_class(model);
    return numClasses;
}
#endif
164 
#define YSIZE(v)
SVMClassifier & operator=(const SVMClassifier &other)
int svm_get_nr_class(const svm_model *model)
Definition: svm.cpp:2471
void LoadModel(const FileName &fnModel)
doublereal * c
double svm_predict_probability(const svm_model *model, const svm_node *x, double *prob_estimates)
Definition: svm.cpp:2598
svm_model * svm_train(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp:2098
svm_parameter param
void SaveModel(const FileName &fnModel)
#define DIRECT_A2D_ELEM(v, i, j)
double value
Definition: svm.h:19
int l
Definition: svm.h:62
int nr_weight
Definition: svm.h:46
int * weight_label
Definition: svm.h:47
#define FOR_ALL_DIRECT_ELEMENTS_IN_ARRAY1D(v)
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
Definition: svm.cpp:3036
double * gamma
#define i
svm_problem prob
Definition: svm.h:58
double p
Definition: svm.h:50
#define DIRECT_A1D_ELEM(v, i)
double cache_size
Definition: svm.h:43
void setParameters(double c, double gamma)
void svm_destroy_param(svm_parameter *param)
Definition: svm.cpp:3046
double eps
Definition: svm.h:44
#define XSIZE(v)
int shrinking
Definition: svm.h:51
const char * svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp:3052
struct svm_node ** x
Definition: svm.h:26
#define j
int index
Definition: svm.h:18
double predict(MultidimArray< double > &featVec, double &score)
int svm_save_model(const char *model_file_name, const svm_model *model)
Definition: svm.cpp:2653
svm_model * model
int probability
Definition: svm.h:52
int degree
Definition: svm.h:38
Definition: svm.h:16
double * y
Definition: svm.h:25
fprintf(glob_prnt.io, "\)
double gamma
Definition: svm.h:39
int l
Definition: svm.h:24
double * weight
Definition: svm.h:48
double C
Definition: svm.h:45
int svm_type
Definition: svm.h:36
svm_model * svm_load_model(const char *model_file_name)
Definition: svm.cpp:2893
double nu
Definition: svm.h:49
void SVMTrain(MultidimArray< double > &trainSet, MultidimArray< double > &lable)
double coef0
Definition: svm.h:40
int kernel_type
Definition: svm.h:37