pktools  2.6.7
Processing Kernel for geospatial data
ConfusionMatrix.h
1 /**********************************************************************
2 ConfusionMatrix.h: class for (classification accuracy) confusion matrix
3 Copyright (C) 2008-2012 Pieter Kempeneers
4 
5 This file is part of pktools
6 
7 pktools is free software: you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation, either version 3 of the License, or
10 (at your option) any later version.
11 
12 pktools is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with pktools. If not, see <http://www.gnu.org/licenses/>.
19 ***********************************************************************/
20 #ifndef _CONFUSIONMATRIX_H_
21 #define _CONFUSIONMATRIX_H_
22 
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>
#include "base/Vector2d.h"
#include "base/Optionpk.h"
27 
28 namespace confusionmatrix
29 {
/// Output formats for printing a ConfusionMatrix.
enum CM_FORMAT {
  ASCII = 0, ///< tab-separated plain text
  LATEX = 1, ///< minimal LaTeX article containing a tabular environment
  HTML = 2   ///< declared only: not accepted by the string parser, printed like ASCII
};

33 public:
35  ConfusionMatrix(short nclass);
36  ConfusionMatrix(const std::vector<std::string>& classNames);
38  ConfusionMatrix& operator=(const ConfusionMatrix& cm);
39  short size() const {return m_results.size();};
40  void resize(short nclass);
41  void setClassNames(const std::vector<std::string>& classNames, bool doSort=false);
42  void pushBackClassName(const std::string& className, bool doSort=false);
43  void setResults(const Vector2d<double>& theResults);
44  void setResult(const std::string& theRef, const std::string& theClass, double theResult);
45  void incrementResult(const std::string& theRef, const std::string& theClass, double theIncrement);
46  void clearResults();
47  double nReference(const std::string& theRef) const;
48  double nReference() const;
49  double nClassified(const std::string& theRef) const;
50  int nClasses() const {return m_classes.size();};
51  std::string getClass(int iclass) const {assert(iclass>=0);assert(iclass<m_classes.size());return m_classes[iclass];};
52  int getClassIndex(std::string className) const {
53  int index=0;
54  for(index=0;index<m_classes.size();++index){
55  if(m_classes[index]==className)
56  break;
57  }
58  if(index>=m_classes.size())
59  index=-1;
60  return index;
61  // int index=distance(m_classes.begin(),find(m_classes.begin(),m_classes.end(),className));
62  // assert(index>=0);
63  // if(index<m_results.size())
64  // return(index);
65  // else
66  // return(-1);
67  }
68  std::vector<std::string> getClassNames() const {return m_classes;};
69  ~ConfusionMatrix();
70  double pa(const std::string& theClass, double* se95=NULL) const;
71  double ua(const std::string& theClass, double* se95=NULL) const;
72  double oa(double* se95=NULL) const;
73  int pa_pct(const std::string& theClass, double* se95=NULL) const;
74  int ua_pct(const std::string& theClass, double* se95=NULL) const;
75  int oa_pct(double* se95=NULL) const;
76  double kappa() const;
77  ConfusionMatrix& operator*=(double weight);
78  ConfusionMatrix operator*(double weight);
79  ConfusionMatrix& operator+=(const ConfusionMatrix &cm);
80  ConfusionMatrix operator+(const ConfusionMatrix &cm){
81  return ConfusionMatrix(*this)+=cm;
82  }
83  void sortClassNames();
84 
85  void reportSE95(bool doReport) {m_se95=doReport;};
86  void setFormat(const CM_FORMAT& theFormat) {m_format=theFormat;};
87  void setFormat(const std::string theFormat) {m_format=getFormat(theFormat);};
88  CM_FORMAT getFormat() const {return m_format;};
89 
90  static const CM_FORMAT getFormat(const std::string theFormat){
91  if(theFormat=="ascii") return(ASCII);
92  else if(theFormat=="latex") return(LATEX);
93  else{
94  std::string errorString="Format not supported: ";
95  errorString+=theFormat;
96  errorString+=" use ascii or latex";
97  throw(errorString);
98  }
99  };
100 
101  friend std::ostream& operator<<(std::ostream& os, const ConfusionMatrix &cm){
102  std::ostringstream streamLine;
103  /* streamosclass << iclass; */
104  /* m_classes[iclass]=osclass.str(); */
105 
106  std::string fieldSeparator=" ";
107  std::string lineSeparator="";
108  std::string mathMode="";
109  switch(cm.getFormat()){
110  case(LATEX):
111  fieldSeparator=" & ";
112  lineSeparator="\\\\";
113  mathMode="$";
114  break;
115  case(ASCII):
116  default:
117  fieldSeparator="\t";
118  lineSeparator="";
119  mathMode="";
120  break;
121  }
122 
123  double se95_ua=0;
124  double se95_pa=0;
125  double se95_oa=0;
126  double dua=0;
127  double dpa=0;
128  double doa=0;
129 
130  doa = cm.oa(&se95_oa);
131 
132  if(cm.getFormat()==LATEX){
133  os << "\\documentclass{article}" << std::endl;
134  os << "\\begin{document}" << std::endl;
135  }
136  os << "Kappa = " << mathMode << cm.kappa() << mathMode ;
137  os << ", Overall Acc. = " << mathMode << 100.0*cm.oa() << mathMode ;
138  if(cm.m_se95)
139  os << " (" << mathMode << se95_oa << mathMode << ")";
140  os << std::endl;
141  os << std::endl;
142  if(cm.getFormat()==LATEX){
143  os << "\\begin{tabular}{@{}l";
144  for(int iclass=0;iclass<cm.nClasses();++iclass)
145  os << "l";
146  os << "}" << std::endl;
147  os << "\\hline" << std::endl;
148  }
149 
150  os << "Class";
151  for(int iclass=0;iclass<cm.nClasses();++iclass)
152  os << fieldSeparator << cm.m_classes[iclass];
153  os << lineSeparator << std::endl;
154  if(cm.getFormat()==LATEX)
155  os << "\\hline" << std::endl;
156  assert(cm.m_classes.size()==cm.m_results.size());
157  for(int irow=0;irow<cm.m_results.size();++irow){
158  os << cm.m_classes[irow];
159  for(int icol=0;icol<cm.m_results[irow].size();++icol)
160  os << fieldSeparator << cm.m_results[irow][icol];
161  os << lineSeparator<< std::endl;
162  }
163  if(cm.getFormat()==LATEX){
164  os << "\\hline" << std::endl;
165  }
166  else
167  os << std::endl;
168 
169  os << "User' Acc.";
170  for(int iclass=0;iclass<cm.nClasses();++iclass){
171  dua=cm.ua_pct(cm.m_classes[iclass],&se95_ua);
172  os << fieldSeparator << dua;
173  if(cm.m_se95)
174  os << " (" << se95_ua << ")";
175  }
176  os << lineSeparator<< std::endl;
177  os << "Prod. Acc.";
178  for(int iclass=0;iclass<cm.nClasses();++iclass){
179  dpa=cm.pa_pct(cm.m_classes[iclass],&se95_ua);
180  os << fieldSeparator << dpa;
181  if(cm.m_se95)
182  os << " (" << se95_pa << ")";
183  }
184  os << lineSeparator<< std::endl;
185  if(cm.getFormat()==LATEX){
186  os << "\\end{tabular}" << std::endl;
187  os << "\\end{document}" << std::endl;
188  }
189  return os;
190  };
191 private:
192  std::vector<std::string> m_classes;
193  Vector2d<double> m_results;
194  CM_FORMAT m_format;
195  bool m_se95;
196 };
197 }
198 #endif /* _CONFUSIONMATRIX_H_ */