#include "backprop.h"

#include <stdio.h>
#include <string.h>

/* Reports a malformed record in the training data file (row is 1-based). */
static void report_seed_row_error(int row, const char *seedFile)
{
    fprintf(stderr, "Error when reading %d row ", row);
    fprintf(stderr, "from training data file \"%s\"\n", seedFile);
}

/*
 * Driver: classifies every interior pixel of an 8-bit PGM image with an MLP.
 *
 * Usage: <program> [parameter-file [image-prefix]]
 *   Missing arguments are prompted for on stdin.
 *
 * Steps, as performed below:
 *   1. Load control parameters, then override the file names from the prefix:
 *      input  "<prefix>.pgm", network output "<prefix>_network.dat",
 *      result "<prefix>_seed.pgm".  The network *input* file becomes the
 *      network output file named in the parameter file.
 *   2. Training is disabled by clearing controlInfo.seedFile, so the
 *      training-data section only runs if that line is removed.
 *   3. Load (or randomly initialise) the network, save it, run a forward
 *      pass on each interior pixel and write the thresholded result image.
 *
 * Returns 0 on success, -1 on any input error.
 */
int main(int argc, char **argv)
{
    char paraFile[256];
    char prefix[256];
    MLP_NETWORK anetwork;
    CONTROL_INFO controlInfo;
    IMAGE8BIT inputImg;
    IMAGE8BIT resImg;
    double **traindata = NULL;  /* stays NULL unless training data is loaded */
    double delta_Sum = 0.0;
    double skipped;
    int ntraining = 0;          /* rows of traindata actually filled */
    int nalloc = 0;             /* rows of traindata actually allocated */
    int nconverged;
    int i, j, k;
    int k1, k2;
    int crow, ccol;
    int gray;
    FILE *fp;

    /* Both names land in fixed 256-byte buffers: reject or bound the input. */
    if (argc > 1) {
        if (strlen(argv[1]) >= sizeof paraFile) {
            fprintf(stderr, "Error: parameter file name is too long\n");
            return -1;
        }
        strcpy(paraFile, argv[1]);
    } else {
        printf("Please specify the parameter file: ");
        if (scanf("%255s", paraFile) != 1) {
            fprintf(stderr, "Error: no parameter file given\n");
            return -1;
        }
    }

    if (argc > 2) {
        if (strlen(argv[2]) >= sizeof prefix) {
            fprintf(stderr, "Error: input image prefix is too long\n");
            return -1;
        }
        strcpy(prefix, argv[2]);
    } else {
        printf("Please specify the input image (*.pgm and prefix only) :");
        if (scanf("%255s", prefix) != 1) {
            fprintf(stderr, "Error: no input image prefix given\n");
            return -1;
        }
    }

    inputImg.row = 0;
    resImg.row = 0;

    SetDefaultParameter(&controlInfo);
    LoadParameter(paraFile, &controlInfo);

    /*
     * Change the input and output files.
     * NOTE(review): the sizes of the CONTROL_INFO name fields are not visible
     * here; a long prefix could overflow them -- confirm and bound these.
     */
    sprintf(controlInfo.inputFile, "%s.pgm", prefix);
    strcpy(controlInfo.networkInFile, controlInfo.networkOutFile);
    sprintf(controlInfo.networkOutFile, "%s_network.dat", prefix);
    sprintf(controlInfo.outFile, "%s_seed.pgm", prefix);

    /* disable the training */
    controlInfo.seedFile[0] = '\0';

    DisplayParameter(&controlInfo);

#if defined(LOCAL_FEATURE_ONLY)
    anetwork.INPUT_UNIT = 4;
#else
    /* One input per pixel of the (2*kernel_size+1)^2 window. */
    anetwork.INPUT_UNIT = (controlInfo.kernel_size * 2 + 1) *
                          (controlInfo.kernel_size * 2 + 1);
#endif
    anetwork.HIDDEN_UNIT = controlInfo.hidden_unit;
    anetwork.OUTPUT_UNIT = 1;

    if (Read_Image_FromP5(&inputImg, controlInfo.inputFile)) {
        printf("Error: cannot open input image \"%s\"\n", controlInfo.inputFile);
        return -1;
    }
    Allocate_Image_8Bit(&resImg, inputImg.row, inputImg.col);

    init_mlp_archit(controlInfo.networkInFile, &anetwork);
    allocate_mlp_network(&anetwork);

    if (controlInfo.seedFile[0] != 0) {
        fp = fopen(controlInfo.seedFile, "rb");
        if (fp == NULL) {
            fprintf(stderr, "Cannot open training data file \"%s\"\n",
                    controlInfo.seedFile);
            return -1;
        }

        /* First token of the file is the number of records that follow. */
        if (fscanf(fp, "%d", &k) != 1 || k < 0) {
            fprintf(stderr, "Error when reading the sample count ");
            fprintf(stderr, "from training data file \"%s\"\n",
                    controlInfo.seedFile);
            fclose(fp);
            return -1;
        }

        /*
         * Row layout: INPUT_UNIT features, one constant -1 entry, OUTPUT_UNIT
         * target values, then the sample's column and row.
         */
        if (k > 0) {
            traindata = (double **)Allocate_Array(
                k, (anetwork.INPUT_UNIT + anetwork.OUTPUT_UNIT + 3) * sizeof(double));
            if (traindata == NULL) {
                fprintf(stderr, "Cannot allocate memory for %d training samples\n", k);
                fclose(fp);
                return -1;
            }
            /* k is reused as a loop index later, so remember the row count.
             * Assumes Free_Array takes the row count given to Allocate_Array
             * -- TODO confirm. */
            nalloc = k;
        }

        Copy_Image_8Bit(&inputImg, &resImg);
        Change_Image_8Bit(&resImg, (unsigned char)255, (unsigned char)254);
        Change_Image_8Bit(&resImg, (unsigned char)0, (unsigned char)1);

        for (i = 0; i < k; i++) {
            /* Each record starts with "column row". */
            if (fscanf(fp, "%d%d", &ccol, &crow) != 2) {
                report_seed_row_error(i + 1, controlInfo.seedFile);
                fclose(fp);
                Free_Array((void **)traindata, nalloc);
                return -1;
            }

            /*
             * Skip samples whose window would leave the image.  The column is
             * checked against the image width and the row against its height.
             */
            if (ccol < controlInfo.kernel_size ||
                ccol >= (inputImg.col - controlInfo.kernel_size) ||
                crow < controlInfo.kernel_size ||
                crow >= (inputImg.row - controlInfo.kernel_size)) {
                /* Consume this record's target values so parsing stays aligned. */
                for (j = 0; j < anetwork.OUTPUT_UNIT; j++) {
                    if (fscanf(fp, "%lf", &skipped) != 1) {
                        report_seed_row_error(i + 1, controlInfo.seedFile);
                        fclose(fp);
                        Free_Array((void **)traindata, nalloc);
                        return -1;
                    }
                }
                continue;
            }

#if defined(LOCAL_FEATURE_ONLY)
            Calc_Local_MinMax(&inputImg, controlInfo.kernel_size, crow, ccol,
                              &(traindata[ntraining][0]),
                              &(traindata[ntraining][1]));
            Calc_Local_Variance(&inputImg, controlInfo.kernel_size, crow, ccol,
                                &(traindata[ntraining][2]),
                                &(traindata[ntraining][3]));
#else
            for (k1 = 0 - controlInfo.kernel_size; k1 <= controlInfo.kernel_size; k1++) {
                for (k2 = 0 - controlInfo.kernel_size; k2 <= controlInfo.kernel_size; k2++) {
                    j = (k1 + controlInfo.kernel_size) * (2 * controlInfo.kernel_size + 1) +
                        (k2 + controlInfo.kernel_size);
                    traindata[ntraining][j] =
                        (double)(inputImg.image[k1 + crow][k2 + ccol]) / 256.;
                }
            }
#endif
            traindata[ntraining][anetwork.INPUT_UNIT] = -1.;

            for (j = 0; j < anetwork.OUTPUT_UNIT; j++) {
                if (fscanf(fp, "%lf",
                           &(traindata[ntraining][j + anetwork.INPUT_UNIT + 1])) != 1) {
                    report_seed_row_error(i + 1, controlInfo.seedFile);
                    fclose(fp);
                    Free_Array((void **)traindata, nalloc);
                    return -1;
                }
            }

            /* Draw a cross at the sample: black if target < 0.5, else white. */
            for (k1 = 0 - controlInfo.kernel_size; k1 <= controlInfo.kernel_size; k1++) {
                if (traindata[ntraining][anetwork.INPUT_UNIT + 1] < 0.5) {
                    resImg.image[k1 + crow][ccol] = 0;
                    resImg.image[crow][ccol + k1] = 0;
                } else {
                    resImg.image[k1 + crow][ccol] = 255;
                    resImg.image[crow][ccol + k1] = 255;
                }
            }

            /* j == OUTPUT_UNIT here: store the coordinates after the targets. */
            traindata[ntraining][j + anetwork.INPUT_UNIT + 1] = ccol;
            traindata[ntraining][j + anetwork.INPUT_UNIT + 2] = crow;
            ntraining++;
        }
        fclose(fp);

        Write_Image_8Bit_ToP5(&resImg, "init_seed_place.pgm");
        printf("There are %d training samples loaded ", ntraining);
        printf("from file \"%s\"\n", controlInfo.seedFile);
    }

    init_mlp_network(&anetwork, &controlInfo);
    if (controlInfo.networkInFile[0] != 0) {
        load_mlp_network(controlInfo.networkInFile, &anetwork);
    } else if (ntraining == 0) {
        printf("Warning: The network is random.\n");
        printf("\tThe network is not trained nor loaded.\n");
    }

    if (ntraining > 0) {
        save_mlp_network("init_network.dat", &anetwork);
        nconverged = 0;
        for (i = 0; i < controlInfo.max_iteration; i++) {
            delta_Sum = 0.0;
            for (j = 0; j < ntraining; j++) {
                for (k = 0; k < (anetwork.INPUT_UNIT + 1); k++) {
                    anetwork.v1[k] = traindata[j][k];
                }
                for (k = 0; k < anetwork.OUTPUT_UNIT; k++) {
                    anetwork.v_out[k] = traindata[j][anetwork.INPUT_UNIT + 1 + k];
                }
                delta_Sum += backpropagation(&anetwork, &controlInfo);
            }
            if (((i + 1) % 1000) == 0) {
                printf("Error at iteration %d: %8.6f\n", i + 1, delta_Sum);
            }
            /* Stop once the error has stayed under tolerance for more than
             * 10 consecutive iterations. */
            if (delta_Sum < (controlInfo.tol * ntraining)) {
                nconverged++;
                if (nconverged > 10)
                    break;
            } else {
                nconverged = 0;
            }
        }
        /* After a full run i == max_iteration, so i itself is the count of
         * iterations performed; after a break it is i + 1. */
        printf("Error at iteration %d: %8.6f\n",
               (i < controlInfo.max_iteration) ? i + 1 : i, delta_Sum);
    }

    save_mlp_network(controlInfo.networkOutFile, &anetwork);

    /* Dump each training sample (abbreviated when there are more than 8
     * inputs) followed by the network's output for it. */
    for (i = 0; i < ntraining; i++) {
        printf("%2d (%d, %d) : ", i,
               (int)traindata[i][anetwork.OUTPUT_UNIT + anetwork.INPUT_UNIT + 1],
               (int)traindata[i][anetwork.OUTPUT_UNIT + anetwork.INPUT_UNIT + 2]);
        for (j = 0; j < (anetwork.OUTPUT_UNIT + anetwork.INPUT_UNIT + 1); j++) {
            if (j == anetwork.INPUT_UNIT)
                continue;   /* skip the constant -1 entry */
            if (anetwork.INPUT_UNIT > 8) {
                if (j < 3 || (j > (anetwork.INPUT_UNIT - 3))) {
                    printf("%6.4f ", traindata[i][j]);
                } else if (j == (anetwork.INPUT_UNIT - 4)) {
                    printf(" ... ");
                }
            } else {
                printf("%6.4f ", traindata[i][j]);
            }
        }
        for (j = 0; j < (anetwork.INPUT_UNIT + 1); j++) {
            anetwork.v1[j] = traindata[i][j];
        }
        forward_prop_only(&anetwork, &controlInfo);
        printf("(%6.4f)", anetwork.v3[0]);
        printf("\n");
    }
    printf("\n");

    /* Start from the input image with 255 remapped to 254, so that 255 in
     * the result only ever means "classified above threshold". */
    Copy_Image_8Bit(&inputImg, &resImg);
    Change_Image_8Bit(&resImg, (unsigned char)255, (unsigned char)254);

    printf("Feedforward classification: ");
    for (i = controlInfo.kernel_size; i < (inputImg.row - controlInfo.kernel_size); i++) {
        if (i % 10 == 0)
            printf("%d ", i);
        fflush(stdout);
        for (j = controlInfo.kernel_size; j < (inputImg.col - controlInfo.kernel_size); j++) {
#if defined(LOCAL_FEATURE_ONLY)
            Calc_Local_MinMax(&inputImg, controlInfo.kernel_size, i, j,
                              &(anetwork.v1[0]), &(anetwork.v1[1]));
            Calc_Local_Variance(&inputImg, controlInfo.kernel_size, i, j,
                                &(anetwork.v1[2]), &(anetwork.v1[3]));
#else
            for (k1 = 0 - controlInfo.kernel_size; k1 <= controlInfo.kernel_size; k1++) {
                for (k2 = 0 - controlInfo.kernel_size; k2 <= controlInfo.kernel_size; k2++) {
                    k = (k1 + controlInfo.kernel_size) * (2 * controlInfo.kernel_size + 1) +
                        (k2 + controlInfo.kernel_size);
                    anetwork.v1[k] = (double)(inputImg.image[k1 + i][k2 + j]) / 256.;
                }
            }
#endif
            anetwork.v1[anetwork.INPUT_UNIT] = -1.;
            forward_prop_only(&anetwork, &controlInfo);

            /* Scale the first network output to 0..255 and threshold it. */
            gray = (int)(anetwork.v3[0] * 255);
            if (gray < 0)
                gray = 0;
            if (gray > 255)
                gray = 255;
            if (gray > controlInfo.grayThreshold) {
                resImg.image[i][j] = 255;
            }
        }
    }
    printf("\n");

    Write_Image_8Bit_ToP5(&resImg, controlInfo.outFile);

    Free_Image_8Bit(&inputImg);
    Free_Image_8Bit(&resImg);
    /* traindata is only allocated when a seed file was loaded. */
    if (traindata != NULL) {
        Free_Array((void **)traindata, nalloc);
    }
    free_mlp_network(&anetwork);
    return 0;
}