Skip to content

Commit

Permalink
Fix crash when used with a real world shader library
Browse files Browse the repository at this point in the history
  • Loading branch information
JGamache-autodesk committed Dec 20, 2024
1 parent ccf4260 commit 82a113b
Showing 1 changed file with 63 additions and 42 deletions.
105 changes: 63 additions & 42 deletions lib/mayaUsd/render/MaterialXGenOgsXml/LobePruner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ class LobePrunerImpl
// nodeGraphName,
// map< attributeName, // AttributeMap
// map< attributeValue, // OptimizableValueMap
// NodeVector
// NodeSet
// >
// >
// }
// >
using NodeVector = std::vector<PXR_NS::TfToken>;
using OptimizableValueMap = std::map<float, NodeVector>;
using NodeSet = PXR_NS::TfToken::HashSet;
using OptimizableValueMap = std::map<float, NodeSet>;
// We want attributes alphabetically sorted:
using AttributeMap = std::map<PXR_NS::TfToken, OptimizableValueMap>;
struct NodeDefData
Expand All @@ -73,7 +73,7 @@ class LobePrunerImpl
};

// Also helps if we have a reverse connection map from source node to dest node:
using Destinations = std::vector<std::string>;
using Destinations = std::set<std::string>;
using ReverseCnxMap = std::map<std::string, Destinations>;

public:
Expand Down Expand Up @@ -133,6 +133,7 @@ class LobePrunerImpl
mx::NodePtr& node,
const std::string& darkNodeName,
const std::string& darkNodeDefName) const;
void updateReverseMap(ReverseCnxMap& reverseMap, const std::string& nameToRemove) const;

std::unordered_map<PXR_NS::TfToken, NodeDefData, PXR_NS::TfToken::HashFunctor> _prunerData;
mx::DocumentPtr _library;
Expand Down Expand Up @@ -325,10 +326,10 @@ void LobePrunerImpl::addOptimizableValue(
auto& valueMap = attrMap._attributeData.find(interfaceName)->second;

if (!valueMap.count(value)) {
valueMap.emplace(value, NodeVector {});
valueMap.emplace(value, NodeSet {});
}

valueMap.find(value)->second.push_back(PXR_NS::TfToken(input->getParent()->getName()));
valueMap.find(value)->second.insert(PXR_NS::TfToken(input->getParent()->getName()));
}

mx::NodeDefPtr LobePrunerImpl::getOptimizedNodeDef(const mx::Node& node)
Expand Down Expand Up @@ -502,7 +503,7 @@ mx::NodeDefPtr LobePrunerImpl::getOrAddOptimizedNodeDef(
if (!reverseMap.count(sourceNodeName)) {
reverseMap.emplace(sourceNodeName, Destinations {});
}
reverseMap.find(sourceNodeName)->second.push_back(node->getName());
reverseMap.find(sourceNodeName)->second.insert(node->getName());
}
}
}
Expand Down Expand Up @@ -587,38 +588,38 @@ void LobePrunerImpl::optimizeMixNode(
if (!bgInput) {
return;
}
for (const auto& destNodeName : reverseMap.find(mixNode->getName())->second) {
auto destNode = optimizedNodeGraph->getNode(destNodeName);
if (!destNode) {
return;
}
for (auto input : destNode->getInputs()) {
if (input->getNodeName() == mixNode->getName()) {
input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE);
if (bgInput->hasNodeName()) {
input->setNodeName(bgInput->getNodeName());
auto& nodeVector = reverseMap.find(bgInput->getNodeName())->second;
nodeVector.push_back(destNodeName);
nodeVector.erase(
std::remove_if(
nodeVector.begin(),
nodeVector.end(),
[mixNode](const std::string& s) { return s == mixNode->getName(); }),
nodeVector.end());
}
if (bgInput->hasInterfaceName()) {
input->setInterfaceName(bgInput->getInterfaceName());
}
if (bgInput->hasOutputString()) {
input->setOutputString(bgInput->getOutputString());
}
if (bgInput->hasValueString()) {
input->setValueString(bgInput->getValueString());
const auto nodesToUpdateIt = reverseMap.find(mixNode->getName());
if (nodesToUpdateIt != reverseMap.end()) {
for (const auto& destNodeName : nodesToUpdateIt->second) {
auto destNode = optimizedNodeGraph->getNode(destNodeName);
if (!destNode) {
return;
}
for (auto input : destNode->getInputs()) {
if (input->getNodeName() == mixNode->getName()) {
input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE);
if (bgInput->hasNodeName()) {
input->setNodeName(bgInput->getNodeName());
const auto bgInputsToUpdateIt = reverseMap.find(bgInput->getNodeName());
if (bgInputsToUpdateIt != reverseMap.end()) {
bgInputsToUpdateIt->second.insert(destNodeName);
}
}
if (bgInput->hasInterfaceName()) {
input->setInterfaceName(bgInput->getInterfaceName());
}
if (bgInput->hasOutputString()) {
input->setOutputString(bgInput->getOutputString());
}
if (bgInput->hasValueString()) {
input->setValueString(bgInput->getValueString());
}
}
}
}
}
optimizedNodeGraph->removeNode(mixNode->getName());
updateReverseMap(reverseMap, mixNode->getName());
}

void LobePrunerImpl::optimizeMultiplyNode(
Expand All @@ -627,19 +628,23 @@ void LobePrunerImpl::optimizeMultiplyNode(
ReverseCnxMap& reverseMap) const
{
// Result will be a zero value of the type it requests:
for (const auto& destNodeName : reverseMap.find(node->getName())->second) {
auto destNode = optimizedNodeGraph->getNode(destNodeName);
for (auto input : destNode->getInputs()) {
if (input->getNodeName() == node->getName()) {
input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE);
const auto defaultValueIt = kZeroMultiplyValueMap.find(input->getType());
if (defaultValueIt != kZeroMultiplyValueMap.end()) {
input->setValueString(defaultValueIt->second);
const auto nodesToUpdateIt = reverseMap.find(node->getName());
if (nodesToUpdateIt != reverseMap.end()) {
for (const auto& destNodeName : nodesToUpdateIt->second) {
auto destNode = optimizedNodeGraph->getNode(destNodeName);
for (auto input : destNode->getInputs()) {
if (input->getNodeName() == node->getName()) {
input->removeAttribute(mx::PortElement::NODE_NAME_ATTRIBUTE);
const auto defaultValueIt = kZeroMultiplyValueMap.find(input->getType());
if (defaultValueIt != kZeroMultiplyValueMap.end()) {
input->setValueString(defaultValueIt->second);
}
}
}
}
}
optimizedNodeGraph->removeNode(node->getName());
updateReverseMap(reverseMap, node->getName());
}

void LobePrunerImpl::optimizePbrNode(
Expand All @@ -658,6 +663,22 @@ void LobePrunerImpl::optimizePbrNode(
}
}

void LobePrunerImpl::updateReverseMap(ReverseCnxMap& reverseMap, const std::string& nameToRemove)
const
{
// Need to remove anything leading to that deleted node:
std::vector<std::string> emptyEntries;
for (auto& mapEntry : reverseMap) {
mapEntry.second.erase(nameToRemove);
if (mapEntry.second.empty()) {
emptyEntries.push_back(mapEntry.first);
}
}
for (const auto& emptyEntry : emptyEntries) {
reverseMap.erase(emptyEntry);
}
}

LobePruner::Ptr LobePruner::create() { return std::make_shared<LobePruner>(); }

LobePruner::~LobePruner() = default;
Expand Down

0 comments on commit 82a113b

Please sign in to comment.