Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

MLImageNeuralNetwork.cpp

Go to the documentation of this file.
00001 /* ---------------------------------------------------------------------------
00002     Phission : 
00003         Realtime Vision Processing System
00004     
00005     Copyright (C) 2003-2005 Philip D.S. Thoren (pthoren@cs.uml.edu)
00006     University of Massachusetts at Lowell,
00007     Laboratory for Artificial Intelligence and Robotics
00008     
00009     This file is part of Phission.
00010 
00011  ---------------------------------------------------------------------------*/
00012 #include <image_collection.h>
00013 #include <MLImageNeuralNetwork.h>
00014 #include <phANNSystem.h>
00015 
00016 /* ------------------------------------------------------------------------ */
00017 /* Todo:
00018     Adjustable Learning rate & sigmoid activation
00019     Adjustable learning method
00020     Vary hidden nodes between low and high
00021     
00022     Add segmentation/classification code; 
00023         Run cropped and resized images through ANN
00024         Test segmentation code with median/mean/gaussian
00025  
00026    Way in the future todo:
00027    
00028     Add code to collect training data by thresholding the image and
00029         outputting the .dat file
00030         that will be input to idc_crop_resize and the rest of the functions
00031  */
00032 /* ------------------------------------------------------------------------ */
00033 
00034 /* ------------------------------------------------------------------------ */
00035 /* Print the command-line help text to stdout and terminate the process with
00036  * exit(1), so control never returns to the caller. Called directly by main()
00037  * when no --data file is given and registered there as the --help handler
00038  * (phARG_FUNC). */
00039 int usage()
00040 {
00041     printf("\n\nUsage:\n");
00042     printf("\t--help\t\tdisplay usage\n");
00043     printf("\n");
00044     printf("\t--data <file>\tfile is the file that contains the list of files and resize info\n");
00045     printf("\n");
00046     printf("\t--resize\tresize and crop the input data to the proper size\n");
00047     printf("\n");
00048     printf("\t--dotrain\ttrain on the data\n");
00049     printf("\n");
00050     printf("\t--rate <float>\tthe learning rate when the ANN uses incremental learning\n");
00051     printf("\n");
00052     printf("\t--connections <float>\t0.0-1.0 the connection ratio of the ANN\n");
00053     printf("\n");
00054     printf("\t--alg <rprop, qprop, inc, batch>\t Choose one of the algorithms to train the ANN (default: rprop)\n");
00055     printf("\n");
00056     printf("\t--error <float>\tThe desired Mean Square Error of the validation set testing.\n"); 
00057     printf("\t\t\tThis determines when the program thinks it's finished.\n");
00058     printf("\n");
00059     printf("\t--maxepochs <epochs>\tLimit the number of training epochs.\n"); 
00060     printf("\n");
00061     printf("\t--classify\tclassifies an input source(image or video)\n");
00062     printf("\t\t\tMust define annfile, input label, input type, and image count\n");
00063     printf("\n");
00064     printf("\t--in <label>\t\"dev/video0\", \"0\", or \"movie/imgs\"\n");
00065     printf("\n");
00066     printf("\t--type <0:1:2>\t0/1 for ppm/jpg images w/<label> as the prefix\n");
00067     printf("\t\t\t2 for V4LCapture or VFWSource w/<label> as the device\n");
00068     printf("\n");
00069     printf("\t--nimgs <count>\ttotal number of images when using ppm as input\n");
00070     printf("\n");
00071     printf("\t--ann <file>\tANN saved file output from FANN\n");
00072     printf("\t\t\twhen classifying, used as the ANN; when training\n");
00073     printf("\t\t\tused as a starting state for the network.\n");
00074     printf("\n");
00075     printf("\t--epoch <epoch>\tUsed to set the starting epoch when loading from an ANN file.\n"); 
00076     printf("\n");
00077     printf("\tTo resize collected training image data to the size of the desired network size\n");
00078     printf("\t\t./program --resize --data <orig_img_datafile>\n\n\n");
00079     printf("\tTo train on a set of images:\n");
00080     printf("\t\t./program --dotrain --data <resize_img_datafile>\n\n\n");
00081     printf("\tTo classify input from a learned ANN:\n");
00082     printf("\t\t./program --classify --ann ann.net --in movie_00003/image ");
00083     printf(" --type 0 --nimgs 100\n\n"); /* --type is parsed as an int: 0 = ppm */
00084     printf("\n");
00085     printf("\t--norgb | --nohsv | --nogrey | --nocanny | --nosobel\n");
00086     printf("\t--nohistrgb | --nohisthsv | --noratio\n\n\n");
00087     exit(1);
00088 }
00089 
00090 /* ------------------------------------------------------------------------ */
00091 int glbl_nodisplays = 0;        /* set by --nodisplays in main(); non-zero makes main() call useDisplays(0) */
00092 int glbl_disable_displays = 0;  /* not referenced in this listing; NOTE(review): confirm whether other code uses it */
00093                        
00094 /* ------------------------------------------------------------------------ */
00095 /* Command-line driver: registers the options, parses argv and then runs the
00096  * requested steps (--print, --resize, --dotrain, --classify) on the image
00097  * collection named by --data; without --data it falls through to usage().
00098  *
00099  * Returns phSUCCESS when the normal path reaches the cleanup code, or 1 when
00100  * a checked call failed first. NOTE(review): phCHECK_RC is assumed to jump to
00101  * the error: label on failure and phFUNCTION to declare rc - confirm. */
00102 int main(int argc, char *argv[] )
00103 {
00104     phFUNCTION("main")
00105 
00106     /* stays 1 (failure) unless the normal path runs to completion */
00107     int                     exit_status = 1;
00108 
00109     char                    *datafile = NULL;
00110     image_data_collection   data_collection = NULL;
00111 
00112     phANNSystem  *ann_system = NULL;
00113     
00114     /* ------------------------------------------------------------------ */
00115     /* command line options */
00116     int resize_crop = 0;
00117     int print       = 0;
00118     int train       = 0;
00119     int classify    = 0;
00120 
00121     int input_flags = phML_ALL;
00122     int norgb       = 0;
00123     int nohsv       = 0;
00124     int nocanny     = 0;
00125     int nosobel     = 0;
00126     int nogrey      = 0;
00127     int noratio     = 0;
00128     int nohistrgb   = 0;
00129     int nohisthsv   = 0;
00130 
00131     int use_random  = 0;
00132 
00133     /* ------------------------------------------------------------------ */
00134     /* Training parameters */
00135     uint32_t    train_set_size      = 50;
00136     uint32_t    test_set_size       = 25;
00137     uint32_t    max_epochs          = 50000;
00138     uint32_t    start_epoch         = 0;
00139     float       learning_rate       = 0.8;
00140     float       desired_error       = 0.00000001; /* very unlikely */
00141     float       connection_rate     = 1.0;
00142     uint32_t    num_layers          = 3;
00143     char        *algorithm          = NULL;
00144     uint32_t    training_algorithm  = FANN_TRAIN_RPROP;
00145     uint32_t    nhidden             = 10;
00146 
00147     /* required for classification */
00148     char        *annfile            = NULL;
00149     char        *input_label        = NULL;
00150     int         input_type          = 0;
00151     int         image_count         = 0;
00152 
00153     /* ------------------------------------------------------------------ */
00154     phArgTable      *arg_parser = new phArgTable();
00155 
00156     /* Setup and parse all the arguments */
00157     rc = arg_parser->add("--data",&datafile,phARG_CHAR);
00158     phCHECK_RC(rc,NULL,"arg_parser->add");
00159 
00160     rc = arg_parser->add("--nodisplays",&glbl_nodisplays,phARG_BOOL);
00161     phCHECK_RC(rc,NULL,"arg_parser->add");
00162     
00163     rc = arg_parser->add("--resize",&resize_crop,phARG_BOOL);
00164     phCHECK_RC(rc,NULL,"arg_parser->add");
00165     rc = arg_parser->add("--dotrain",&train,phARG_BOOL);
00166     phCHECK_RC(rc,NULL,"arg_parser->add");
00167     rc = arg_parser->add("--classify",&classify,phARG_BOOL);
00168     phCHECK_RC(rc,NULL,"arg_parser->add");
00169     
00170     rc = arg_parser->add("--print",&print,phARG_BOOL);
00171     phCHECK_RC(rc,NULL,"arg_parser->add");
00172     rc = arg_parser->add("--help",(void *)&usage,phARG_FUNC);
00173     phCHECK_RC(rc,NULL,"arg_parser->add");
00174 
00175     /* required for classification */
00176     rc = arg_parser->add("--ann",&annfile,phARG_CHAR);
00177     phCHECK_RC(rc,NULL,"arg_parser->add");
00178     rc = arg_parser->add("--in",&input_label,phARG_CHAR);
00179     phCHECK_RC(rc,NULL,"arg_parser->add");
00180     rc = arg_parser->add("--type",&input_type,phARG_INT);
00181     phCHECK_RC(rc,NULL,"arg_parser->add");
00182     rc = arg_parser->add("--nimgs",&image_count,phARG_INT);
00183     phCHECK_RC(rc,NULL,"arg_parser->add");
00184 
00185     /* fann variables */
00186     rc = arg_parser->add("--trainsize",&train_set_size,phARG_UINT32);
00187     phCHECK_RC(rc,NULL,"arg_parser->add");
00188     rc = arg_parser->add("--testsize",&test_set_size,phARG_UINT32);
00189     phCHECK_RC(rc,NULL,"arg_parser->add");
00190     rc = arg_parser->add("--maxepochs",&max_epochs,phARG_UINT32);
00191     phCHECK_RC(rc,NULL,"arg_parser->add");
00192     rc = arg_parser->add("--epoch",&start_epoch,phARG_UINT32);
00193     phCHECK_RC(rc,NULL,"arg_parser->add");
00194     rc = arg_parser->add("--layers",&num_layers,phARG_UINT32);
00195     phCHECK_RC(rc,NULL,"arg_parser->add");
00196     rc = arg_parser->add("--error",&desired_error,phARG_FLOAT);
00197     phCHECK_RC(rc,NULL,"arg_parser->add");
00198     rc = arg_parser->add("--rate",&learning_rate,phARG_FLOAT);
00199     phCHECK_RC(rc,NULL,"arg_parser->add");
00200     rc = arg_parser->add("--connections",&connection_rate,phARG_FLOAT);
00201     phCHECK_RC(rc,NULL,"arg_parser->add");
00202     rc = arg_parser->add("--nhidden",&nhidden,phARG_UINT32);
00203     phCHECK_RC(rc,NULL,"arg_parser->add");
00204     rc = arg_parser->add("--alg",&algorithm,phARG_CHAR);
00205     phCHECK_RC(rc,NULL,"arg_parser->add");
00206     
00207     rc = arg_parser->add("--random",&use_random,phARG_INT);
00208     phCHECK_RC(rc,NULL,"arg_parser->add");
00209     
00210     /* variables to disable certain inputs to the ANN */
00211     /* RGB, HSV, CANNY, SOBEL, GREY, RATIO, HISTRGB, HISTHSV */
00212     rc = arg_parser->add("--norgb",&norgb,phARG_BOOL);
00213     phCHECK_RC(rc,NULL,"arg_parser->add");
00214     rc = arg_parser->add("--nohsv",&nohsv,phARG_BOOL);
00215     phCHECK_RC(rc,NULL,"arg_parser->add");
00216     rc = arg_parser->add("--nocanny",&nocanny,phARG_BOOL);
00217     phCHECK_RC(rc,NULL,"arg_parser->add");
00218     rc = arg_parser->add("--nosobel",&nosobel,phARG_BOOL);
00219     phCHECK_RC(rc,NULL,"arg_parser->add");
00220     rc = arg_parser->add("--nogrey",&nogrey,phARG_BOOL);
00221     phCHECK_RC(rc,NULL,"arg_parser->add");
00222     rc = arg_parser->add("--noratio",&noratio,phARG_BOOL);
00223     phCHECK_RC(rc,NULL,"arg_parser->add");
00224     rc = arg_parser->add("--nohistrgb",&nohistrgb,phARG_BOOL);
00225     phCHECK_RC(rc,NULL,"arg_parser->add");
00226     rc = arg_parser->add("--nohisthsv",&nohisthsv,phARG_BOOL);
00227     phCHECK_RC(rc,NULL,"arg_parser->add");
00228 
00229     /* parse the command line */
00230     rc = arg_parser->parse(argc,argv);
00231     phCHECK_RC(rc,NULL,"arg_parser->parse");
00232     
00233     /* process the input flags: each --no<x> option clears its phML_<X> bit */
00234     if (norgb)      input_flags &= ~(phML_RGB);
00235     if (nohsv)      input_flags &= ~(phML_HSV);
00236     if (nocanny)    input_flags &= ~(phML_CANNY);
00237     if (nosobel)    input_flags &= ~(phML_SOBEL);
00238     if (nogrey)     input_flags &= ~(phML_GREY);
00239     if (noratio)    input_flags &= ~(phML_RATIO);
00240     if (nohistrgb)  input_flags &= ~(phML_HISTRGB);
00241     if (nohisthsv)  input_flags &= ~(phML_HISTHSV);
00242     /* figure out which training algorithm was specified */
00243     if (algorithm != NULL)
00244     {
00245         if (strcmp(algorithm,"rprop")==0)
00246         {
00247             training_algorithm = FANN_TRAIN_RPROP;
00248         }
00249         else if (strcmp(algorithm,"inc")==0)
00250         {
00251             training_algorithm = FANN_TRAIN_INCREMENTAL;
00252         }
00253         else if (strcmp(algorithm,"batch")==0)
00254         {
00255             training_algorithm = FANN_TRAIN_BATCH;
00256         }
00257         else if (strcmp(algorithm,"qprop")==0)
00258         {
00259             training_algorithm = FANN_TRAIN_QUICKPROP;
00260         }
00261         else
00262         {
00263             printf("Unknown --alg value \"%s\"; keeping the default (rprop)\n",algorithm);
00264         }
00265     }
00266 
00267     if (datafile != NULL) 
00268     {
00269         /* Allocate a new piece of data */
00270         rc = idc_new( &data_collection );
00271         phCHECK_RC(rc,NULL,"idc_new");
00272 
00273         /* Read the data set information from the file */
00274         rc = idc_readfile( datafile, data_collection);
00275         phCHECK_RC(rc,NULL,"idc_readfile");
00276 
00277         /* Print the data read from the file ? */
00278         if (print)
00279         {
00280             rc = idc_print( data_collection );
00281             phCHECK_RC(rc,NULL,"idc_print");
00282         }
00283 
00284         /* Resize the images and output them so they are formatted for learning */
00285         if (resize_crop)
00286         {
00287             rc = idc_crop_resize( data_collection );
00288             phCHECK_RC(rc,NULL,"idc_crop_resize");
00289         }
00290 
00291         if ((train) || (classify))
00292         {
00293             ann_system = new phANNSystem(data_collection,input_flags);
00294             ann_system->setBrightnessStep(10);
00295             
00296             if (glbl_nodisplays)
00297             {
00298                 ann_system->useDisplays(0);
00299             }
00300             
00301             /* always call this, some parameters can't be loaded from the ANN .net file
00302              * so they need to be given on the command line. When ann_system->setNetwork
00303              * is called with the file, it will load from the file but any settings 
00304              * that couldn't be retrieved will not be erased from this call */
00305             rc = ann_system->setNetwork( connection_rate, learning_rate,
00306                                          training_algorithm, num_layers, nhidden );
00307             phPRINT_RC(rc,NULL,"ann_system->setNetwork(x,x,x,x,x)");
00308 
00309             /* If a FANN file was specified, then load the ANN from the file */
00310             if (annfile != NULL)
00311             {
00312                 rc = ann_system->setNetwork( annfile );
00313                 phPRINT_RC(rc,NULL,"ann_system->setNetwork( annfile )");
00314             }
00315         }
00316         
00317         /* if we're training, then do it! */
00318         if (train)
00319         {
00320             rc = ann_system->train( 1 /* preload images */, use_random,
00321                                     train_set_size, test_set_size,
00322                                     max_epochs, desired_error, start_epoch );
00323             phPRINT_RC(rc,NULL,"ann_system->train");
00324         }
00325     
00326         /* if we're classifying, then we're running live with the learned 
00327          * neural network to classify objects */
00328         if (classify && 
00329             (annfile != NULL) && 
00330             (input_label != NULL) &&
00331             ((input_type >= 0) && (input_type <= 2)) && 
00332             (image_count >= 0))
00333         {
00334             rc = ann_system->classify( input_label,
00335                                        input_type,
00336                                        image_count );
00337             phPRINT_RC(rc,NULL,"ann_system->classify");
00338         }
00339         else if (classify)
00340         {
00341             printf("--classify needs --ann, --in, --type 0..2 and --nimgs >= 0; nothing classified\n");
00342         }
00343     }
00344     else
00345     {
00346         usage();
00347     }
00348     
00349     exit_status = phSUCCESS;    /* skipped when a checked call jumps to error: */
00350 error:
00351     if (data_collection != NULL)
00352     {
00353         rc = idc_free( &data_collection );
00354         phPRINT_RC(rc,NULL,"idc_free"); /* report only: must not re-enter error: */
00355     }
00356 
00357     phFree(datafile);
00358     phFree(input_label);
00359     phFree(annfile);
00360     phFree(algorithm);  /* parsed as phARG_CHAR like the three strings above */
00361     phDelete(ann_system);
00362     
00363     phDelete(arg_parser);
00364 
00365     return exit_status;
00366 }
00367 




Copyright (C) 2002 - 2007 Philip D.S. Thoren ( pthoren@users.sourceforge.net )
University Of Massachusetts at Lowell
Robotics Lab
SourceForge.net Logo

Generated on Sat Jun 16 02:44:06 2007 for phission by  doxygen 1.4.4