-
Notifications
You must be signed in to change notification settings - Fork 531
/
hf_generate.py
400 lines (361 loc) · 13.3 KB
/
hf_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import itertools
import random
import time
import warnings
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from contextlib import nullcontext
from typing import Union
import numpy as np
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from llmfoundry.utils import prompt_files as utils
def get_dtype(dtype: str):
    """Translate a short precision name into the matching ``torch`` dtype.

    Args:
        dtype: One of ``'fp32'``, ``'fp16'`` or ``'bf16'``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        NotImplementedError: If ``dtype`` is not one of the three names above.
    """
    known_precisions = (
        ('fp32', torch.float32),
        ('fp16', torch.float16),
        ('bf16', torch.bfloat16),
    )
    for short_name, torch_dtype in known_precisions:
        if dtype == short_name:
            return torch_dtype
    raise NotImplementedError(
        f'dtype {dtype} is not supported. '
        'We only support fp32, fp16, and bf16 currently',
    )
def str2bool(v: Union[str, bool]):
    """Convert a command-line style yes/no string into a ``bool``.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string that is
            compared case-insensitively against the accepted spellings.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.

    Raises:
        ArgumentTypeError: If the string is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v

    truthy_words = ('yes', 'true', 't', 'y', '1')
    falsy_words = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy_words:
        return True
    if lowered in falsy_words:
        return False
    raise ArgumentTypeError('Boolean value expected.')
def str_or_bool(v: Union[str, bool]):
    """Convert yes/no style strings to ``bool`` and pass anything else through.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.

    Returns:
        ``True`` for yes/true/t/y/1, ``False`` for no/false/f/n/0 (both
        case-insensitive); any other string is returned exactly as given.
    """
    if isinstance(v, bool):
        return v

    truthy_words = ('yes', 'true', 't', 'y', '1')
    falsy_words = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy_words:
        return True
    if lowered in falsy_words:
        return False
    # Not a recognized boolean spelling: hand the original string back.
    return v
def parse_args() -> Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        The ``Namespace`` produced by ``ArgumentParser.parse_args()``. Only
        ``--name_or_path`` is required; every other option has a default.
    """
    cli = ArgumentParser(
        description='Load a HF CausalLM Model and use it to generate text.',
    )

    def add_sweep(flag: str, kind: type, default: list) -> None:
        # Option that accepts one or more values of the same type.
        cli.add_argument(flag, type=kind, nargs='+', default=default)

    def add_switch(flag: str, default=True, converter=str2bool) -> None:
        # A bare ``--flag`` yields True; ``--flag VALUE`` is run through
        # ``converter``; omitting the flag yields ``default``.
        cli.add_argument(
            flag,
            type=converter,
            nargs='?',
            const=True,
            default=default,
        )

    def add_precision(flag: str) -> None:
        # Precision selector restricted to the three supported short names.
        cli.add_argument(
            flag,
            type=str,
            choices=['fp32', 'fp16', 'bf16'],
            default=None,
        )

    cli.add_argument('-n', '--name_or_path', type=str, required=True)
    cli.add_argument(
        '-p',
        '--prompts',
        nargs='+',
        default=[
            'My name is',
            'This is an explanation of deep learning to a five year old. Deep learning is',
        ],
        help=(
            'List of generation prompts or list of delimited files. Use syntax '
            '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.'
        ),
    )
    cli.add_argument(
        '--prompt-delimiter',
        default=None,
        help=
        'Prompt delimiter for txt files. By default, a file is a single prompt',
    )
    for flag, default in (
        ('--max_seq_len', None),
        ('--max_new_tokens', 100),
        ('--max_batch_size', None),
    ):
        cli.add_argument(flag, type=int, default=default)

    #####
    # Note: Generation config defaults are set to match Hugging Face defaults
    add_sweep('--temperature', float, [1.0])
    add_sweep('--top_k', int, [50])
    add_sweep('--top_p', float, [1.0])
    add_sweep('--repetition_penalty', float, [1.0])
    add_sweep('--no_repeat_ngram_size', int, [0])
    #####
    add_sweep('--seed', int, [42])

    add_switch('--do_sample')
    add_switch('--use_cache')

    cli.add_argument('--eos_token_id', type=int, default=None)
    cli.add_argument('--pad_token_id', type=int, default=None)

    add_precision('--model_dtype')
    add_precision('--autocast_dtype')

    add_switch('--warmup')
    add_switch('--trust_remote_code')
    # May be a boolean spelling or an arbitrary string, hence str_or_bool.
    add_switch('--use_auth_token', default=None, converter=str_or_bool)

    for flag in ('--revision', '--device', '--device_map', '--attn_impl'):
        cli.add_argument(flag, type=str, default=None)

    return cli.parse_args()
def maybe_synchronize():
    """Call ``torch.cuda.synchronize()`` when CUDA is available, else do nothing.

    Used around the wall-clock timers in this script.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
def main(args: Namespace) -> None:
    """Load a Hugging Face causal LM and generate text for the given prompts.

    The model, tokenizer and config are loaded once. Generation then runs for
    every combination of the temperature / top_p / top_k / repetition_penalty
    / no_repeat_ngram_size / seed lists in ``args``. Each prompt batch has its
    generations and timing statistics printed to stdout.

    Args:
        args: Parsed command-line options, as returned by ``parse_args()``.

    Raises:
        ValueError: If both ``args.device`` and ``args.device_map`` are set.
        RuntimeError: If loading the config or the model fails; the original
            exception is chained as the cause.
    """
    # Set device or device_map
    if args.device and args.device_map:
        raise ValueError('You can only set one of `device` and `device_map`.')
    if args.device is not None:
        device = args.device
        device_map = None
    else:
        device = None
        device_map = args.device_map or 'auto'
    print(f'Using {device=} and {device_map=}')

    # Set model_dtype
    if args.model_dtype is not None:
        model_dtype = get_dtype(args.model_dtype)
    else:
        model_dtype = torch.float32
    print(f'Using {model_dtype=}')

    # Load prompts
    prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter)

    # Grab config first
    print('Loading HF Config...')
    from_pretrained_kwargs = {
        'use_auth_token': args.use_auth_token,
        'trust_remote_code': args.trust_remote_code,
        'revision': args.revision,
    }
    try:
        config = AutoConfig.from_pretrained(
            args.name_or_path,
            **from_pretrained_kwargs,
        )
        # These overrides are applied only when the config exposes the
        # corresponding attribute.
        if hasattr(config, 'init_device') and device is not None:
            config.init_device = device
        if args.attn_impl is not None and hasattr(config, 'attn_config'):
            config.attn_config['attn_impl'] = args.attn_impl
        if args.max_seq_len is not None and hasattr(config, 'max_seq_len'):
            config.max_seq_len = args.max_seq_len
    except Exception as e:
        raise RuntimeError(
            'If you are having auth problems, try logging in via `huggingface-cli login` ' +\
            'or by setting the environment variable `export HF_TOKEN=... ' +\
            'using your access token from https://huggingface.co/settings/tokens.',
        ) from e

    # Build tokenizer
    print('\nLoading HF tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(
        args.name_or_path,
        **from_pretrained_kwargs,
    )
    if tokenizer.pad_token_id is None:
        warnings.warn(
            'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.',
        )
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    # Load HF Model
    print(f'Loading HF model with dtype={model_dtype}...')
    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.name_or_path,
            config=config,
            torch_dtype=model_dtype,
            device_map=device_map,
            **from_pretrained_kwargs,
        )
        model.eval()
        print(f'n_params={sum(p.numel() for p in model.parameters())}')
        if device is not None:
            print(f'Placing model on {device=}...')
            model.to(device)
    except Exception as e:
        raise RuntimeError(
            'Unable to load HF model. ' +
            'If you are having auth problems, try logging in via `huggingface-cli login` '
            + 'or by setting the environment variable `export HF_TOKEN=... ' +
            'using your access token from https://huggingface.co/settings/tokens.',
        ) from e

    # Autocast
    if args.autocast_dtype is not None:
        autocast_dtype = get_dtype(args.autocast_dtype)
        autocast_context = torch.autocast(model.device.type, autocast_dtype)
        print(f'Using autocast with dtype={autocast_dtype}...')
    else:
        autocast_context = nullcontext()
        print('NOT using autocast...')

    # Resolve the special-token overrides once. Compare against None rather
    # than relying on truthiness: 0 is a valid token id and an explicit
    # `--eos_token_id 0` / `--pad_token_id 0` must not be silently replaced
    # by the tokenizer's value.
    if args.eos_token_id is not None:
        eos_token_id = args.eos_token_id
    else:
        eos_token_id = tokenizer.eos_token_id
    if args.pad_token_id is not None:
        pad_token_id = args.pad_token_id
    else:
        pad_token_id = tokenizer.pad_token_id

    done_warmup = False

    for temp, topp, topk, repp, nrnz, seed in itertools.product(
        args.temperature,
        args.top_p,
        args.top_k,
        args.repetition_penalty,
        args.no_repeat_ngram_size,
        args.seed,
    ):
        # Seed randomness
        random.seed(seed)
        torch.manual_seed(seed)
        print(f'\nGenerate seed:\n{seed}')

        generate_kwargs = {
            'max_new_tokens': args.max_new_tokens,
            'temperature': temp,
            'top_p': topp,
            'top_k': topk,
            'repetition_penalty': repp,
            'no_repeat_ngram_size': nrnz,
            'use_cache': args.use_cache,
            # A temperature of 0 disables sampling regardless of --do_sample.
            'do_sample': False if temp == 0 else args.do_sample,
            'eos_token_id': eos_token_id,
            'pad_token_id': pad_token_id,
        }
        print(f'\nGenerate kwargs:\n{generate_kwargs}')

        # Generate function with correct context managers
        def _generate(encoded_inp: dict[str, torch.Tensor]):
            with torch.no_grad():
                with autocast_context:
                    return model.generate(
                        input_ids=encoded_inp['input_ids'],
                        attention_mask=encoded_inp['attention_mask'],
                        **generate_kwargs,
                    )

        # Split into prompt batches
        if args.max_batch_size:
            bs = args.max_batch_size
            batches = [
                prompt_strings[i:i + bs]
                for i in range(0, len(prompt_strings), bs)
            ]
        else:
            batches = [prompt_strings]

        for batch in batches:
            print('\nTokenizing prompts...')
            maybe_synchronize()
            # perf_counter is monotonic, so the measured intervals cannot be
            # distorted by system clock adjustments.
            encode_start = time.perf_counter()
            encoded_inp = tokenizer(batch, return_tensors='pt', padding=True)
            for key, value in encoded_inp.items():
                encoded_inp[key] = value.to(model.device)
            maybe_synchronize()
            encode_end = time.perf_counter()
            # Per-row count of prompt tokens that are not padding.
            input_tokens = torch.sum(
                encoded_inp['input_ids'] !=
                tokenizer.pad_token_id,  # type: ignore
                axis=1,
            ).numpy(force=True)

            # Warmup
            if args.warmup and (not done_warmup):
                print('Warming up...')
                _ = _generate(encoded_inp)
                done_warmup = True

            # Run HF generate
            print('Generating responses...')
            maybe_synchronize()
            gen_start = time.perf_counter()
            encoded_gen = _generate(encoded_inp)
            maybe_synchronize()
            gen_end = time.perf_counter()

            decode_start = time.perf_counter()
            decoded_gen = tokenizer.batch_decode(
                encoded_gen,
                skip_special_tokens=True,
            )
            maybe_synchronize()
            decode_end = time.perf_counter()
            # Per-row count of non-padding tokens in the generated sequences.
            gen_tokens = torch.sum(
                encoded_gen != tokenizer.pad_token_id,
                axis=1,
            ).numpy(force=True)  # type: ignore

            # Print generations
            delimiter = '#' * 100
            # decode the encoded prompt to handle the case when the tokenizer
            # trims extra spaces or does other pre-tokenization things
            effective_prompts = tokenizer.batch_decode(
                encoded_inp['input_ids'],
                skip_special_tokens=True,
            )
            for idx, (effective_prompt, prompt, gen) in enumerate(
                zip(effective_prompts, batch, decoded_gen),
            ):
                continuation = gen[len(effective_prompt):]
                print(delimiter)
                if len(continuation) > 0:
                    print('\033[92m' + prompt + '\033[0m' + continuation)
                else:
                    print('Warning. No non-special output tokens generated.')
                    print(
                        'This can happen if the generation only contains padding/eos tokens.',
                    )
                    print('Debug:')
                    # Only this row is needed, so decode it alone instead of
                    # re-decoding the whole batch.
                    full_generation = tokenizer.decode(
                        encoded_gen[idx],
                        skip_special_tokens=False,
                    )
                    print('\033[92m' + 'Prompt:\n' + prompt + '\033[0m')
                    print('Full generation:\n' + full_generation)
            print(delimiter)

            # Print timing info
            bs = len(batch)
            # Count at least one output token per row. If the model produced
            # only padding, gen_tokens equals input_tokens and the difference
            # would be 0, which makes the per-token latency below divide by
            # zero.
            output_tokens = np.maximum(gen_tokens - input_tokens, 1)
            total_input_tokens = input_tokens.sum()
            total_output_tokens = output_tokens.sum()

            encode_latency = 1000 * (encode_end - encode_start)
            gen_latency = 1000 * (gen_end - gen_start)
            decode_latency = 1000 * (decode_end - decode_start)
            total_latency = encode_latency + gen_latency + decode_latency

            latency_per_output_token = total_latency / total_output_tokens
            output_tok_per_sec = 1000 / latency_per_output_token
            print(f'{bs=}, {input_tokens=}, {output_tokens=}')
            print(f'{total_input_tokens=}, {total_output_tokens=}')
            print(
                f'{encode_latency=:.2f}ms, {gen_latency=:.2f}ms, {decode_latency=:.2f}ms, {total_latency=:.2f}ms',
            )
            print(f'{latency_per_output_token=:.2f}ms/tok')
            print(f'{output_tok_per_sec=:.2f}tok/sec')
if __name__ == '__main__':
    # Script entry point: parse the command line, then run generation.
    main(args=parse_args())