From 82a113b0898da1f20fee82a67c691b9eaa9621e7 Mon Sep 17 00:00:00 2001 From: Jerry Gamache Date: Fri, 20 Dec 2024 12:42:45 -0500 Subject: [PATCH] Fix crash when used with a real world shader library --- .../render/MaterialXGenOgsXml/LobePruner.cpp | 105 +++++++++++------- 1 file changed, 63 insertions(+), 42 deletions(-) diff --git a/lib/mayaUsd/render/MaterialXGenOgsXml/LobePruner.cpp b/lib/mayaUsd/render/MaterialXGenOgsXml/LobePruner.cpp index 927d704a7..98984bdb5 100644 --- a/lib/mayaUsd/render/MaterialXGenOgsXml/LobePruner.cpp +++ b/lib/mayaUsd/render/MaterialXGenOgsXml/LobePruner.cpp @@ -57,13 +57,13 @@ class LobePrunerImpl // nodeGraphName, // map< attributeName, // AttributeMap // map< attributeValue, // OptimizableValueMap - // NodeVector + // NodeSet // > // > // } // > - using NodeVector = std::vector; - using OptimizableValueMap = std::map; + using NodeSet = PXR_NS::TfToken::HashSet; + using OptimizableValueMap = std::map; // We want attributes alphabetically sorted: using AttributeMap = std::map; struct NodeDefData @@ -73,7 +73,7 @@ class LobePrunerImpl }; // Also helps if we have a reverse connection map from source node to dest node: - using Destinations = std::vector; + using Destinations = std::set; using ReverseCnxMap = std::map; public: @@ -133,6 +133,7 @@ class LobePrunerImpl mx::NodePtr& node, const std::string& darkNodeName, const std::string& darkNodeDefName) const; + void updateReverseMap(ReverseCnxMap& reverseMap, const std::string& nameToRemove) const; std::unordered_map _prunerData; mx::DocumentPtr _library; @@ -325,10 +326,10 @@ void LobePrunerImpl::addOptimizableValue( auto& valueMap = attrMap._attributeData.find(interfaceName)->second; if (!valueMap.count(value)) { - valueMap.emplace(value, NodeVector {}); + valueMap.emplace(value, NodeSet {}); } - valueMap.find(value)->second.push_back(PXR_NS::TfToken(input->getParent()->getName())); + valueMap.find(value)->second.insert(PXR_NS::TfToken(input->getParent()->getName())); } mx::NodeDefPtr LobePrunerImpl::getOptimizedNodeDef(const mx::Node& node) @@ -502,7 +503,7 @@ mx::NodeDefPtr LobePrunerImpl::getOrAddOptimizedNodeDef( if (!reverseMap.count(sourceNodeName)) { reverseMap.emplace(sourceNodeName, Destinations {}); } - reverseMap.find(sourceNodeName)->second.push_back(node->getName()); + reverseMap.find(sourceNodeName)->second.insert(node->getName()); } } } @@ -587,38 +588,38 @@ void LobePrunerImpl::optimizeMixNode( if (!bgInput) { return; } - for (const auto& destNodeName : reverseMap.find(mixNode->getName())->second) { - auto destNode = optimizedNodeGraph->getNode(destNodeName); - if (!destNode) { - return; - } - for (auto input : destNode->getInputs()) { - if (input->getNodeName() == mixNode->getName()) { - input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE); - if (bgInput->hasNodeName()) { - input->setNodeName(bgInput->getNodeName()); - auto& nodeVector = reverseMap.find(bgInput->getNodeName())->second; - nodeVector.push_back(destNodeName); - nodeVector.erase( - std::remove_if( - nodeVector.begin(), - nodeVector.end(), - [mixNode](const std::string& s) { return s == mixNode->getName(); }), - nodeVector.end()); - } - if (bgInput->hasInterfaceName()) { - input->setInterfaceName(bgInput->getInterfaceName()); - } - if (bgInput->hasOutputString()) { - input->setOutputString(bgInput->getOutputString()); - } - if (bgInput->hasValueString()) { - input->setValueString(bgInput->getValueString()); + const auto nodesToUpdateIt = reverseMap.find(mixNode->getName()); + if (nodesToUpdateIt != reverseMap.end()) { + for (const auto& destNodeName : nodesToUpdateIt->second) { + auto destNode = optimizedNodeGraph->getNode(destNodeName); + if (!destNode) { + return; + } + for (auto input : destNode->getInputs()) { + if (input->getNodeName() == mixNode->getName()) { + input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE); + if (bgInput->hasNodeName()) { + input->setNodeName(bgInput->getNodeName()); + const auto bgInputsToUpdateIt = reverseMap.find(bgInput->getNodeName()); + if (bgInputsToUpdateIt != reverseMap.end()) { + bgInputsToUpdateIt->second.insert(destNodeName); + } + } + if (bgInput->hasInterfaceName()) { + input->setInterfaceName(bgInput->getInterfaceName()); + } + if (bgInput->hasOutputString()) { + input->setOutputString(bgInput->getOutputString()); + } + if (bgInput->hasValueString()) { + input->setValueString(bgInput->getValueString()); + } } } } } optimizedNodeGraph->removeNode(mixNode->getName()); + updateReverseMap(reverseMap, mixNode->getName()); } void LobePrunerImpl::optimizeMultiplyNode( @@ -627,19 +628,23 @@ void LobePrunerImpl::optimizeMultiplyNode( ReverseCnxMap& reverseMap) const { // Result will be a zero value of the type it requests: - for (const auto& destNodeName : reverseMap.find(node->getName())->second) { - auto destNode = optimizedNodeGraph->getNode(destNodeName); - for (auto input : destNode->getInputs()) { - if (input->getNodeName() == node->getName()) { - input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE); - const auto defaultValueIt = kZeroMultiplyValueMap.find(input->getType()); - if (defaultValueIt != kZeroMultiplyValueMap.end()) { - input->setValueString(defaultValueIt->second); + const auto nodesToUpdateIt = reverseMap.find(node->getName()); + if (nodesToUpdateIt != reverseMap.end()) { + for (const auto& destNodeName : nodesToUpdateIt->second) { + auto destNode = optimizedNodeGraph->getNode(destNodeName); + for (auto input : destNode->getInputs()) { + if (input->getNodeName() == node->getName()) { + input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE); + const auto defaultValueIt = kZeroMultiplyValueMap.find(input->getType()); + if (defaultValueIt != kZeroMultiplyValueMap.end()) { + input->setValueString(defaultValueIt->second); + } } } } } optimizedNodeGraph->removeNode(node->getName()); + updateReverseMap(reverseMap, node->getName()); } void LobePrunerImpl::optimizePbrNode( @@ -658,6 +663,22 @@ void LobePrunerImpl::optimizePbrNode( } } +void LobePrunerImpl::updateReverseMap(ReverseCnxMap& reverseMap, const std::string& nameToRemove) + const +{ + // Need to remove anything leading to that deleted node: + std::vector emptyEntries; + for (auto& mapEntry : reverseMap) { + mapEntry.second.erase(nameToRemove); + if (mapEntry.second.empty()) { + emptyEntries.push_back(mapEntry.first); + } + } + for (const auto& emptyEntry : emptyEntries) { + reverseMap.erase(emptyEntry); + } +} + LobePruner::Ptr LobePruner::create() { return std::make_shared(); } LobePruner::~LobePruner() = default;