Skip to content

Commit

Permalink
Fix map and apply and fix displaying quoted
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Sep 24, 2023
1 parent feab9ec commit c246f86
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 92 deletions.
75 changes: 17 additions & 58 deletions auto_editor/lang/palet.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,43 +458,15 @@ def __str__(self) -> str:
###############################################################################


def check_args(
o: str,
values: list | tuple,
arity: tuple[int, int | None],
cont: list[Contract] | None,
) -> None:
lower, upper = arity
amount = len(values)

assert not (upper is not None and lower > upper)
base = f"`{o}` has an arity mismatch. Expected "

if lower == upper and len(values) != lower:
raise MyError(f"{base}{lower}, got {amount}")
if upper is None and amount < lower:
raise MyError(f"{base}at least {lower}, got {amount}")
if upper is not None and (amount > upper or amount < lower):
raise MyError(f"{base}between {lower} and {upper}, got {amount}")

if cont is None:
return

for i, val in enumerate(values):
check = cont[-1] if i >= len(cont) else cont[i]
if not check_contract(check, val):
exp = f"{check}" if callable(check) else print_str(check)
raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}")


is_cont = Contract("contract?", is_contract)
is_iterable = Contract(
"iterable?",
lambda v: type(v) in (str, range) or isinstance(v, (list, dict, np.ndarray)),
lambda v: type(v) in (str, range, Quoted)
or isinstance(v, (list, dict, np.ndarray)),
)
is_sequence = Contract(
"sequence?",
lambda v: type(v) in (str, range) or isinstance(v, (list, np.ndarray)),
lambda v: type(v) in (str, range, Quoted) or isinstance(v, (list, np.ndarray)),
)
is_boolarr = Contract(
"bool-array?",
Expand All @@ -504,9 +476,7 @@ def check_args(
"(or/c bool? bool-array?)",
lambda v: type(v) is bool or is_boolarr(v),
)
is_keyw = Contract(
"keyword?", lambda v: type(v) is list and len(v) == 2 and type(v[1]) is Keyword
)
is_keyw = Contract("keyword?", lambda v: type(v) is QuotedKeyword)


def raise_(msg: str) -> None:
Expand Down Expand Up @@ -680,21 +650,14 @@ def vector_extend(vec: list, *more_vecs: list) -> None:
vec.extend(more)


def palet_map(proc: Proc, seq: str | list | range | NDArray) -> Any:
def palet_map(proc: Proc, seq: Any) -> Any:
if type(seq) is str:
return str(map(proc, seq))
if type(seq) is Quoted:
return Quoted(list(map(proc, seq.val)))
if isinstance(seq, (list, range)):
return list(map(proc, seq))

if isinstance(seq, np.ndarray):
if proc.arity[0] != 0:
raise MyError("map: procedure must take at least one arg")
check_args(proc.name, [0], (1, 1), None)
return proc(seq)


def apply(proc: Proc, seq: str | list | range) -> Any:
return reduce(proc, seq)
return proc(seq)


def ref(seq: Any, ref: int) -> Any:
Expand Down Expand Up @@ -1126,10 +1089,12 @@ def syn_for(env: Env, node: list) -> None:
my_eval(env, c)


def syn_quote(env: Env, node: list) -> list:
def syn_quote(env: Env, node: list) -> Any:
guard_term(node, 2, 2)
if type(node[1]) is list or type(node[1]) is Keyword:
return [list, node[1]]
if type(node[1]) is Keyword:
return QuotedKeyword(node[1])
if type(node[1]) is list:
return Quoted(node[1])
return node[1]


Expand Down Expand Up @@ -1382,13 +1347,7 @@ def my_eval(env: Env, node: object) -> Any:
if type(oper) is Syntax:
return oper(env, node)

values = [my_eval(env, c) for c in node[1:]]
if type(oper) is Contract:
check_args(oper.name, values, (1, 1), None)
else:
check_args(oper.name, values, oper.arity, oper.contracts)

return oper(*values)
return oper(*(my_eval(env, c) for c in node[1:]))

return node

Expand Down Expand Up @@ -1526,8 +1485,8 @@ def my_eval(env: Env, node: object) -> Any:
"~v": Proc("~v", lambda *v: " ".join([print_str(a) for a in v]), (0, None)),
# keyword
"keyword?": is_keyw,
"keyword->string": Proc("keyword->string", lambda k: k[1].val, (1, 1), [is_keyw]),
"string->keyword": Proc("string->keyword", lambda s: [list, Keyword(s)], (1, 1), [is_str]),
"keyword->string": Proc("keyword->string", lambda v: v.val.val, (1, 1), [is_keyw]),
"string->keyword": Proc("string->keyword", QuotedKeyword, (1, 1), [is_str]),
# vectors
"vector": Proc("vector", lambda *a: list(a), (0, None)),
"make-vector": Proc(
Expand Down Expand Up @@ -1566,7 +1525,7 @@ def my_eval(env: Env, node: object) -> Any:
"slice": Proc("slice", p_slice, (2, 4), [is_sequence, is_int]),
# procedures
"map": Proc("map", palet_map, (2, 2), [is_proc, is_sequence]),
"apply": Proc("apply", apply, (2, 2), [is_proc, is_sequence]),
"apply": Proc("apply", lambda p, s: p(*s), (2, 2), [is_proc, is_sequence]),
"and/c": Proc("and/c", andc, (1, None), [is_cont]),
"or/c": Proc("or/c", orc, (1, None), [is_cont]),
"not/c": Proc("not/c", notc, (1, 1), [is_cont]),
Expand Down
83 changes: 58 additions & 25 deletions auto_editor/lib/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,17 @@
from .err import MyError


@dataclass(slots=True)
class Proc:
name: str
proc: Callable
arity: tuple[int, int | None] = (1, None)
contracts: list[Any] | None = None

def __call__(self, *args: Any) -> Any:
return self.proc(*args)

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
n = "inf" if self.arity[1] is None else f"{self.arity[1]}"

if self.contracts is None:
c = ""
else:
c = " (" + " ".join([f"{c}" for c in self.contracts]) + ")"
return f"#<proc:{self.name} ({self.arity[0]} {n}){c}>"


@dataclass(slots=True)
class Contract:
# Convenient flat contract class
name: str
c: Callable[[object], bool]

def __call__(self, v: object) -> bool:
return self.c(v)
def __call__(self, *v: object) -> bool:
if len(v) != 1:
o = self.name
raise MyError(f"`{o}` has an arity mismatch. Expected 1, got {len(v)}")
return self.c(v[0])

def __str__(self) -> str:
return self.name
Expand All @@ -66,6 +46,59 @@ def check_contract(c: object, val: object) -> bool:
raise MyError(f"Invalid contract, got: {print_str(c)}")


def check_args(
o: str,
values: list | tuple,
arity: tuple[int, int | None],
cont: list[Contract] | None,
) -> None:
lower, upper = arity
amount = len(values)

assert not (upper is not None and lower > upper)
base = f"`{o}` has an arity mismatch. Expected "

if lower == upper and len(values) != lower:
raise MyError(f"{base}{lower}, got {amount}")
if upper is None and amount < lower:
raise MyError(f"{base}at least {lower}, got {amount}")
if upper is not None and (amount > upper or amount < lower):
raise MyError(f"{base}between {lower} and {upper}, got {amount}")

if cont is None:
return

for i, val in enumerate(values):
check = cont[-1] if i >= len(cont) else cont[i]
if not check_contract(check, val):
exp = f"{check}" if callable(check) else print_str(check)
raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}")


@dataclass(slots=True)
class Proc:
name: str
proc: Callable
arity: tuple[int, int | None] = (1, None)
contracts: list[Any] | None = None

def __call__(self, *args: Any) -> Any:
check_args(self.name, args, self.arity, self.contracts)
return self.proc(*args)

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
n = "inf" if self.arity[1] is None else f"{self.arity[1]}"

if self.contracts is None:
c = ""
else:
c = " (" + " ".join([f"{c}" for c in self.contracts]) + ")"
return f"#<proc:{self.name} ({self.arity[0]} {n}){c}>"


def is_contract(c: object) -> bool:
if type(c) is Contract:
return True
Expand Down
51 changes: 42 additions & 9 deletions auto_editor/lib/data_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,43 @@ def __eq__(self, obj: object) -> bool:
return type(obj) is Keyword and self.val == obj.val


class QuotedKeyword:
__slots__ = "val"

def __init__(self, val: Keyword | str):
self.val = val if isinstance(val, Keyword) else Keyword(val)

def __str__(self) -> str:
return f"{self.val}"

__repr__ = __str__

def __eq__(self, obj: object) -> bool:
return type(obj) is QuotedKeyword and self.val == obj.val


class Quoted:
__slots__ = "val"

def __init__(self, val: list):
self.val = val

def __len__(self) -> int:
return len(self.val)

def __getitem__(self, index: int) -> object:
return self.val[index]

def __iter__(self) -> list:
return self.val

def __contains__(self, item: object) -> bool:
return item in self.val

def __eq__(self, obj: object) -> bool:
return type(obj) is Quoted and self.val == obj.val


class Char:
__slots__ = "val"

Expand Down Expand Up @@ -141,16 +178,12 @@ def display_str(val: object) -> str:
if type(val) is Fraction:
return f"{val.numerator}/{val.denominator}"

if type(val) is list and val and val[0] is list:
if type(val[1]) is Keyword:
return f"{val[1]}"

if not val[1]:
if type(val) is Quoted:
if not val:
return "()"

result = StringIO()
result.write(f"({display_str(val[1][0])}")
for item in val[1][1:]:
result.write(f"({display_str(val[0])}")
for item in val[1:]:
result.write(f" {display_str(item)}")
result.write(")")
return result.getvalue()
Expand Down Expand Up @@ -213,7 +246,7 @@ def print_str(val: object) -> str:
return f"{val!r}"
if type(val) is Keyword:
return f"'{val}"
if type(val) is Sym or (type(val) is list and val and val[0] is list):
if type(val) in (Sym, Quoted, QuotedKeyword):
return f"'{display_str(val)}"

return display_str(val)
4 changes: 4 additions & 0 deletions resources/scripts/scope.pal
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@
(let* ([x 1] [y (+ x 1)]) #(y x))
#(2 1)
))

(assert (= (apply add1 #(4)) 5))
(assert (= (apply sub1 #(4)) 3))
(assert (equal? (map add1 '(3 4 5)) '(4 5 6)))

0 comments on commit c246f86

Please sign in to comment.