/*
 * Sara Moss
 * Intro. to Artificial Intelligence, HW4, due Oct. 4, 2000
 *
 * Steepest descent on three functions F of two variables.
 *
 * The starting X1 and X2 are both 0.5. The successive X1 and X2 are determined as:
 *   X1 = X[0] - learnRate*g[0];
 *   X2 = X[1] - learnRate*g[1];
 * where:
 *   X[0] is the last X1
 *   X[1] is the last X2
 *   g[0] is the partial derivative of F with respect to X1, evaluated at the last point (X1, X2)
 *   g[1] is the partial derivative of F with respect to X2, evaluated at the last point (X1, X2)
 *   learnRate is a step-size factor chosen separately for each function
 *
 * The algorithm terminates when F is less than 0.000005.
 */

import java.awt.*;
import java.awt.event.*;
import java.util.*;
import java.text.*;

/** A Frame that exits the application when its window is closed. */
class CloseableFrame extends Frame {
    public CloseableFrame() {
        addWindowListener(new WindowAdapter() {
            public void windowClosing(WindowEvent e) {
                System.exit(0);
            }
        });
    }
}

/**
 * Plots the points visited by the descent on a square chart. Both axes span
 * [lowerBound, upperBound]; the plot area is width x height pixels, inset from
 * the canvas edge by a fixed margin that holds the axis labels.
 */
class PopulationCanvas extends Canvas {
    private final int width;       // width of the plot area, in pixels
    private final int height;      // height of the plot area, in pixels
    private final Vector points;   // pixel positions (java.awt.Point) of the plotted points
    private final int lowerBound;  // axis minimum, shared by both axes
    private final int upperBound;  // axis maximum, shared by both axes
    private final int offset;      // margin between the canvas edge and the plot area

    /**
     * @param lb axis minimum (both axes)
     * @param ub axis maximum (both axes)
     * @param w  plot-area width in pixels
     * @param h  plot-area height in pixels
     */
    public PopulationCanvas(int lb, int ub, int w, int h) {
        width = w;
        height = h;
        lowerBound = lb;
        upperBound = ub;
        offset = 20;
        points = new Vector();
    }

    /**
     * Adds the point (x, y), given in axis units, and schedules a repaint.
     * Larger y is drawn higher on the canvas. Points are not clamped, so values
     * outside the axis range are drawn outside the plot rectangle.
     */
    public void AddPoint(double x, double y) {
        double span = upperBound - lowerBound;
        int posX = offset + (int) (width * ((x - lowerBound) / span));
        // Screen y grows downward, so flip the vertical axis.
        int posY = offset + (int) (height - (height * ((y - lowerBound) / span)));
        points.addElement(new Point(posX, posY));
        repaint();
    }

    /** Removes all plotted points and schedules a repaint. */
    public void Reset() {
        points.removeAllElements();
        repaint();
    }

    /** Draws the axis labels, the border and plot rectangles, and every stored point. */
    public void paint(Graphics g) {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(1);
        nf.setMinimumFractionDigits(1);
        double midPoint = lowerBound + (upperBound - lowerBound) / 2.0;

        // Horizontal-axis labels along the top margin.
        g.drawString(String.valueOf(lowerBound), offset, offset - 3);
        g.drawString(nf.format(midPoint), offset + (width / 2), offset - 3);
        g.drawString(String.valueOf(upperBound), offset + width, offset - 3);

        // Vertical-axis labels along the left margin (maximum at the top).
        g.drawString(String.valueOf(upperBound), 3, offset + 10);
        g.drawString(nf.format(midPoint), 3, offset + (height / 2));
        g.drawString(String.valueOf(lowerBound), 3, offset + height);

        g.drawRect(0, 0, width + 2 * offset, height + 2 * offset);
        g.drawRect(offset, offset, width, height);
        for (int i = 0; i < points.size(); i++) {
            Point p = (Point) points.elementAt(i);
            g.drawRect(p.x, p.y, 1, 1);
        }
    }
}

/**
 * Window that runs steepest descent on the function chosen by the user,
 * printing progress to a text area and plotting the path (X1, X2) on a canvas.
 * Each run executes on its own worker thread; starting a new run tells any
 * previous worker to stop.
 */
public class HW4 extends CloseableFrame implements ActionListener, Runnable {
    /** Starting value of both X1 and X2. */
    private static final double START = 0.5;
    /** Descent stops once F falls below this value. */
    private static final double TOLERANCE = 0.000005;
    /** Pause after each displayed iteration, in milliseconds. */
    private static final long DELAY_MS = 200;
    /** For the Rosenbrock function, only every this-many-th iteration is displayed. */
    private static final int ROSENBROCK_DISPLAY_EVERY = 100;

    private Button b1, b2, b3;
    private TextArea ta;                  // progress output
    private PopulationCanvas popCanvas;   // visual graph of the descent path
    private String function;              // function chosen by the user (a button label)
    // The worker currently allowed to run. Reassigned on every click; an older
    // worker that sees a different value here stops. Volatile so workers see it.
    private volatile Thread thread;

    /**
     * Stateless definitions of the three objective functions offered by the
     * buttons: value, gradient, per-function learning rate, and one descent step.
     * Functions are identified by their button label, compared with equals().
     */
    static final class Objectives {
        /** Identifier / button label for F = X1^2 + X2^2. */
        static final String SPHERE = "X1^2 + X2^2";
        /** Identifier / button label for F = X1^2 + 25*X2^2. */
        static final String ELLIPSE = "X1^2 + 25*X2^2";
        /** Identifier / button label for F = 100(X1^2 - X2)^2 + (1 - X1)^2. */
        static final String ROSENBROCK = "100(X1^2 - X2)^2 + (1 - X1)^2";

        private Objectives() {
        }

        /** Returns true if {@code name} identifies one of the supported functions. */
        static boolean isKnown(String name) {
            return SPHERE.equals(name) || ELLIPSE.equals(name) || ROSENBROCK.equals(name);
        }

        /**
         * Returns the learning rate used for the given function.
         *
         * @throws IllegalArgumentException if {@code name} is not a supported function
         */
        static double learnRate(String name) {
            if (SPHERE.equals(name) || ELLIPSE.equals(name)) {
                return 0.02;
            }
            if (ROSENBROCK.equals(name)) {
                return 0.001;
            }
            throw unknown(name);
        }

        /**
         * Evaluates F at (x1, x2).
         *
         * @throws IllegalArgumentException if {@code name} is not a supported function
         */
        static double value(String name, double x1, double x2) {
            if (SPHERE.equals(name)) {
                return Math.pow(x1, 2) + Math.pow(x2, 2);
            }
            if (ELLIPSE.equals(name)) {
                return Math.pow(x1, 2) + 25 * Math.pow(x2, 2);
            }
            if (ROSENBROCK.equals(name)) {
                return 100 * Math.pow(Math.pow(x1, 2) - x2, 2) + Math.pow(1 - x1, 2);
            }
            throw unknown(name);
        }

        /**
         * Writes the gradient of F at (x1, x2) into {@code out[0]} (dF/dX1) and
         * {@code out[1]} (dF/dX2).
         *
         * @throws IllegalArgumentException if {@code name} is not a supported function
         */
        static void gradient(String name, double x1, double x2, double[] out) {
            if (SPHERE.equals(name)) {
                out[0] = 2 * x1;
                out[1] = 2 * x2;
            } else if (ELLIPSE.equals(name)) {
                out[0] = 2 * x1;
                out[1] = 50 * x2;
            } else if (ROSENBROCK.equals(name)) {
                out[0] = 400 * Math.pow(x1, 3) - 400 * x1 * x2 + 2 * x1 - 2;
                out[1] = -200 * Math.pow(x1, 2) + 200 * x2;
            } else {
                throw unknown(name);
            }
        }

        /**
         * Performs one steepest-descent update in place. It evaluates the gradient
         * at {@code x} into {@code grad}, moves {@code x} by {@code -rate * grad},
         * and returns F at the new point.
         *
         * @param x    current point {X1, X2}; overwritten with the new point
         * @param grad length-2 scratch array; left holding the gradient at the old point
         * @return F evaluated at the updated {@code x}
         */
        static double step(String name, double rate, double[] x, double[] grad) {
            gradient(name, x[0], x[1], grad);
            x[0] -= rate * grad[0];
            x[1] -= rate * grad[1];
            return value(name, x[0], x[1]);
        }

        private static IllegalArgumentException unknown(String name) {
            return new IllegalArgumentException("unknown function: " + name);
        }
    }

    public static void main(String[] args) {
        Frame f = new HW4();
        f.setSize(508, 400);
        f.setVisible(true);
    }

    /** Builds the function buttons (north), the plot (center) and the output area (east). */
    public HW4() {
        setTitle("Steepest Descent");
        setLayout(new BorderLayout());

        Panel p = new Panel();
        p.setLayout(new FlowLayout());
        p.add(b1 = new Button(Objectives.SPHERE));
        p.add(b2 = new Button(Objectives.ELLIPSE));
        p.add(b3 = new Button(Objectives.ROSENBROCK));

        add(p, BorderLayout.NORTH);
        add(popCanvas = new PopulationCanvas(-1, 1, 300, 300), BorderLayout.CENTER);
        add(ta = new TextArea("", 7, 20, TextArea.SCROLLBARS_VERTICAL_ONLY), BorderLayout.EAST);

        b1.addActionListener(this);
        b2.addActionListener(this);
        b3.addActionListener(this);
    }

    /** Worker-thread entry point: runs one descent for the chosen function. */
    public void run() {
        steepestDescent();
    }

    /**
     * Starts a new run for the clicked function. Any run still in progress is
     * superseded and interrupted, so it does not keep writing to the shared
     * output. Unknown commands are ignored.
     */
    public void actionPerformed(ActionEvent e) {
        String arg = e.getActionCommand();
        if (!Objectives.isKnown(arg)) {
            return;
        }
        Thread previous = thread;
        function = arg;
        thread = new Thread(this);   // from here on, any older worker sees it is stale
        if (previous != null) {
            previous.interrupt();    // wake it from sleep so it exits promptly
        }
        ta.setText("");
        popCanvas.Reset();
        thread.start();
    }

    /**
     * Runs steepest descent from (START, START) until F < TOLERANCE, displaying
     * each iteration (every ROSENBROCK_DISPLAY_EVERY-th for the Rosenbrock
     * function). Returns early without the DONE summary if this worker is
     * superseded by a newer run or interrupted.
     */
    private void steepestDescent() {
        final Thread self = Thread.currentThread();
        // Snapshot the choice: a later click reassigns the field.
        final String name = function;
        final boolean sparse = Objectives.ROSENBROCK.equals(name);
        final double rate = Objectives.learnRate(name);
        // Per-run state, so concurrent or successive runs cannot interfere.
        double[] x = {START, START};
        double[] grad = new double[2];

        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(6);
        nf.setMinimumFractionDigits(6);

        int count = 0;
        double fx;
        do {
            if (thread != self) {
                return;   // superseded by a newer run
            }
            count++;
            fx = Objectives.step(name, rate, x, grad);
            if (!sparse || count % ROSENBROCK_DISPLAY_EVERY == 0) {
                ta.append("\n\nGen = " + count);
                ta.append("\nF = " + nf.format(fx));
                popCanvas.AddPoint(x[0], x[1]);
                try {
                    Thread.sleep(DELAY_MS);
                } catch (InterruptedException e) {
                    // Interrupted by actionPerformed: this run was cancelled.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        } while (!(fx < TOLERANCE));   // terminate when F is sufficiently small

        if (thread == self) {
            ta.append("\n\nDONE:"
                    + "\nGen = " + count
                    + "\nF = " + nf.format(fx)
                    + "\nX1 = " + nf.format(x[0])
                    + "\nX2 = " + nf.format(x[1]));
        }
    }
}