import copy

import torch
from transformers import GemmaForCausalLM, GemmaTokenizerFast
from transformers.modeling_outputs import CausalLMOutputWithPast

import todd
6
# Refuse to run on more than 4 visible GPUs. This is an explicit check rather
# than ``assert`` so it still fires when Python runs with ``-O``.
# NOTE(review): per the message, exceeding the limit leads to a RuntimeError
# later on; the failing call is not visible here.
if torch.cuda.device_count() > 4:
    raise RuntimeError(
        "Please use no more than 4 GPUs, in order to avoid RuntimeError."
    )
10
PRETRAINED = 'pretrained/gemma/gemma-1.1-2b-it'

# Load the instruction-tuned Gemma checkpoint and its tokenizer from the same
# directory. ``device_map='auto'`` lets transformers place the layers on the
# available devices; ``torch_dtype='auto'`` keeps the checkpoint's own dtype.
# NOTE(review): ``revision`` selects a Hub branch; PRETRAINED looks like a
# local directory, where this argument is presumably ignored — confirm.
model: GemmaForCausalLM = GemmaForCausalLM.from_pretrained(
    PRETRAINED,
    revision='float16',
    torch_dtype='auto',
    device_map='auto',
)
tokenizer: GemmaTokenizerFast = GemmaTokenizerFast.from_pretrained(PRETRAINED)
20
WORD_TOKEN = '<word>'  # nosec B105
TEMPLATE = (
    "Imagine an instance of the given word and describe its appearance in a "
    "single sentence. Be creative and avoid describing the scene. The word is"
    f": {WORD_TOKEN}"
)

# Render the chat prompt as text, then cut it at the placeholder. The part
# before the word is encoded and run through the model once, and its KV cache
# is kept so it does not have to be recomputed for each word.
conversation = [dict(role='user', content=TEMPLATE)]
inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
)
# A bare two-way unpack of ``split`` fails with an opaque "values to unpack"
# error if the chat template drops or duplicates the placeholder; report it.
if inputs.count(WORD_TOKEN) != 1:
    raise ValueError(
        f"Expected exactly one {WORD_TOKEN!r} in the rendered chat prompt, "
        f"found {inputs.count(WORD_TOKEN)}."
    )
prefix, suffix = inputs.split(WORD_TOKEN)

# ``add_special_tokens=False``: presumably the rendered chat prompt already
# carries its own special tokens as text, so none should be added again.
prefix_ids: torch.Tensor = tokenizer.encode(
    prefix,
    add_special_tokens=False,
    return_tensors='pt',
)
suffix_ids: torch.Tensor = tokenizer.encode(
    suffix,
    add_special_tokens=False,
    return_tensors='pt',
)
if todd.Store.cuda:  # pylint: disable=using-constant-test
    prefix_ids = prefix_ids.cuda()
    suffix_ids = suffix_ids.cuda()

# Pre-fill the KV cache for the shared prefix without autograd. Model
# parameters require grad by default, so a plain forward would record a graph
# that the cached key/value tensors keep alive as long as ``prefix_cache``.
with torch.no_grad():
    prefix_outputs: CausalLMOutputWithPast = model(prefix_ids, use_cache=True)
prefix_cache = prefix_outputs.past_key_values
50
WORD = 'car'

word_ids: torch.Tensor = tokenizer.encode(
    WORD,
    add_special_tokens=False,
    return_tensors='pt',
)
if todd.Store.cuda:  # pylint: disable=using-constant-test
    word_ids = word_ids.cuda()

# NOTE(review): TEMPLATE places the placeholder after ": ", so the space before
# the word is encoded with the prefix rather than merged into the word's first
# token, as a single-pass encode of the full prompt might do. Confirm the
# resulting ids are what the model expects.
input_ids = torch.cat([prefix_ids, word_ids, suffix_ids], -1)

# ``generate`` gets the full prompt plus the cache covering its prefix, and
# presumably only processes the uncached tail. It is handed a copy because, in
# some transformers versions, the cache is a ``Cache`` object that generation
# extends in place. That would make ``prefix_cache`` unusable for another word.
output_ids = model.generate(
    input_ids,
    past_key_values=copy.deepcopy(prefix_cache),
    use_cache=True,
    max_new_tokens=50,
    do_sample=True,
)

# The returned sequence begins with the prompt; keep only the new tokens of
# the single batch row.
input_length = input_ids.shape[-1]
output_ids = output_ids[0, input_length:]
output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(output_text)