ChainResolver.kt
package com.depanalyzer.core.graph
import com.depanalyzer.report.Vulnerability
/**
 * Builds [VulnerabilityChain]s that show how each vulnerable dependency is reached, by walking
 * from the vulnerable node up its `parent` links toward a direct dependency.
 */
object ChainResolver {
    /** Upper bound on the number of nodes recorded in a single chain. */
    private const val MAX_DEPTH = 50

    /**
     * Resolves one chain per vulnerable node of [graph] that has findings in [vulnerabilities].
     *
     * Nodes come from `graph.getAllVulnerableNodes()` and are looked up in [vulnerabilities] by
     * their `id`; nodes with no entry, or an empty entry, are skipped. Each chain is ordered from
     * the top-most ancestor reached down to the vulnerable node and carries that node's full
     * vulnerability list.
     *
     * @param graph source of the vulnerable nodes.
     * @param vulnerabilities findings keyed by node id.
     * @return at most one chain per (direct dependency id, vulnerable node id, set of CVE ids);
     *   empty when [vulnerabilities] is empty.
     */
    fun resolveAllChains(
        graph: DependencyGraph,
        vulnerabilities: Map<String, List<Vulnerability>>
    ): List<VulnerabilityChain> {
        if (vulnerabilities.isEmpty()) return emptyList()

        val allChains = mutableListOf<VulnerabilityChain>()
        for (vulnNode in graph.getAllVulnerableNodes()) {
            val nodeVulns = vulnerabilities[vulnNode.id]
            if (nodeVulns.isNullOrEmpty()) continue
            val path = pathToRoot(vulnNode)
            allChains += VulnerabilityChain(
                chain = path,
                vulnerabilities = nodeVulns,
                isShortestPath = false,
                classification = classifyVulnerability(path)
            )
        }

        markShortestPaths(allChains)
        return deduplicateChains(allChains)
    }

    /**
     * Walks from [vulnerableNode] up its `parent` links and returns the visited nodes ordered
     * top-most ancestor first, [vulnerableNode] last. The result is never empty.
     *
     * The walk stops at the first node whose `isDirectDependency()` is true, at a node with no
     * parent, when the next parent's id is already on the path (a cycle), or once the path holds
     * [MAX_DEPTH] nodes. In the last three cases the first element is not a direct dependency.
     * The partial path is still returned so the vulnerability is not silently dropped.
     *
     * NOTE(review): only the single `parent` link is followed. If a node is reachable through
     * several direct dependencies, only one of those routes is reported.
     */
    private fun pathToRoot(vulnerableNode: DependencyNode): List<DependencyNode> {
        val upward = mutableListOf(vulnerableNode)
        val seenIds = mutableSetOf(vulnerableNode.id)
        var current = vulnerableNode
        while (!current.isDirectDependency() && upward.size < MAX_DEPTH) {
            val parent = current.parent ?: break
            // Cycle in the parent links: stop before repeating a node on the path.
            if (!seenIds.add(parent.id)) break
            upward += parent
            current = parent
        }
        return upward.reversed()
    }

    /**
     * Sets `isShortestPath = true` on every chain whose depth equals the minimum depth among the
     * chains sharing its [chainKey]. Other chains are left unchanged.
     */
    private fun markShortestPaths(chains: MutableList<VulnerabilityChain>) {
        // One minimum per group, computed once rather than once per chain.
        val minDepthByKey = chains
            .groupBy { chainKey(it) }
            .mapValues { entry -> entry.value.minOf { it.depth } }
        chains.replaceAll { chain ->
            if (chain.depth == minDepthByKey[chainKey(chain)]) chain.copy(isShortestPath = true) else chain
        }
    }

    /**
     * Classifies a non-empty root-to-node [path]:
     * - [VulnerabilityClassification.DIRECTLY_VULNERABLE] when it starts and ends at the same node id;
     * - [VulnerabilityClassification.INDIRECTLY_VULNERABLE] when it starts at a direct dependency;
     * - [VulnerabilityClassification.TRANSITIVE_VULNERABLE] otherwise.
     */
    private fun classifyVulnerability(path: List<DependencyNode>): VulnerabilityClassification {
        val directDep = path.first()
        val vulnerableNode = path.last()
        return when {
            directDep.id == vulnerableNode.id -> VulnerabilityClassification.DIRECTLY_VULNERABLE
            directDep.isDirectDependency() -> VulnerabilityClassification.INDIRECTLY_VULNERABLE
            else -> VulnerabilityClassification.TRANSITIVE_VULNERABLE
        }
    }

    /**
     * Keeps one minimum-depth chain per [chainKey]. Ties resolve to the earliest chain in
     * [chains] order.
     */
    private fun deduplicateChains(chains: List<VulnerabilityChain>): List<VulnerabilityChain> =
        chains.groupBy { chainKey(it) }.values.map { group ->
            // groupBy never yields an empty group; the elvis only satisfies the type system.
            group.minByOrNull { it.depth } ?: group.first()
        }

    /**
     * Grouping key shared by [markShortestPaths] and [deduplicateChains], so both passes agree on
     * which chains are equivalent. CVE ids are compared as a set, i.e. order-insensitively.
     */
    private fun chainKey(chain: VulnerabilityChain) =
        Triple(chain.directDependency.id, chain.vulnerableNode.id, chain.cveIds.toSet())
}