Gemma

# Clone the Gemma 1.1 2B and 7B "-it" checkpoints into ${root}.
root=pretrained/gemma
mkdir -p "${root}"
(
    # Work inside a subshell: the caller's working directory is never
    # changed, so no `cd ../..` (which hard-codes the depth of ${root}) is
    # needed, and a failed `cd` aborts before anything is cloned into the
    # wrong place.
    cd "${root}" || exit 1
    git clone git@hf.co:google/gemma-1.1-2b-it
    git clone git@hf.co:google/gemma-1.1-7b-it
)
pretrained/gemma/
├── gemma-1.1-2b-it
└── gemma-1.1-7b-it
 1import torch
 2from transformers import GemmaForCausalLM, GemmaTokenizerFast
 3from transformers.modeling_outputs import CausalLMOutputWithPast
 4
 5import todd
 6
 7assert torch.cuda.device_count() <= 4, (  # yapf: disable
 8    "Please use no more than 4 GPUs, in order to avoid RuntimeError."
 9)
10
# Directory of the checkpoint to load (one of the clones under
# pretrained/gemma/).
PRETRAINED = 'pretrained/gemma/gemma-1.1-2b-it'

# Tokenizer and causal language model are both loaded from the same
# checkpoint directory.
tokenizer: GemmaTokenizerFast = GemmaTokenizerFast.from_pretrained(PRETRAINED)

model: GemmaForCausalLM
model = GemmaForCausalLM.from_pretrained(
    PRETRAINED,
    revision='float16',
    # Device placement and parameter dtype are both left to the library.
    torch_dtype='auto',
    device_map='auto',
)
20
# Placeholder marking where the word goes in the prompt. It is a plain
# marker string, not a credential, hence the bandit suppression.
WORD_TOKEN = '<word>'  # nosec B105

# User prompt; the placeholder is the last thing in it, so everything that
# precedes it is the same for every word.
TEMPLATE = ' '.join((
    "Imagine an instance of the given word and describe its appearance in a",
    "single sentence. Be creative and avoid describing the scene. The word is:",
    WORD_TOKEN,
))
27
# Render the single-turn conversation to prompt *text* (tokenize=False) so
# that it can be cut around the word placeholder.
conversation = [dict(role='user', content=TEMPLATE)]
inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
)

# Split the rendered prompt at the placeholder. The two-way unpacking
# requires the placeholder to occur exactly once.
prefix, suffix = inputs.split(WORD_TOKEN)

# Encode both fragments as-is; add_special_tokens=False keeps the tokenizer
# from adding special tokens around either piece.
prefix_ids: torch.Tensor = tokenizer.encode(
    prefix,
    add_special_tokens=False,
    return_tensors='pt',
)
suffix_ids: torch.Tensor = tokenizer.encode(
    suffix,
    add_special_tokens=False,
    return_tensors='pt',
)
if todd.Store.cuda:  # pylint: disable=using-constant-test
    prefix_ids = prefix_ids.cuda()
    suffix_ids = suffix_ids.cuda()

# Run the prefix through the model once and keep its key/value cache for
# generation. Only the cache is needed from this pass, so autograd is
# disabled: otherwise the cached tensors would carry a grad_fn and keep the
# whole forward graph alive in memory.
with torch.no_grad():
    prefix_outputs: CausalLMOutputWithPast = model(prefix_ids, use_cache=True)
prefix_cache = prefix_outputs.past_key_values
50
# Word substituted for the placeholder in the prompt.
WORD = 'car'

word_ids: torch.Tensor = tokenizer.encode(
    WORD, add_special_tokens=False, return_tensors='pt',
)
if todd.Store.cuda:  # pylint: disable=using-constant-test
    word_ids = word_ids.cuda()

# Full prompt ids: prefix + word + suffix, joined along the last (token) axis.
input_ids = torch.cat((prefix_ids, word_ids, suffix_ids), dim=-1)

# Sample up to 50 new tokens, handing over the key/value cache computed for
# the prefix.
# NOTE(review): some transformers versions update the cache object in place
# during generation; copy `prefix_cache` first if it is to be reused for
# another word -- confirm against the installed version.
output_ids = model.generate(
    input_ids,
    max_new_tokens=50,
    do_sample=True,
    use_cache=True,
    past_key_values=prefix_cache,
)

# Keep the first returned sequence and drop its first `input_length`
# positions (presumably the echoed prompt), then decode without special
# tokens.
input_length = input_ids.shape[1]
output_ids = output_ids[0, input_length:]
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(output_text)