// Java(TM) Neural Network Toolkit - Wilfred Gander, Lorenzo Patocchi


//////////////////////////////////////////////////////////////////
//
//    TC problem.
//
//    Learn a T and a C, each drawn in a 3x3 pixel square and rotated
//    by 90 degrees between each pattern
//
//  T's
//
//  ###  #     #     #
//   #   ###   #   ###
//   #   #    ###    #
//
//  C's
//
//  ###  ###  ###  # #
//  #    # #    #  # #
//  ###  # #  ###  ###
//
//////////////////////////////////////////////////////////////////
import jaNet.backprop.*;

/**
 * TC problem trainer.
 *
 * Builds a back-propagation network with layer sizes {9,5,1}, trains it on
 * eight 3x3 pixel patterns (four rotations of a T, four rotations of a C),
 * prints the classification of every pattern before and after training,
 * saves the network to ./tc.bpn and finally reloads it from that file to
 * check that the reloaded network classifies the same patterns.
 */
public class myTCtrainer{

    /** Value of a lit pixel; also the target output used for the C patterns. */
    public static final double FIRE =  1.0;
    /** Value of an unlit pixel; also the target output used for the T patterns. */
    public static final double DOWN = -1.0;

    /** Training stops as soon as the summed error over all patterns is not above this value. */
    private static final double MAX_GLOBAL_ERROR = 0.09;
    /** Training stops after this many learning steps even if the error is still too high. */
    private static final int MAX_STEPS = 10000;


    /**
     * Runs the whole experiment: build, test, train, test, save, reload, test.
     * Any BPNException is printed to standard output and ends the run.
     *
     * @param args command line arguments (not used)
     */
    public static void main(String args[]){
        int layers[] = {9,5,1};
        String activations[] = {"jaNet.backprop.Sigmoidc","jaNet.backprop.Sigmoidc"};

        BPN myBPN;
        try{
            myBPN = new BPN(layers, activations);
        }catch(BPNException bpne){
            System.out.println(bpne);
            // Without a network nothing below can work; carrying on would
            // only end in a NullPointerException.
            return;
        }

        try{
            System.out.println("____________________________________________________\n\n");
            System.out.println("     Learn T and C letters");
            System.out.println("\n____________________________________________________\n");
            //
            // Input table, 8 nine input patterns (3x3 pixels, row by row)
            //
            double in[][] = {{FIRE,  FIRE,  FIRE,  DOWN,  FIRE,  DOWN,  DOWN,  FIRE,  DOWN}, // T
                             {DOWN,  DOWN,  FIRE,  FIRE,  FIRE,  FIRE,  DOWN,  DOWN,  FIRE}, // T
                             {DOWN,  FIRE,  DOWN,  DOWN,  FIRE,  DOWN,  FIRE,  FIRE,  FIRE}, // T
                             {FIRE,  DOWN,  DOWN,  FIRE,  FIRE,  FIRE,  FIRE,  DOWN,  DOWN}, // T
                             {FIRE,  FIRE,  FIRE,  FIRE,  DOWN,  DOWN,  FIRE,  FIRE,  FIRE}, // C
                             {FIRE,  FIRE,  FIRE,  FIRE,  DOWN,  FIRE,  FIRE,  DOWN,  FIRE}, // C
                             {FIRE,  FIRE,  FIRE,  DOWN,  DOWN,  FIRE,  FIRE,  FIRE,  FIRE}, // C
                             {FIRE,  DOWN,  FIRE,  FIRE,  DOWN,  FIRE,  FIRE,  FIRE,  FIRE}  // C
                            };
            //
            // Target table, 8 one output patterns (T -> DOWN, C -> FIRE)
            //
            double out[][] = {{DOWN},{DOWN},{DOWN},{DOWN},{FIRE},{FIRE},{FIRE},{FIRE}};

            //
            // test what the network gives without learning
            //
            System.out.println("state at begin... "+myBPN);
            printResults(myBPN, in, out);

            //
            // init learning variables
            //
            double globalError = 10.0;   // above the threshold so the loop is entered
            double lowererror = 100.0;   // lowest global error reported so far
            double thiserror;
            int counts = 0;

            //
            // setup learning parameters
            //
            myBPN.setLearningRate(0.5);
            myBPN.setMomentum(0.3);

            //
            // learn input table until the error is lower than 0.09 or we reach 10000 steps
            //
            System.out.println("Perform until global error is lower than 0.09 or 10000 setps ...");
            while(globalError > MAX_GLOBAL_ERROR && counts < MAX_STEPS){
                counts ++;

                // choose an input pattern at random
                int p = (int)(Math.random()*(double)in.length);

                // make the network learn this input pattern so that it
                // reproduces the corresponding target
                myBPN.learn(in[p],out[p]);

                // read the single-pattern error before the propagations
                // below run the other patterns through the network
                thiserror = myBPN.getError();
                globalError = globalError(myBPN, in, out);

                // if we found a lower global error then print it
                if(globalError < lowererror){
                    System.out.println("state at "+counts+" single error is "+thiserror+" global error is "+globalError);
                    lowererror = globalError;
                }
            }
            //
            // finally print last error
            //
            System.out.println("state at the end, error is "+myBPN.getError());

            //
            // test each input pattern, it should give output pattern table correctly
            //
            printResults(myBPN, in, out);

            //
            // Save the resulting network in a file
            //
            myBPN.saveNeuralNetwork(".","tc.bpn");

            System.out.println("load ./tc.bpn and check if it works still good ...");

            //
            // Load the network structure from the file tc.bpn and run the
            // same check on the reloaded network
            //
            BPN mySecondBPN = new BPN(".","tc.bpn");
            printResults(mySecondBPN, in, out);

        }catch(BPNException bpne){
            System.out.println(bpne);
        }
    }

    /**
     * Sums, over all patterns, half the squared difference between the
     * target and the first network output.
     *
     * @param net network to evaluate
     * @param in  input patterns, one row per pattern
     * @param out target patterns; only element 0 of each row is used
     * @return the summed error over all patterns
     * @throws BPNException if the network reports a failure
     */
    private static double globalError(BPN net, double in[][], double out[][]) throws BPNException{
        double sum = 0.0;
        for(int i=0; i<in.length; i++){
            net.propagate(in[i]);
            double diff = out[i][0] - net.getOutputVector()[0];
            sum += 0.5 * diff * diff;
        }
        return sum;
    }

    /**
     * Propagates every input pattern through the network and prints one line
     * per pattern: the input, the network output and the expected output,
     * each rendered with {@link #bool(double)}.
     *
     * @param net network to test
     * @param in  input patterns, one row per pattern
     * @param out target patterns; only element 0 of each row is printed
     * @throws BPNException if the network reports a failure
     */
    private static void printResults(BPN net, double in[][], double out[][]) throws BPNException{
        for(int i=0; i<in.length; i++){
            //
            // Take 'i'th pattern from input table and make it
            // propagate through the network
            //
            net.propagate(in[i]);

            //
            // print this input pattern
            //
            double input[] = in[i];
            System.out.print("input = {");
            for(int h=0; h<input.length; h++)
                System.out.print(" "+bool(input[h]));
            System.out.print(" }  ");

            //
            // Fetch output result from the network
            //
            double output[] = net.getOutputVector();
            System.out.print("output = {");
            for(int h=0; h<output.length; h++)
                System.out.print(" "+bool(output[h]));

            //
            // print output table (right result)
            //
            System.out.println(" }    right is "+bool(out[i][0]));
        }
    }

    /**
     * Renders a value as a three-state label.
     *
     * @param x value to classify
     * @return "true " (with trailing space) if x is above 0.5, "false" if x
     *         is below -0.5, and "undef" otherwise
     */
    public static String bool(double x){
        return (x>0.5)? "true ":(x< -0.5)? "false": "undef";
    }
}

// © 1996 ISBiel by W.Gander & L.Patocchi