llama-cpp multi server support
llama-cpp does not support batching, concurrent completion requests, or really anything else we can use to speed up our processes.

The only clear solution here is to create our own form of parallelism by supporting running multiple servers at once.

Via a `--num-servers` flag from the CLI, a user can spin up 2, 3, or even 4 instances of the `mistral 7b instruct` model, since each only takes about 5 GB of RAM.

This allows us to split our dataset into batches, as we do with vLLM, and execute threads running each batch in parallel; each server handles its own batch.
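
For orientation, here is a minimal sketch of the scheme (illustrative only; `run_batch` and `generate_in_parallel` are stand-in names, not the functions added by this commit):

from concurrent.futures import ThreadPoolExecutor
import math

from datasets import Dataset, concatenate_datasets


def run_batch(batch: Dataset, client) -> Dataset:
    # Placeholder for running the SDG pipeline for one batch against the
    # llama-cpp server behind `client`; here it just echoes the batch back.
    return batch


def generate_in_parallel(ds: Dataset, clients: list) -> Dataset:
    # One batch per client/server.
    batch_size = math.ceil(len(ds) / len(clients))
    batches = [
        ds.select(range(i * batch_size, min((i + 1) * batch_size, len(ds))))
        for i in range(len(clients))
    ]
    # One thread per batch; each thread talks only to its own server.
    with ThreadPoolExecutor(max_workers=len(clients)) as executor:
        futures = [
            executor.submit(run_batch, batch, client)
            for batch, client in zip(batches, clients)
        ]
        results = [future.result() for future in futures]
    return concatenate_datasets(results)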

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Oct 22, 2024
1 parent 067e4c1 commit b8f614a
Showing 3 changed files with 204 additions and 42 deletions.
223 changes: 189 additions & 34 deletions src/instructlab/sdg/generate_data.py
@@ -1,23 +1,30 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from importlib import resources
from pathlib import Path
from typing import Optional
from threading import Lock

Check warning on line 8 in src/instructlab/sdg/generate_data.py (GitHub Actions / pylint): W0611: Unused Lock imported from threading (unused-import)
from typing import List, Optional

Check warning on line 9 in src/instructlab/sdg/generate_data.py (GitHub Actions / pylint): W0611: Unused List imported from typing (unused-import)
import dataclasses
import json
import logging
import math
import os
import threading

Check warning on line 15 in src/instructlab/sdg/generate_data.py (GitHub Actions / pylint): W0611: Unused import threading (unused-import)
import time

# Third Party
# instructlab - All of these need to go away (other than sdg) - issue #6
from datasets import Dataset
from datasets import Dataset, concatenate_datasets
from tqdm import tqdm

Check warning on line 21 in src/instructlab/sdg/generate_data.py (GitHub Actions / pylint): W0611: Unused tqdm imported from tqdm (unused-import)
from xdg_base_dirs import xdg_data_dirs, xdg_data_home
import openai

# First Party
from instructlab.sdg.checkpointing import Checkpointer

# pylint: disable=ungrouped-imports
from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
@@ -262,7 +269,7 @@ def _mixer_init(ctx, output_dir, date_suffix, knowledge_auxiliary_inst):
# TODO - parameter removal needs to be done in sync with a CLI change.
# to be removed: logger, prompt_file_path, rouge_threshold, tls_*
def generate_data(
client: openai.OpenAI,
clients: list[openai.OpenAI],
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
model_family: Optional[str] = None,
model_name: Optional[str] = None,
@@ -332,7 +339,7 @@ def generate_data(
model_family = MODEL_FAMILY_MERLINITE

ctx = _context_init(
client,
clients[0],
model_family,
model_name,
num_instructions_to_generate,
@@ -359,60 +366,79 @@ def generate_data(

generated_data = None
empty_sdg_leaf_nodes = []

for leaf_node in leaf_nodes.values():
is_knowledge = False
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
samples = leaf_node_to_samples(leaf_node, server_ctx_size, chunk_word_count)

if not samples:
raise GenerateException("Error: No samples found in leaf node.")

is_knowledge = False
is_g_skill = False
is_f_skill = False
if samples[0].get("document"):
pipe = knowledge_pipe
is_knowledge = True

elif samples[0].get("seed_context"):
is_g_skill = True
pipe = grounded_skills_pipe

else:
is_f_skill = True
pipe = freeform_skills_pipe

logger.debug("Samples: %s", samples)

ds = Dataset.from_list(samples)
logger.debug("Dataset: %s", ds)
new_generated_data = pipe.generate(ds, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data = (
[new_generated_data]
if generated_data is None
else generated_data + [new_generated_data]
)
logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
# if we have multiple servers using llama we need to
# 1. split ds into batches per server
# 2. execute a thread running each batch in its own pipeline
# 3. add the data back together
# 4. return that data.
if len(clients) > 1:
new_generated_data = generate_on_multiple_servers(
ds,
clients,
checkpoint_dir,
model_family,
model_name,
num_instructions_to_generate,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
num_cpus,
)
else:
new_generated_data = pipe.generate(ds, leaf_node_path)
if len(new_generated_data) == 0:
empty_sdg_leaf_nodes.append(leaf_node_path)
logger.warning("Empty dataset for qna node: %s", leaf_node_path)
continue
generated_data = (
[new_generated_data]
if generated_data is None
else generated_data + [new_generated_data]
)
logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
ds,
output_dir,
date_suffix,
)

mixer.collect(leaf_node_path, new_generated_data, is_knowledge)

if generated_data is None:
generated_data = []

_gen_train_data(
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
)
mixer.collect(leaf_node_path, new_generated_data, is_knowledge)

mixer.generate()

@@ -424,3 +450,132 @@ def generate_data(
" ".join(empty_sdg_leaf_nodes)
)
)


def process_llama_server_batch(
ds,
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
batch_size,
num_cpus,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
thread,
):
logger.info(f"Running on client {client.base_url} {model_name} ")
ctx = _context_init(
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
1, # save_freq
batch_size=batch_size,
batch_num_workers=num_cpus,
)

knowledge_pipe, freeform_skills_pipe, grounded_skills_pipe = _sdg_init(
ctx, pipeline
)
mmlu_ctx = dataclasses.replace(ctx, checkpoint_dir=None)
mmlu_bench_pipe = mmlubench_pipe_init(mmlu_ctx)
logger.debug("Dataset: %s", ds)
pipe = None
if is_knowledge:
pipe = knowledge_pipe
elif is_g_skill:
pipe = grounded_skills_pipe
elif is_f_skill:
pipe = freeform_skills_pipe

new_data = pipe.generate(ds, leaf_node_path, thread)

if is_knowledge:
generate_eval_task_data(
mmlu_bench_pipe,
leaf_node_path,
ds,
output_dir,
date_suffix,
)
return new_data


def generate_on_multiple_servers(
ds,
clients,
checkpoint_dir,
model_family,
model_name,
num_instructions_to_generate,
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
num_cpus,
):
# num batches == num clients
total_size = len(ds)

batch_size = math.ceil(total_size / len(clients))

# Create a batch for each client using ds.select() and indices
batches = [
ds.select(range(i * batch_size, min((i + 1) * batch_size, total_size)))
for i in range(len(clients))
]
# the number of batches matches the number of clients.

checkpointer = Checkpointer(checkpoint_dir, 1)
output_splits = []

logger.debug(f" batches {len(batches)}, clients {len(clients)}, {clients}")
# Using ThreadPoolExecutor to process each (ds, client) pair in parallel
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
process_llama_server_batch,
ds,
client,
model_family,
model_name,
num_instructions_to_generate,
checkpoint_dir,
0,
num_cpus // len(clients),  # integer share of workers per server
output_dir,
date_suffix,
is_knowledge,
is_g_skill,
is_f_skill,
pipeline,
leaf_node_path,
thread,
)
for thread, (ds, client) in enumerate(zip(batches, clients))
]
for i, future in enumerate(futures):
if future.running():
logger.debug(f"Thread {i} is running")
elif future.done():
logger.debug(f"Thread {i} has completed")
elif future.cancelled():
logger.debug(f"Thread {i} was canceled")
for future in futures:
new_data = future.result()
output_splits.append(new_data)
checkpointer.checkpoint(new_data)
generated = concatenate_datasets(output_splits)
logger.debug("Dataset: %s", generated)
return generated
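
For context, a hypothetical caller-side sketch (not part of this commit): a CLI honoring `--num-servers` would build one OpenAI client per running llama-cpp server and pass the list in as the new `clients` argument. The ports and `api_key` value below are assumptions.

# Hypothetical caller-side usage; ports and api_key are assumed, not defined by this commit.
import openai

num_servers = 4  # e.g. the value of the --num-servers CLI flag
clients = [
    openai.OpenAI(
        # one llama-cpp server is assumed to be listening on each port
        base_url=f"http://127.0.0.1:{8000 + i}/v1",
        api_key="no-key-required",
    )
    for i in range(num_servers)
]

# generate_data(clients, ...) then fans the work out across the servers
# whenever len(clients) > 1.
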
11 changes: 7 additions & 4 deletions src/instructlab/sdg/llmblock.py
@@ -2,6 +2,7 @@

# Standard
from collections import ChainMap
from threading import Lock

Check warning on line 5 in src/instructlab/sdg/llmblock.py (GitHub Actions / pylint): W0611: Unused Lock imported from threading (unused-import)
from typing import Any, Dict
import logging
import re
@@ -156,7 +157,7 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
gen_kwargs["temperature"] = float(gen_kwargs["temperature"])
return gen_kwargs

def _generate(self, samples) -> list:
def _generate(self, samples, thread) -> list:
prompts = [self._format_prompt(sample) for sample in samples]
logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
if self.server_supports_batched:
@@ -167,7 +168,9 @@ def _generate(self, samples) -> list:

results = []
progress_bar = tqdm(
range(len(prompts)), desc=f"{self.block_name} Prompt Generation"
range(len(prompts)),
desc=f"{self.block_name} Prompt Generation Thread {thread}",
position=thread,
)
for prompt in prompts:
logger.debug(f"CREATING COMPLETION FOR PROMPT: {prompt}")
@@ -179,7 +182,7 @@ def _generate(self, samples, thread) -> list:
progress_bar.update(1)
return results

def generate(self, samples: Dataset) -> Dataset:
def generate(self, samples: Dataset, thread: int) -> Dataset:
"""
Generate the output from the block. This method should first validate the input data,
then generate the output, and finally parse the generated output before returning it.
@@ -211,7 +214,7 @@ def generate(self, samples: Dataset, thread: int) -> Dataset:

# generate the output

outputs = self._generate(samples)
outputs = self._generate(samples, thread)
logger.debug("Generated outputs: %s", outputs)

num_parallel_samples = self.gen_kwargs.get("n", 1)
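
To illustrate the new `thread` parameter threaded through `_generate`: giving each tqdm bar a fixed `position` keeps the per-server progress bars on separate terminal lines instead of overwriting one another. A standalone sketch, not project code:

from concurrent.futures import ThreadPoolExecutor
import time

from tqdm import tqdm


def work(thread: int, n: int = 20) -> None:
    # position=thread pins this bar to its own terminal line
    for _ in tqdm(range(n), desc=f"Prompt Generation Thread {thread}", position=thread):
        time.sleep(0.05)  # stand-in for one completion request


with ThreadPoolExecutor(max_workers=3) as executor:
    for t in range(3):
        executor.submit(work, t)
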
12 changes: 8 additions & 4 deletions src/instructlab/sdg/pipeline.py
@@ -4,6 +4,7 @@
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from importlib import resources
from threading import Lock

Check warning on line 7 in src/instructlab/sdg/pipeline.py (GitHub Actions / pylint): W0611: Unused Lock imported from threading (unused-import)
from typing import Dict, Iterable, List, Optional
import logging
import math
@@ -131,7 +132,7 @@ def from_file(cls, ctx, pipeline_yaml):
pipeline_yaml = os.path.join(resources.files(__package__), pipeline_yaml)
return cls(ctx, pipeline_yaml, *_parse_pipeline_config_file(pipeline_yaml))

def generate(self, dataset, checkpoint_name=None) -> Dataset:
def generate(self, dataset, checkpoint_name=None, thread: int = 0) -> Dataset:
"""
Generate the dataset by running the pipeline steps.
dataset: the input dataset
@@ -151,7 +152,7 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:
# If not batching, simply delegate to _generate_single
if not self.ctx.batching_enabled:
logger.info("Running pipeline single-threaded")
return self._generate_single(dataset)
return self._generate_single(dataset, thread)

# Otherwise, split the dataset into batches and run each batch as a
# future in the thread pool
@@ -181,7 +182,7 @@ def generate(self, dataset, checkpoint_name=None) -> Dataset:

## Implementation Details ##

def _generate_single(self, dataset) -> Dataset:
def _generate_single(self, dataset, thread: int = None) -> Dataset:
"""Generate a single dataset by running the pipeline steps."""
for block_prop in self.chained_blocks:
# Initialize arguments for error handling to None
@@ -198,7 +199,10 @@ def _generate_single(self, dataset) -> Dataset:
logger.info(dataset)

# Execute the block and wrap errors with the block name/type
dataset = block.generate(dataset)
if block_type == llmblock.LLMBlock:
dataset = block.generate(dataset, thread)
else:
dataset = block.generate(dataset)
except Exception as err:
raise PipelineBlockError(
exception=err,
