-
Notifications
You must be signed in to change notification settings - Fork 13
/
adapters.py
140 lines (118 loc) · 4.47 KB
/
adapters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python3
from __future__ import annotations

import os
import re
from typing import Any

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerBase
def get_packed_sft_dataset(
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str | os.PathLike,
    seq_length: int,
    shuffle: bool,
) -> Dataset:
    """
    Build a packed language-modeling Dataset from a file of instruction-tuning examples.

    "Packed" means that every example in the returned dataset has the same
    constant length of `seq_length` tokens.

    Args:
        tokenizer: transformers.PreTrainedTokenizerBase
            Tokenizer used to tokenize and encode the text.
        dataset_path: str | os.PathLike
            Path to the file holding the instruction-tuning examples.
        seq_length: int
            Number of tokens in each packed example.
        shuffle: bool
            If true, shuffle the documents before packing them into examples.

    Returns:
        A PyTorch Dataset whose items are dictionaries with the keys
        "input_ids" and "labels", both tensors of shape (seq_length,).
        "input_ids" holds the token IDs fed to the language model as inputs,
        and "labels" holds the token IDs used as language-modeling labels.

    Raises:
        NotImplementedError: always; this adapter is still a stub.
    """
    # Stub: the packing logic has not been written yet.
    raise NotImplementedError
def run_iterate_batches(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool,
) -> DataLoader:
    """
    Given a PyTorch Dataset, return an iterable over batches of size `batch_size`.

    One pass over the returned iterable constitutes one epoch over the Dataset:
    every example is emitted exactly once. The iterable can be iterated again
    for further epochs; when `shuffle` is true, the order is re-drawn on each pass.

    Args:
        dataset: Dataset
            Dataset to emit batches from.
        batch_size: int
            Number of examples to include per batch.
        shuffle: bool
            If true, shuffle examples before batching them.

    Returns:
        A `torch.utils.data.DataLoader` yielding collated batches of
        `batch_size` examples. If `len(dataset)` is not a multiple of
        `batch_size`, the final batch holds the remaining examples rather than
        being dropped, so that the epoch still covers the whole dataset.

    Raises:
        ValueError: raised by DataLoader if `batch_size` is not a positive integer.
    """
    # drop_last=False keeps the trailing partial batch; dropping it would
    # silently skip examples and break the "one pass == one epoch" contract.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
    )
# A reply consisting of nothing but an option letter, e.g. "B", "(B)", "B.".
_MMLU_BARE_LETTER_RE = re.compile(r"^\s*\(?([ABCD])\)?[.:]?\s*$")
# An explicit answer statement such as "The correct answer is B" or
# "Answer: (C)". Only the keywords are case-insensitive; the option letter
# must be uppercase, so the article in "the answer is a ..." is not read as A.
_MMLU_ANSWER_STATEMENT_RE = re.compile(
    r"\b(?i:answer)(?:\s+(?i:is))?\s*:?\s*\(?([ABCD])\b"
)


def run_parse_mmlu_response(
    mmlu_example: dict[str, Any],
    model_output: str,
) -> str | None:
    """
    Given an MMLU example and a model output, parse the model output into a
    predicted option letter (i.e., 'A', 'B', 'C', or 'D'). If the model output
    cannot be parsed into a prediction option letter, return None.

    Two forms are recognized, checked in this order:
      1. the whole output is a single option letter, optionally wrapped in
         parentheses and/or followed by "." or ":" (e.g. "C", "(C)", "C.");
      2. the first explicit answer statement, e.g. "The correct answer is B"
         or "Answer: (D)".
    Known limitation: "The answer is A polymerase" is parsed as "A".

    Args:
        mmlu_example: dict[str, Any]
            Dictionary with an MMLU example. Contains the following keys:
            - "subject": str with the subject of the question.
            - "question": str with the text of the question.
            - "options": list[str] with the four answer options (in order).
              The first option refers to letter "A", the second to "B", etc.
            - "answer": str with the option of the correct answer (e.g., "A")
        model_output: str
            str with the model's output to the MMLU example.

    Returns:
        str (one of "A", "B", "C", or "D") if the model output can be parsed
        into a prediction, else None.
    """
    # `mmlu_example` is part of the adapter interface but is not needed for
    # parsing: the four options always map to the fixed letters A-D.
    bare = _MMLU_BARE_LETTER_RE.match(model_output)
    if bare:
        return bare.group(1)
    statement = _MMLU_ANSWER_STATEMENT_RE.search(model_output)
    if statement:
        return statement.group(1)
    return None
# A number: optional leading minus, digits with optional thousands separators,
# and an optional decimal part. The lookbehind stops a match from starting in
# the middle of a number, so "48-24" yields "48" and "24", not "48" and "-24".
_GSM8K_NUMBER_RE = re.compile(r"(?<![\d.])-?\d+(?:,\d{3})*(?:\.\d+)?")


def run_parse_gsm8k_response(
    model_output: str,
) -> str | None:
    """
    Given a GSM8K model output, parse the model output into a predicted numeric answer by
    taking the last number that occurs in the output.

    Numbers may carry a leading minus sign, comma thousands separators and a
    decimal part. Separators are removed from the result (e.g. "1,234" becomes
    "1234") so that it can be compared directly with a reference answer.

    Args:
        model_output: str
            str with the model's output to a GSM8K example.

    Returns:
        str with the predicted numeric answer if the model output can be parsed into a prediction,
        else None.
    """
    numbers = _GSM8K_NUMBER_RE.findall(model_output)
    if not numbers:
        return None
    return numbers[-1].replace(",", "")
def compute_per_instance_dpo_loss(
    lm: torch.nn.Module,
    lm_ref: torch.nn.Module,
    tokenizer: PreTrainedTokenizerBase,
    beta: float,
    prompt: str,
    response_chosen: str,
    response_rejected: str,
) -> torch.Tensor:
    """
    Compute the DPO loss for a single preference pair.

    The inputs are the model being trained (`lm`), a frozen "reference model"
    (`lm_ref`), the tokenizer both models share, the DPO beta hyperparameter,
    a prompt, and two responses to that prompt: one preferred and one rejected.

    Args:
        lm: torch.nn.Module
            Language model being trained.
        lm_ref: torch.nn.Module
            Reference language model.
        tokenizer: PreTrainedTokenizerBase
            Tokenizer for both language models.
        beta: float
            DPO beta hyperparameter.
        prompt: str
            Prompt for this instance of the preference pair.
        response_chosen: str
            Preferred response to the prompt.
        response_rejected: str
            Rejected response to the prompt.

    Returns:
        torch.Tensor with the DPO loss for this example.

    Raises:
        NotImplementedError: always; this adapter is still a stub.
    """
    # Stub: the loss computation has not been written yet.
    raise NotImplementedError