Last active
June 13, 2023 23:37
-
-
Save agucova/cae477b3e7d913487b9849e0c9560075 to your computer and use it in GitHub Desktop.
Automated libCST refactor for squigglepy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import libcst as cst | |
from libcst.codemod import CodemodTest, VisitorBasedCodemodCommand | |
from libcst.codemod.visitors import AddImportsVisitor | |
# Lookup table used by the codemod below: maps the string literal compared
# against a ``.type`` attribute to the distribution class name that replaces
# it in the generated ``isinstance`` call.
DISTS: dict[str, str] = dict(
    base="BaseDistribution",
    complex="ComplexDistribution",
    const="ConstantDistribution",
    uniform="UniformDistribution",
    norm="NormalDistribution",
    lognorm="LognormalDistribution",
    binomial="BinomialDistribution",
    beta="BetaDistribution",
    bernoulli="BernoulliDistribution",
    discrete="DiscreteDistribution",
    tdist="TDistribution",
    log_tdist="LogTDistribution",
    triangular="TriangularDistribution",
    poisson="PoissonDistribution",
    chisquare="ChiSquareDistribution",
    exponential="ExponentialDistribution",
    gamma="GammaDistribution",
    pareto="ParetoDistribution",
    mixture="MixtureDistribution",
)
class DistributionTypeRefactorCommand(VisitorBasedCodemodCommand):
    """Rewrite string comparisons on a ``.type`` attribute as ``isinstance`` checks.

    ``x.type == "beta"`` / ``x.type is "beta"`` become
    ``isinstance(x, BetaDistribution)``; the negated operators (``!=`` and
    ``is not``) become ``not isinstance(x, BetaDistribution)``. The class name
    is looked up in ``DISTS`` and an import of it from ``.distributions`` is
    scheduled through ``AddImportsVisitor``.
    """

    DESCRIPTION = """
    Refactor expressions of the form 'prior.type == "beta"' to 'isinstance(prior, BetaDistribution)'
    """

    def leave_Comparison(
        self, original_node: cst.Comparison, updated_node: cst.Comparison
    ) -> cst.BaseExpression:
        """Return the replacement for a comparison node, or the node unchanged.

        Args:
            original_node: The comparison as it appeared before any child
                transformations (unused).
            updated_node: The comparison after child transformations; this is
                the node that is inspected and possibly replaced.

        Returns:
            ``updated_node`` untouched when the comparison is not of the form
            ``<expr>.type <op> "<known name>"``; otherwise an ``isinstance``
            call, wrapped in ``not`` for the negated operators.
        """
        # Chained comparisons (``a == b == c``) are out of scope.
        if len(updated_node.comparisons) != 1:
            return updated_node

        attribute = updated_node.left
        target = updated_node.comparisons[0]

        # The left-hand side must be an attribute access named ``type``.
        if not isinstance(attribute, cst.Attribute):
            return updated_node
        if not isinstance(attribute.attr, cst.Name) or attribute.attr.value != "type":
            return updated_node

        # ``==`` / ``is`` map to a positive check; ``!=`` / ``is not`` must be
        # negated, otherwise the rewritten condition means the opposite of
        # the original one.
        operator = target.operator
        if isinstance(operator, (cst.Equal, cst.Is)):
            negate = False
        elif isinstance(operator, (cst.NotEqual, cst.IsNot)):
            negate = True
        else:
            return updated_node

        # The right-hand side must be a plain string literal.
        if not isinstance(target.comparator, cst.SimpleString):
            return updated_node

        type_name = target.comparator.evaluated_value
        if type_name not in DISTS:
            print(f"Unsupported distribution: {type_name}")
            return updated_node

        # We have a match: build ``isinstance(<expr>, <Class>)``.
        dist = DISTS[type_name]
        isinstance_call = cst.Call(
            func=cst.Name("isinstance"),
            args=[
                cst.Arg(attribute.value),
                cst.Arg(cst.Name(dist)),
            ],
        )

        # Ensure import is added
        AddImportsVisitor.add_needed_import(self.context, ".distributions", dist)

        # Carry over any parentheses that wrapped the original comparison.
        # This matters for the negated form: ``(x.type != "a") == y`` must not
        # turn into ``not isinstance(...) == y``, which parses differently.
        if negate:
            return cst.UnaryOperation(
                operator=cst.Not(),
                expression=isinstance_call,
                lpar=updated_node.lpar,
                rpar=updated_node.rpar,
            )
        return isinstance_call.with_changes(
            lpar=updated_node.lpar, rpar=updated_node.rpar
        )
class TestDistributionTypeRefactorCommand(CodemodTest):
    """Fixture-based tests for ``DistributionTypeRefactorCommand``."""

    TRANSFORM = DistributionTypeRefactorCommand

    def test_noop(self) -> None:
        """Code without any ``.type`` comparison is left untouched."""
        source = """
            from squigglepy import BaseDistribution
            prior = BaseDistribution()
        """
        expected = """
            from squigglepy import BaseDistribution
            prior = BaseDistribution()
        """
        self.assertCodemod(source, expected)

    def test_simple(self) -> None:
        """A bare ``prior.type == "base"`` expression statement is rewritten."""
        source = """
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            prior.type == "base"
        """
        # NOTE(review): the expected output repeats the relative import;
        # presumably the import visitor does not match the existing relative
        # import in this test context -- confirm before "fixing" it.
        expected = """
            from .distributions import BaseDistribution
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            isinstance(prior, BaseDistribution)
        """
        self.assertCodemod(source, expected)

    def test_if(self) -> None:
        """A comparison used as an ``if`` condition is rewritten in place."""
        source = """
            from .distributions import BaseDistribution
            prior = BaseDistribution()
            if prior.type == "beta":
                print("hello")
        """
        expected = """
            from .distributions import BaseDistribution
            from .distributions import BetaDistribution
            prior = BaseDistribution()
            if isinstance(prior, BetaDistribution):
                print("hello")
        """
        self.assertCodemod(source, expected)

    def test_auto_import(self) -> None:
        """The distribution class import is added when the module lacks it."""
        source = """
            import squigglepy as sq
            posterior = sq.norm(0, 1)
            if posterior.type == "norm":
                print("hello")
        """
        expected = """
            import squigglepy as sq
            from .distributions import NormalDistribution
            posterior = sq.norm(0, 1)
            if isinstance(posterior, NormalDistribution):
                print("hello")
        """
        self.assertCodemod(source, expected)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment