Reduce stack usage when compiling large switch statements (#107)
This ports the fix for
Beamdog/nwn-issues#561. The root cause is that
the compiler recurses heavily, and when the parse tree is
sufficiently deep, it runs out of stack. This is usually not an issue on
modern platforms and is mostly seen on Borland/Toolset, but it can
technically happen anywhere. You'll most likely encounter it on Windows
first, and on debug builds sooner than on release builds.

This changeset fixes two issues:
1) `ConstantFoldNode()` always walks the tree recursively. There is no
reason to do this if the node is not something we can fold, so just
check and bail early. This is a regression, as stack usage would have
gone up significantly with #68.
2) `TraverseTreeForSwitchLabels()` recursively visits the tree when
parsing case labels. This is preexisting, but it seems stack usage has
increased from other changes, so the overall number of cases you can have
before it goes boom is lower. Switch it over to iterative traversal.
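
As a rough illustration of the second fix, the sketch below performs an in-order traversal with a heap-allocated stack instead of recursion. `Node`, `OP_SWITCH_BLOCK`, `InOrderSkippingSwitches` and `VisitDeepChain` are hypothetical names invented for this example; the real code operates on `CScriptParseTreeNode` and does considerably more work per node.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for CScriptParseTreeNode; only what the sketch needs.
struct Node {
    int op = 0;             // OP_SWITCH_BLOCK marks a nested switch in this sketch
    int id = 0;
    Node *left = nullptr;
    Node *right = nullptr;
};

constexpr int OP_SWITCH_BLOCK = 1;

// Visits nodes in order (left subtree, node, right subtree) without recursion
// and returns the ids in visit order. Subtrees rooted at a switch block are
// skipped entirely.
std::vector<int> InOrderSkippingSwitches(Node *node)
{
    std::vector<int> visited;
    std::vector<Node *> pending;    // replaces the call stack

    while (node || !pending.empty())
    {
        // Walk down the left spine, remembering every node still to be visited.
        while (node && node->op != OP_SWITCH_BLOCK)
        {
            pending.push_back(node);
            node = node->left;
        }

        if (pending.empty())
            break;

        node = pending.back();
        pending.pop_back();
        assert(node != nullptr);

        visited.push_back(node->id);    // "process" the current node

        // Continue with the right subtree on the next iteration.
        node = node->right;
    }
    return visited;
}

// Builds a left-leaning chain of `depth` nodes and returns how many were visited.
std::size_t VisitDeepChain(std::size_t depth)
{
    std::vector<Node> nodes(depth);
    for (std::size_t i = 0; i + 1 < depth; ++i)
        nodes[i].left = &nodes[i + 1];
    return InOrderSkippingSwitches(depth ? &nodes[0] : nullptr).size();
}
```

With this shape, the memory needed for a deep tree grows in the vector on the heap rather than in stack frames, so the depth limit is set by available memory rather than by the platform's stack size.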

## Testing

- Modified repro mod from
Beamdog/nwn-issues#561 works in toolset again
- Counting recursion depth for various functions gives sane results now
- @Daztek's module compiles and runs all scripts fine
- @tinygiant98 confirmed scripts with nested switches generate
equivalent NCS
- nwn_script_comp test suite passes
- Added a nested switch test case in this PR.
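
The commit does not say how recursion depth was counted. One possible approach, shown here purely as an assumption, is a small RAII probe placed at the top of each recursive function. `DepthProbe` and `CountDown` are hypothetical names and are not part of the compiler.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical instrumentation: tracks how deep the instrumented calls nest.
struct DepthProbe {
    static int current;     // frames currently alive
    static int deepest;     // high-water mark since the last reset

    DepthProbe() { deepest = std::max(deepest, ++current); }
    ~DepthProbe() { --current; }

    DepthProbe(const DepthProbe &) = delete;
    DepthProbe &operator=(const DepthProbe &) = delete;
};

int DepthProbe::current = 0;
int DepthProbe::deepest = 0;

// Toy recursive function standing in for a recursive tree walker.
int CountDown(int n)
{
    DepthProbe probe;       // one probe per stack frame
    return n == 0 ? 0 : 1 + CountDown(n - 1);
}
```

Reading `DepthProbe::deepest` after compiling a script would then give the maximum nesting reached by the instrumented function, which can be compared before and after a change.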

## Changelog

### Fixed
- Fixed a potential stack overflow when compiling particularly
complex switch statements.

## Licence

- [x] I am licencing my change under the project's MIT licence,
including all changes to GPL-3.0 licenced parts of the codebase.
mtijanic authored Apr 3, 2024
1 parent 9d108d9 commit 9bd1d63
Showing 3 changed files with 193 additions and 119 deletions.
34 changes: 34 additions & 0 deletions neverwinter/nwscript/native/scriptcompcore.cpp
@@ -1527,6 +1527,40 @@ BOOL CScriptCompiler::ConstantFoldNode(CScriptParseTreeNode *pNode, BOOL bForce)
     if (!pNode)
         return FALSE;
 
+    // Only proceed if the operation is something we can actually const-fold.
+    // Sometimes the parse tree is *really* deep and we recursively call this
+    // function thousands of times until we run out of stack space, but that
+    // only tends to happen on giant switches and/or nested structures, which
+    // we don't care about here.
+    switch (pNode->nOperation)
+    {
+        case CSCRIPTCOMPILER_OPERATION_LOGICAL_OR:
+        case CSCRIPTCOMPILER_OPERATION_LOGICAL_AND:
+        case CSCRIPTCOMPILER_OPERATION_INCLUSIVE_OR:
+        case CSCRIPTCOMPILER_OPERATION_EXCLUSIVE_OR:
+        case CSCRIPTCOMPILER_OPERATION_BOOLEAN_AND:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_EQUAL:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_NOT_EQUAL:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_GEQ:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_GT:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_LT:
+        case CSCRIPTCOMPILER_OPERATION_CONDITION_LEQ:
+        case CSCRIPTCOMPILER_OPERATION_SHIFT_LEFT:
+        case CSCRIPTCOMPILER_OPERATION_SHIFT_RIGHT:
+        case CSCRIPTCOMPILER_OPERATION_ADD:
+        case CSCRIPTCOMPILER_OPERATION_SUBTRACT:
+        case CSCRIPTCOMPILER_OPERATION_MULTIPLY:
+        case CSCRIPTCOMPILER_OPERATION_DIVIDE:
+        case CSCRIPTCOMPILER_OPERATION_MODULUS:
+        case CSCRIPTCOMPILER_OPERATION_NEGATION:
+        case CSCRIPTCOMPILER_OPERATION_UNSIGNED_SHIFT_RIGHT:
+        case CSCRIPTCOMPILER_OPERATION_ONES_COMPLEMENT:
+        case CSCRIPTCOMPILER_OPERATION_BOOLEAN_NOT:
+            break;
+        default:
+            return FALSE;
+    }
+
     BOOL bUnary = pNode->nOperation == CSCRIPTCOMPILER_OPERATION_BOOLEAN_NOT ||
                   pNode->nOperation == CSCRIPTCOMPILER_OPERATION_ONES_COMPLEMENT ||
                   pNode->nOperation == CSCRIPTCOMPILER_OPERATION_NEGATION;
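
The effect of the guard above can be sketched on a toy constant folder. This is a simplified illustration, not the compiler's actual `ConstantFoldNode()`: `Op`, `Expr`, `IsFoldable` and `Fold` are hypothetical names, and only three operations are handled. The point is that a node whose operation cannot be folded returns immediately, so its (possibly very deep) children are never recursed into.

```cpp
#include <cassert>

// Hypothetical, trimmed-down operation set for the sketch.
enum class Op { Add, Multiply, Negation, ConstantInteger, StatementList };

struct Expr {
    Op op;
    int value;
    Expr *left;
    Expr *right;
};

// Mirrors the idea of the switch guard: only arithmetic nodes are candidates.
bool IsFoldable(Op op)
{
    switch (op)
    {
        case Op::Add:
        case Op::Multiply:
        case Op::Negation:
            return true;
        default:
            return false;
    }
}

// Folds constant arithmetic in place. Returns true if `e` became a constant.
bool Fold(Expr *e)
{
    // Early bail: no recursion at all for nodes we could never fold.
    if (!e || !IsFoldable(e->op))
        return false;

    const bool unary = e->op == Op::Negation;
    Fold(e->left);
    if (!unary)
        Fold(e->right);

    const bool leftConst = e->left && e->left->op == Op::ConstantInteger;
    const bool rightConst = unary || (e->right && e->right->op == Op::ConstantInteger);
    if (!leftConst || !rightConst)
        return false;

    const int l = e->left->value;
    const int r = unary ? 0 : e->right->value;
    switch (e->op)
    {
        case Op::Add:      e->value = l + r; break;
        case Op::Multiply: e->value = l * r; break;
        default:
            assert(e->op == Op::Negation);
            e->value = -l;
            break;
    }
    e->op = Op::ConstantInteger;
    e->left = e->right = nullptr;
    return true;
}
```

In this sketch, calling `Fold` on a `StatementList` root returns `false` after a single frame, whereas an unguarded version would descend through every statement before discovering there was nothing to fold.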
252 changes: 133 additions & 119 deletions neverwinter/nwscript/native/scriptcompfinalcode.cpp
@@ -919,160 +919,174 @@ void CScriptCompiler::InitializeSwitchLabelList()
///////////////////////////////////////////////////////////////////////////////
int32_t CScriptCompiler::TraverseTreeForSwitchLabels(CScriptParseTreeNode *pNode)
{
// First, we scan to see if there are multiple labels of the same type.
int nReturnValue;

if (pNode == NULL)
{
return 0;
}
// First of all, if we are about to go into another switch block, abort!
if (pNode->nOperation == CSCRIPTCOMPILER_OPERATION_SWITCH_BLOCK)
{
return 0;
}

nReturnValue = TraverseTreeForSwitchLabels(pNode->pLeft);
if (nReturnValue < 0)
{
return nReturnValue;
}
//
// This function uses the in-order tree traversal, meaning that the left
// subtree is processed first, then the current node, then the right subtree.
// Or, as pseudocode, something like:
// TraverseTreeForSwitchLabels(pNode->pLeft);
// ProcessNode(pNode);
// TraverseTreeForSwitchLabels(pNode->pRight);
// However, recursively processing the tree means we run the risk of a stack
// overflow if the tree depth is too large. What is 'too large' depends on
// the platform, but there will always be a switch/case statement complex
// enough that processing it causes a crash.
// So, instead we use a heap-allocated custom stack to store the nodes and
// process them iteratively instead.
//
std::vector<CScriptParseTreeNode *> nodestack;

if (pNode->nOperation == CSCRIPTCOMPILER_OPERATION_DEFAULT)
while (pNode || !nodestack.empty())
{
if (m_bSwitchLabelDefault == TRUE)
// First of all, if we are about to go into another switch block, abort!
while (pNode && pNode->nOperation != CSCRIPTCOMPILER_OPERATION_SWITCH_BLOCK)
{
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_MULTIPLE_DEFAULT_STATEMENTS_WITHIN_SWITCH, pNode);
nodestack.push_back(pNode);
pNode = pNode->pLeft;
}
m_bSwitchLabelDefault = TRUE;
}

if (pNode->nOperation == CSCRIPTCOMPILER_OPERATION_CASE)
{
int32_t nCaseValue;
if (nodestack.empty())
break;

ConstantFoldNode(pNode->pLeft, TRUE);
// Evaluate the constant value that is contained.
if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_NEGATION &&
pNode->pLeft->pLeft != NULL &&
pNode->pLeft->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_INTEGER)
{
nCaseValue = -pNode->pLeft->pLeft->nIntegerData;
}
else if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_INTEGER)
{
nCaseValue = pNode->pLeft->nIntegerData;
}
else if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_STRING)
{
nCaseValue = pNode->pLeft->m_psStringData->GetHash();
}
else
{
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_CASE_PARAMETER_NOT_A_CONSTANT_INTEGER,pNode);
}
pNode = nodestack.back();
nodestack.pop_back();

// Now, we have to check if any of the previous case statements have the same value.
int nCount;
//
// At this point, we have processed all the nodes to the left of pNode.
//

for (nCount = 0; nCount < m_nSwitchLabelNumber; ++nCount)
if (pNode->nOperation == CSCRIPTCOMPILER_OPERATION_DEFAULT)
{
if (m_pnSwitchLabelStatements[nCount] == nCaseValue)
if (m_bSwitchLabelDefault == TRUE)
{
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_MULTIPLE_CASE_CONSTANT_STATEMENTS_WITHIN_SWITCH,pNode);
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_MULTIPLE_DEFAULT_STATEMENTS_WITHIN_SWITCH, pNode);
}
m_bSwitchLabelDefault = TRUE;
}

// Add the case statement to the list.
if (m_nSwitchLabelNumber >= m_nSwitchLabelArraySize)
if (pNode->nOperation == CSCRIPTCOMPILER_OPERATION_CASE)
{
int32_t *pNewIntArray = new int32_t[m_nSwitchLabelArraySize * 2];
int32_t nCaseValue;

ConstantFoldNode(pNode->pLeft, TRUE);
// Evaluate the constant value that is contained.
if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_NEGATION &&
pNode->pLeft->pLeft != NULL &&
pNode->pLeft->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_INTEGER)
{
nCaseValue = -pNode->pLeft->pLeft->nIntegerData;
}
else if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_INTEGER)
{
nCaseValue = pNode->pLeft->nIntegerData;
}
else if (pNode->pLeft != NULL &&
pNode->pLeft->nOperation == CSCRIPTCOMPILER_OPERATION_CONSTANT_STRING)
{
nCaseValue = pNode->pLeft->m_psStringData->GetHash();
}
else
{
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_CASE_PARAMETER_NOT_A_CONSTANT_INTEGER,pNode);
}

// Now, we have to check if any of the previous case statements have the same value.
int nCount;

for (nCount = 0; nCount < m_nSwitchLabelNumber; ++nCount)
{
pNewIntArray[nCount] = m_pnSwitchLabelStatements[nCount];
if (m_pnSwitchLabelStatements[nCount] == nCaseValue)
{
return OutputWalkTreeError(STRREF_CSCRIPTCOMPILER_ERROR_MULTIPLE_CASE_CONSTANT_STATEMENTS_WITHIN_SWITCH,pNode);
}
}

// Add the case statement to the list.
if (m_nSwitchLabelNumber >= m_nSwitchLabelArraySize)
{
int32_t *pNewIntArray = new int32_t[m_nSwitchLabelArraySize * 2];
for (nCount = 0; nCount < m_nSwitchLabelNumber; ++nCount)
{
pNewIntArray[nCount] = m_pnSwitchLabelStatements[nCount];
}
m_nSwitchLabelArraySize *= 2;
delete[] m_pnSwitchLabelStatements;
m_pnSwitchLabelStatements = pNewIntArray;
}
m_nSwitchLabelArraySize *= 2;
delete[] m_pnSwitchLabelStatements;
m_pnSwitchLabelStatements = pNewIntArray;
}

m_pnSwitchLabelStatements[m_nSwitchLabelNumber] = nCaseValue;
++m_nSwitchLabelNumber;
m_pnSwitchLabelStatements[m_nSwitchLabelNumber] = nCaseValue;
++m_nSwitchLabelNumber;

// Now, we add the pseudocode:
// COPYTOP fffffffc,0004 // copies the switch result so that we can use it.
// CONSTI nCaseValue // adds the constant that we're to compare against.
// EQUALII // compares the two, leaving the result on the stack.
// JNZ _SC_nCaseValue_nSwitchIdentifier // result goes away, jump executed.
// Now, we add the pseudocode:
// COPYTOP fffffffc,0004 // copies the switch result so that we can use it.
// CONSTI nCaseValue // adds the constant that we're to compare against.
// EQUALII // compares the two, leaving the result on the stack.
// JNZ _SC_nCaseValue_nSwitchIdentifier // result goes away, jump executed.

// CODE GENERATION
// Here, we would dump the "appropriate" data from the run-time stack
// on to the top of the stack, making a copy of it ... that's why
// we're adding one to the appropriate run time stack.
// CODE GENERATION
// Here, we would dump the "appropriate" data from the run-time stack
// on to the top of the stack, making a copy of it ... that's why
// we're adding one to the appropriate run time stack.

int32_t nStackElementsDown = -4;
int32_t nSize = 4;
int32_t nStackElementsDown = -4;
int32_t nSize = 4;

m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_RUNSTACK_COPY;
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPE_VOID;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_RUNSTACK_COPY;
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPE_VOID;

m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_EXTRA_DATA_LOCATION] = (char) (((nStackElementsDown) >> 24) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+1] = (char) (((nStackElementsDown) >> 16) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+2] = (char) (((nStackElementsDown) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+3] = (char) (((nStackElementsDown)) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_EXTRA_DATA_LOCATION] = (char) (((nStackElementsDown) >> 24) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+1] = (char) (((nStackElementsDown) >> 16) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+2] = (char) (((nStackElementsDown) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+3] = (char) (((nStackElementsDown)) & 0x0ff);

m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+4] = (char) (((nSize) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+5] = (char) (((nSize)) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+4] = (char) (((nSize) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+5] = (char) (((nSize)) & 0x0ff);

m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 6;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);
m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 6;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);

// CODE GENERATION
// Here, we have a "constant integer" op-code that would be added.
int32_t nIntegerData = nCaseValue;
// CODE GENERATION
// Here, we have a "constant integer" op-code that would be added.
int32_t nIntegerData = nCaseValue;

m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_CONSTANT;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPE_INTEGER;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_CONSTANT;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPE_INTEGER;

// Enter the integer constant.
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION] = (char) (((nIntegerData) >> 24) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+1] = (char) (((nIntegerData) >> 16) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+2] = (char) (((nIntegerData) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+3] = (char) (((nIntegerData)) & 0x0ff);
// Enter the integer constant.
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION] = (char) (((nIntegerData) >> 24) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+1] = (char) (((nIntegerData) >> 16) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+2] = (char) (((nIntegerData) >> 8) & 0x0ff);
m_pchOutputCode[m_nOutputCodeLength+CVIRTUALMACHINE_EXTRA_DATA_LOCATION+3] = (char) (((nIntegerData)) & 0x0ff);

m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 4;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);
m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 4;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);

// CODE GENERATION
// Write an "condition EQUALII" operation.
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_EQUAL;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPETYPE_INTEGER_INTEGER;
m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);
// CODE GENERATION
// Write an "condition EQUALII" operation.
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_EQUAL;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = CVIRTUALMACHINE_AUXCODE_TYPETYPE_INTEGER_INTEGER;
m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);

// CODE GENERATION
// Add the "JNZ _SC_nCaseValue_nSwitchIdentifier" operation.
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_JNZ;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = 0;
// CODE GENERATION
// Add the "JNZ _SC_nCaseValue_nSwitchIdentifier" operation.
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_OPCODE_LOCATION] = CVIRTUALMACHINE_OPCODE_JNZ;
m_pchOutputCode[m_nOutputCodeLength + CVIRTUALMACHINE_AUXCODE_LOCATION] = 0;

/* CExoString sSymbolName;
sSymbolName.Format("_SC_%08x_%08x",nCaseValue,m_nSwitchIdentifier); */
AddSymbolToQueryList(m_nOutputCodeLength + CVIRTUALMACHINE_EXTRA_DATA_LOCATION,
CSCRIPTCOMPILER_SYMBOL_TABLE_ENTRY_TYPE_SWITCH_CASE,
nCaseValue,m_nSwitchIdentifier);
/* CExoString sSymbolName;
sSymbolName.Format("_SC_%08x_%08x",nCaseValue,m_nSwitchIdentifier); */
AddSymbolToQueryList(m_nOutputCodeLength + CVIRTUALMACHINE_EXTRA_DATA_LOCATION,
CSCRIPTCOMPILER_SYMBOL_TABLE_ENTRY_TYPE_SWITCH_CASE,
nCaseValue,m_nSwitchIdentifier);

m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 4;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);
m_nOutputCodeLength += CVIRTUALMACHINE_OPERATION_BASE_SIZE + 4;
m_aOutputCodeInstructionBoundaries.push_back(m_nOutputCodeLength);

}
}

nReturnValue = TraverseTreeForSwitchLabels(pNode->pRight);
if (nReturnValue < 0)
{
return nReturnValue;
// Done with current node. Move one to the right, then process left subtree again.
pNode = pNode->pRight;
}

return 0;
26 changes: 26 additions & 0 deletions tests/scriptcomp/corpus/switch.nss
@@ -16,4 +16,30 @@ void main()
         case S+S: Assert(FALSE); break;
         default: Assert(FALSE); break;
     }
+
+    // Nested switch
+    switch (N)
+    {
+        case N:
+        {
+            switch (N+1)
+            {
+                case N+1: Assert(TRUE); break;
+                case 0: Assert(FALSE); break;
+                default: Assert(FALSE); break;
+            }
+            break;
+        }
+        case 0:
+        {
+            switch (N+1)
+            {
+                case N+1: Assert(TRUE); break;
+                case 0: Assert(FALSE); break;
+                default: Assert(FALSE); break;
+            }
+            break;
+        }
+        default: Assert(FALSE); break;
+    }
 }