llama2 70b chat accelerate example (#2494)
* llam2 accelerate example

* add readme

* fmt

* fixing the padding and prompt

* update steps

* Updated readme with more details

* changed to inheriting from basehandler

* add model_path

* change to int8

* add download cmd

* update download path

* minor edit for model_path

---------

Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
Co-authored-by: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
5 people committed Aug 28, 2023
1 parent 683608b commit 04e0b37
Showing 6 changed files with 224 additions and 0 deletions.
60 changes: 60 additions & 0 deletions examples/large_models/Huggingface_accelerate/llama2/Readme.md
@@ -0,0 +1,60 @@
# Loading meta-llama/Llama-2-70b-chat-hf on AWS EC2 g5.24xlarge using accelerate

This document describes how to serve large Hugging Face models with limited resources using accelerate. This option is activated with `low_cpu_mem_usage=True`: the model is first created on the meta device (with empty weights), and the state dict is then loaded into it (shard by shard in the case of a sharded checkpoint).
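
To make this concrete, the sketch below shows roughly what that loading path does using accelerate's `init_empty_weights` and `load_checkpoint_and_dispatch`. The checkpoint path and the `no_split_module_classes` value are illustrative assumptions, not part of the handler in this example.

```python
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-chat-hf")

# Build the model skeleton on the meta device: no memory is allocated for the weights yet.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Stream the sharded checkpoint into the empty model, placing layers across the available devices.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/<snapshot-id>",  # illustrative path
    device_map="auto",
    no_split_module_classes=["LlamaDecoderLayer"],  # assumption: keep each decoder block on one device
)
```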

### Step 1: Get permission to download the model

Follow [these instructions](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to request access to the model weights.

Log in with your Hugging Face account:
```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../Download_model.py --model_path model --model_name meta-llama/Llama-2-70b-chat-hf
```
The model will be saved under `model/models--meta-llama--Llama-2-70b-chat-hf`.
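
If you prefer not to use the helper script, a roughly equivalent download can be done directly with `huggingface_hub` (a sketch only; `Download_model.py` may filter files or behave differently, and `huggingface_hub` must be installed and authenticated):

```python
from huggingface_hub import snapshot_download

# Pull the checkpoint into the local "model" cache directory, reusing the token from `huggingface-cli login`.
snapshot_download(
    repo_id="meta-llama/Llama-2-70b-chat-hf",
    cache_dir="model",
    token=True,
)
```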

### Step 2: Generate MAR file

Set `model_path:` in `model-config.yaml` to the downloaded checkpoint path, then run the following command:

```bash
torch-model-archiver --model-name llama2-70b-chat --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
```

If you are using conda and run into issues with `mpi4py`, install `openmpi-mpicc` as follows:

```bash
conda install -c conda-forge openmpi-mpicc
```

### Step 3: Add the MAR file to the model store

```bash
mkdir model_store
mv llama2-70b-chat model_store
mv model model_store/llama2-70b-chat
```

### Step 4: Start TorchServe

Update `config.properties` and start TorchServe:

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama2-70b-chat
```
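
Loading the 70B checkpoint can take several minutes. One way to check that the worker has finished loading is to query the management API; a minimal sketch, assuming the default management port 8081 from `config.properties` and that the `requests` package is available:

```python
import requests

# Describe the registered model; the returned worker status shows whether loading has completed.
status = requests.get("http://localhost:8081/models/llama2-70b-chat")
print(status.json())
```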

### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/llama2-70b-chat" -T sample_text.txt
```

This results in the following output:
```
Mayonnaise is a thick, creamy condiment made from a mixture of egg yolks, oil, vinegar or lemon juice, and seasonings'
```
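
The same request can also be sent from Python; a minimal sketch, assuming TorchServe is running locally on the default inference port 8080 and the `requests` package is installed:

```python
import requests

# Send the prompt file as the raw request body, mirroring the curl call above.
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://localhost:8080/predictions/llama2-70b-chat", data=f.read())
print(response.text)
```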
6 changes: 6 additions & 0 deletions examples/large_models/Huggingface_accelerate/llama2/config.properties
@@ -0,0 +1,6 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
enable_envvars_config=true
install_py_dep_per_model=true

139 changes: 139 additions & 0 deletions examples/large_models/Huggingface_accelerate/llama2/custom_handler.py
@@ -0,0 +1,139 @@
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with Llama 2 chat models.
    """

    def __init__(self):
        super(LlamaHandler, self).__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        """In this initialize function, the HF large model is loaded with
        accelerate and partitioned across the available GPUs.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        logger.info("Model %s loading tokenizer", ctx.model_name)
        # Load the checkpoint in int8 via bitsandbytes, spreading the layers across the available GPUs.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="balanced",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama 2 has no pad token by default; add one so batched inputs can be padded.
        self.tokenizer.add_special_tokens(
            {
                "pad_token": "<PAD>",
            }
        )
        self.model.resize_token_embeddings(self.model.config.vocab_size + 1)

        logger.info("Model %s loaded successfully", ctx.model_name)
        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
            containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
            attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.model.device)
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=True,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text for the received input using the loaded transformers checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
            of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the generated text for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        input_ids_batch = input_ids_batch.to(self.model.device)
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the predictions and explanations.
        """
        return inference_output
13 changes: 13 additions & 0 deletions examples/large_models/Huggingface_accelerate/llama2/model-config.yaml
@@ -0,0 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"

handler:
    model_name: "meta-llama/Llama-2-70b-chat-hf"
    model_path: "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/9ff8b00464fc439a64bb374769dec3dd627be1c2"
    max_length: 50
    max_new_tokens: 50
    manual_seed: 40
5 changes: 5 additions & 0 deletions examples/large_models/Huggingface_accelerate/llama2/requirements.txt
@@ -0,0 +1,5 @@
transformers==4.31.0
accelerate
bitsandbytes
scipy
mpi4py
1 change: 1 addition & 0 deletions examples/large_models/Huggingface_accelerate/llama2/sample_text.txt
@@ -0,0 +1 @@
what is the recipe of mayonnaise?
