From 2107da2e1a1041a0b508b30fa5e3a41bec22b2de Mon Sep 17 00:00:00 2001 From: Skylot Date: Tue, 21 Jan 2020 18:11:34 +0300 Subject: [PATCH] fix: improve 'out' block detection in switch (#826) --- .../java/jadx/core/codegen/RegionGen.java | 46 ++-- .../core/dex/instructions/InsnDecoder.java | 2 +- .../core/dex/instructions/SwitchNode.java | 22 +- .../jadx/core/dex/regions/SwitchRegion.java | 30 +-- .../core/dex/visitors/DotGraphVisitor.java | 16 ++ .../dex/visitors/regions/RegionMaker.java | 241 ++++++++++-------- .../main/java/jadx/core/utils/BlockUtils.java | 124 ++++++++- .../main/java/jadx/core/utils/DebugUtils.java | 8 + .../java/jadx/core/utils/ImmutableList.java | 7 +- .../api/utils/assertj/JadxCodeAssertions.java | 4 + .../integration/switches/TestSwitch2.java | 2 +- .../switches/TestSwitchFallThrough.java | 8 +- .../switches/TestSwitchReturnFromCase.java | 2 + .../TestSwitchWithFallThroughCase.java | 1 + 14 files changed, 345 insertions(+), 168 deletions(-) diff --git a/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java b/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java index 4596d8c48..47237deaf 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/RegionGen.java @@ -275,37 +275,41 @@ public class RegionGen extends InsnGen { List keys = caseInfo.getKeys(); IContainer c = caseInfo.getContainer(); for (Object k : keys) { - code.startLine("case "); - if (k instanceof FieldNode) { - FieldNode fn = (FieldNode) k; - if (fn.getParentClass().isEnum()) { - code.add(fn.getAlias()); - } else { - staticField(code, fn.getFieldInfo()); - // print original value, sometimes replace with incorrect field - FieldInitAttr valueAttr = fn.get(AType.FIELD_INIT); - if (valueAttr != null && valueAttr.getValue() != null) { - code.add(" /*").add(valueAttr.getValue().toString()).add("*/"); - } - } - } else if (k instanceof Integer) { - code.add(TypeGen.literalToString((Integer) k, arg.getType(), mth, fallback)); + if (k == SwitchRegion.DEFAULT_CASE_KEY) { + code.startLine("default:"); } else { - throw new JadxRuntimeException("Unexpected key in switch: " + (k != null ? k.getClass() : null)); + code.startLine("case "); + addCaseKey(code, arg, k); + code.add(':'); } - code.add(':'); } makeRegionIndent(code, c); } - if (sw.getDefaultCase() != null) { - code.startLine("default:"); - makeRegionIndent(code, sw.getDefaultCase()); - } code.decIndent(); code.startLine('}'); return code; } + private void addCaseKey(CodeWriter code, InsnArg arg, Object k) { + if (k instanceof FieldNode) { + FieldNode fn = (FieldNode) k; + if (fn.getParentClass().isEnum()) { + code.add(fn.getAlias()); + } else { + staticField(code, fn.getFieldInfo()); + // print original value, sometimes replace with incorrect field + FieldInitAttr valueAttr = fn.get(AType.FIELD_INIT); + if (valueAttr != null && valueAttr.getValue() != null) { + code.add(" /*").add(valueAttr.getValue().toString()).add("*/"); + } + } + } else if (k instanceof Integer) { + code.add(TypeGen.literalToString((Integer) k, arg.getType(), mth, fallback)); + } else { + throw new JadxRuntimeException("Unexpected key in switch: " + (k != null ? k.getClass() : null)); + } + } + private void makeTryCatch(TryCatchRegion region, CodeWriter code) throws CodegenException { code.startLine("try {"); makeRegionIndent(code, region.getTryRegion()); diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java b/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java index ace669f08..67b318a6a 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java @@ -623,7 +623,7 @@ public class InsnDecoder { targets[i] = targets[i] - payloadOffset + offset; } int nextOffset = getNextInsnOffset(insnArr, offset); - return new SwitchNode(InsnArg.reg(insn, 0, ArgType.NARROW), keys, targets, nextOffset); + return new SwitchNode(InsnArg.reg(insn, 0, ArgType.NARROW), keys, targets, nextOffset, packed); } private InsnNode fillArray(DecodedInstruction insn) { diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/SwitchNode.java b/jadx-core/src/main/java/jadx/core/dex/instructions/SwitchNode.java index 641b45ac0..ed4f79076 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/SwitchNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/SwitchNode.java @@ -16,20 +16,22 @@ public class SwitchNode extends TargetInsnNode { private final Object[] keys; private final int[] targets; private final int def; // next instruction + private final boolean packed; // type of switch insn, if true can contain filler keys private BlockNode[] targetBlocks; private BlockNode defTargetBlock; - public SwitchNode(InsnArg arg, Object[] keys, int[] targets, int def) { - this(keys, targets, def); + public SwitchNode(InsnArg arg, Object[] keys, int[] targets, int def, boolean packed) { + this(keys, targets, def, packed); addArg(arg); } - private SwitchNode(Object[] keys, int[] targets, int def) { + private SwitchNode(Object[] keys, int[] targets, int def, boolean packed) { super(InsnType.SWITCH, 1); this.keys = keys; this.targets = targets; this.def = def; + this.packed = packed; } public int getCasesCount() { @@ -48,6 +50,10 @@ public class SwitchNode extends TargetInsnNode { return def; } + public boolean isPacked() { + return packed; + } + public BlockNode[] getTargetBlocks() { return targetBlocks; } @@ -103,7 +109,7 @@ public class SwitchNode extends TargetInsnNode { @Override public InsnNode copy() { - SwitchNode copy = new SwitchNode(keys, targets, def); + SwitchNode copy = new SwitchNode(keys, targets, def, packed); copy.targetBlocks = targetBlocks; copy.defTargetBlock = defTargetBlock; return copyCommonParams(copy); @@ -114,9 +120,13 @@ public class SwitchNode extends TargetInsnNode { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); for (int i = 0; i < targets.length; i++) { - sb.append(" case ").append(keys[i]) - .append(": goto ").append(InsnUtils.formatOffset(targets[i])); sb.append(CodeWriter.NL); + sb.append(" case ").append(keys[i]); + sb.append(": goto ").append(InsnUtils.formatOffset(targets[i])); + } + if (def != -1) { + sb.append(CodeWriter.NL); + sb.append(" default: goto ").append(InsnUtils.formatOffset(def)); } return sb.toString(); } diff --git a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java index bbdfbcb91..482f40495 100644 --- a/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java +++ b/jadx-core/src/main/java/jadx/core/dex/regions/SwitchRegion.java @@ -13,10 +13,11 @@ import jadx.core.utils.Utils; public final class SwitchRegion extends AbstractRegion implements IBranchRegion { + public static final Object DEFAULT_CASE_KEY = new Object(); + private final BlockNode header; private final List cases; - private IContainer defCase; public SwitchRegion(IRegion parent, BlockNode header) { super(parent); @@ -50,14 +51,6 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion cases.add(new CaseInfo(keysList, c)); } - public void setDefaultCase(IContainer block) { - defCase = block; - } - - public IContainer getDefaultCase() { - return defCase; - } - public List getCases() { return cases; } @@ -68,23 +61,15 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion @Override public List getSubBlocks() { - List all = new ArrayList<>(cases.size() + 2); + List all = new ArrayList<>(cases.size() + 1); all.add(header); all.addAll(getCaseContainers()); - if (defCase != null) { - all.add(defCase); - } return Collections.unmodifiableList(all); } @Override public List getBranches() { - List branches = new ArrayList<>(cases.size() + 1); - branches.addAll(getCaseContainers()); - if (defCase != null) { - branches.add(defCase); - } - return Collections.unmodifiableList(branches); + return Collections.unmodifiableList(getCaseContainers()); } @Override @@ -97,13 +82,12 @@ public final class SwitchRegion extends AbstractRegion implements IBranchRegion StringBuilder sb = new StringBuilder(); sb.append("Switch: ").append(cases.size()); for (CaseInfo caseInfo : cases) { + List keyStrings = Utils.collectionMap(caseInfo.getKeys(), + k -> k == DEFAULT_CASE_KEY ? "default" : k.toString()); sb.append(CodeWriter.NL).append(" case ") - .append(Utils.listToString(caseInfo.getKeys())) + .append(Utils.listToString(keyStrings)) .append(" -> ").append(caseInfo.getContainer()); } - if (defCase != null) { - sb.append(CodeWriter.NL).append(" default -> ").append(defCase); - } return sb.toString(); } } diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/DotGraphVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/DotGraphVisitor.java index e5a944e74..29b084930 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/DotGraphVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/DotGraphVisitor.java @@ -27,6 +27,7 @@ public class DotGraphVisitor extends AbstractVisitor { private static final String NL = "\\l"; private static final boolean PRINT_DOMINATORS = false; + private static final boolean PRINT_DOMINATORS_INFO = false; private final boolean useRegions; private final boolean rawInsn; @@ -182,6 +183,14 @@ public class DotGraphVisitor extends AbstractVisitor { if (!attrs.isEmpty()) { dot.add('|').add(attrs); } + if (PRINT_DOMINATORS_INFO) { + dot.add('|'); + dot.startLine("doms: ").add(escape(block.getDoms())); + dot.startLine("\\lidom: ").add(escape(block.getIDom())); + dot.startLine("\\ldom-f: ").add(escape(block.getDomFrontier())); + dot.startLine("\\ldoms-on: ").add(escape(Utils.listToString(block.getDominatesOn()))); + dot.startLine("\\l"); + } String insns = insertInsns(mth, block); if (!insns.isEmpty()) { dot.add('|').add(insns); @@ -272,6 +281,13 @@ public class DotGraphVisitor extends AbstractVisitor { } } + private String escape(Object obj) { + if (obj == null) { + return "null"; + } + return escape(obj.toString()); + } + private String escape(String string) { return string .replace("\\", "") // TODO replace \" diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java index 4632fb1d5..780b1f8c7 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java @@ -2,6 +2,7 @@ package jadx.core.dex.visitors.regions; import java.util.ArrayList; import java.util.BitSet; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -10,6 +11,7 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -737,135 +739,71 @@ public class RegionMaker { } private BlockNode processSwitch(IRegion currentRegion, BlockNode block, SwitchNode insn, RegionStack stack) { - SwitchRegion sw = new SwitchRegion(currentRegion, block); - currentRegion.getSubBlocks().add(sw); - + // map case blocks to keys int len = insn.getTargets().length; - // sort by target Map> blocksMap = new LinkedHashMap<>(len); + Object[] keysArr = insn.getKeys(); + BlockNode[] targetBlocksArr = insn.getTargetBlocks(); for (int i = 0; i < len; i++) { - Object key = insn.getKeys()[i]; - BlockNode targ = insn.getTargetBlocks()[i]; - List keys = blocksMap.computeIfAbsent(targ, k -> new ArrayList<>(2)); - keys.add(key); + List keys = blocksMap.computeIfAbsent(targetBlocksArr[i], k -> new ArrayList<>(2)); + keys.add(keysArr[i]); } BlockNode defCase = insn.getDefTargetBlock(); if (defCase != null) { - blocksMap.remove(defCase); + List keys = blocksMap.computeIfAbsent(defCase, k -> new ArrayList<>(1)); + keys.add(SwitchRegion.DEFAULT_CASE_KEY); } + + // search 'out' block - 'next' block after whole switch statement + BlockNode out; LoopInfo loop = mth.getLoopForBlock(block); - - Map fallThroughCases = new LinkedHashMap<>(); - - List basicBlocks = mth.getBasicBlocks(); - BitSet outs = new BitSet(basicBlocks.size()); - outs.or(block.getDomFrontier()); - for (BlockNode s : block.getCleanSuccessors()) { - BitSet df = s.getDomFrontier(); - // fall through case block - if (df.cardinality() > 1) { - if (df.cardinality() > 2) { - LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth); - } else { - BlockNode first = basicBlocks.get(df.nextSetBit(0)); - BlockNode second = basicBlocks.get(df.nextSetBit(first.getId() + 1)); - if (second.getDomFrontier().get(first.getId())) { - fallThroughCases.put(s, second); - df = new BitSet(df.size()); - df.set(first.getId()); - } else if (first.getDomFrontier().get(second.getId())) { - fallThroughCases.put(s, first); - df = new BitSet(df.size()); - df.set(second.getId()); - } - } + if (loop == null) { + out = calcPostDomOut(mth, block, mth.getExitBlocks()); + } else { + BlockNode loopEnd = loop.getEnd(); + // treat 'continue' as exit + out = calcPostDomOut(mth, block, loopEnd.getPredecessors()); + if (out != null) { + insertContinueInSwitch(block, out, loopEnd); + } else { + // no 'continue' + out = calcPostDomOut(mth, block, Collections.singletonList(loopEnd)); } - outs.or(df); - } - outs.clear(block.getId()); - if (loop != null) { - outs.clear(loop.getStart().getId()); } + SwitchRegion sw = new SwitchRegion(currentRegion, block); + currentRegion.getSubBlocks().add(sw); stack.push(sw); - stack.addExits(BlockUtils.bitSetToBlocks(mth, outs)); + stack.addExit(out); - // check cases order if fall through case exists - if (!fallThroughCases.isEmpty() - && isBadCasesOrder(blocksMap, fallThroughCases)) { - LOG.debug("Fixing incorrect switch cases order, method: {}", mth); - blocksMap = reOrderSwitchCases(blocksMap, fallThroughCases); - if (isBadCasesOrder(blocksMap, fallThroughCases)) { - mth.addWarn("Can't fix incorrect switch cases order"); + // detect fallthrough cases + Map fallThroughCases = new LinkedHashMap<>(); + if (out != null) { + BitSet caseBlocks = BlockUtils.blocksToBitSet(mth, blocksMap.keySet()); + caseBlocks.clear(out.getId()); + for (BlockNode successor : block.getCleanSuccessors()) { + BlockNode fallThroughBlock = searchFallThroughCase(successor, out, caseBlocks); + if (fallThroughBlock != null) { + fallThroughCases.put(successor, fallThroughBlock); + } } - } - - // filter 'out' block - if (outs.cardinality() > 1) { - // remove exception handlers - BlockUtils.cleanBitSet(mth, outs); - } - if (outs.cardinality() > 1) { - // filter loop start and successors of other blocks - for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) { - BlockNode b = basicBlocks.get(i); - outs.andNot(b.getDomFrontier()); - if (b.contains(AFlag.LOOP_START)) { - outs.clear(b.getId()); + // check fallthrough cases order + if (!fallThroughCases.isEmpty() && isBadCasesOrder(blocksMap, fallThroughCases)) { + Map> newBlocksMap = reOrderSwitchCases(blocksMap, fallThroughCases); + if (isBadCasesOrder(newBlocksMap, fallThroughCases)) { + mth.addComment("JADX INFO: Can't fix incorrect switch cases order, some code will duplicate"); + fallThroughCases.clear(); } else { - for (BlockNode s : b.getCleanSuccessors()) { - outs.clear(s.getId()); - } + blocksMap = newBlocksMap; } } } - if (loop != null && outs.cardinality() > 1) { - outs.clear(loop.getEnd().getId()); - } - if (outs.cardinality() == 0) { - // one or several case blocks are empty, - // run expensive algorithm for find 'out' block - for (BlockNode maybeOut : block.getSuccessors()) { - boolean allReached = true; - for (BlockNode s : block.getSuccessors()) { - if (!isPathExists(s, maybeOut)) { - allReached = false; - break; - } - } - if (allReached) { - outs.set(maybeOut.getId()); - break; - } - } - } - BlockNode out = null; - if (outs.cardinality() == 1) { - out = basicBlocks.get(outs.nextSetBit(0)); - stack.addExit(out); - } else if (loop == null && outs.cardinality() > 1) { - LOG.warn("Can't detect out node for switch block: {} in {}", block, mth); - } - if (loop != null) { - // check if 'continue' must be inserted - BlockNode end = loop.getEnd(); - if (out != end && out != null) { - insertContinueInSwitch(block, out, end); - } - } - - if (!stack.containsExit(defCase)) { - Region defRegion = makeRegion(defCase, stack); - if (RegionUtils.notEmpty(defRegion)) { - sw.setDefaultCase(defRegion); - } - } for (Entry> entry : blocksMap.entrySet()) { + List keysList = entry.getValue(); BlockNode caseBlock = entry.getKey(); if (stack.containsExit(caseBlock)) { - // empty case block - sw.addCase(entry.getValue(), new Region(stack.peekRegion())); + sw.addCase(keysList, new Region(stack.peekRegion())); } else { BlockNode next = fallThroughCases.get(caseBlock); stack.addExit(next); @@ -875,17 +813,100 @@ public class RegionMaker { next.add(AFlag.FALL_THROUGH); caseRegion.add(AFlag.FALL_THROUGH); } - sw.addCase(entry.getValue(), caseRegion); + sw.addCase(keysList, caseRegion); // 'break' instruction will be inserted in RegionMakerVisitor.PostRegionVisitor } } + removeEmptyCases(insn, sw, defCase); + stack.pop(); return out; } - private boolean isBadCasesOrder(Map> blocksMap, - Map fallThroughCases) { + @Nullable + private BlockNode searchFallThroughCase(BlockNode successor, BlockNode out, BitSet caseBlocks) { + BitSet df = successor.getDomFrontier(); + if (df.intersects(caseBlocks)) { + return getOneIntersectionBlock(out, caseBlocks, df); + } + Set allPathsBlocks = BlockUtils.getAllPathsBlocks(successor, out); + Map bitSetMap = BlockUtils.calcPartialPostDominance(mth, allPathsBlocks, out); + BitSet pdoms = bitSetMap.get(successor); + if (pdoms != null && pdoms.intersects(caseBlocks)) { + return getOneIntersectionBlock(out, caseBlocks, pdoms); + } + return null; + } + + @Nullable + private BlockNode getOneIntersectionBlock(BlockNode out, BitSet caseBlocks, BitSet fallThroughSet) { + BitSet caseExits = BlockUtils.copyBlocksBitSet(mth, fallThroughSet); + caseExits.clear(out.getId()); + caseExits.and(caseBlocks); + return BlockUtils.bitSetToOneBlock(mth, caseExits); + } + + @Nullable + private static BlockNode calcPostDomOut(MethodNode mth, BlockNode block, List exits) { + if (exits.size() == 1 && mth.getExitBlocks().equals(exits)) { + // simple case: for only one exit which is equal to method exit block + return BlockUtils.calcImmediatePostDominator(mth, block); + } + // fast search: union of blocks dominance frontier + // work if no fallthrough cases and no returns inside switch + BitSet outs = BlockUtils.copyBlocksBitSet(mth, block.getDomFrontier()); + for (BlockNode s : block.getCleanSuccessors()) { + outs.or(s.getDomFrontier()); + } + outs.clear(block.getId()); + + if (outs.cardinality() != 1) { + // slow search: calculate partial post-dominance for every exit node + BitSet ipdoms = BlockUtils.newBlocksBitSet(mth); + for (BlockNode exitBlock : exits) { + Set pathBlocks = BlockUtils.getAllPathsBlocks(block, exitBlock); + BlockNode ipdom = BlockUtils.calcPartialImmediatePostDominator(mth, block, pathBlocks, exitBlock); + if (ipdom != null) { + ipdoms.set(ipdom.getId()); + } + } + outs.and(ipdoms); + } + return BlockUtils.bitSetToOneBlock(mth, outs); + } + + /** + * Remove empty case blocks: + * 1. single 'default' case + * 2. filler cases if switch is 'packed' and 'default' case is empty + */ + private void removeEmptyCases(SwitchNode insn, SwitchRegion sw, BlockNode defCase) { + boolean defaultCaseIsEmpty; + if (defCase == null) { + defaultCaseIsEmpty = true; + } else { + defaultCaseIsEmpty = sw.getCases().stream() + .anyMatch(c -> c.getKeys().contains(SwitchRegion.DEFAULT_CASE_KEY) + && RegionUtils.isEmpty(c.getContainer())); + } + if (defaultCaseIsEmpty) { + sw.getCases().removeIf(caseInfo -> { + if (RegionUtils.isEmpty(caseInfo.getContainer())) { + List keys = caseInfo.getKeys(); + if (keys.contains(SwitchRegion.DEFAULT_CASE_KEY)) { + return true; + } + if (insn.isPacked()) { + return true; + } + } + return false; + }); + } + } + + private boolean isBadCasesOrder(Map> blocksMap, Map fallThroughCases) { BlockNode nextCaseBlock = null; for (BlockNode caseBlock : blocksMap.keySet()) { if (nextCaseBlock != null && !caseBlock.equals(nextCaseBlock)) { diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java index 256a594fd..d076e03a7 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java @@ -4,9 +4,11 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import org.jetbrains.annotations.Nullable; @@ -285,7 +287,19 @@ public class BlockUtils { return null; } - public static BitSet blocksToBitSet(MethodNode mth, List blocks) { + public static BitSet newBlocksBitSet(MethodNode mth) { + return new BitSet(mth.getBasicBlocks().size()); + } + + public static BitSet copyBlocksBitSet(MethodNode mth, BitSet bitSet) { + BitSet copy = new BitSet(mth.getBasicBlocks().size()); + if (!bitSet.isEmpty()) { + copy.or(bitSet); + } + return copy; + } + + public static BitSet blocksToBitSet(MethodNode mth, Collection blocks) { BitSet bs = new BitSet(mth.getBasicBlocks().size()); for (BlockNode block : blocks) { bs.set(block.getId()); @@ -293,8 +307,16 @@ public class BlockUtils { return bs; } + @Nullable + public static BlockNode bitSetToOneBlock(MethodNode mth, BitSet bs) { + if (bs == null || bs.cardinality() != 1) { + return null; + } + return mth.getBasicBlocks().get(bs.nextSetBit(0)); + } + public static List bitSetToBlocks(MethodNode mth, BitSet bs) { - if (bs == null) { + if (bs == null || bs == EmptyBitSet.EMPTY) { return Collections.emptyList(); } int size = bs.cardinality(); @@ -649,4 +671,102 @@ public class BlockUtils { } return false; } + + public static Map calcPostDominance(MethodNode mth) { + return calcPartialPostDominance(mth, mth.getBasicBlocks(), mth.getExitBlocks().get(0)); + } + + public static Map calcPartialPostDominance(MethodNode mth, Collection blockNodes, BlockNode exitBlock) { + int blocksCount = mth.getBasicBlocks().size(); + Map map = new HashMap<>(blocksCount); + + BitSet initSet = new BitSet(blocksCount); + for (BlockNode block : blockNodes) { + initSet.set(block.getId()); + } + + for (BlockNode block : blockNodes) { + BitSet postDoms = new BitSet(blocksCount); + postDoms.or(initSet); + map.put(block, postDoms); + } + BitSet exitBitSet = map.get(exitBlock); + exitBitSet.clear(); + exitBitSet.set(exitBlock.getId()); + + BitSet domSet = new BitSet(blocksCount); + boolean changed; + do { + changed = false; + for (BlockNode block : blockNodes) { + if (block == exitBlock) { + continue; + } + BitSet d = map.get(block); + if (!changed) { + domSet.clear(); + domSet.or(d); + } + for (BlockNode scc : block.getSuccessors()) { + BitSet scPDoms = map.get(scc); + if (scPDoms != null) { + d.and(scPDoms); + } + } + d.set(block.getId()); + if (!changed && !d.equals(domSet)) { + changed = true; + map.put(block, d); + } + } + } while (changed); + + blockNodes.forEach(block -> { + BitSet postDoms = map.get(block); + postDoms.clear(block.getId()); + if (postDoms.isEmpty()) { + map.put(block, EmptyBitSet.EMPTY); + } + }); + return map; + } + + @Nullable + public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block) { + BlockNode oneSuccessor = Utils.getOne(block.getSuccessors()); + if (oneSuccessor != null) { + return oneSuccessor; + } + return calcImmediatePostDominator(mth, block, calcPostDominance(mth)); + } + + @Nullable + public static BlockNode calcPartialImmediatePostDominator(MethodNode mth, BlockNode block, + Collection blockNodes, BlockNode exitBlock) { + BlockNode oneSuccessor = Utils.getOne(block.getSuccessors()); + if (oneSuccessor != null) { + return oneSuccessor; + } + Map pDomsMap = calcPartialPostDominance(mth, blockNodes, exitBlock); + return calcImmediatePostDominator(mth, block, pDomsMap); + } + + @Nullable + public static BlockNode calcImmediatePostDominator(MethodNode mth, BlockNode block, Map postDomsMap) { + BlockNode oneSuccessor = Utils.getOne(block.getSuccessors()); + if (oneSuccessor != null) { + return oneSuccessor; + } + List basicBlocks = mth.getBasicBlocks(); + BitSet postDoms = postDomsMap.get(block); + BitSet bs = copyBlocksBitSet(mth, postDoms); + for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) { + BlockNode pdomBlock = basicBlocks.get(i); + BitSet pdoms = postDomsMap.get(pdomBlock); + if (pdoms != null) { + bs.andNot(pdoms); + } + } + return bitSetToOneBlock(mth, bs); + } } diff --git a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java index 16d8dbcb0..8947f731e 100644 --- a/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/DebugUtils.java @@ -5,6 +5,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -159,4 +160,11 @@ public class DebugUtils { cw.startLine(indent).add("|+ ").add(it.next()); } } + + public static void printMap(Map map, String desc) { + LOG.debug("Map {} (size = {}):", desc, map.size()); + for (Map.Entry entry : map.entrySet()) { + LOG.debug(" {}: {}", entry.getKey(), entry.getValue()); + } + } } diff --git a/jadx-core/src/main/java/jadx/core/utils/ImmutableList.java b/jadx-core/src/main/java/jadx/core/utils/ImmutableList.java index c399e5d16..70f226241 100644 --- a/jadx-core/src/main/java/jadx/core/utils/ImmutableList.java +++ b/jadx-core/src/main/java/jadx/core/utils/ImmutableList.java @@ -76,7 +76,12 @@ public final class ImmutableList implements List, RandomAccess { @Override public boolean containsAll(@NotNull Collection c) { - throw new UnsupportedOperationException(); + for (Object obj : c) { + if (!contains(obj)) { + return false; + } + } + return true; } @NotNull diff --git a/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxCodeAssertions.java b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxCodeAssertions.java index 1a24ac94c..adbe9d7d7 100644 --- a/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxCodeAssertions.java +++ b/jadx-core/src/test/java/jadx/tests/api/utils/assertj/JadxCodeAssertions.java @@ -10,6 +10,10 @@ public class JadxCodeAssertions extends AbstractStringAssert super(code, JadxCodeAssertions.class); } + public JadxCodeAssertions containsOne(String substring) { + return countString(1, substring); + } + public JadxCodeAssertions countString(int count, String substring) { isNotNull(); int actualCount = TestUtils.count(actual, substring); diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java index 15f471d7e..38e33eb33 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitch2.java @@ -17,7 +17,7 @@ public class TestSwitch2 extends IntegrationTest { boolean isScrolling; float multiTouchZoomOldDist; - void test(int action) { + public void test(int action) { switch (action & 255) { case 0: this.isLongtouchable = true; diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchFallThrough.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchFallThrough.java index 7e2a0fe9d..4f56cac47 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchFallThrough.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchFallThrough.java @@ -2,7 +2,6 @@ package jadx.tests.integration.switches; import org.junit.jupiter.api.Test; -import jadx.NotYetImplemented; import jadx.tests.api.IntegrationTest; import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; @@ -12,6 +11,7 @@ public class TestSwitchFallThrough extends IntegrationTest { public static class TestCls { public int r; + @SuppressWarnings("fallthrough") public void test(int a) { int i = 10; switch (a) { @@ -43,12 +43,14 @@ public class TestSwitchFallThrough extends IntegrationTest { } } - @NotYetImplemented("switch fallthrough") @Test public void test() { assertThat(getClassNode(TestCls.class)) .code() - .containsOnlyOnce("switch"); + .containsOne("switch (a) {") + .containsOne("r = i;") + .containsOne("r = -1;") + .countString(2, "break;"); // code correctness checks done in 'check' method } } diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java index c77012de0..2fec12709 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchReturnFromCase.java @@ -47,6 +47,8 @@ public class TestSwitchReturnFromCase extends IntegrationTest { String code = cls.getCode().toString(); assertThat(code, containsString("switch (a % 10) {")); + + // case 5: removed assertEquals(5, count(code, "case ")); assertEquals(3, count(code, "break;")); diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java index 3c6bb05ae..e3ccad501 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchWithFallThroughCase.java @@ -26,6 +26,7 @@ public class TestSwitchWithFallThroughCase extends IntegrationTest { } break; } + // fallthrough case 2: if (b) { str += "2";