Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

ml_class_stats.cpp

Go to the documentation of this file.
00001 /* ---------------------------------------------------------------------------
00002     Phission : 
00003         Realtime Vision Processing System
00004     
00005     Copyright (C) 2003-2005 Philip D.S. Thoren (pthoren@cs.uml.edu)
00006     University of Massachusetts at Lowell,
00007     Laboratory for Artificial Intelligence and Robotics
00008     
00009     This file is part of Phission.
00010 
00011  ---------------------------------------------------------------------------*/
00012 #include <ml_class_stats.h>
00013 #include <phission.h>
00014 
00015 #ifdef __cplusplus
00016 extern "C" {
00017 #endif
00018     
00019 /* ------------------------------------------------------------------------ */
00020 int ml_class_stats_new( uint32_t           num_outputs,
00021                          ml_class_stats  **pstats )
00022 {
00023     phFUNCTION("ml_class_stats_new")
00024 
00025     if (pstats == NULL) return phFAIL;
00026     
00027     *pstats = (ml_class_stats *)phCalloc(num_outputs,
00028                                     sizeof(ml_class_stats));
00029     phCHECK_PTR(*pstats,"phCalloc","phCalloc failed.");
00030 
00031     (*pstats)->num_outputs = num_outputs;
00032 
00033     (*pstats)->node_correctly_classified = 
00034             (uint32_t *)phCalloc(num_outputs,sizeof(uint32_t));
00035     phCHECK_PTR((*pstats),"phCalloc","phCalloc failed.");
00036     
00037     (*pstats)->node_incorrectly_classified =
00038             (uint32_t *)phCalloc(num_outputs,sizeof(uint32_t));
00039     phCHECK_PTR((*pstats),"phCalloc","phCalloc failed.");
00040 
00041     (*pstats)->print_output_nodes = 0;
00042     (*pstats)->print_aggregate = 1;
00043     (*pstats)->mse = 0.0F;
00044    
00045     return phSUCCESS;
00046 error:
00047     return phFAIL;
00048 }
00049 
00050 /* ------------------------------------------------------------------------ */
00051 int ml_class_stats_free( ml_class_stats **pstats )
00052 {
00053     if (pstats == NULL) return phSUCCESS;
00054 
00055     phFree((*pstats)->node_correctly_classified);
00056     phFree((*pstats)->node_incorrectly_classified);
00057     phFree(*pstats);
00058 
00059     return phSUCCESS;
00060 }
00061 /* ------------------------------------------------------------------------ */
00062 int ml_class_stats_reset( ml_class_stats *stats )
00063 {
00064     uint32_t i = 0;
00065     if (stats == NULL) return phSUCCESS;
00066 
00067     stats->output_matches_active   = 0;
00068     stats->output_matches_inactive = 0;
00069     stats->correctly_classified    = 0;
00070     stats->incorrectly_classified  = 0;
00071     stats->output_nodes_active     = 0;
00072     stats->output_nodes_inactive   = 0;
00073     stats->nclassifications        = 0;
00074 
00075     for (i = 0; i < stats->num_outputs; i++ )
00076     {
00077         stats->node_correctly_classified[i] = 0;
00078         stats->node_incorrectly_classified[i] = 0;
00079     }
00080 
00081     stats->mse = 0.0F;
00082 
00083     return phSUCCESS;
00084 }
00085  
00086 /* ------------------------------------------------------------------------ */
00087 /* This is the function that compares the actual to the expected outputs */
00088 int ml_class_stats_classify( ml_class_stats   *stats,
00089                              fann_type         *expected_output,
00090                              fann_type         *actual_output   )
00091 {
00092     uint32_t k = 0;
00093     
00094     /* TODO: make this a function */
00095     /* correctly classified */
00096     stats->output_matches_active   = 0;
00097     stats->output_matches_inactive = 0;
00098     stats->output_nodes_active     = 0;
00099     stats->output_nodes_inactive   = 0;
00100     for (k = 0; k < stats->num_outputs; k++ )
00101     {
00102         if (expected_output[k] ==  1.0F)  stats->output_nodes_active++;
00103         if (expected_output[k] == -1.0F)  stats->output_nodes_inactive++;
00104 
00105         /* if the actual output falls within the threshold
00106          * (1.0 +/- thresh) and the output node value is
00107          * also active (node is either 1 or 0), then the
00108          * node is correctly classified. */                         
00109         //if ((((actual_output[k] >= (1.0-thresh)) && (actual_output[k] <= (1.0+thresh))) && (output_node_values[k] > 0)))
00110         //if ((actual_output[k] >= 1.0) && (expected_output[k] > 0.0))
00111         if ((expected_output[k] > 0.99999F) && 
00112             (actual_output[k]   > 0.99999F))
00113         {
00114             stats->output_matches_active++;
00115             stats->node_correctly_classified[k]++;
00116         }
00117         /* if the actual output isn't within the threshold
00118          * of 1.0 and the node is inactive, then the node
00119          * is correctly classified */
00120         //else if ((actual_output[k] < 1.0) && (expected_output[k] < 1.0))
00121         else if ((expected_output[k] < -0.99999F) && 
00122                  (actual_output[k]   < -0.99999F))
00123         {
00124             stats->output_matches_inactive++;
00125             stats->node_correctly_classified[k]++;
00126         }
00127         /* otherwise the node is incorrectly classified */
00128         else
00129         {
00130             stats->node_incorrectly_classified[k]++;
00131         }
00132     }
00133     if ((stats->output_matches_active == stats->output_nodes_active) &&
00134         (stats->output_matches_inactive == stats->output_nodes_inactive))
00135     {
00136         stats->correctly_classified++;
00137     }
00138     else
00139     {
00140         stats->incorrectly_classified++;
00141     }
00142     
00143     stats->nclassifications++;
00144 
00145     return phSUCCESS;
00146 }
00147 
00148 /* ------------------------------------------------------------------------ */
00149 int ml_class_stats_set_mse( ml_class_stats *stats, float mse )
00150 {
00151     if (stats == NULL) return phFAIL;
00152     stats->mse = mse;
00153     return phSUCCESS;
00154 }
00155 
00156 /* ------------------------------------------------------------------------ */
00157 int ml_class_stats_set_options( ml_class_stats *stats,
00158                                 int print_output_nodes,
00159                                 int print_aggregate )
00160 {
00161     if (stats == NULL) return phFAIL;
00162 
00163     stats->print_output_nodes = print_output_nodes;
00164     stats->print_aggregate   = print_aggregate;
00165     return phSUCCESS;
00166 }
00167 
00168 /* ------------------------------------------------------------------------ */
00169 int ml_class_stats_headers_print( FILE             *log_fp, 
00170                                   ml_class_stats   *stats,
00171                                   char            **tags, 
00172                                   uint32_t          ntags )
00173 {
00174     uint32_t k = 0;
00175 
00176     if ((ntags == 0)    || 
00177         (log_fp == NULL)|| 
00178         (tags == NULL)  ||
00179         (stats == NULL)) 
00180     {
00181         return phFAIL;
00182     }
00183 
00184     /* Print out the row headers for the data being logged, to the logfile */
00185     fprintf(log_fp, "%s", "Epoch" );
00186     if (stats->print_aggregate)
00187     {
00188         fprintf(log_fp,
00189                 " %s %s %s %s %s %s %s",
00190                 "TrainMSE",
00191                 "TrainSetMSE",
00192                 "TrainSetSize",
00193                 "TrainSetClassified",
00194                 "TrainSetIncorrectClassified",
00195                 "TrainSetPercentClassified",
00196                 "TrainSetPercentIncorrectClassified" );
00197     }
00198     
00199     if (stats->print_output_nodes)
00200     {
00201         /* Print out a row header for each of the output node statistics */
00202         for (k = 0; k < ntags; k++ )
00203         {
00204             fprintf(log_fp," %sCorrect_train %sIncorrect_train", 
00205                     tags[k], tags[k] );
00206             fprintf(log_fp," %sPrctCorrect_train %sPrctIncorrect_train", 
00207                     tags[k], tags[k] );
00208         }
00209     }
00210     
00211     if (stats->print_aggregate)
00212     {
00213         fprintf(log_fp,
00214                 " %s %s %s %s %s %s",
00215                 "TestSetMSE",
00216                 "TestSetSize",
00217                 "TestSetClassified",
00218                 "TestSetIncorrectClassified",
00219                 "TestSetPercentClassified",
00220                 "TestSetPercentIncorrectClassified");
00221     }
00222     
00223     if (stats->print_output_nodes)
00224     {
00225         for (k = 0; k < ntags; k++ )
00226         {
00227             fprintf(log_fp," %sCorrect_test %sIncorrect_test", 
00228                     tags[k], tags[k] );
00229             fprintf(log_fp," %sPrctCorrect_test %sPrctIncorrect_test", 
00230                     tags[k], tags[k] );
00231         }
00232     }
00233 
00234     fprintf(log_fp,"\n");
00235     /* flush the headers out NOW! Don't wait for the data to be
00236      * automatically flushed by the system. */
00237     fflush(log_fp);
00238 
00239     return phSUCCESS;
00240 }
00241      
00242 /* ------------------------------------------------------------------------ */
00243 int ml_class_stats_info_print( FILE             *fp, 
00244                                ml_class_stats   *stats )
00245 {
00246     uint32_t k = 0;
00247     float ratio_correct     = 0.0;
00248     float ratio_incorrect   = 0.0;
00249     
00250     if (fp == NULL) return phSUCCESS;
00251 
00252     if (stats->print_aggregate)
00253     {
00254         ratio_correct = ((float)stats->correctly_classified / 
00255                          (float)(stats->nclassifications)),
00256         ratio_incorrect = ((float)stats->incorrectly_classified / 
00257                            (float)(stats->nclassifications));
00258     
00259         /* print the training set MSE to the log file */
00260         fprintf(fp," %0.20f %u %u %u %0.13f %0.13f",
00261                 stats->mse,
00262                 stats->nclassifications,
00263                 stats->correctly_classified,
00264                 stats->incorrectly_classified,
00265                 ratio_correct,
00266                 ratio_incorrect );
00267     }
00268 
00269     if (stats->print_output_nodes)
00270     {
00271         for (k = 0; k < stats->num_outputs; k++ )
00272         {
00273             ratio_correct = ((float)stats->node_correctly_classified[k] / 
00274                              (float)(stats->nclassifications));
00275             ratio_incorrect= ((float)stats->node_incorrectly_classified[k] / 
00276                               (float)(stats->nclassifications));
00277             fprintf(fp," %u %u %0.13f %0.13f",
00278                     stats->node_correctly_classified[k],
00279                     stats->node_incorrectly_classified[k],
00280                     ratio_correct,
00281                     ratio_incorrect );
00282         }
00283     }
00284     fflush(fp);
00285 
00286     return phSUCCESS;
00287 }
00288  
00289 /* ------------------------------------------------------------------------ */
00290 int ml_class_stats_row_print( FILE             *fp, 
00291                               int               epoch,
00292                               float             train_mse,
00293                               ml_class_stats   *train_stats,
00294                               ml_class_stats   *test_stats )
00295 {
00296     uint32_t k = 0;
00297     float ratio_correct     = 0.0;
00298     float ratio_incorrect   = 0.0;
00299     
00300     if (fp == NULL) return phSUCCESS;
00301 
00302     fprintf(fp,"%u",epoch);
00303     fprintf(fp," %0.20f", train_mse );
00304     /* print the validation set MSE to the log file */
00305     /* Print the classification statistics to file */
00306     ml_class_stats_info_print(fp,train_stats);
00307     ml_class_stats_info_print(fp,test_stats);
00308     fprintf(fp,"\n");
00309     fflush(fp);
00310 
00311     return phSUCCESS;
00312 }
00313  
00314 
00315 #ifdef __cplusplus
00316 }
00317 #endif
00318 
00319 




Copyright (C) 2002 - 2007 Philip D.S. Thoren ( pthoren@users.sourceforge.net )
University Of Massachusetts at Lowell
Robotics Lab
SourceForge.net Logo

Generated on Sat Jun 16 02:44:06 2007 for phission by  doxygen 1.4.4