Skip to content

Commit

Permalink
Add support for auxiliary dataset generation
Browse files Browse the repository at this point in the history
This adds support for generating auxiliary datasets during knowledge
data generation. An auxiliary dataset is where we ask the model to
generate some additional data samples with a different prompt than the
standard dataset, along with some extra instruction prompts that will
get matched to the auxiliary generated samples and used during
training.

The auxiliary instructions are a new part of the pipeline config, as
they are tightly coupled to the pipeline config. An example, where
you'll note the `spellcheck` value from the pipeline config has to match
across both the pipeline config and the new auxiliary instructions, so
we just list both in the same config file.

version: "1.0"
blocks:
...
  - name: flatten_auxiliary_columns
    type: FlattenColumnsBlock
    config:
      var_cols:
        - spellcheck
        - base_document
      value_name: corrected_document
      var_name: dataset_type
...
datamixing:
  auxiliary_instructions:
    spellcheck:
      - Correct any spelling errors in the document and output the corrected version.
      - Rewrite the document to remove any spelling errors.

Parts of this are extracted and rebased from
aakankshaduggal#4
aakankshaduggal#21

Refs instructlab#162.

Co-authored-by: shivchander <shivchander.s30@gmail.com>
Co-authored-by: Khaled Sulayman <khaled@thesulaymans.com>
Co-authored-by: abhi1092 <abhi1092@gmail.com>
Co-authored-by: Aakanksha Duggal <aduggal@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Ben Browning <bbrownin@redhat.com>
  • Loading branch information
6 people committed Jul 29, 2024
1 parent ca30d98 commit 4ccdc30
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 13 deletions.
17 changes: 17 additions & 0 deletions src/instructlab/sdg/configs/knowledge/spellcheck.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
system: You are an AI assistant that is an expert at fixing spelling errors in documents.

introduction: |
Give me a copy of the below document with all spelling errors corrected.
principles: |
Do not add any new information.
Do not leave out any information.
examples: ""

generation: |
Document:
{document}
start_tags: [""]
end_tags: [""]
104 changes: 96 additions & 8 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Standard
from typing import Optional
from typing import Dict, List, Optional
import json
import logging
import os.path
Expand All @@ -12,6 +12,7 @@

# First Party
from instructlab.sdg.utils import GenerateException, pandas
from instructlab.sdg.utils.pandas import dataset_from_pandas_dataframe

ALLOWED_COLS = ["id", "messages", "metadata"]
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -374,7 +375,68 @@ def _conv_pretrain(rec):
return rec


def _create_phase10_ds(generated_dataset: Dataset):
def _create_auxiliary_dataset(
generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
# Samples that went through the auxiliary generation pipeline will
# have a dataset_type column created by that pipeline. If that's
# not present, then we may be running in a pipeline without any
# auxiliary dataset generation enabled.
if "dataset_type" not in generated_dataset.column_names:
return None
# If we didn't find any auxiliary instructions to load, then
# that's also another sign that we're not running with any
# auxiliary datasets enabled.
if auxiliary_inst is None:
return None
# This "base_document" dataset_type is set in the knowledge
# pipeline config, and represents samples that do not have the
# auxiliary generated document attached, so we filter those out.
auxiliary_ds = generated_dataset.filter(
lambda x: x["dataset_type"] != "base_document"
)
unique_document_auxiliary = auxiliary_ds.to_pandas().drop_duplicates(
subset=["document"]
)
unique_document_auxiliary = dataset_from_pandas_dataframe(unique_document_auxiliary)
unique_document_auxiliary = unique_document_auxiliary.select_columns(
[
"raw_document",
"document_outline",
"domain",
"dataset_type",
"document",
]
)
unique_document_auxiliary = unique_document_auxiliary.rename_columns(
{"raw_document": "context", "document": "response"}
)

def __create_auxiliary_ds(rec):
instruction = random.choice(auxiliary_inst[rec["dataset_type"]])
messages = [
{"role": "user", "content": f"{rec['context']}\n\n{instruction}"},
{"role": "assistant", "content": rec["response"]},
]
metadata = json.dumps(
{
"dataset_type": rec["dataset_type"],
"raw_document": rec["context"],
"dataset": f"document_{rec['dataset_type']}",
"domain": rec["domain"],
}
)
return {"messages": messages, "metadata": metadata, "id": str(uuid.uuid4())}

unique_document_auxiliary = unique_document_auxiliary.map(
__create_auxiliary_ds, remove_columns=unique_document_auxiliary.column_names
)
return unique_document_auxiliary


def _create_phase10_ds(
generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
"""
Create a dataset for Phase 1.0 of downstream training.
Expand All @@ -387,10 +449,17 @@ def _create_phase10_ds(generated_dataset: Dataset):
)
knowledge_ds = _add_extra_contexts_to_samples(knowledge_ds, p=0.4)

return knowledge_ds
auxiliary_dataset = _create_auxiliary_dataset(generated_dataset, auxiliary_inst)
if auxiliary_dataset is not None:
phase10 = concatenate_datasets([knowledge_ds, auxiliary_dataset])
else:
phase10 = knowledge_ds
return phase10


def _create_phase07_ds(generated_dataset: Dataset):
def _create_phase07_ds(
generated_dataset: Dataset, auxiliary_inst: Optional[Dict[str, List[str]]]
):
"""
Create a dataset for Phase 0.7 of downstream training.
Expand All @@ -404,7 +473,13 @@ def _create_phase07_ds(generated_dataset: Dataset):
)
knowledge_ds = knowledge_ds.map(_conv_pretrain)

return knowledge_ds
auxiliary_dataset = _create_auxiliary_dataset(generated_dataset, auxiliary_inst)
if auxiliary_dataset is not None:
auxiliary_dataset = auxiliary_dataset.map(_conv_pretrain)
phase07 = concatenate_datasets([knowledge_ds, auxiliary_dataset])
else:
phase07 = knowledge_ds
return phase07


def _convert_to_leaf_node_messages(sample: dict, sys_prompt: str):
Expand Down Expand Up @@ -440,12 +515,21 @@ class DataMixer:
# once.
NUM_SYNTH_SKILLS = 30

def __init__(self, data_dirs, output_dir, date_suffix, sys_prompt, num_procs):
def __init__(
self,
data_dirs,
output_dir,
date_suffix,
sys_prompt,
num_procs,
auxiliary_inst=None,
):
self.data_dirs = data_dirs
self.output_dir = output_dir
self.sys_prompt = sys_prompt
self.date_suffix = date_suffix
self.num_procs = num_procs
self.auxiliary_inst = auxiliary_inst

self.knowledge_recipe = self._load_default_recipe("knowledge.yaml")
self.skills_recipe = self._load_default_recipe("skills.yaml")
Expand Down Expand Up @@ -482,7 +566,9 @@ def _gen_leaf_node_data(

def collect(self, leaf_node_path, new_generated_data, is_knowledge):
if is_knowledge:
knowledge_phase_data = _create_phase07_ds(new_generated_data)
knowledge_phase_data = _create_phase07_ds(
new_generated_data, self.auxiliary_inst
)
output_file_leaf_knowledge = (
f"node_datasets_{self.date_suffix}/{leaf_node_path}_p07.jsonl"
)
Expand All @@ -492,7 +578,9 @@ def collect(self, leaf_node_path, new_generated_data, is_knowledge):
output_file_leaf_knowledge,
)

skills_phase_data = _create_phase10_ds(new_generated_data)
skills_phase_data = _create_phase10_ds(
new_generated_data, self.auxiliary_inst
)
output_file_leaf_skills = (
f"node_datasets_{self.date_suffix}/{leaf_node_path}_p10.jsonl"
)
Expand Down
8 changes: 6 additions & 2 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def load_pipeline(yaml_basename):
)


def _mixer_init(ctx, output_dir, date_suffix):
def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst):
pd = platformdirs.PlatformDirs(
appname=os.path.join("instructlab", "sdg"), multipath=True
)
Expand All @@ -258,6 +258,7 @@ def _mixer_init(ctx, output_dir, date_suffix):
date_suffix,
_SYS_PROMPT,
ctx.dataset_num_procs,
knowledge_auxiliary_inst,
)


Expand Down Expand Up @@ -367,7 +368,10 @@ def generate_data(
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)

mixer = _mixer_init(ctx, output_dir, date_suffix)
# FIXME: remove SDG https://github.com/instructlab/sdg/pull/64
mixer = _mixer_init(
ctx, output_dir, date_suffix, sdg_knowledge.pipelines[0].auxiliary_inst
)

if console_output:
logger.info(
Expand Down
13 changes: 10 additions & 3 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from importlib import resources
from typing import Iterable, Optional
from typing import Dict, Iterable, List, Optional
import logging
import math
import os.path
Expand Down Expand Up @@ -109,6 +109,7 @@ def __init__(
ctx: PipelineContext,
config_path: str,
chained_blocks: list[dict],
auxiliary_inst: Optional[Dict[str, List[str]]] = None,
) -> None:
"""
Initialize the Pipeline class with a configuration dictionary.
Expand All @@ -120,12 +121,14 @@ def __init__(
self.config_path = config_path
# pipeline config is the run configuration that consists of the pipeline steps
self.chained_blocks = chained_blocks
# datamixing instructions for auxiliary data generated by this pipeline
self.auxiliary_inst = auxiliary_inst

@classmethod
def from_file(cls, ctx, pipeline_yaml):
if not os.path.isabs(pipeline_yaml):
pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
return cls(ctx, pipeline_yaml, _parse_pipeline_config_file(pipeline_yaml))
return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml))

def generate(self, dataset) -> Dataset:
"""
Expand Down Expand Up @@ -296,7 +299,11 @@ def _parse_pipeline_config_file(pipeline_yaml):
"The pipeline config file contains no 'blocks' section"
)

return content["blocks"]
auxiliary_inst = None
if "datamixing" in content and "auxiliary_instructions" in content["datamixing"]:
auxiliary_inst = content["datamixing"]["auxiliary_instructions"]

return content["blocks"], auxiliary_inst


# This is part of the public API.
Expand Down
37 changes: 37 additions & 0 deletions src/instructlab/sdg/pipelines/full/knowledge.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
version: "1.0"
blocks:
- name: duplicate_document_col
type: DuplicateColumnsBlock
config:
columns_map:
document: base_document

- name: gen_spellcheck
type: LLMBlock
config:
config_path: ../../configs/knowledge/spellcheck.yaml
output_cols:
- spellcheck
gen_kwargs:
max_tokens: 2048

- name: flatten_auxiliary_columns
type: FlattenColumnsBlock
config:
var_cols:
- spellcheck
- base_document
value_name: corrected_document
var_name: dataset_type

- name: rename_to_document_column
type: RenameColumnsBlock
config:
columns_map:
document: raw_document
corrected_document: document

- name: gen_knowledge
type: LLMBlock
config:
Expand Down Expand Up @@ -73,3 +104,9 @@ blocks:
- explanation
- rating
- __index_level_0__

datamixing:
auxiliary_instructions:
spellcheck:
- Correct any spelling errors in the document and output the corrected version.
- Rewrite the document to remove any spelling errors.
17 changes: 17 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,23 @@
}
}
}
},
"datamixing": {
"type": "object",
"additionalProperties": false,
"properties": {
"auxiliary_instructions": {
"type": "object",
"patternProperties": {
".*": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
}
}
}
6 changes: 6 additions & 0 deletions tests/test_default_pipeline_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from instructlab.sdg.pipeline import Pipeline, PipelineContext
from instructlab.sdg.utilblocks import (
CombineColumnsBlock,
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
)
Expand All @@ -23,8 +26,11 @@ def _noop_generate(self, samples):

@patch.object(CombineColumnsBlock, "generate", _noop_generate)
@patch.object(ConditionalLLMBlock, "generate", _noop_generate)
@patch.object(DuplicateColumnsBlock, "generate", _noop_generate)
@patch.object(FilterByValueBlock, "generate", _noop_generate)
@patch.object(FlattenColumnsBlock, "generate", _noop_generate)
@patch.object(LLMBlock, "generate", _noop_generate)
@patch.object(RenameColumnsBlock, "generate", _noop_generate)
@patch.object(SamplePopulatorBlock, "generate", _noop_generate)
@patch.object(SelectorBlock, "generate", _noop_generate)
@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True)
Expand Down

0 comments on commit 4ccdc30

Please sign in to comment.