Language modeling example #101

Closed · wants to merge 6 commits
1 change: 0 additions & 1 deletion .gitignore
@@ -177,4 +177,3 @@ files/
# Generated Config files.
examples/**/*.yaml
**/config.yaml
examples/language_modeling/scripts
21 changes: 21 additions & 0 deletions examples/language_modeling/README.md
@@ -0,0 +1,21 @@
# Training data attribution with a language modeling task
This directory contains the code for running training data attribution with large-scale language models such as Llama 3. In essence, the code ranks the pretraining data by how important each data point was to the generation of a target sentence. The procedure is as follows:
1. **Data Preparation**: Generate the model outputs that we will analyze. These are simple scripts that generate an output from a prompt. We experimented with `Meta-Llama-3-8B-Instruct`, `pythia-1.4b`, and `gpt2-xl`. Use `generate_llama3.py` for `Meta-Llama-3-8B-Instruct` and `generate.py` for `pythia-1.4b` and `gpt2-xl`.
```bash
python generate_llama3.py
python generate.py
```

2. **Extract Log**: `extract_log.py` extracts the training gradient for each pretraining data point, compresses it with LoGra, and saves it to disk. Note that by default we use 1B tokens from the `openwebtext` dataset, leveraging data parallelism. An example command is shown below (the exact commands used for the paper can be found in the `scripts` folder). This is the most time-consuming part of the pipeline.
```bash
accelerate launch --num_processes 2 --num_machines 1 --multi_gpu extract_log.py --model_name meta-llama/Meta-Llama-3-8B-Instruct --lora random --hessian raw --mlp_only --data_name openwebtext
```
As a result, the code generates a folder containing the compressed gradients for each data point, along with other statistics needed to run LoGra (e.g. the random initialization of the LoGra parameters and the covariance of the gradients).

3. **Compute Influence Function**: `compute_influence.py` computes the influence score for each training data point, using the compressed gradients we just generated. The specified query data (`data_name`) is used to compute the query gradient. Since the training gradients have already been saved (and preconditioned), this is a relatively fast process.
```bash
python compute_influence.py --model_name meta-llama/Meta-Llama-3-8B-Instruct --lora random --hessian raw --split generated --mlp_only --data_name openwebtext --mode cosine
```

4. **Analysis**: Finally, we also include a minimal analysis script, `analysis.py`, that extracts the top-k most influential data points for each query and saves them to a file.
```bash
python analysis.py
```
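
For intuition about the `--mode` option in step 3, below is a minimal PyTorch sketch of how a query gradient and a set of compressed training gradients could be turned into `dot` or `cosine` influence scores and then ranked. This is only an illustration: `train_grads` and `query_grad` are hypothetical stand-ins for the LoGra-projected gradients, and the actual computation (including the Hessian preconditioning selected by `--hessian`) is handled inside LogIX by `compute_influence_all`.

```python
import torch
import torch.nn.functional as F

# Hypothetical placeholders: N compressed training gradients and one query gradient,
# both assumed to live in the same low-dimensional projected space.
train_grads = torch.randn(1000, 256)  # (num_train_examples, projected_dim)
query_grad = torch.randn(256)         # (projected_dim,)


def influence_scores(train_grads, query_grad, mode="dot"):
    """Toy influence scoring: dot product or cosine similarity per training example."""
    if mode == "dot":
        return train_grads @ query_grad                       # (N,)
    if mode == "cosine":
        return F.cosine_similarity(train_grads, query_grad.unsqueeze(0), dim=1)  # (N,)
    raise ValueError(f"unknown mode: {mode}")


scores = influence_scores(train_grads, query_grad, mode="cosine")
# Indices of the most influential training examples for this query.
topk_scores, topk_indices = torch.topk(scores, k=20)
```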
34 changes: 34 additions & 0 deletions examples/language_modeling/analysis.py
@@ -0,0 +1,34 @@
import os

import torch


# Top-k examples to report. `experiment` should point to the directory written by
# compute_influence.py, i.e. <model>_<lora>_<hessian>_<data_name>[_mlp]/<split>,
# and `mode` should match the --mode used there.
k = 20
experiment = "pythia-1.4b_random_raw_openwebtext_mlp/generated"
mode = "cosine"

scores = torch.load(os.path.join(experiment, f"scores_{mode}.pt"), map_location="cpu")
train_ids = torch.load(
    os.path.join(experiment, f"train_ids_{mode}.pt"), map_location="cpu"
)
test_ids = torch.load(
    os.path.join(experiment, f"test_ids_{mode}.pt"), map_location="cpu"
)
print(len(train_ids), len(test_ids), scores.shape)
assert len(train_ids) == scores.shape[1]
assert len(test_ids) == scores.shape[0]

out = ""
for idx, test_id in enumerate(test_ids):
    out += "==========================================================\n"
    out += f"Query: {test_id}\n"
    out += "==========================================================\n"
    topk_indices = torch.topk(scores[idx], k=k)[1]
    for j, topk_idx in enumerate(topk_indices):
        score = scores[idx][topk_idx]
        train_id = train_ids[topk_idx]
        out += f"Top {j + 1} (score: {score}): {train_id}\n"
    out += "==========================================================\n"

with open(os.path.join(experiment, f"top_{mode}.txt"), "w") as file:
    file.write(out)
78 changes: 34 additions & 44 deletions examples/language_modeling/compute_influence.py
@@ -1,17 +1,15 @@
import os
import copy
import argparse
import copy
import os

import logix
import torch
import torch.nn.functional as F
from accelerate import Accelerator
import logix
from logix.analysis import InfluenceFunction
from logix.utils import merge_logs
from tqdm import tqdm

from utils import get_model, get_tokenizer, get_loader, set_seed

from utils import get_loader, get_model, get_tokenizer, set_seed

if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
@@ -20,55 +18,37 @@
def main():
parser = argparse.ArgumentParser("GPT2 Influence Analysis")
parser.add_argument("--config_path", type=str, default="./config.yaml")
parser.add_argument(
"--cache_dir",
type=str,
default="/data/tir/projects/tir3/users/sangkeuc/huggingface",
)
parser.add_argument(
"--save_dir",
type=str,
default="/data/tir/projects/tir3/users/sangkeuc/gpt/results",
)
parser.add_argument("--model_name", type=str, default="gpt2-xl")
parser.add_argument("--data_path", type=str, default="wikitext")
parser.add_argument("--data_name", type=str, default=None)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--hessian", type=str, default="raw")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--hessian", type=str, default="kfac")
parser.add_argument("--lora", type=str, default="random")
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--split", type=str, default="valid")
parser.add_argument("--mlp_only", action="store_true")
parser.add_argument("--layerwise", action="store_true")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--damping", type=float, default=1e-5)
parser.add_argument("--data_name", type=str, default="openwebtext")
parser.add_argument("--mode", type=str, default="dot")
args = parser.parse_args()

set_seed(0)
accelerator = Accelerator()
influence_groups = None
if args.layerwise:
layer_id = "h" if args.model_name == "gpt2-xl" else "layers"
layer_num = 48 if args.model_name == "gpt2-xl" else 32
influence_groups = [f".{layer_id}.{i}." for i in range(layer_num)]

# prepare model & data loader
model = get_model(model_name=args.model_name, cache_dir=args.cache_dir)
tokenizer = get_tokenizer(
model_name=args.model_name, cache_dir=args.cache_dir, add_padding_token=True
)
tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
data_loader = get_loader(
model_name=args.model_name,
data_path=args.data_path,
data_name=args.data_name,
tokenizer=tokenizer,
batch_size=args.batch_size,
cache_dir=args.cache_dir,
split=args.split,
data_name=args.data_name,
)
model, data_loader = accelerator.prepare(model, data_loader)

# Set-up LogIX
model_name_strip = args.model_name.split("/")[-1]
project = f"{model_name_strip}_{args.lora}_{args.hessian}"
project = f"{model_name_strip}_{args.lora}_{args.hessian}_{args.data_name}"
name_filter = ["att", "mlp"]
if args.mlp_only:
project += "_mlp"
@@ -82,6 +62,9 @@ def main():
# Influence analysis
logix.setup({"log": "grad"})
logix.eval()
if_scores = []
train_ids = None
test_ids = []
merged_test_logs = []
for idx, batch in enumerate(tqdm(data_loader)):
data_id = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
@@ -103,23 +86,30 @@
test_log = logix.get_log()
merged_test_logs.append(copy.deepcopy(test_log))

if idx == 12 or idx == len(data_loader) - 1:
if idx == len(data_loader) - 1:
merged_test_log = merge_logs(merged_test_logs)
result = run.influence.compute_influence_all(
merged_test_log, log_loader, influence_groups=influence_groups
if_score, train_ids_batch = run.influence.compute_influence_all(
merged_test_log,
log_loader,
mode=args.mode,
)
if_scores.append(if_score)
if train_ids is None:
train_ids = train_ids_batch
else:
assert train_ids == train_ids_batch
test_ids.extend(merged_test_log[0])
if_scores = torch.cat(if_scores, dim=0)
merged_test_logs = []
break

post_fix = f"{args.split}_{model_name_strip}_{args.lora}_{args.hessian}"
if args.mlp_only:
post_fix += "_mlp"
save_dir = os.path.join(args.save_dir, post_fix)
base_dir = os.path.dirname(os.path.abspath(__file__)) # current file's directory
save_dir = os.path.join(base_dir, project, f"{args.split}")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.save(result["influence"], os.path.join(save_dir, "scores.pt"))
torch.save(result["src_ids"], os.path.join(save_dir, "test_ids.pt"))
torch.save(result["tgt_ids"], os.path.join(save_dir, "train_ids.pt"))
torch.save(if_scores, os.path.join(save_dir, f"scores_{args.mode}.pt"))
torch.save(train_ids, os.path.join(save_dir, f"train_ids_{args.mode}.pt"))
torch.save(test_ids, os.path.join(save_dir, f"test_ids_{args.mode}.pt"))


if __name__ == "__main__":
28 changes: 9 additions & 19 deletions examples/language_modeling/extract_log.py
@@ -1,16 +1,12 @@
import os
import argparse

from tqdm import tqdm
import logix
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import GradScalerKwargs
import logix
from logix.statistic import Covariance

from utils import get_model, get_tokenizer, get_loader, set_seed
from tqdm import tqdm

from utils import get_loader, get_model, get_tokenizer, set_seed

# Enable TF32 if possible
if torch.cuda.is_available():
@@ -20,21 +16,15 @@
def main():
parser = argparse.ArgumentParser("GPT2 Influence Analysis")
parser.add_argument("--config_path", type=str, default="./config.yaml")
parser.add_argument(
"--cache_dir",
type=str,
default="/data/tir/projects/tir3/users/sangkeuc/huggingface",
)
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--model_name", type=str, default="gpt2")
parser.add_argument("--data_path", type=str, default="wikitext")
parser.add_argument("--data_name", type=str, default=None)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--hessian", type=str, default="raw")
parser.add_argument("--hessian", type=str, default="kfac")
parser.add_argument("--lora", type=str, default="random")
parser.add_argument("--save", type=str, default="grad")
parser.add_argument("--mlp_only", action="store_true")
parser.add_argument("--data_name", type=str, default="openwebtext")
args = parser.parse_args()
print(args)

set_seed(0)
accelerator = Accelerator()
@@ -44,16 +34,15 @@ def main():
tokenizer = get_tokenizer(model_name=args.model_name, cache_dir=args.cache_dir)
data_loader = get_loader(
model_name=args.model_name,
data_path=args.data_path,
data_name=args.data_name,
tokenizer=tokenizer,
batch_size=args.batch_size,
cache_dir=args.cache_dir,
data_name=args.data_name,
)

# LogIX Setup
model_name_strip = args.model_name.split("/")[-1]
project = f"{model_name_strip}_{args.lora}_{args.hessian}"
project = f"{model_name_strip}_{args.lora}_{args.hessian}_{args.data_name}"
name_filter = ["att", "mlp"]
if args.mlp_only:
project += "_mlp"
@@ -87,6 +76,7 @@ def main():
)
accelerator.backward(loss)
logix.finalize()
print(f"Log saved in {project}")


if __name__ == "__main__":