fix: improve deobfuscation performance for overridden methods (#1133)

This commit is contained in:
Skylot
2021-03-20 15:48:22 +00:00
parent a1247f4d96
commit 19572a674e
5 changed files with 81 additions and 30 deletions
@@ -163,17 +163,33 @@ public class DeobfPresets {
} }
public String getForCls(ClassInfo cls) { public String getForCls(ClassInfo cls) {
if (clsPresetMap.isEmpty()) {
return null;
}
return clsPresetMap.get(cls.makeRawFullName()); return clsPresetMap.get(cls.makeRawFullName());
} }
public String getForFld(FieldInfo fld) { public String getForFld(FieldInfo fld) {
if (fldPresetMap.isEmpty()) {
return null;
}
return fldPresetMap.get(fld.getRawFullId()); return fldPresetMap.get(fld.getRawFullId());
} }
public String getForMth(MethodInfo mth) { public String getForMth(MethodInfo mth) {
if (mthPresetMap.isEmpty()) {
return null;
}
return mthPresetMap.get(mth.getRawFullId()); return mthPresetMap.get(mth.getRawFullId());
} }
public Set<String> getForVars(MethodInfo mth) {
if (varPresetMap.isEmpty()) {
return null;
}
return varPresetMap.get(mth.getRawFullId());
}
public void clear() { public void clear() {
clsPresetMap.clear(); clsPresetMap.clear();
fldPresetMap.clear(); fldPresetMap.clear();
@@ -43,6 +43,8 @@ public class Deobfuscator {
private final Set<String> pkgSet = new TreeSet<>(); private final Set<String> pkgSet = new TreeSet<>();
private final Set<String> reservedClsNames = new HashSet<>(); private final Set<String> reservedClsNames = new HashSet<>();
private final NavigableSet<MethodNode> mthProcessQueue = new TreeSet<>();
private final int maxLength; private final int maxLength;
private final int minLength; private final int minLength;
private final boolean useSourceNameAsAlias; private final boolean useSourceNameAsAlias;
@@ -155,6 +157,13 @@ public class Deobfuscator {
for (ClassNode cls : root.getClasses()) { for (ClassNode cls : root.getClasses()) {
processClass(cls); processClass(cls);
} }
while (true) {
MethodNode next = mthProcessQueue.pollLast();
if (next == null) {
break;
}
renameMethod(next);
}
} }
private void processClass(ClassNode cls) { private void processClass(ClassNode cls) {
@@ -182,9 +191,8 @@ public class Deobfuscator {
} }
renameField(field); renameField(field);
} }
for (MethodNode mth : cls.getMethods()) { mthProcessQueue.addAll(cls.getMethods());
renameMethod(mth);
}
for (ClassNode innerCls : cls.getInnerClasses()) { for (ClassNode innerCls : cls.getInnerClasses()) {
processClass(innerCls); processClass(innerCls);
} }
@@ -203,9 +211,10 @@ public class Deobfuscator {
} }
private void renameMethod(MethodNode mth) { private void renameMethod(MethodNode mth) {
Set<String> names = deobfPresets.getVarPresetMap().get(mth.getMethodInfo().getRawFullId()); MethodInfo mthInfo = mth.getMethodInfo();
Set<String> names = deobfPresets.getForVars(mthInfo);
if (names != null) { if (names != null) {
mth.getMethodInfo().setVarNameMap(names); mthInfo.setVarNameMap(names);
} }
String alias = getMethodAlias(mth); String alias = getMethodAlias(mth);
if (alias != null) { if (alias != null) {
@@ -219,28 +228,25 @@ public class Deobfuscator {
} }
private void applyMethodAlias(MethodNode mth, String alias) { private void applyMethodAlias(MethodNode mth, String alias) {
MethodInfo methodInfo = mth.getMethodInfo(); setSingleMethodAlias(mth, alias);
methodInfo.setAlias(alias);
String prev = mthMap.put(methodInfo, alias);
if (prev == null) {
resolveOverriding(mth, alias);
}
}
private void resolveOverriding(MethodNode mth, String alias) {
MethodOverrideAttr overrideAttr = mth.get(AType.METHOD_OVERRIDE); MethodOverrideAttr overrideAttr = mth.get(AType.METHOD_OVERRIDE);
if (overrideAttr != null) { if (overrideAttr != null) {
for (MethodNode ovrdMth : overrideAttr.getRelatedMthNodes()) { for (MethodNode ovrdMth : overrideAttr.getRelatedMthNodes()) {
if (ovrdMth == mth) { if (ovrdMth != mth) {
continue; setSingleMethodAlias(ovrdMth, alias);
} }
MethodInfo methodInfo = ovrdMth.getMethodInfo();
methodInfo.setAlias(alias);
mthMap.put(methodInfo, alias);
} }
} }
} }
private void setSingleMethodAlias(MethodNode mth, String alias) {
MethodInfo mthInfo = mth.getMethodInfo();
mthInfo.setAlias(alias);
mthMap.put(mthInfo, alias);
mthProcessQueue.remove(mth);
}
public void addPackagePreset(String origPkgName, String pkgAlias) { public void addPackagePreset(String origPkgName, String pkgAlias) {
PackageNode pkg = getPackageNode(origPkgName, true); PackageNode pkg = getPackageNode(origPkgName, true);
pkg.setAlias(pkgAlias); pkg.setAlias(pkgAlias);
@@ -497,15 +503,6 @@ public class Deobfuscator {
if (alias != null) { if (alias != null) {
return alias; return alias;
} }
MethodOverrideAttr overrideAttr = mth.get(AType.METHOD_OVERRIDE);
if (overrideAttr != null) {
for (MethodNode relatedMthNode : overrideAttr.getRelatedMthNodes()) {
String assignedAlias = getAssignedAlias(relatedMthNode.getMethodInfo());
if (assignedAlias != null) {
return assignedAlias;
}
}
}
if (shouldRename(mth.getName())) { if (shouldRename(mth.getName())) {
return makeMethodAlias(mth); return makeMethodAlias(mth);
} }
@@ -19,6 +19,9 @@ public final class MethodInfo implements Comparable<MethodInfo> {
private final List<ArgType> argTypes; private final List<ArgType> argTypes;
private final ClassInfo declClass; private final ClassInfo declClass;
private final String shortId; private final String shortId;
private final String rawFullId;
private final int hash;
private String alias; private String alias;
private Map<String, String> varNameMap; private Map<String, String> varNameMap;
@@ -29,6 +32,8 @@ public final class MethodInfo implements Comparable<MethodInfo> {
this.argTypes = args; this.argTypes = args;
this.retType = retType; this.retType = retType;
this.shortId = makeShortId(name, argTypes, retType); this.shortId = makeShortId(name, argTypes, retType);
this.rawFullId = declClass.makeRawFullName() + '.' + shortId;
this.hash = calcHashCode();
} }
public static MethodInfo fromRef(RootNode root, IMethodRef methodRef) { public static MethodInfo fromRef(RootNode root, IMethodRef methodRef) {
@@ -103,7 +108,7 @@ public final class MethodInfo implements Comparable<MethodInfo> {
} }
public String getRawFullId() { public String getRawFullId() {
return declClass.makeRawFullName() + '.' + shortId; return rawFullId;
} }
/** /**
@@ -172,9 +177,13 @@ public final class MethodInfo implements Comparable<MethodInfo> {
return varNameMap != null && varNameMap.size() > 0; return varNameMap != null && varNameMap.size() > 0;
} }
public int calcHashCode() {
return shortId.hashCode() + 31 * declClass.hashCode();
}
@Override @Override
public int hashCode() { public int hashCode() {
return shortId.hashCode() + 31 * declClass.hashCode(); return hash;
} }
@Override @Override
@@ -111,7 +111,11 @@ public class RootNode {
// sort classes by name, expect top classes before inner // sort classes by name, expect top classes before inner
classes.sort(Comparator.comparing(ClassNode::getFullName)); classes.sort(Comparator.comparing(ClassNode::getFullName));
initInnerClasses(); initInnerClasses();
LOG.info("Classes loaded: {}", classes.size());
// print stats for loaded classes
int mthCount = classes.stream().mapToInt(c -> c.getMethods().size()).sum();
int insnsCount = classes.stream().flatMap(c -> c.getMethods().stream()).mapToInt(MethodNode::getInsnsCount).sum();
LOG.info("Loaded classes: {}, methods: {}, instructions: {}", classes.size(), mthCount, insnsCount);
} }
private void addDummyClass(IClassData classData, Exception exc) { private void addDummyClass(IClassData classData, Exception exc) {
@@ -1,11 +1,14 @@
package jadx.core.utils; package jadx.core.utils;
import java.io.File; import java.io.File;
import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@@ -18,13 +21,16 @@ import jadx.api.ICodeWriter;
import jadx.api.impl.SimpleCodeWriter; import jadx.api.impl.SimpleCodeWriter;
import jadx.core.codegen.InsnGen; import jadx.core.codegen.InsnGen;
import jadx.core.codegen.MethodGen; import jadx.core.codegen.MethodGen;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.IAttributeNode; import jadx.core.dex.attributes.IAttributeNode;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.nodes.BlockNode; import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock; import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer; import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion; import jadx.core.dex.nodes.IRegion;
import jadx.core.dex.nodes.InsnNode; import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.DotGraphVisitor; import jadx.core.dex.visitors.DotGraphVisitor;
import jadx.core.dex.visitors.IDexTreeVisitor; import jadx.core.dex.visitors.IDexTreeVisitor;
@@ -184,4 +190,23 @@ public class DebugUtils {
public static void printStackTrace(String label) { public static void printStackTrace(String label) {
LOG.debug("StackTrace: {}\n{}", label, Utils.getStackTrace(new Exception())); LOG.debug("StackTrace: {}\n{}", label, Utils.getStackTrace(new Exception()));
} }
public static void printMethodOverrideTop(RootNode root) {
LOG.debug("Methods override top 10:");
root.getClasses().stream()
.flatMap(c -> c.getMethods().stream())
.filter(m -> m.contains(AType.METHOD_OVERRIDE))
.map(m -> m.get(AType.METHOD_OVERRIDE))
.filter(o -> !o.getOverrideList().isEmpty())
.filter(distinctByKey(methodOverrideAttr -> methodOverrideAttr.getRelatedMthNodes().size()))
.filter(distinctByKey(MethodOverrideAttr::getRelatedMthNodes))
.sorted(Comparator.comparingInt(o -> -o.getRelatedMthNodes().size()))
.limit(10)
.forEach(o -> LOG.debug(" {} : {}", o.getRelatedMthNodes().size(), Utils.last(o.getOverrideList())));
}
private static <T> Predicate<T> distinctByKey(Function<? super T, ?> keyExtractor) {
Set<Object> seen = ConcurrentHashMap.newKeySet();
return t -> seen.add(keyExtractor.apply(t));
}
} }