// // load for early releases //import com.sun.java.swing.*; import javax.swing.*; import javax.swing.event.*; import java.awt.*; import java.awt.event.*; import java.util.*; public class SwingApplet extends JApplet implements ActionListener,Runnable{ static final int BW=300, BH=300, BX=8, BY=8, NUM_WALLS=20, SAMP_W = 100, SAMP_H = 100; static final int DEF_EPOCHS = 50000; static final long DELAY=500; static int MAXX=400, MAXY=400; CatAndMouseGame game; CatAndMouseWorld trainWorld, playWorld; // seperate world from playing world RLController rlc; RLearner rl; JTabbedPane tabbedPane; Container instructions, playPanel, trainPanel, worldPanel; // world setting components JTextField rows, cols, obst; sampleWorlds samples; boolean[][] selectedWalls; ButtonGroup worldSelGroup; boolean sampleWorld=true, designWorld=false; // instructions components JLabel instructLabel, usageLabel; final String INSTRUCT_MESSAGE = "

This applet demonstrates how reinforcement

learning can be used to train an agent to play

a simple game. In this case the game is Cat and

Mouse- the mouse tries to get to the cheese

and back to it's hole, the cat tries to catch the mouse.", USAGE_MESSAGE = "

You can train the agent by selecting the Train tab. At

any time you can select the Play tab to see how

well the agent is performing! Of course, the more

training, the better the chance the mouse

has of surviving :)"; // train panel components public static final String START="S", CONT_CHECK="C"; final String SETTINGS_TEXT = "These settings adjust some of the internal workings of the reinforcement learning algorithm.", SETTINGS_TEXT2 = "Please see the web pages for more details on what the parameters do."; JTextField alpha, gamma, epsilon, epochs, penalty, reward; JButton startTraining, stopTraining; JRadioButton softmax, greedy, sarsa, qlearn; JProgressBar progress; JLabel learnEpochsDone; // play panel components JButton startbutt, stopbutt, pausebutt; boardPanel bp; public int mousescore=0, catscore =0; JLabel catscorelabel, mousescorelabel; final String MS_TEXT = "Mouse Score:", CS_TEXT = "Cat Score:"; JSlider speed, smoothSlider; Image catImg, mouseImg; chartPanel graphPanel; JLabel winPerc; boardObject cat, mouse, cheese, back, hole, wall; public SwingApplet() { getRootPane().putClientProperty("defeatSystemEventQueueCheck",Boolean.TRUE); } public void init() { // load images catImg = getImage(getCodeBase(), "cat.gif"); mouseImg = getImage(getCodeBase(), "mouse.gif"); Image wallImg = getImage(getCodeBase(), "wall.gif"); Image cheeseImg = getImage(getCodeBase(), "cheese.gif"); Image floorImg = getImage(getCodeBase(), "floor.gif"); /* Image catImg = getImage(ClassLoader.getSystemResource("cat.gif")); Image mouseImg = getImage(ClassLoader.getSystemResource("mouse.gif")); Image wallImg = getImage(ClassLoader.getSystemResource("wall.gif")); Image cheeseImg = getImage(ClassLoader.getSystemResource("cheese.gif"));*/ // set up board objects cat = new boardObject(catImg); mouse = new boardObject(mouseImg); cheese = new boardObject(cheeseImg); back = new boardObject(floorImg); hole = new boardObject(Color.orange); wall = new boardObject(wallImg); // setup content panes tabbedPane = new JTabbedPane(); //instructions = makeInstructions(); worldPanel = makeWorldPanel(); playPanel = makePlayPanel(); trainPanel = makeTrainPanel(); tabbedPane.addTab("World", worldPanel); tabbedPane.addTab("Play", playPanel); //tabbedPane.addTab("Instructions", instructions); tabbedPane.addTab("Train", trainPanel); tabbedPane.setSelectedIndex(0); // disable panes until world created tabbedPane.setEnabledAt(1,false); tabbedPane.setEnabledAt(2,false); // set up controls //setContentPane(new JPanel()); //getContentPane().add(tabbedPane); getContentPane().add(tabbedPane); } public void worldInit(int xdim, int ydim, int numwalls) { trainWorld = new CatAndMouseWorld(xdim, ydim,numwalls); gameInit(xdim,ydim); } public void worldInit(boolean[][] givenWalls) { int xdim = givenWalls.length, ydim = givenWalls[0].length; trainWorld = new CatAndMouseWorld(xdim, ydim,givenWalls); gameInit(xdim,ydim); } private void gameInit(int xdim, int ydim) { // disable this pane tabbedPane.setEnabledAt(0,false); playWorld = new CatAndMouseWorld(xdim, ydim,trainWorld.walls); bp.setDimensions(xdim, ydim); rlc = new RLController(this, trainWorld, DELAY); rl = rlc.learner; rlc.start(); game = new CatAndMouseGame(this, DELAY, playWorld, rl.getPolicy()); game.start(); // set text fields on panels penalty.setText(Integer.toString(trainWorld.deathPenalty)); reward.setText(Integer.toString(trainWorld.cheeseReward)); alpha.setText(Double.toString(rl.getAlpha())); gamma.setText(Double.toString(rl.getGamma())); epsilon.setText(Double.toString(rl.getEpsilon())); // enable other panes tabbedPane.setEnabledAt(1,true); tabbedPane.setEnabledAt(2,true); // switch active pane tabbedPane.setSelectedIndex(1); // set first position on board updateBoard(); } // this method is triggered by SwingUtilities.invokeLater in other threads public void run() { updateBoard(); } /************ general functions ****************/ public void updateBoard() { // update score panels mousescorelabel.setText(MS_TEXT+" "+Integer.toString(mousescore)); catscorelabel.setText(CS_TEXT+" "+Integer.toString(catscore)); if (game.newInfo) { updateScore(); game.newInfo = false; } // update progress info progress.setValue(rlc.epochsdone); learnEpochsDone.setText(Integer.toString(rlc.totaldone)); if (rlc.newInfo) endTraining(); // update game board bp.clearBoard(); // draw walls boolean[][] w = game.getWalls(); for (int i=0; i= 0) && (newval <= 1)) previous = newval; else System.err.println("Invalid new value: "+newval); int yval = (int) (newval * getHeight()); int xval = x-(history.size() - getWidth()); //System.out.println("index="+x+" thisval="+thisval+"newval="+newval+" xval="+xval+" yval="+yval+" previous="+previous); g.drawOval(xval,yval,POINTW,POINTH); } } public Dimension getPreferredSize() { return new Dimension(PREFX, PREFY); } public void setSmoothing(double s) { smoothing = s; } void addScore(double s) { if (!((s >= 0) && (s<=1))) { System.err.println("Graph: rejecting value"+s); return; } history.addElement(new Double(s)); // prune if list too big if (history.size() >= MAXSIZE) history.remove(0); //System.out.println("Size:"+history.size()+" maxsize:"+MAXSIZE); /* System.out.println("History being pruned."+Thread.currentThread().getName()); Vector nVec = new Vector(); for (int i=(MAXSIZE/3); i