Using data attribution for AI alignment

This is a post on a recent paper I thought was cool. I give some follow-up project ideas after.

In-Run Data Shapley: a data attribution method efficient enough to run during pre-training.

Essentially, it tracks how individual data points (or clusters) affect model performance over the course of pre-training. You just need a set of validation examples on which to keep checking the model's performance as training proceeds. Amazingly, this works within a single training run; unlike other data attribution methods, it does not require multiple pre-training runs.

Other methods, like influence functions, are too computationally expensive to run during pre-training and can only be run post-training.
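
To make the "single training run" idea concrete, here is a minimal toy sketch. It is not the paper's implementation: it assumes a first-order approximation in which a training point's contribution at each step is the learning rate times the dot product between its gradient and the validation gradient, accumulated over training. The model (1-D linear regression), the data, and all function names are made up for illustration, and the sign convention (positive means "predicted to reduce validation loss") is my own choice.

```python
import random

def grad(w, x, y):
    """Gradient of the loss 0.5 * (w.x - y)^2 with respect to w."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def train_with_attribution(train, val, lr=0.05, epochs=20, batch_size=2, seed=0):
    """Plain mini-batch SGD that also accumulates a per-example score.

    Assumption: score_i += lr * <grad_val(w_t), grad_i(w_t)> whenever example i
    is in the batch, i.e. a first-order estimate of how much that example's
    update reduces the validation loss at step t.
    """
    rng = random.Random(seed)
    w = [0.0] * len(train[0][0])
    scores = [0.0] * len(train)
    for _ in range(epochs):
        order = list(range(len(train)))
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            # Mean validation gradient at the current weights.
            g_val = [0.0] * len(w)
            for x, y in val:
                g_val = [a + b / len(val) for a, b in zip(g_val, grad(w, x, y))]
            # Per-example training gradients, all taken at the same weights.
            grads = {i: grad(w, *train[i]) for i in batch}
            for i, g in grads.items():
                scores[i] += lr * dot(g_val, g)
            # SGD update using the sum of the batch gradients.
            for g in grads.values():
                w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w, scores

# Toy data: targets follow y = 2x, except the last training point is mislabeled.
TRAIN = [([1.0], 2.0), ([2.0], 4.0), ([-1.0], -2.0), ([1.5], -3.0)]
VAL = [([1.0], 2.0), ([-2.0], -4.0)]

weights, scores = train_with_attribution(TRAIN, VAL)
worst = min(range(len(scores)), key=scores.__getitem__)
print("accumulated scores:", scores)
print("least helpful training example:", worst)
```

In this toy setup the mislabeled point accumulates a negative score while the clean points accumulate positive ones. At pre-training scale, the hard part is getting these per-example gradient dot products cheaply, and this sketch says nothing about that.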

So, here's why this might be interesting from an alignment perspective:

  • You might be able to set up a bunch of validation examples to test specific behaviour in the models so that we are hyper-aware of which data points contribute the most to that behaviour. For example, self-awareness or self-preservation.
  • Given that this is possible to run during pre-training, you might understand model behaviour at such a granular level that you can construct data mixtures/curriculums that push the model towards internalizing 'human values' much sooner than it develops behaviours or capabilities we wouldn't want. Alternatively, you could delay the emergence of self-awareness and similar traits until much later in the training process.
  • In his post A "Bitter Lesson" Approach to Aligning AGI and ASI, Roger Dearnaley proposes training an AI on a synthetic dataset where all intelligences are motivated by the collective well-being of humanity. The aim is to bias the model to be as close to the basin of attraction for alignment as possible. In-Run Data Shapley could be used to construct such a dataset and guide the training process so that the training data best exemplifies the desired aligned behaviour.
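
As a rough illustration of the data-mixture idea in the bullets above, here is one hypothetical way to turn per-cluster attribution scores (measured against a behaviour-specific validation set) into sampling weights. Everything here is an assumption for the sake of the sketch: the cluster names and scores are invented placeholders, not results, and the softmax-with-temperature rule is just one simple choice, not something the paper prescribes.

```python
import math

def mixture_weights(cluster_scores, temperature=1.0):
    """Map attribution scores to sampling probabilities with a softmax.

    Higher score -> sampled more often. A lower temperature makes the
    mixture more aggressive; a higher one keeps it closer to uniform.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    top = max(cluster_scores.values())  # subtract the max for numerical stability
    exps = {name: math.exp((score - top) / temperature)
            for name, score in cluster_scores.items()}
    total = sum(exps.values())
    return {name: value / total for name, value in exps.items()}

# Placeholder scores, invented for illustration only.
EXAMPLE_SCORES = {
    "cooperative_fiction": 0.8,
    "general_web": 0.1,
    "deceptive_agents": -0.5,
}

weights = mixture_weights(EXAMPLE_SCORES, temperature=0.5)
for name, weight in weights.items():
    print(f"{name}: {weight:.3f}")
```

In practice you would probably want to recompute the scores periodically during training, since a cluster's contribution to a behaviour can change as the model changes.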

I think data is underrated in the alignment community (synthetic/transformed data even more so). I have been thinking about it from the perspective of both pre-training and post-training. My initial look into synthetic data was related to online learning and, essentially, controlling model behaviour. I was interested in papers like this one by Google, where they significantly reduce sycophancy in an LLM via 1k synthetically generated examples. Data shapes behaviour, and I think many people do not acknowledge this enough (which sometimes leads them to draw confused conclusions about model behaviour).

In terms of specific research projects, my current ideas fall into these kinds of buckets:

Pre-training close to the basin of attraction for alignment

1. How much can we improve on "Pretraining Language Models with Human Preferences"? I'd like to transform the training data in various ways (as mentioned in your posts). For example, I could take FineWeb and pre-train a GPT-2-sized model on both the original dataset and a transformed version. It is unclear so far which things I'd most like to measure at that model size, though. A downstream experiment: is one model more likely to reward hack than the other? Does shard theory help us come up with useful experiments (pre-training with human feedback is almost like reinforcing behaviour and leveraging some form of shard theory)? Note that Google used a similar pre-training scheme for PaLM 2.

2. How can the "basin of attraction for alignment" be mathematically formalized?

3. Trying to understand the impact of systematic errors:

Studying reward misspecification: do the reward labels introduce a systematic bias in the direction the model is pushed? How much of the model's behaviour is determined by the data itself vs. the reward model's misspecification? My current read of the literature on this is still hazy. However, there's a paper saying: "We present a novel observation about the behaviour of offline reinforcement learning (RL) algorithms: on many benchmark datasets, offline RL can produce well-performing and safe policies even when trained with "wrong" reward labels, such as those that are zero everywhere or are negatives of the true rewards."

4. How do we design the training curriculum to significantly bias the model's pre-training close to the basin of attraction for alignment?

Studying some form of iterative training where we compare a synthetically trained model against a normally trained model and then measure things like model drift. For example, is a model pre-trained on normal text more likely to drift (in an online setting) in ways we wouldn't want, and is the process more safely guided through synthetic pre-training?

5. Part of the alignment challenge (for example, the concern of scheming AIs) is that the order in which the model learns things might matter. For example, you'd want the model to internalize a solid world model of human values before it gains the situational awareness required to manipulate its training process (i.e. to scheme). So, can we design a training curriculum for specific capabilities such that the model learns capabilities in an ideal sequence?

Data attribution project ideas

1. How can this approach be made to work in tandem with unlearning?

2. Use data attribution methods to understand how specific data shapes model behaviour and use that information to reconstruct pre-training to shape model behaviour in the way we want. For example, can we side-step the need for unlearning? Can these data attribution methods augment unlearning to work better?

As Roger said in his comment, we can try to manage the dataset to prevent WMD-dangerous capabilities and things like self-replication. It's quite possible that unlearning will not be enough.

Another project would be to fine-tune on a dataset with and without the dangerous capabilities we don't want and use that as a benchmark for unlearning methods (and how easy it is to fine-tune the capability back into the model).

3. Including other methods beyond data attribution (e.g. sparse autoencoders, SAEs) to measure model evolution through training.

4. Is it possible to better understand and predict emergence via data attribution?

5. Studying model generalization via data attribution (doing similar things to the influence functions paper, but through time). Though the most interesting behaviour may only come at scales I wouldn't have the compute for.

6. Would there be value in taking an early checkpoint in training and then training on the synthetic data from that point forward? At which point in training does it make sense to do this?

If you are interested in this kind of research, let me know! I'd love to brainstorm some potential projects and then apply for funding if there is something promising there.