JAVA™ NEURAL NETWORK TOOLKIT |
|
Wilfred Gander - Lorenzo Patocchi
|
//////////////////////////////////////////////////////////////////
//
// TC problem.
//
// Learn to tell a T from a C, each drawn in a 3x3 pixel square and rotated
// by 90 degrees from one pattern to the next
//
// T's
//
// ###     #    #    #
//  #    ###    #    ###
//  #      #   ###   #
//
// C's
//
// ###   ###   ###   # #
// #     # #     #   # #
// ###   # #   ###   ###
//
//////////////////////////////////////////////////////////////////
import jaNet.backprop.*;
/**
 * Command-line demo for the T/C problem: trains a {@code BPN} back-propagation
 * network to tell the letter T from the letter C.
 *
 * Each letter is drawn in a 3x3 pixel square (row by row, nine inputs) in four
 * orientations, giving eight training patterns. The single target value is
 * {@link #DOWN} for a T and {@link #FIRE} for a C.
 *
 * The program reports the untrained network, trains it on randomly chosen
 * patterns, reports the trained network, saves it to {@code ./tc.bpn}, reloads
 * it into a second network and reports that one as well.
 */
public class myTCtrainer{
    /** Level used for a lit pixel and for the "C" target. */
    public static final double FIRE = 1.0;
    /** Level used for a dark pixel and for the "T" target. */
    public static final double DOWN = -1.0;

    /**
     * Runs the whole demo. Any {@code BPNException} is printed to standard
     * output and ends the run.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String args[]){
        // Layer sizes: nine inputs (one per pixel), five units in the middle
        // layer, one output.
        int layers[] = {9,5,1};
        // One activation class name per layer after the input layer
        // (presumably resolved by name inside BPN -- confirm against jaNet).
        String activations[] = {"jaNet.backprop.Sigmoidc","jaNet.backprop.Sigmoidc"};

        BPN myBPN;
        try{
            myBPN = new BPN(layers, activations);
        }catch(BPNException bpne){
            // Without a network nothing below can run; stop here instead of
            // failing later with a NullPointerException.
            System.out.println(bpne);
            return;
        }

        try{
            System.out.println("____________________________________________________\n\n");
            System.out.println(" Learn T and C letters");
            System.out.println("\n____________________________________________________\n");

            //
            // Input table: 8 patterns of 9 pixels each, listed row by row
            //
            double in[][] = {
                {FIRE, FIRE, FIRE, DOWN, FIRE, DOWN, DOWN, FIRE, DOWN}, // T
                {DOWN, DOWN, FIRE, FIRE, FIRE, FIRE, DOWN, DOWN, FIRE}, // T
                {DOWN, FIRE, DOWN, DOWN, FIRE, DOWN, FIRE, FIRE, FIRE}, // T
                {FIRE, DOWN, DOWN, FIRE, FIRE, FIRE, FIRE, DOWN, DOWN}, // T
                {FIRE, FIRE, FIRE, FIRE, DOWN, DOWN, FIRE, FIRE, FIRE}, // C
                {FIRE, FIRE, FIRE, FIRE, DOWN, FIRE, FIRE, DOWN, FIRE}, // C
                {FIRE, FIRE, FIRE, DOWN, DOWN, FIRE, FIRE, FIRE, FIRE}, // C
                {FIRE, DOWN, FIRE, FIRE, DOWN, FIRE, FIRE, FIRE, FIRE}  // C
            };

            //
            // Target table: one output per pattern, DOWN for T and FIRE for C
            //
            double out[][] = {{DOWN},{DOWN},{DOWN},{DOWN},{FIRE},{FIRE},{FIRE},{FIRE}};

            //
            // Show what the network answers before any learning
            //
            System.out.println("state at begin... "+myBPN);
            printResults(myBPN, in, out);

            //
            // Learning state. globalError starts above the stop threshold so
            // the loop is entered; lowererror tracks the best global error
            // seen so far and starts above any value we expect to reach.
            //
            double globalError = 10.0;
            double lowererror = 100.0;
            int counts = 0;

            //
            // Learning parameters
            //
            myBPN.setLearningRate(0.5);
            myBPN.setMomentum(0.3);

            //
            // Learn until the global error drops to 0.09 or below, or
            // 10000 steps have been made
            //
            System.out.println("Perform until global error is lower than 0.09 or 10000 setps ...");
            while(globalError > 0.09 && counts < 10000){
                counts++;

                // pick one training pattern at random and learn it
                int p = (int)(Math.random()*(double)in.length);
                myBPN.learn(in[p],out[p]);

                // read the error of this single step before the propagate
                // calls below touch the network again
                double thiserror = myBPN.getError();

                // global error = sum of the errors over all patterns
                globalError = sumSquaredError(myBPN, in, out);

                // report only when a new best global error is reached
                if(globalError < lowererror){
                    System.out.println("state at "+counts+" single error is "+thiserror+" global error is "+globalError);
                    lowererror = globalError;
                }
            }

            //
            // Finally print the last error
            //
            System.out.println("state at the end, error is "+myBPN.getError());

            //
            // Test each input pattern; the output should match the target table
            //
            printResults(myBPN, in, out);

            //
            // Save the resulting network in a file
            //
            myBPN.saveNeuralNetwork(".","tc.bpn");
            System.out.println("load ./tc.bpn and check if it works still good ...");

            //
            // Load the network from tc.bpn and run the same check on it
            //
            BPN mySecondBPN = new BPN(".","tc.bpn");
            printResults(mySecondBPN, in, out);
        }catch(BPNException bpne){
            System.out.println(bpne);
        }
    }

    /**
     * Propagates every pattern through the network and sums half the squared
     * difference between the target and the first output value.
     *
     * @param net network to evaluate
     * @param in  input patterns
     * @param out target values, one row per input pattern (only element 0 is used)
     * @return the summed error over all patterns
     * @throws BPNException if the network reports a failure
     */
    private static double sumSquaredError(BPN net, double in[][], double out[][]) throws BPNException{
        double sum = 0.0;
        for(int i=0; i<in.length; i++){
            net.propagate(in[i]);
            double diff = out[i][0] - net.getOutputVector()[0];
            sum += 0.5 * diff * diff;
        }
        return sum;
    }

    /**
     * Propagates every pattern through the network and prints one line per
     * pattern: the input pixels, the network output and the expected answer,
     * each rendered with {@link #bool(double)}.
     *
     * @param net network to query
     * @param in  input patterns
     * @param out target values, one row per input pattern (only element 0 is printed)
     * @throws BPNException if the network reports a failure
     */
    private static void printResults(BPN net, double in[][], double out[][]) throws BPNException{
        for(int i=0; i<in.length; i++){
            net.propagate(in[i]);

            double input[] = in[i];
            System.out.print("input = {");
            for(int h=0; h<input.length; h++){
                System.out.print(" "+bool(input[h]));
            }
            System.out.print(" } ");

            double output[] = net.getOutputVector();
            System.out.print("output = {");
            for(int h=0; h<output.length; h++){
                System.out.print(" "+bool(output[h]));
            }

            System.out.println(" } right is "+bool(out[i][0]));
        }
    }

    /**
     * Renders an activation level as a fixed-width (five character) label.
     *
     * @param x activation level
     * @return {@code "true "} (with a trailing space) when {@code x > 0.5},
     *         {@code "false"} when {@code x < -0.5}, and {@code "undef"}
     *         for anything else, including NaN
     */
    public static String bool(double x){
        return (x>0.5)? "true ":(x< -0.5)? "false": "undef";
    }
}
© 1996 ISBiel by W.Gander & L.Patocchi