fix(plugins): improve custom passes merge ordering

This commit is contained in:
Skylot
2023-03-30 17:14:10 +01:00
parent ee3a653c1b
commit a992c93198
24 changed files with 437 additions and 61 deletions
@@ -3,6 +3,7 @@ package jadx.api.impl.passes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.types.JadxDecompilePass;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.MethodNode;
@@ -10,7 +11,7 @@ import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
public class DecompilePassWrapper extends AbstractVisitor {
public class DecompilePassWrapper extends AbstractVisitor implements IPassWrapperVisitor {
private static final Logger LOG = LoggerFactory.getLogger(DecompilePassWrapper.class);
private final JadxDecompilePass decompilePass;
@@ -19,6 +20,11 @@ public class DecompilePassWrapper extends AbstractVisitor {
this.decompilePass = decompilePass;
}
@Override
public JadxPass getPass() {
return decompilePass;
}
@Override
public void init(RootNode root) throws JadxException {
try {
@@ -48,7 +54,7 @@ public class DecompilePassWrapper extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return decompilePass.getInfo().getName();
}
}
@@ -0,0 +1,9 @@
package jadx.api.impl.passes;
import jadx.api.plugins.pass.JadxPass;
import jadx.core.dex.visitors.IDexTreeVisitor;
public interface IPassWrapperVisitor extends IDexTreeVisitor {
JadxPass getPass();
}
@@ -3,12 +3,13 @@ package jadx.api.impl.passes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.types.JadxPreparePass;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.exceptions.JadxException;
public class PreparePassWrapper extends AbstractVisitor {
public class PreparePassWrapper extends AbstractVisitor implements IPassWrapperVisitor {
private static final Logger LOG = LoggerFactory.getLogger(PreparePassWrapper.class);
private final JadxPreparePass preparePass;
@@ -17,6 +18,11 @@ public class PreparePassWrapper extends AbstractVisitor {
this.preparePass = preparePass;
}
@Override
public JadxPass getPass() {
return preparePass;
}
@Override
public void init(RootNode root) throws JadxException {
try {
@@ -27,7 +33,7 @@ public class PreparePassWrapper extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return preparePass.getInfo().getName();
}
}
@@ -55,7 +55,7 @@ public class PluginsContext implements JadxPluginContext {
return codeInputs;
}
public void setCurrentPlugin(JadxPlugin currentPlugin) {
public void setCurrentPlugin(@Nullable JadxPlugin currentPlugin) {
this.currentPlugin = currentPlugin;
}
@@ -1,8 +1,26 @@
package jadx.api.plugins;
import jadx.api.plugins.pass.types.JadxAfterLoadPass;
import jadx.api.plugins.pass.types.JadxPreparePass;
/**
* Base interface for all jadx plugins
* <br>
* To create new plugin implement this interface and add to resources
* a {@code META-INF/services/jadx.api.plugins.JadxPlugin} file with a full name of your class.
*/
public interface JadxPlugin {
/**
* Method for provide plugin information, like name and description.
* Can be invoked several times.
*/
JadxPluginInfo getPluginInfo();
/**
* Init plugin.
* Use {@link JadxPluginContext} to register passes, code inputs and options.
* For long operation, prefer {@link JadxPreparePass} or {@link JadxAfterLoadPass} instead.
*/
void init(JadxPluginContext context);
}
@@ -4,11 +4,35 @@ import java.util.List;
public interface JadxPassInfo {
/**
* Add this to 'run after' list to place pass before others
*/
String START = "start";
/**
* Add this to 'run before' list to place pass at end
*/
String END = "end";
/**
* Pass short id, should be unique.
*/
String getName();
/**
* Pass description
*/
String getDescription();
/**
* This pass will be executed after these passes.
* Passes names list.
*/
List<String> runAfter();
/**
* This pass will be executed before these passes.
* Passes names list.
*/
List<String> runBefore();
}
@@ -52,4 +52,9 @@ public class OrderedJadxPassInfo implements JadxPassInfo {
public List<String> runBefore() {
return runBefore;
}
@Override
public String toString() {
return "PassInfo{'" + name + '\'' + ", desc='" + desc + '\'' + ", runAfter=" + runAfter + ", runBefore=" + runBefore + '}';
}
}
+2 -2
View File
@@ -201,7 +201,7 @@ public class Jadx {
if (args.isRawCFGOutput()) {
passes.add(DotGraphVisitor.dumpRaw());
}
passes.add(new MethodVisitor(mth -> mth.add(AFlag.DISABLE_BLOCKS_LOCK)));
passes.add(new MethodVisitor("DisableBlockLock", mth -> mth.add(AFlag.DISABLE_BLOCKS_LOCK)));
passes.add(new BlockProcessor());
passes.add(new SSATransform());
passes.add(new MoveInlineVisitor());
@@ -220,7 +220,7 @@ public class Jadx {
passes.add(new ReSugarCode());
passes.add(new CodeShrinkVisitor());
passes.add(new SimplifyVisitor());
passes.add(new MethodVisitor(mth -> mth.remove(AFlag.DONT_GENERATE)));
passes.add(new MethodVisitor("ForceGenerateAll", mth -> mth.remove(AFlag.DONT_GENERATE)));
if (args.isCfgOutput()) {
passes.add(DotGraphVisitor.dump());
}
@@ -73,7 +73,7 @@ public class DeobfuscatorVisitor extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "DeobfuscatorVisitor";
}
}
@@ -47,7 +47,7 @@ public class SaveDeobfMapping extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "SaveDeobfMapping";
}
}
@@ -309,12 +309,10 @@ public class RootNode {
}
public void mergePasses(Map<JadxPassType, List<JadxPass>> customPasses) {
PassMerge.run(preDecompilePasses,
customPasses.get(JadxPreparePass.TYPE),
p -> new PreparePassWrapper((JadxPreparePass) p));
PassMerge.run(processClasses.getPasses(),
customPasses.get(JadxDecompilePass.TYPE),
p -> new DecompilePassWrapper((JadxDecompilePass) p));
new PassMerge(preDecompilePasses)
.merge(customPasses.get(JadxPreparePass.TYPE), p -> new PreparePassWrapper((JadxPreparePass) p));
new PassMerge(processClasses.getPasses())
.merge(customPasses.get(JadxDecompilePass.TYPE), p -> new DecompilePassWrapper((JadxDecompilePass) p));
}
public void runPreDecompileStage() {
@@ -615,6 +613,10 @@ public class RootNode {
return processClasses.getPasses();
}
public List<IDexTreeVisitor> getPreDecompilePasses() {
return preDecompilePasses;
}
public void initPasses() {
processClasses.initPasses(this);
}
@@ -24,7 +24,12 @@ public abstract class AbstractVisitor implements IDexTreeVisitor {
}
@Override
public String toString() {
public String getName() {
return this.getClass().getSimpleName();
}
@Override
public String toString() {
return getName();
}
}
@@ -10,6 +10,11 @@ import jadx.core.utils.exceptions.JadxException;
*/
public interface IDexTreeVisitor {
/**
* Visitor short id
*/
String getName();
/**
* Called after loading dex tree, but before visitor traversal.
*/
@@ -2,16 +2,16 @@ package jadx.core.dex.visitors;
import java.util.function.Consumer;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.exceptions.JadxException;
public class MethodVisitor implements IDexTreeVisitor {
public class MethodVisitor extends AbstractVisitor {
private final String name;
private final Consumer<MethodNode> visitor;
public MethodVisitor(Consumer<MethodNode> visitor) {
public MethodVisitor(String name, Consumer<MethodNode> visitor) {
this.name = name;
this.visitor = visitor;
}
@@ -21,11 +21,7 @@ public class MethodVisitor implements IDexTreeVisitor {
}
@Override
public void init(RootNode root) throws JadxException {
}
@Override
public boolean visit(ClassNode cls) throws JadxException {
return true;
public String getName() {
return name;
}
}
@@ -462,7 +462,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "OverrideMethodVisitor";
}
}
@@ -362,7 +362,7 @@ public class ProcessAnonymous extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "ProcessAnonymous";
}
}
@@ -73,7 +73,7 @@ public class ProcessMethodsForInline extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "ProcessMethodsForInline";
}
}
@@ -278,7 +278,7 @@ public class SignatureProcessor extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "SignatureProcessor";
}
}
@@ -246,7 +246,7 @@ public class RenameVisitor extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "RenameVisitor";
}
}
@@ -177,7 +177,7 @@ public class UsageInfoVisitor extends AbstractVisitor {
}
@Override
public String toString() {
public String getName() {
return "UsageInfoVisitor";
}
}
@@ -1,9 +1,14 @@
package jadx.core.utils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.JadxPassInfo;
@@ -12,71 +17,216 @@ import jadx.core.utils.exceptions.JadxRuntimeException;
public class PassMerge {
public static void run(List<IDexTreeVisitor> passes, List<JadxPass> customPasses, Function<JadxPass, IDexTreeVisitor> wrap) {
private final List<IDexTreeVisitor> visitors;
private Set<String> mergePassesNames;
private Map<IDexTreeVisitor, String> namesMap;
public PassMerge(List<IDexTreeVisitor> visitors) {
this.visitors = visitors;
}
public void merge(List<JadxPass> customPasses, Function<JadxPass, IDexTreeVisitor> wrap) {
if (Utils.isEmpty(customPasses)) {
return;
}
for (JadxPass customPass : customPasses) {
IDexTreeVisitor pass = wrap.apply(customPass);
int pos = searchInsertPos(passes, customPass.getInfo());
List<MergePass> mergePasses = ListUtils.map(customPasses, p -> new MergePass(p, wrap.apply(p), p.getInfo()));
linkDeps(mergePasses);
mergePasses.sort(new ExtDepsComparator(visitors).thenComparing(InvertedDepsComparator.INSTANCE));
namesMap = new IdentityHashMap<>();
visitors.forEach(p -> namesMap.put(p, p.getName()));
mergePasses.forEach(p -> namesMap.put(p.getVisitor(), p.getName()));
mergePassesNames = mergePasses.stream().map(MergePass::getName).collect(Collectors.toSet());
for (MergePass mergePass : mergePasses) {
int pos = searchInsertPos(mergePass);
if (pos == -1) {
passes.add(pass);
visitors.add(mergePass.getVisitor());
} else {
passes.add(pos, pass);
visitors.add(pos, mergePass.getVisitor());
}
}
}
private static int searchInsertPos(List<IDexTreeVisitor> passes, JadxPassInfo info) {
List<String> runAfter = info.runAfter();
List<String> runBefore = info.runBefore();
private int searchInsertPos(MergePass pass) {
List<String> runAfter = pass.after();
List<String> runBefore = pass.before();
if (runAfter.isEmpty() && runBefore.isEmpty()) {
return -1; // last
}
if (ListUtils.isSingleElement(runAfter, "start")) {
if (ListUtils.isSingleElement(runAfter, JadxPassInfo.START)) {
return 0;
}
if (ListUtils.isSingleElement(runBefore, "end")) {
if (ListUtils.isSingleElement(runBefore, JadxPassInfo.END)) {
return -1;
}
Map<String, Integer> namesMap = buildNamesMap(passes);
int after = 0;
int visitorsCount = visitors.size();
Map<String, Integer> namePosMap = new HashMap<>(visitorsCount);
for (int i = 0; i < visitorsCount; i++) {
namePosMap.put(namesMap.get(visitors.get(i)), i);
}
int after = -1;
for (String name : runAfter) {
Integer pos = namesMap.get(name);
Integer pos = namePosMap.get(name);
if (pos != null) {
after = Math.max(after, pos);
} else {
if (mergePassesNames.contains(name)) {
// ignore known passes
continue;
}
throw new JadxRuntimeException("Ordering pass not found: " + name
+ ", listed in 'runAfter' of pass: " + pass
+ "\n all passes: " + ListUtils.map(visitors, namesMap::get));
}
}
int before = Integer.MAX_VALUE;
for (String name : runBefore) {
Integer pos = namesMap.get(name);
Integer pos = namePosMap.get(name);
if (pos != null) {
before = Math.min(before, pos);
} else {
if (mergePassesNames.contains(name)) {
// ignore known passes
continue;
}
throw new JadxRuntimeException("Ordering pass not found: " + name
+ ", listed in 'runBefore' of pass: " + pass
+ "\n all passes: " + ListUtils.map(visitors, namesMap::get));
}
}
if (before <= after) {
throw new JadxRuntimeException("Conflict pass order requirements: " + info.getName()
throw new JadxRuntimeException("Conflict order requirements for pass: " + pass
+ "\n run after: " + runAfter
+ "\n run before: " + runBefore
+ "\n passes: " + ListUtils.map(passes, PassMerge::getPassName));
+ "\n passes: " + ListUtils.map(visitors, namesMap::get));
}
if (after == 0) {
if (after == -1) {
if (before == Integer.MAX_VALUE) {
// not ordered, put at last
return -1;
}
return before;
}
int pos = after + 1;
return pos >= passes.size() ? -1 : pos;
return pos >= visitorsCount ? -1 : pos;
}
private static Map<String, Integer> buildNamesMap(List<IDexTreeVisitor> passes) {
int size = passes.size();
Map<String, Integer> namesMap = new HashMap<>(size);
for (int i = 0; i < size; i++) {
namesMap.put(getPassName(passes.get(i)), i);
private static final class MergePass {
private final JadxPass pass;
private final IDexTreeVisitor visitor;
private final JadxPassInfo info;
// copy dep lists for future modifications
private final List<String> before;
private final List<String> after;
private MergePass(JadxPass pass, IDexTreeVisitor visitor, JadxPassInfo info) {
this.pass = pass;
this.visitor = visitor;
this.info = info;
this.before = new ArrayList<>(info.runBefore());
this.after = new ArrayList<>(info.runAfter());
}
public JadxPass getPass() {
return pass;
}
public IDexTreeVisitor getVisitor() {
return visitor;
}
public String getName() {
return info.getName();
}
public JadxPassInfo getInfo() {
return info;
}
public List<String> before() {
return before;
}
public List<String> after() {
return after;
}
@Override
public String toString() {
return info.getName();
}
return namesMap;
}
private static String getPassName(IDexTreeVisitor pass) {
return pass.getClass().getSimpleName();
/**
* Make deps double linked
*/
private static void linkDeps(List<MergePass> mergePasses) {
Map<String, MergePass> map = mergePasses.stream().collect(Collectors.toMap(MergePass::getName, p -> p));
for (MergePass pass : mergePasses) {
for (String after : pass.getInfo().runAfter()) {
MergePass beforePass = map.get(after);
if (beforePass != null) {
beforePass.before().add(pass.getName());
}
}
for (String before : pass.getInfo().runBefore()) {
MergePass afterPass = map.get(before);
if (afterPass != null) {
afterPass.after().add(pass.getName());
}
}
}
}
/**
* Place passes with visitors dependencies before others.
*/
private static class ExtDepsComparator implements Comparator<MergePass> {
private final Set<String> names;
public ExtDepsComparator(List<IDexTreeVisitor> visitors) {
this.names = visitors.stream()
.map(IDexTreeVisitor::getName)
.collect(Collectors.toSet());
}
@Override
public int compare(MergePass first, MergePass second) {
boolean isFirst = containsVisitor(first.before()) || containsVisitor(first.after());
boolean isSecond = containsVisitor(second.before()) || containsVisitor(second.after());
return -Boolean.compare(isFirst, isSecond);
}
private boolean containsVisitor(List<String> deps) {
for (String dep : deps) {
if (names.contains(dep)) {
return true;
}
}
return false;
}
}
/**
* Sort to get inverted dependencies i.e. if pass depends on another place it before.
*/
private static class InvertedDepsComparator implements Comparator<MergePass> {
public static final InvertedDepsComparator INSTANCE = new InvertedDepsComparator();
@Override
public int compare(MergePass first, MergePass second) {
if (first.before().contains(second.getName())
|| first.after().contains(second.getName())) {
return 1;
}
if (second.before().contains(first.getName())
|| second.after().contains(first.getName())) {
return -1;
}
return 0;
}
}
}
@@ -0,0 +1,143 @@
package jadx.core.utils;
import java.util.List;
import org.junit.jupiter.api.Test;
import jadx.api.impl.passes.DecompilePassWrapper;
import jadx.api.plugins.pass.JadxPass;
import jadx.api.plugins.pass.JadxPassInfo;
import jadx.api.plugins.pass.impl.OrderedJadxPassInfo;
import jadx.api.plugins.pass.impl.SimpleJadxPassInfo;
import jadx.api.plugins.pass.types.JadxDecompilePass;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.utils.exceptions.JadxRuntimeException;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;
class PassMergeTest {
@Test
public void testSimple() {
List<String> base = asList("a", "b", "c");
check(base, mockPass("x"), asList("a", "b", "c", "x"));
check(base, mockPass(mockInfo("x").after(JadxPassInfo.START)), asList("x", "a", "b", "c"));
check(base, mockPass(mockInfo("x").before(JadxPassInfo.END)), asList("a", "b", "c", "x"));
}
@Test
public void testSingle() {
List<String> base = asList("a", "b", "c");
check(base, mockPass(mockInfo("x").after("a")), asList("a", "x", "b", "c"));
check(base, mockPass(mockInfo("x").before("c")), asList("a", "b", "x", "c"));
check(base, mockPass(mockInfo("x").before("a")), asList("x", "a", "b", "c"));
check(base, mockPass(mockInfo("x").after("c")), asList("a", "b", "c", "x"));
}
@Test
public void testMulti() {
List<String> base = asList("a", "b", "c");
JadxPass x = mockPass(mockInfo("x").after("a"));
JadxPass y = mockPass(mockInfo("y").after("a"));
JadxPass z = mockPass(mockInfo("z").before("b"));
check(base, asList(x, y, z), asList("a", "y", "x", "z", "b", "c"));
}
@Test
public void testMultiWithDeps() {
List<String> base = asList("a", "b", "c");
JadxPass x = mockPass(mockInfo("x").after("a"));
JadxPass y = mockPass(mockInfo("y").after("x"));
JadxPass z = mockPass(mockInfo("z").before("b").after("y"));
check(base, asList(x, y, z), asList("a", "x", "y", "z", "b", "c"));
}
@Test
public void testMultiWithDeps2() {
List<String> base = asList("a", "b", "c");
JadxPass x = mockPass(mockInfo("x").before("y"));
JadxPass y = mockPass(mockInfo("y").before("b"));
JadxPass z = mockPass(mockInfo("z").after("y"));
check(base, asList(x, y, z), asList("a", "x", "y", "z", "b", "c"));
}
@Test
public void testMultiWithDeps3() {
List<String> base = asList("a", "b", "c");
JadxPass x = mockPass(mockInfo("x"));
JadxPass y = mockPass(mockInfo("y").after("x").before("b"));
check(base, asList(x, y), asList("a", "x", "y", "b", "c"));
}
@Test
public void testLoop() {
List<String> base = asList("a", "b", "c");
JadxPass x = mockPass(mockInfo("x").before("y"));
JadxPass y = mockPass(mockInfo("y").before("x"));
Throwable thrown = catchThrowable(() -> check(base, asList(x, y), emptyList()));
assertThat(thrown).isInstanceOf(JadxRuntimeException.class);
}
private void check(List<String> visitorNames, JadxPass pass, List<String> result) {
check(visitorNames, singletonList(pass), result);
}
private void check(List<String> visitorNames, List<JadxPass> passes, List<String> result) {
List<IDexTreeVisitor> visitors = ListUtils.map(visitorNames, PassMergeTest::mockVisitor);
new PassMerge(visitors).merge(passes, p -> new DecompilePassWrapper((JadxDecompilePass) p));
List<String> resultVisitors = ListUtils.map(visitors, IDexTreeVisitor::getName);
assertThat(resultVisitors).isEqualTo(result);
}
private static IDexTreeVisitor mockVisitor(String name) {
return new AbstractVisitor() {
@Override
public String getName() {
return name;
}
};
}
private JadxPass mockPass(String name) {
return mockPass(new SimpleJadxPassInfo(name));
}
private OrderedJadxPassInfo mockInfo(String name) {
return new OrderedJadxPassInfo(name, name);
}
private JadxPass mockPass(JadxPassInfo info) {
return new JadxDecompilePass() {
@Override
public void init(RootNode root) {
}
@Override
public boolean visit(ClassNode cls) {
return false;
}
@Override
public void visit(MethodNode mth) {
}
@Override
public JadxPassInfo getInfo() {
return info;
}
@Override
public String toString() {
return info.getName();
}
};
}
}
@@ -15,4 +15,12 @@ class Debug(private val jadx: JadxScriptInstance) {
fun saveCFG(mth: MethodNode, file: File = File("dump-mth-raw")) {
DotGraphVisitor.dumpRaw().save(file, mth)
}
fun printPreparePasses() {
jadx.internalDecompiler.root.preDecompilePasses.forEach { jadx.log.info { it.name } }
}
fun printPasses() {
jadx.internalDecompiler.root.passes.forEach { jadx.log.info { it.name } }
}
}
@@ -13,8 +13,7 @@ class Decompile(private val jadx: JadxScriptInstance) {
fun allThreaded(threadsCount: Int = JadxArgs.DEFAULT_THREADS_COUNT) {
val executor = Executors.newFixedThreadPool(threadsCount)
val dec = jadx.internalDecompiler
val batches = dec.decompileScheduler.buildBatches(jadx.classes)
val batches = jadx.internalDecompiler.decompileScheduler.buildBatches(jadx.classes)
for (batch in batches) {
executor.submit {
batch.forEach(JavaClass::decompile)