écrits/tutorial/2024/12
Tutorial14 déc. 2024·15 min

Fine-tuning de Gemma pour l'appel de fonctions

Un guide complet sur le fine-tuning du modèle Gemma pour l'appel de fonctions en utilisant Torch XLA et le framework TRL de Hugging Face.

Bienvenue dans ce guide complet sur le fine-tuning du modèle Gemma pour l'appel de fonctions. Ce tutoriel vous accompagnera tout au long du processus, de la configuration de votre environnement à l'entraînement et au test du modèle. Nous utiliserons Torch XLA et le framework TRL (Transformer Reinforcement Learning) de Hugging Face.

Introduction à Gemma

Gemma est une famille de modèles ouverts légers et à la pointe de la technologie de Google, conçus pour diverses tâches de génération de texte. Ces modèles sont text-to-text, decoder-only large language models, disponibles en anglais, et bien adaptés aux environnements avec des ressources limitées.

Configuration de l'environnement

Sélection de l'environnement d'exécution

Vous pouvez choisir Google Colab ou Kaggle comme plateforme. Pour ce guide, nous nous concentrerons sur Kaggle.

Configuration Kaggle

  1. Cliquez sur Ouvrir dans Kaggle.
  2. Cliquez sur Paramètres dans la barre latérale droite.
  3. Sous Accélérateur, sélectionnez TPUs.
  4. Enregistrez les paramètres, et le notebook redémarrera avec le support TPU.

Gemma avec Hugging Face

  1. Créer un compte Hugging Face : Inscrivez-vous gratuitement si vous n'en avez pas.
  2. Accéder au modèle Gemma : Visitez la page du modèle Gemma et acceptez les conditions d'utilisation.
  3. Générer un token Hugging Face : Allez dans les paramètres de votre compte Hugging Face et générez un nouveau token d'accès.

Configurer vos identifiants

import os
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ['HF_TOKEN'] = user_secrets.get_secret("HF_TOKEN")

Installer les dépendances

!pip install transformers==4.46.1 datasets==3.1.0 trl==0.12.0 peft==0.13.2 accelerate==0.34.0 torch~=2.5.0 torch_xla[tpu]~=2.5.0 tpu-info

Fine-tuning de Gemma 2 pour l'appel de fonctions

Initialisation du modèle Gemma 2

from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

Charger un dataset

Utilisez un dataset existant de Hugging Face, comme lilacai/glaive-function-calling-v2-sharegpt.

from datasets import load_dataset
dataset = load_dataset("lilacai/glaive-function-calling-v2-sharegpt", split="train[:15%]")

Créer un template de chat personnalisé

chat_template = "{{ bos_token }}{% if messages[0]['from'] == 'system' %}{{'<start_of_turn>user\n' + messages[0]['value'] | trim + ' ' + messages[1]['value'] | trim + '<end_of_turn>\n'}}{% set messages = messages[2:] %}{% endif %}{% for message in messages %}{% if message['from'] == 'human' %}{{'<start_of_turn>user\n' + message['value'] | trim + '<end_of_turn>\n'}}{% elif message['from'] == 'gpt' %}{{'<start_of_turn>model\n' + message['value'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
tokenizer.chat_template = chat_template

Définir la fonction de formatage

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return { "text" : texts, }
dataset = dataset.map(formatting_prompts_func, batched=True)

Configuration LoRA

from peft import LoraConfig
peft_config = LoraConfig(lora_alpha=16, lora_dropout=0, r=16, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])

Définir la configuration d'entraînement

from trl import SFTTrainer, SFTConfig
training_arguments = SFTConfig(output_dir="./results", max_steps=100, per_device_train_batch_size=32, optim="adafactor", learning_rate=0.0002, bf16=True, max_seq_length=1024, dataset_text_field="text", packing=True, logging_steps=1, seed=42)

Entraîner le modèle

trainer = SFTTrainer(model=model, train_dataset=dataset, peft_config=peft_config, args=training_arguments)
trainer.train()

Sauvegarder le modèle fine-tuné

trainer.model.to('cpu').save_pretrained("gemma-func-ft")

Tester le modèle fine-tuné

Rechargez le modèle fine-tuné et testez-le avec un prompt exemple.

base_model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "gemma-func-ft")
model = model.merge_and_unload()
input_text = "<start_of_turn>user\nVous êtes un assistant utile..."
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
outputs = model.generate(**input_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))

Conclusion

Félicitations ! Vous avez réussi à fine-tuner Gemma pour l'appel de fonctions en utilisant Torch XLA et PEFT avec LoRA sur TPUs. Explorez davantage en expérimentant avec différents datasets et en ajustant les hyperparamètres.


Référence : Google Gemma Cookbook par Google LLC.