Last active
August 3, 2017 17:16
-
-
Save aadnk/906c809671db5297daf309c292bea236 to your computer and use it in GitHub Desktop.
Test of the pattern matching pseudo code in "Pattern Matching with Brian Goetz @briangoetz"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.comphenix.test; | |
import java.util.Objects; | |
/**
 * Compares two Java translations of the expression simplifier shown as pattern-matching pseudo
 * code in the talk "Pattern Matching with Brian Goetz".
 *
 * <p>{@link #simplifyA(Node)} follows the pseudo code directly. It recurses without end whenever a
 * negation, sum or product reaches its general case with children that no longer change (for
 * example {@code (1 + 1)}). {@link #simplifyB(Node)} adds a guard that stops as soon as rewriting a
 * node yields a node equal to the one it was rewritten from.
 */
public class PatternMatchingTest {

    /**
     * Prints what both simplifiers produce for {@code (0 + 1)} and {@code (1 + 1)}.
     *
     * <p>The direct translation overflows the stack on {@code (1 + 1)}. That outcome is printed as
     * a line of output instead of terminating the demo with a stack trace.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        AddNode zeroPlusOne = new AddNode(IntNode.ZERO, IntNode.ONE);
        AddNode onePlusOne = new AddNode(IntNode.ONE, IntNode.ONE);

        System.out.println("Testing our proposed fix: ");
        System.out.println("simplify(" + zeroPlusOne + ") => " + simplifyB(zeroPlusOne));
        System.out.println("simplify(" + onePlusOne + ") => " + simplifyB(onePlusOne));

        System.out.println("\nTesting original code: ");
        System.out.println("simplify(" + zeroPlusOne + ") => " + simplifyAOrReport(zeroPlusOne));
        System.out.println("simplify(" + onePlusOne + ") => " + simplifyAOrReport(onePlusOne));
    }

    /**
     * Runs {@link #simplifyA(Node)} and renders its result, or a marker if it recursed forever.
     *
     * @param n the expression to simplify
     * @return the simplified expression's text, or a note that the stack overflowed
     */
    private static String simplifyAOrReport(Node n) {
        try {
            return String.valueOf(simplifyA(n));
        } catch (StackOverflowError ignored) {
            // Deliberately caught: showing that the original algorithm never terminates on
            // already-simplified input is the whole point of this demo.
            return "StackOverflowError (infinite recursion)";
        }
    }

    /**
     * Simplifies {@code n} by directly translating the talk's pseudo code.
     * Source: https://www.youtube.com/watch?v=n3_8YcYKScw&feature=youtu.be
     *
     * <p>The general case of each operator rebuilds the node from its simplified children. It then
     * simplifies the result again, even when nothing changed. Inputs such as {@code (1 + 1)} or
     * {@code (-1)} are therefore rebuilt unchanged until a {@link StackOverflowError} is thrown.
     *
     * <p>Package-private rather than private so the simplifier can be exercised directly.
     *
     * @param n the expression to simplify
     * @return the simplified expression, when the recursion terminates
     * @throws IllegalArgumentException if {@code n} is {@code null} or of an unsupported node type
     */
    static Node simplifyA(Node n) {
        if (n instanceof IntNode) {
            return n;
        } else if (n instanceof NegNode) {
            NegNode negNode = (NegNode) n;
            if (negNode.inner instanceof NegNode) {
                // -(-x) => x: unwrap BOTH negations, as NegNode(NegNode(var n)) -> simplify(n)
                // binds the innermost operand.
                return simplifyA(((NegNode) negNode.inner).inner);
            } else {
                return simplifyA(new NegNode(simplifyA(negNode.inner)));
            }
        } else if (n instanceof AddNode) {
            AddNode addNode = (AddNode) n;
            if (IntNode.ZERO.equals(addNode.left)) {
                return simplifyA(addNode.right);
            } else if (IntNode.ZERO.equals(addNode.right)) {
                return simplifyA(addNode.left);
            } else {
                return simplifyA(new AddNode(simplifyA(addNode.left), simplifyA(addNode.right)));
            }
        } else if (n instanceof MulNode) {
            MulNode mulNode = (MulNode) n;
            if (IntNode.ONE.equals(mulNode.left)) {
                return simplifyA(mulNode.right);
            } else if (IntNode.ONE.equals(mulNode.right)) {
                return simplifyA(mulNode.left);
            } else if (IntNode.ZERO.equals(mulNode.left)) {
                return IntNode.ZERO;
            } else if (IntNode.ZERO.equals(mulNode.right)) {
                return IntNode.ZERO;
            } else {
                return simplifyA(new MulNode(simplifyA(mulNode.left), simplifyA(mulNode.right)));
            }
        }
        // We can't handle this node yet
        throw new IllegalArgumentException("Unknown node " + n);
    }

    /**
     * Simplifies {@code n}, stopping as soon as a rewrite leaves a node unchanged.
     *
     * <p>Package-private rather than private so the simplifier can be exercised directly.
     *
     * @param n the expression to simplify
     * @return the simplified expression
     * @throws IllegalArgumentException if {@code n} is {@code null} or of an unsupported node type
     */
    static Node simplifyB(Node n) {
        return simplifyB(n, null);
    }

    /**
     * Simplifies {@code n}, which was obtained by rewriting {@code source}.
     *
     * <p>This is the proposed fix. If the general case rebuilt a node equal to {@code source}, then
     * simplifying the children changed nothing. In that case {@code n} is returned as-is instead of
     * being simplified again. Children are simplified by independent calls that start without a
     * source.
     *
     * @param n the expression to simplify
     * @param source the node {@code n} was rewritten from, or {@code null} for a top-level call
     * @return the simplified expression
     * @throws IllegalArgumentException if {@code n} is {@code null} or of an unsupported node type
     */
    private static Node simplifyB(Node n, Node source) {
        // Prevent infinite loops. A null n must not "match" an absent (null) source; it falls
        // through to the "Unknown node" error below, exactly as in simplifyA.
        if (n != null && n.equals(source)) {
            return n;
        }
        if (n instanceof IntNode) {
            return n;
        } else if (n instanceof NegNode) {
            NegNode negNode = (NegNode) n;
            if (negNode.inner instanceof NegNode) {
                // -(-x) => x: unwrap BOTH negations, not just the outer one.
                return simplifyB(((NegNode) negNode.inner).inner, n);
            } else {
                return simplifyB(new NegNode(simplifyB(negNode.inner)), n);
            }
        } else if (n instanceof AddNode) {
            AddNode addNode = (AddNode) n;
            if (IntNode.ZERO.equals(addNode.left)) {
                return simplifyB(addNode.right, n);
            } else if (IntNode.ZERO.equals(addNode.right)) {
                return simplifyB(addNode.left, n);
            } else {
                return simplifyB(new AddNode(simplifyB(addNode.left), simplifyB(addNode.right)), n);
            }
        } else if (n instanceof MulNode) {
            MulNode mulNode = (MulNode) n;
            if (IntNode.ONE.equals(mulNode.left)) {
                return simplifyB(mulNode.right, n);
            } else if (IntNode.ONE.equals(mulNode.right)) {
                return simplifyB(mulNode.left, n);
            } else if (IntNode.ZERO.equals(mulNode.left)) {
                return IntNode.ZERO;
            } else if (IntNode.ZERO.equals(mulNode.right)) {
                return IntNode.ZERO;
            } else {
                return simplifyB(new MulNode(simplifyB(mulNode.left), simplifyB(mulNode.right)), n);
            }
        }
        // We can't handle this node yet
        throw new IllegalArgumentException("Unknown node " + n);
    }

    // *** Dependencies ***

    /** Base type of the expression tree. */
    static abstract class Node {
    }

    /** An integer literal. */
    static class IntNode extends Node {
        /** The literal {@code 0}; the simplifiers compare against it with {@code equals}. */
        static final IntNode ZERO = new IntNode(0);
        /** The literal {@code 1}; the simplifiers compare against it with {@code equals}. */
        static final IntNode ONE = new IntNode(1);

        final int value;

        IntNode(int value) {
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return value == ((IntNode) o).value;
        }

        @Override
        public int hashCode() {
            return 31 * value;
        }

        @Override
        public String toString() {
            return String.valueOf(value);
        }
    }

    /** Arithmetic negation of {@code inner}. */
    static class NegNode extends Node {
        final Node inner;

        NegNode(Node inner) {
            this.inner = inner;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return Objects.equals(inner, ((NegNode) o).inner);
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(inner);
        }

        @Override
        public String toString() {
            return "(-" + inner + ")";
        }
    }

    /** Sum of {@code left} and {@code right}. */
    static class AddNode extends Node {
        final Node left;
        final Node right;

        AddNode(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            AddNode other = (AddNode) o;
            return Objects.equals(left, other.left) && Objects.equals(right, other.right);
        }

        @Override
        public int hashCode() {
            return Objects.hash(left, right);
        }

        @Override
        public String toString() {
            return "(" + left + " + " + right + ")";
        }
    }

    /** Product of {@code left} and {@code right}. */
    static class MulNode extends Node {
        final Node left;
        final Node right;

        MulNode(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            MulNode other = (MulNode) o;
            return Objects.equals(left, other.left) && Objects.equals(right, other.right);
        }

        @Override
        public int hashCode() {
            return Objects.hash(left, right);
        }

        @Override
        public String toString() {
            return "(" + left + " * " + right + ")";
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Original pseudo code | |
// Expression simplifier as presented in the talk, kept verbatim. It is pseudo code
// (see the heading), so it is not expected to compile as written.
// NOTE(review): the NegNode(NegNode(var n)) and NegNode(var n) cases re-declare `n`,
// shadowing the method parameter; the fixed listing below uses `x` instead.
// NOTE(review): there is no default case, so node types other than the four listed
// are not covered; the Java translation throws IllegalArgumentException for them.
Node simplify(Node n) { | |
return switch(n) { | |
// A literal is already fully simplified.
case IntNode -> n; | |
// Double negation: -(-x) => x. The pattern binds the innermost operand.
case NegNode(NegNode(var n)) -> simplify(n); | |
// General negation: rebuild from the simplified operand and simplify again. When the
// operand does not change (e.g. -1), the rebuilt node equals the input and this
// recurses without end.
case NegNode(var n) -> simplify(new NegNode(simplify(n))); | |
// Additive identity: 0 + x => x and x + 0 => x.
case AddNode(IntNode(0), var right) -> simplify(right); | |
case AddNode(var left, IntNode(0)) -> simplify(left); | |
// General sum: same unbounded re-simplification as above, e.g. for 1 + 1.
case AddNode(var left, var right) | |
-> simplify(new AddNode(simplify(left), simplify(right))); | |
// Multiplicative identity: 1 * x => x and x * 1 => x.
case MulNode(IntNode(1), var right) -> simplify(right); | |
case MulNode(var left, IntNode(1)) -> simplify(left); | |
// Zero product: 0 * x => 0 and x * 0 => 0.
case MulNode(IntNode(0), var right) -> new IntNode(0); | |
case MulNode(var left, IntNode(0)) -> new IntNode(0); | |
// General product: same unbounded re-simplification as above.
case MulNode(var left, var right) | |
-> simplify(new MulNode(simplify(left), simplify(right))); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fixed pseudo code | |
// Same simplifier with a loop guard. `s` is the node that `n` was rewritten from. If
// the rewrite produced an equal node, nothing changed, so `n` is returned as-is instead
// of being simplified again. Like the original, this is pseudo code.
// NOTE(review): there is no one-argument entry point here; a top-level call
// presumably passes null for `s` (the Java simplifyB does) — confirm.
Node simplify(Node n, Node s) { | |
// Loop guard.
// NOTE(review): when s is null, a null n also matches and is returned rather than
// reaching the default case below.
if (Objects.equals(n, s)) return n; | |
return switch(n) { | |
case IntNode -> n; | |
// Double negation: -(-x) => x (innermost operand).
case NegNode(NegNode(var x)) -> simplify(x, n); | |
// NOTE(review): children are simplified with `n` as their source here, whereas the
// Java simplifyB passes no source for children. A child is a proper subtree of `n`,
// so it should never compare equal to `n`; this looks equivalent — confirm.
case NegNode(var x) -> simplify(new NegNode(simplify(x, n)), n); | |
// Additive identity: 0 + x => x and x + 0 => x.
case AddNode(IntNode(0), var right) -> simplify(right, n); | |
case AddNode(var left, IntNode(0)) -> simplify(left, n); | |
// General sum: the rebuilt node is compared with `n` by the guard, which stops the
// recursion once the children no longer change.
case AddNode(var left, var right) | |
-> simplify(new AddNode(simplify(left, n), simplify(right, n)), n); | |
// Multiplicative identity: 1 * x => x and x * 1 => x.
case MulNode(IntNode(1), var right) -> simplify(right, n); | |
case MulNode(var left, IntNode(1)) -> simplify(left, n); | |
// Zero product: 0 * x => 0 and x * 0 => 0.
case MulNode(IntNode(0), var right) -> new IntNode(0); | |
case MulNode(var left, IntNode(0)) -> new IntNode(0); | |
// General product: terminated by the same guard as the general sum.
case MulNode(var left, var right) | |
-> simplify(new MulNode(simplify(left, n), simplify(right, n)), n); | |
// Any other node type is rejected.
default -> throw new IllegalArgumentException("Not recognized: " + n); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment