Use line_by_line for default gen if present
minimaxir committed May 1, 2021
1 parent 34e6cad commit 342a361
Showing 2 changed files with 5 additions and 1 deletion.
aitextgen/TokenDataset.py: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def __init__(
                 self.tokens = np.load(f)
             self.num_subsets = self.tokens.shape[0] - block_size
             self.block_size = block_size
+            self.line_by_line = line_by_line
             self.str_suffix = "via cache."
 
             logger.info(
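
What this addition does: when a TokenDataset is rebuilt from a cache, the line_by_line argument is now stored on the object instead of being silently dropped, so later code (such as this commit's change to generate()) can check how the dataset was built. A minimal round-trip sketch, assuming the from_cache constructor path shown in the hunk; the cache filename is illustrative, not from this commit:

    from aitextgen.TokenDataset import TokenDataset

    # Reload a previously cached dataset; the token array comes from the
    # cache, while line_by_line must be supplied again by the caller.
    data = TokenDataset(
        "dataset_cache.tar.gz", from_cache=True, line_by_line=True
    )
    assert data.line_by_line  # retained after this commit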
aitextgen/aitextgen.py: 4 additions & 1 deletion
@@ -277,7 +277,7 @@ def generate(
         self,
         n: int = 1,
         prompt: str = "",
-        prepend_bos: bool = False,
+        prepend_bos: bool = None,
         min_length: int = None,
         max_length: int = 256,
         temperature: float = 0.7,
@@ -325,6 +325,9 @@ def generate(
             prompt_tensors["input_ids"].to(self.get_device()) if prompt else None
         )
 
+        if prepend_bos is None:
+            prepend_bos = getattr(self.model.config, "line_by_line", None)
+
         if prepend_bos:
             bos = torch.tensor([[self.tokenizer.bos_token_id]]).to(self.get_device())
             if prompt:
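
Effect on generation: prepend_bos now defaults to None rather than False. When it is left as None, generate() falls back to getattr(self.model.config, "line_by_line", None), so a model whose config records a line-by-line dataset automatically gets the BOS token prepended before the prompt; if the config lacks the attribute, getattr returns None, which is falsy, preserving the old no-BOS behavior. An explicit True or False still overrides the config. A hedged usage sketch (the model folder and prompt are illustrative):

    from aitextgen import aitextgen

    ai = aitextgen(model_folder="trained_model")

    # Default: prepend_bos=None resolves from ai.model.config.line_by_line.
    ai.generate(n=1, prompt="Hello")

    # An explicit value bypasses the config lookup entirely.
    ai.generate(n=1, prompt="Hello", prepend_bos=False)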
