20 #include "ConfusionMatrix.h" 26 bool compareClass(
const std::string& string1,
const std::string& string2){
// Ordering predicate for class names; ConfusionMatrix::sortClassNames()
// passes it to sort(). Both names are parsed as integers with
// string2type<int>, presumably so numeric class labels order by value
// rather than lexically.
// NOTE(review): the comparison/return statement is not visible in this
// excerpt; presumably it returns int1<int2 — confirm. Whatever it returns
// must be a strict weak ordering for sort() to be well-defined.
// NOTE(review): behavior for non-numeric class names depends on
// string2type<int> (defined elsewhere) — confirm it neither throws nor
// aborts on such input.
27 int int1=string2type<int>(string1);
28 int int2=string2type<int>(string2);
// Default constructor: starts with no class names and an empty result
// matrix, with m_se95 enabled and output format ASCII.
32 ConfusionMatrix::ConfusionMatrix()
33 : m_classes(),m_results(),m_se95(true),m_format(ASCII)
// Destructor. NOTE(review): the body is not visible in this excerpt;
// nothing in the visible members suggests manual cleanup is needed.
37 ConfusionMatrix::~ConfusionMatrix()
// Constructs a matrix sized for nclass classes.
// NOTE(review): the body is not visible in this excerpt; presumably it
// delegates to resize(nclass), which generates one class name per index
// and sizes m_results to nclass x nclass — confirm.
42 ConfusionMatrix::ConfusionMatrix(
short nclass){
// Constructs a matrix for the given class names by delegating to
// setClassNames(), which also sizes m_results to match.
// NOTE(review): whether the names get sorted depends on setClassNames'
// default doSort argument, which is declared in the header (not visible).
46 ConfusionMatrix::ConfusionMatrix(
const std::vector<std::string>& classNames){
47 setClassNames(classNames);
// Copies class names first, then results, from cm.
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably this is the copy constructor. m_se95 and m_format are not
// copied in the visible lines — confirm that is intended.
52 setClassNames(cm.m_classes);
53 setResults(cm.m_results);
// Copies class names first, then results, from cm.
// NOTE(review): the enclosing signature (presumably the copy-assignment
// operator), any self-assignment check and the return are not visible in
// this excerpt. m_se95 and m_format are not copied in the visible lines.
62 setClassNames(cm.m_classes);
63 setResults(cm.m_results);
// Adds cm's results cell-by-cell into this matrix. First checks that the
// class count, the row count and every row length match, and reports any
// mismatch to std::cerr.
// NOTE(review): the signature, and what happens after each error message
// (return/exit), are not visible in this excerpt. Class *names* are only
// compared by count, not by content or order — confirm both matrices use
// the same class ordering.
// NOTE(review): int loop indices are compared against unsigned size() values
// (sign-compare warnings).
70 if(cm.m_classes.size()!=this->m_classes.size()){
71 std::cerr <<
"error0: "<< cm.m_classes.size() <<
"!=" << this->m_classes.size() << std::endl;
74 if(cm.m_results.size()!=this->m_results.size()){
75 std::cerr <<
"error1: "<< cm.m_results.size() <<
"!=" << this->m_results.size() << std::endl;
78 for(
int irow=0;irow<m_results.size();++irow){
79 if(cm.m_results[irow].size()!=this->m_results[irow].size()){
80 std::cerr <<
"error2: " << cm.m_results[irow].size() <<
"!=" << this->m_results[irow].size() << std::endl;
// Element-wise accumulation of the other matrix's cells.
83 for(
int icol=0;icol<m_results[irow].size();++icol)
84 this->m_results[irow][icol]+=cm.m_results[irow][icol];
// Scales every cell of m_results by weight, in place.
// NOTE(review): the enclosing signature and return are not visible in this
// excerpt. Presumably this is operator*=(double weight) — confirm.
91 for(
int irow=0;irow<m_results.size();++irow){
92 for(
int icol=0;icol<m_results[irow].size();++icol)
93 m_results[irow][icol]*=weight;
// Sorts m_classes in place using compareClass (integer-valued ordering).
// NOTE(review): only the name list is reordered; m_results is not permuted
// here. If results are already filled in, rows/columns would no longer line
// up with their names — confirm callers only sort before recording results.
98 void ConfusionMatrix::sortClassNames(){
99 sort(m_classes.begin(),m_classes.end(),compareClass);
// Resizes the matrix to nclass classes. Each class gets a name produced via
// an ostringstream, and m_results is resized to nclass x nclass.
// @param nclass  number of classes.
// NOTE(review): the statement that writes into osclass is not visible in
// this excerpt; presumably it streams iclass, giving names "0".."nclass-1"
// — confirm. Any previous class names are overwritten.
109 void ConfusionMatrix::resize(
short nclass){
110 m_classes.resize(nclass);
111 for(
short iclass=0;iclass<nclass;++iclass){
112 std::ostringstream osclass;
114 m_classes[iclass]=osclass.str();
116 m_results.resize(nclass,nclass);
// Replaces the class-name list. Resizes m_results to n x n when its row
// count no longer matches the number of classes.
// @param classNames  new class names.
// @param doSort      NOTE(review): its use is on lines not visible in this
//                    excerpt; presumably it triggers sortClassNames() — confirm.
// NOTE(review): when the count is unchanged, existing results are kept even
// though the names may now differ.
119 void ConfusionMatrix::setClassNames(
const std::vector<std::string>& classNames,
bool doSort){
120 m_classes=classNames;
// Keep the result matrix square and in sync with the class count.
123 if(m_results.size()!=m_classes.size())
124 m_results.resize(m_classes.size(),m_classes.size());
// Appends one class name and grows m_results to n x n when the row count no
// longer matches the number of classes.
// @param className  name to append; no duplicate check is visible here.
// @param doSort     NOTE(review): its use is on lines not visible in this
//                   excerpt; presumably it triggers sortClassNames() — confirm.
127 void ConfusionMatrix::pushBackClassName(
const std::string& className,
bool doSort){
128 m_classes.push_back(className);
// Keep the result matrix square and in sync with the class count.
131 if(m_results.size()!=m_classes.size())
132 m_results.resize(m_classes.size(),m_classes.size());
// Replaces the whole result matrix with theResults.
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably setResults(). No size check against m_classes is visible.
137 m_results=theResults;
// Resets the result matrix to n x n, where n is the current class count.
// NOTE(review): a line preceding the resize is not visible in this excerpt.
// Whether existing cell values are zeroed depends on that line (presumably
// a clear()) or on the container's resize semantics — confirm.
140 void ConfusionMatrix::clearResults(){
142 m_results.resize(m_classes.size(),m_classes.size());
// Overwrites one cell: row = reference class theRef, column = predicted
// class theClass.
// @param theRef     reference class name (row).
// @param theClass   predicted class name (column).
// @param theResult  value to store.
// NOTE(review): unlike incrementResult, no bounds check is visible here. If
// getClassIndex does not find a name, the indexing below may be out of range.
145 void ConfusionMatrix::setResult(
const std::string& theRef,
const std::string& theClass,
double theResult){
152 int ir=getClassIndex(theRef);
153 int ic=getClassIndex(theClass);
154 m_results[ir][ic]=theResult;
// Adds theIncrement to one cell: row = reference class theRef, column =
// predicted class theClass.
// @param theRef        reference class name (row).
// @param theClass      predicted class name (column).
// @param theIncrement  amount to add (may be fractional).
// NOTE(review): an unknown theRef is reported to std::cerr and then caught
// only by assert(). Under NDEBUG both asserts disappear and the access below
// goes out of range. No message is printed for an unknown theClass.
157 void ConfusionMatrix::incrementResult(
const std::string& theRef,
const std::string& theClass,
double theIncrement){
160 int ir=getClassIndex(theRef);
161 int ic=getClassIndex(theClass);
// A lookup failure is detected as an index >= row count (int vs unsigned
// comparison, so a negative index also compares as too large).
163 if(ir>=m_results.size())
164 std::cerr <<
"Error: " << theRef <<
" not found in class ConfusionMatrix when incrementing for class " << theClass << std::endl;
165 assert(ir<m_results.size());
167 assert(ic<m_results[ir].size());
168 m_results[ir][ic]+=theIncrement;
171 double ConfusionMatrix::nReference(
const std::string& theRef)
const{
173 int ir=getClassIndex(theRef);
174 return accumulate(m_results[ir].begin(),m_results[ir].end(),0);
177 double ConfusionMatrix::nReference()
const{
179 for(
int ir=0;ir<m_classes.size();++ir)
180 nref+=accumulate(m_results[ir].begin(),m_results[ir].end(),0);
// Total classified as theClass: the sum of column ic over all reference rows.
// @param theClass  predicted class name, resolved via getClassIndex().
// NOTE(review): the return statement is not visible in this excerpt;
// presumably it returns nclassified — confirm. The bounds check is an assert
// only, so it is removed under NDEBUG.
184 double ConfusionMatrix::nClassified(
const std::string& theClass)
const{
186 int ic=getClassIndex(theClass);
187 double nclassified=0;
188 for(
int iref=0;iref<m_results.size();++iref){
189 assert(ic<m_results[iref].size());
190 nclassified+=m_results[iref][ic];
// Producer's accuracy for theClass: the diagonal cell divided by the row
// total (everything whose reference is theClass). Returns 0 when the row is
// empty.
// @param theClass  reference class name.
// @param se95      output: sqrt(p*q/(n-1)) with n = row total, or 0 when
//                  p is exactly 0 or 1.
// NOTE(review): the declarations of producer and dqa, any null check on
// se95, and the return are on lines not visible in this excerpt. Presumably
// producer starts at 0 and dqa = 1-dpa — confirm.
// NOTE(review): despite the name, the visible formula has no 1.96 factor.
// Confirm whether se95 is meant to be a standard error or a 95% half-width.
195 double ConfusionMatrix::pa(
const std::string& theClass,
double* se95)
const{
196 assert(m_results.size());
197 assert(m_results.size()==m_classes.size());
200 int ir=getClassIndex(theClass);
202 assert(ir<m_results.size());
// Sanity check that the looked-up index really maps back to theClass.
203 assert(!theClass.compare(m_classes[ir]));
204 for(
int iclass=0;iclass<m_results.size();++iclass){
205 assert(iclass<m_results[ir].size());
206 producer+=m_results[ir][iclass];
208 double dpa=(producer>0)? static_cast<double>(m_results[ir][ir])/producer : 0;
// The 0<dpa<1 guard implies producer>1 here, so producer-1 is positive for
// integer counts.
211 *se95=(dpa<1&&dpa>0)? sqrt(dpa*dqa/(producer-1)) : 0;
// Producer's accuracy as an integer percentage (rounded half up).
// @param se95  output: pa()'s se95 converted to percent, rounded to one
//              decimal place.
// NOTE(review): se95 is dereferenced unconditionally in the visible lines.
// A null check may sit on a line not visible in this excerpt — confirm
// callers never pass nullptr.
215 int ConfusionMatrix::pa_pct(
const std::string& theClass,
double* se95)
const{
216 double dpa=pa(theClass,se95);
// se * 1000 rounded to int, then / 10 gives percent with one decimal.
218 *se95=
static_cast<double>(
static_cast<int>(0.5+1000*(*se95)))/10.0;
219 return static_cast<int>(0.5+100.0*dpa);
// User's accuracy for theClass: the diagonal cell divided by the column
// total (everything classified as theClass). Returns 0 when the column is
// empty.
// @param theClass  predicted class name.
// @param se95      output: sqrt(p*q/(n-1)) with n = column total, or 0
//                  when p is exactly 0 or 1.
// NOTE(review): the declarations of user and dva, any null check on se95,
// and the return are on lines not visible in this excerpt. Presumably user
// starts at 0 and dva = 1-dua — confirm. As in pa(), the visible formula has
// no 1.96 factor.
223 double ConfusionMatrix::ua(
const std::string& theClass,
double* se95)
const{
224 assert(m_results.size());
225 assert(m_results.size()==m_classes.size());
228 int ic=getClassIndex(theClass);
230 assert(ic<m_results.size());
// Sanity check that the looked-up index really maps back to theClass.
231 assert(!theClass.compare(m_classes[ic]));
232 for(
int iref=0;iref<m_results.size();++iref){
233 assert(ic<m_results[iref].size());
234 user+=m_results[iref][ic];
236 double dua=(user>0)? static_cast<double>(m_results[ic][ic])/user : 0;
239 *se95=(dua<1&&dua>0)? sqrt(dua*dva/(user-1)) : 0;
// User's accuracy as an integer percentage (rounded half up).
// @param se95  output: ua()'s se95 converted to percent, rounded to one
//              decimal place.
// NOTE(review): se95 is dereferenced unconditionally in the visible lines.
// A null check may sit on a line not visible in this excerpt — confirm
// callers never pass nullptr.
243 int ConfusionMatrix::ua_pct(
const std::string& theClass,
double* se95)
const{
244 double dua=ua(theClass,se95);
// se * 1000 rounded to int, then / 10 gives percent with one decimal.
246 *se95=
static_cast<double>(
static_cast<int>(0.5+1000*(*se95)))/10.0;
247 return static_cast<int>(0.5+100.0*dua);
// Overall accuracy: the sum of the diagonal divided by the grand total of
// m_results.
// @param se95  output: sqrt(p*q/(n-1)) with n = grand total, or 0 when p is
//              exactly 0 or 1.
// NOTE(review): the pCorrect initialisation, any null check on se95, and
// the return are on lines not visible in this excerpt. If ntotal is 0 (no
// results) and no guard exists there, the division below yields NaN.
250 double ConfusionMatrix::oa(
double* se95)
const{
251 double ntotal=m_results.sum();
254 for(
int iclass=0;iclass<m_classes.size();++iclass)
255 pCorrect+=static_cast<double>(m_results[iclass][iclass])/ntotal;
256 double qCorrect=1-pCorrect;
258 *se95=(pCorrect<1&&pCorrect>0)? sqrt(pCorrect*qCorrect/(ntotal-1)) : 0;
// Overall accuracy as an integer percentage (rounded half up).
// @param se95  output: the standard error converted to percent, rounded to
//              one decimal place.
// NOTE(review): doa is defined on a line not visible in this excerpt
// (presumably doa = oa(se95)). se95 is dereferenced unconditionally in the
// visible lines — confirm callers never pass nullptr.
265 int ConfusionMatrix::oa_pct(
double* se95)
const{
// se * 1000 rounded to int, then / 10 gives percent with one decimal.
268 *se95=
static_cast<double>(
static_cast<int>(0.5+1000*(*se95)))/10.0;
269 return static_cast<int>(0.5+100.0*doa);
// Cohen's kappa: (pCorrect - pChance) / (1 - pChance).
// pCorrect is the observed agreement (diagonal / total). pChance is the
// expected agreement, the sum over classes of
// (column total * row total) / total^2.
// NOTE(review): pChance and pCorrect are declared on lines not visible in
// this excerpt (presumably initialised to 0). There is no visible guard for
// ntotal == 0 (NaN) or pChance == 1 (division by zero).
// NOTE(review): each iteration re-resolves the class name inside
// nClassified/nReference, so the cost is quadratic in the class count;
// fine for small matrices.
272 double ConfusionMatrix::kappa()
const{
273 double ntotal=m_results.sum();
276 for(
int iclass=0;iclass<m_classes.size();++iclass){
277 pChance+=nClassified(m_classes[iclass])*nReference(m_classes[iclass])/ntotal/ntotal;
278 pCorrect+=
static_cast<double>(m_results[iclass][iclass])/ntotal;
281 return((pCorrect-pChance)/(1-pChance));