Skip to content

Commit

Permalink
add moreee textttsss
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Oct 13, 2024
1 parent 5a4fdc2 commit 6fabfc1
Showing 1 changed file with 172 additions and 21 deletions.
193 changes: 172 additions & 21 deletions docs/Cursors.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,33 +104,184 @@ s2 <- c

## Cursor inspection

- `StmtCursor`s wrap the underlying IR object which can be inspected
- Ex. check cursor type with `isinstance(c, PC.AlloCursor)`
`StmtCursor`s wrap the underlying Exo IR object and can be inspected.
- Ex. check cursor type with `isinstance(c, PC.AllocCursor)`

## Cursor forwarding
`StmtCursor`s are one of the following types.

- `Procedure.forward(cursor)` applies forwarding to resolve a cursor from a previous procedure
- Each pass returns a fwd function mapping old IR to new IR
- `p.forward(c)` composes fwd functions from `c` to `p`
#### `ArgCursor`

`p.forward(...)` provides some examples.
Represents a cursor pointing to a procedure argument of the form:
```
name : type @ mem
```

Methods:
- `name() -> str`: Returns the name of the argument.
- `mem() -> Memory`: Returns the memory location of the argument.
- `is_tensor() -> bool`: Checks if the argument is a tensor.
- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list.
- `type() -> API.ExoType`: Returns the type of the argument.

#### `AssignCursor`

Represents a cursor pointing to an assignment statement of the form:
```
name[idx] = rhs
```

Methods:
- `name() -> str`: Returns the name of the variable being assigned to.
- `idx() -> ExprListCursor`: Returns a cursor to the index expression list.
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.
- `type() -> API.ExoType`: Returns the type of the assignment.

#### `ReduceCursor`

Each scheduling primitive returns a forwarding function.
Procedure objects have a pointer to the previous procedure, and the forwarding function from the previous procedure to the current procedure.
When `p.forward(...)` is called, all the forwarding functions up until the Cursor's procedure will get applied.
Represents a cursor pointing to a reduction statement of the form:
```
name[idx] += rhs
```

Methods:
- `name() -> str`: Returns the name of the variable being reduced.
- `idx() -> ExprListCursor`: Returns a cursor to the index expression list.
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.


#### `AssignConfigCursor`

Represents a cursor pointing to a configuration assignment statement of the form:
```
config.field = rhs
```

Methods:
- `config() -> Config`: Returns the configuration object.
- `field() -> str`: Returns the name of the configuration field being assigned to.
- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression.

#### `PassCursor`

Represents a cursor pointing to a no-op statement:
```
pass
```

#### `IfCursor`

Represents a cursor pointing to an if statement of the form:
```
if condition:
body
```
or
```
if condition:
body
else:
orelse
```
Returns an invalid cursor if `orelse` isn't present.

Methods:
- `cond() -> ExprCursor`: Returns a cursor to the if condition expression.
- `body() -> BlockCursor`: Returns a cursor to the if body block.
- `orelse() -> Cursor`: Returns a cursor to the else block (if present).

#### `ForCursor`

Represents a cursor pointing to a loop statement of the form:
```
for name in seq(0, hi):
body
```

Literally the code in src/exo/API.py
Methods:
- `name() -> str`: Returns the loop variable name.
- `lo() -> ExprCursor`: Returns a cursor to the lower bound expression (defaults to 0).
- `hi() -> ExprCursor`: Returns a cursor to the upper bound expression.
- `body() -> BlockCursor`: Returns a cursor to the loop body block.


#### `AllocCursor`

Represents a cursor pointing to a buffer allocation statement of the form:
```
name : type @ mem
```

Methods:
- `name() -> str`: Returns the name of the allocated buffer.
- `mem() -> Memory`: Returns the memory location of the buffer.
- `is_tensor() -> bool`: Checks if the allocated buffer is a tensor.
- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list.
- `type() -> API.ExoType`: Returns the type of the allocated buffer.


#### `CallCursor`

Represents a cursor pointing to a sub-procedure call statement of the form:
```
subproc(args)
```

Methods:
- `subproc()`: Returns the called sub-procedure.
- `args() -> ExprListCursor`: Returns a cursor to the argument expression list.


#### `WindowStmtCursor`

Represents a cursor pointing to a window declaration statement of the form:
```
name = winexpr
```

Methods:
- `name() -> str`: Returns the name of the window.
- `winexpr() -> ExprCursor`: Returns a cursor to the window expression.


## Cursor Forwarding

When a procedure `p` is transformed into a new procedure `p'` by applying scheduling primitives, any cursors pointing into `p` need to be updated to point to the corresponding locations in `p'`. This process is called *cursor forwarding*.

To forward a cursor `c` from `p` to `p'`, you can use the `forward` method on the new procedure:
```python
def forward(self, cur: C.Cursor):
p = self
fwds = []
while p is not None and p is not cur.proc():
fwds.append(p._forward)
p = p._provenance_eq_Procedure
c' = p'.forward(c)
```

### How Forwarding Works

ir = cur._impl
for fn in reversed(fwds):
ir = fn(ir)
Internally, each scheduling primitive returns a *forwarding function* that maps locations in the input procedure to locations in the output procedure.

return C.lift_cursor(ir, self)
When you call `p'.forward(c)`, Exo composes the forwarding functions for all the scheduling steps between `c.proc()` (the procedure `c` points into, in this case `p`) and `p'` (the final procedure). This composition produces a single function that can map `c` from its original procedure to the corresponding location in `p'`.

Here's the actual implementation of forwarding in `src/exo/API.py`:

```python
def forward(self, cur: C.Cursor):
p = self
fwds = []
while p is not None and p is not cur.proc():
fwds.append(p._forward)
p = p._provenance_eq_Procedure

ir = cur._impl
for fn in reversed(fwds):
ir = fn(ir)

return C.lift_cursor(ir, self)
```

The key steps are:

1. Collect the forwarding functions (`p._forward`) for all procedures between `cur.proc()` and `self` (the final procedure).
2. Get the underlying Exo IR for the input cursor (`cur._impl`).
3. Apply the forwarding functions in reverse order to map the IR node to its final location.
4. Lift the mapped IR node back into a cursor in the final procedure.

So in summary, `p.forward(c)` computes and applies the composite forwarding function to map cursor `c` from its original procedure to the corresponding location in procedure `p`.


0 comments on commit 6fabfc1

Please sign in to comment.