Skip to content

Commit

Permalink
Fix class detection for namespaced classes (Py)
Browse files Browse the repository at this point in the history
This commit asjusts the Python generated parser to correctly deal with
namespaced classes (e.g., those coming from cwltool extensions).
  • Loading branch information
GlassOfWhiskey committed Dec 4, 2024
1 parent f3518e2 commit daf7967
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 33 deletions.
2 changes: 1 addition & 1 deletion schema_salad/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .typescript_codegen import TypeScriptCodeGen
from .utils import aslist

FIELD_SORT_ORDER = ["id", "class", "name"]
FIELD_SORT_ORDER = ["class", "id", "name"]


def codegen(
Expand Down
28 changes: 28 additions & 0 deletions schema_salad/metaschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,8 @@ class RecordField(Documented):
A field of a record.
"""

class_uri = "https://w3id.org/cwl/salad#RecordField"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -1428,6 +1430,8 @@ def save(


class RecordSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#RecordSchema"

def __init__(
self,
type_: Any,
Expand Down Expand Up @@ -1632,6 +1636,8 @@ class EnumSchema(Saveable):
"""

class_uri = "https://w3id.org/cwl/salad#EnumSchema"

def __init__(
self,
symbols: Any,
Expand Down Expand Up @@ -1898,6 +1904,8 @@ def save(


class ArraySchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#ArraySchema"

def __init__(
self,
items: Any,
Expand Down Expand Up @@ -2097,6 +2105,8 @@ def save(


class MapSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#MapSchema"

def __init__(
self,
type_: Any,
Expand Down Expand Up @@ -2296,6 +2306,8 @@ def save(


class UnionSchema(Saveable):
class_uri = "https://w3id.org/cwl/salad#UnionSchema"

def __init__(
self,
names: Any,
Expand Down Expand Up @@ -2501,6 +2513,8 @@ class JsonldPredicate(Saveable):
"""

class_uri = "https://w3id.org/cwl/salad#JsonldPredicate"

def __init__(
self,
_id: Optional[Any] = None,
Expand Down Expand Up @@ -3239,6 +3253,8 @@ def save(


class SpecializeDef(Saveable):
class_uri = "https://w3id.org/cwl/salad#SpecializeDef"

def __init__(
self,
specializeFrom: Any,
Expand Down Expand Up @@ -3463,6 +3479,8 @@ class SaladRecordField(RecordField):
A field of a record.
"""

class_uri = "https://w3id.org/cwl/salad#SaladRecordField"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -3844,6 +3862,8 @@ def save(


class SaladRecordSchema(NamedType, RecordSchema, SchemaDefinedType):
class_uri = "https://w3id.org/cwl/salad#SaladRecordSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -4705,6 +4725,8 @@ class SaladEnumSchema(NamedType, EnumSchema, SchemaDefinedType):
"""

class_uri = "https://w3id.org/cwl/salad#SaladEnumSchema"

def __init__(
self,
symbols: Any,
Expand Down Expand Up @@ -5446,6 +5468,8 @@ class SaladMapSchema(NamedType, MapSchema, SchemaDefinedType):
"""

class_uri = "https://w3id.org/cwl/salad#SaladMapSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -6131,6 +6155,8 @@ class SaladUnionSchema(NamedType, UnionSchema, DocType):
"""

class_uri = "https://w3id.org/cwl/salad#SaladUnionSchema"

def __init__(
self,
name: Any,
Expand Down Expand Up @@ -6757,6 +6783,8 @@ class Documentation(NamedType, DocType):
"""

class_uri = "https://w3id.org/cwl/salad#Documentation"

def __init__(
self,
name: Any,
Expand Down
51 changes: 22 additions & 29 deletions schema_salad/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def begin_class(
idfield: str,
optional_fields: set[str],
) -> None:
class_uri = classname
classname = self.safe_name(classname)

if extends:
Expand All @@ -163,6 +164,8 @@ def begin_class(
self.out.write(" pass\n\n\n")
return

self.out.write(f' class_uri = "{class_uri}"\n\n')

required_field_names = [f for f in field_names if f not in optional_fields]
optional_field_names = [f for f in field_names if f in optional_fields]

Expand Down Expand Up @@ -276,27 +279,6 @@ def save(
"""
)

if "class" in field_names:
self.out.write(
"""
if "class" not in _doc:
raise ValidationException("Missing 'class' field")
if _doc.get("class") != "{class_}":
raise ValidationException("tried `{class_}` but")
""".format(
class_=classname
)
)

self.serializer.write(
"""
r["class"] = "{class_}"
""".format(
class_=classname
)
)

def end_class(self, classname: str, field_names: list[str]) -> None:
"""Signal that we are done with this class."""
if self.current_class_is_abstract:
Expand Down Expand Up @@ -554,9 +536,6 @@ def declare_field(
if self.current_class_is_abstract:
return

if shortname(name) == "class":
return

if optional:
self.out.write(f""" {self.safe_name(name)} = None\n""")
self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907
Expand Down Expand Up @@ -608,8 +587,22 @@ def declare_field(
spc=spc,
)
)
self.out.write(
"""

if shortname(name) == "class":
self.out.write(
"""{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri:
{spc} raise ValidationException(f"tried `{{cls.__name__}}` but")
{spc} except ValidationException as e:
{spc} raise e
""".format(
safename=self.safe_name(name),
spc=spc,
)
)

else:
self.out.write(
"""
{spc} except ValidationException as e:
{spc} error_message, to_print, verb_tensage = parse_errors(str(e))
Expand Down Expand Up @@ -647,10 +640,10 @@ def declare_field(
{spc} )
{spc} )
""".format(
fieldname=shortname(name),
spc=spc,
fieldname=shortname(name),
spc=spc,
)
)
)

if name == self.idfield or not self.idfield:
baseurl = "base_url"
Expand Down
6 changes: 3 additions & 3 deletions schema_salad/tests/test_codegen_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None:
def test_error_message6(tmp_path: Path) -> None:
t = "test_schema/test6.cwl"
match = r"""\*\s+tried\s+`CommandLineTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`ExpressionTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`Workflow`\s+but
\s+Missing\s+'class'\s+field"""
\s+missing\s+required\s+field\s+`class`"""
path = get_data("tests/" + t)
assert path
with pytest.raises(ValidationException, match=match):
Expand Down

0 comments on commit daf7967

Please sign in to comment.