In Improve RAG retrieval and training with Feast and Kubeflow Trainer, we established the infrastructure by setting up Feast and ingesting our knowledge base into Milvus. Now, we focus on the model itself. This post walks you through preprocessing the training data, fine-tuning a RAG model with our custom Feast retriever, and scaling the training workflow using Kubeflow Trainer on Red Hat OpenShift AI.
Preprocessing the Natural Questions dataset for RAG training
Now that you have a knowledge base in Milvus, the next step is to prepare the training dataset for fine-tuning. We will use Google's Natural Questions dataset, which is derived from Wikipedia. This dataset offers both short and long answers for each question, but we only need the short answers. First, we'll filter the dataset to include only questions with short answers.
Next, because our knowledge base uses only 5% of the Wikipedia dataset, you must confirm that the knowledge base can answer the questions in your preprocessed dataset. To do this, perform an intersection check between the training dataset and the knowledge base using the is_question_answerable() function. This function retrieves relevant passages and validates whether the retrieved context contains the expected answer. This small-scale fine-tuning example uses 3,000 training samples and 300 evaluation samples.
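As a rough illustration of this step, the sketch below filters an already-loaded Hugging Face dataset and applies the answerability check. The nq_dataset variable and the field names (question, short_answers) are assumptions for illustration, while is_question_answerable() is the notebook's helper described above:

# Illustrative preprocessing sketch; the nq_dataset variable and field names
# (question, short_answers) are assumptions, not the notebook's exact schema.

# Keep only questions that come with a short answer
nq_dataset = nq_dataset.filter(lambda ex: len(ex["short_answers"]) > 0)

# Keep only questions the 5% knowledge base can actually answer:
# is_question_answerable() retrieves passages and checks that the expected
# answer appears in the retrieved context.
nq_dataset = nq_dataset.filter(
    lambda ex: is_question_answerable(ex["question"], ex["short_answers"][0])
)

# Small-scale split used in this example: 3,000 training and 300 evaluation samples
train_dataset = nq_dataset.select(range(3000))
eval_dataset = nq_dataset.select(range(3000, 3300))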
Understanding the dual tokenization strategy
The following code demonstrates how to tokenize the question and answer:
# Tokenize question (uses facebook dpr encoder model)
tokenized_question = question_encoder_tokenizer(
    question_text,
    truncation=True,
    max_length=32,
    padding="max_length",
)

# Tokenize answer (uses facebook bart model)
tokenized_answer_for_labels = generator_tokenizer(
    text_target=answer_text,
    truncation=True,
    max_length=32,
    padding="max_length",
)

RAG preprocessing must handle the dual tokenization requirements inherent to the RAG architecture. Unlike traditional language models that use a single tokenizer, RAG models use a dual-encoder approach. This separates retrieval and generation tasks, and each requires specialized tokenization:
- Question tokenization (DPR question encoder): We use facebook/dpr-question_encoder-single-nq-base to tokenize questions because it's optimized for retrieval tasks. This encoder was trained with contrastive learning to create embeddings that excel at finding semantically similar passages in the knowledge base. The tokenization strategy helps the model distinguish between relevant and irrelevant documents during similarity search.
- Answer tokenization (generator model): For answers, we use the generator tokenizer facebook/bart-large because it's designed for sequence-to-sequence text generation. This tokenizer structures tokens to produce coherent responses and handles the formatting required for generative tasks.
After you preprocess and save the dataset, run a validation check to verify that the dataset questions and answers are tokenized and structured correctly.
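Along those lines, a quick spot check like the following can catch truncation or padding problems early. The field names input_ids and labels are assumptions about how the preprocessed dataset stores the tokenized question and answer:

# Decode one preprocessed sample to confirm the tokenization round-trips cleanly
sample = train_dataset[0]

decoded_question = question_encoder_tokenizer.decode(
    sample["input_ids"], skip_special_tokens=True
)
decoded_answer = generator_tokenizer.decode(
    sample["labels"], skip_special_tokens=True
)

print(f"Question: {decoded_question}")
print(f"Answer:   {decoded_answer}")

# Both sequences were padded/truncated to 32 tokens during preprocessing
assert len(sample["input_ids"]) == 32
assert len(sample["labels"]) == 32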
Fine-tuning the RAG model
This cell implements the core training pipeline: it creates and fine-tunes a complete RAG model using RagSequenceForGeneration with the custom FeastRAGRetriever. This is the heart of the implementation, where the retrieval and generation components are optimized together.
Integration with the custom FeastRAGRetriever
The training integrates the custom retriever:
# Initialize custom RAG retriever for training
rag_retriever = FeastRAGRetriever(
    question_encoder_tokenizer=question_encoder_tokenizer,
    generator_tokenizer=generator_tokenizer,
    question_encoder=question_encoder_model,
    generator_model=generator_model,
    feast_repo_path="/mnt/shared/kfto-sft-feast-rag/feature_repo",
    feature_view=wiki_passage_feature_view,
    features=features_to_retrieve,
    search_type="vector",
    config=rag_config,
    index=feast_index,
)

- Training integration: During training, the RAG model calls the retriever for each question. The retriever queries Feast and Milvus to get relevant passages. The generator uses these passages to produce answers, and gradients flow back through both components.
- Dynamic retrieval: Unlike static approaches, this enables the question encoder to learn better query representations based on what actually helps the generator produce correct answers.
This joint training approach creates a RAG system where retrieval and generation are optimized together, which performs better than independently trained components.
Understanding RagSequenceForGeneration vs. RagTokenForGeneration
We use RagSequenceForGeneration instead of RagTokenForGeneration for a specific architectural reason:
# Create RAG model for training
rag_model = RagSequenceForGeneration(
    config=rag_config,
    question_encoder=question_encoder_model,
    generator=generator_model,
    retriever=rag_retriever,
)

RagSequenceForGeneration retrieves documents once at the sequence level and conditions the entire generation process on the same retrieved context. This approach is more efficient and suitable for longer responses where consistency across the entire answer is important.
RagTokenForGeneration retrieves documents at every token generation step. This makes it computationally expensive and potentially inconsistent, as the model might generate different tokens based on different retrieved contexts.
For tasks like Natural Questions, sequence-level retrieval provides better coherence and efficiency. The Facebook RAG paper recommends this approach for most practical applications.
Joint training architecture: Following Facebook's RAG paper
The joint training follows the method from Facebook's 2020 RAG paper, where you optimize both retrieval and generation components together:
# Initialize components that will be jointly trained
question_encoder_model = DPRQuestionEncoder.from_pretrained(
    question_encoder_model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=model_args.torch_dtype,
)

generator_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=model_args.torch_dtype,
)

During training, gradients flow through both the question encoder (for better retrieval) and the generator (for better response generation). The question encoder learns to retrieve passages that are most useful for the generator, while the generator learns to better utilize the retrieved context.
Unlike traditional approaches that train retrieval and generation separately, joint training ensures that the retrieval component optimizes for passages that help generate better answers, not just semantically similar passages.
Why Transformers Seq2SeqTrainer instead of the default Trainer or TRL SFTTrainer?
We use Seq2SeqTrainer because RAG models require specialized handling for sequence-to-sequence generation tasks:
# Setup distributed training with Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=rag_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=question_encoder_tokenizer,
    data_collator=default_data_collator,
)

Seq2SeqTrainer advantages:
- Generation-aware evaluation: Supports predict_with_generate=True for proper evaluation using actual generation rather than just loss computation.
- Sequence-to-sequence metrics: Handles ROUGE, BLEU, and other generation-specific metrics.
- Special token handling: Manages encoder-decoder attention patterns and special tokens required for generation tasks.
Why not SFTTrainer? SFTTrainer is designed for supervised fine-tuning of decoder-only models (like GPT), not encoder-decoder architectures with retrieval components.
Why not default Trainer? The default Trainer lacks the specialized generation evaluation and metrics handling required for sequence-to-sequence tasks.
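To make the sequence-to-sequence metrics point above concrete, here is a sketch of a generation-aware metric function you could pass to Seq2SeqTrainer through compute_metrics. It uses the evaluate library's ROUGE implementation and is illustrative rather than the notebook's exact code:

import evaluate
import numpy as np

rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # With predict_with_generate=True, preds are generated token IDs
    decoded_preds = generator_tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 (ignored label positions) with the pad token before decoding
    labels = np.where(labels != -100, labels, generator_tokenizer.pad_token_id)
    decoded_labels = generator_tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=decoded_preds, references=decoded_labels)

# Passed to the trainer: Seq2SeqTrainer(..., compute_metrics=compute_metrics)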
Custom training arguments configuration
The training configuration is carefully adapted from SFTConfig to Seq2SeqTrainingArguments:
# Convert SFTConfig parameters to standard TrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir=training_args_trl.output_dir,
    num_train_epochs=training_args_trl.num_train_epochs,
    per_device_train_batch_size=training_args_trl.per_device_train_batch_size,
    gradient_accumulation_steps=training_args_trl.gradient_accumulation_steps,
    learning_rate=training_args_trl.learning_rate,
    warmup_steps=training_args_trl.warmup_steps,
    bf16=training_args_trl.bf16,
    tf32=training_args_trl.tf32,
    remove_unused_columns=training_args_trl.remove_unused_columns,
    predict_with_generate=True,  # Critical for RAG evaluation
)

Key parameters:
- predict_with_generate=True: Enables proper evaluation using generation rather than teacher forcing.
- remove_unused_columns=False: Preserves all data needed for retrieval during training.
- FSDP configuration: Enables distributed training across multiple GPUs with memory optimization (see the sketch after this list).
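The FSDP settings themselves are passed through Seq2SeqTrainingArguments. The following is a minimal sketch; the sharding policy, wrap settings, and output directory are illustrative assumptions, since the right values depend on your model size and cluster:

# Illustrative FSDP configuration; values are assumptions, not the notebook's exact settings
training_args = Seq2SeqTrainingArguments(
    output_dir="/mnt/shared/output",
    predict_with_generate=True,
    remove_unused_columns=False,
    bf16=True,
    fsdp="full_shard auto_wrap",          # shard parameters and gradients across workers
    fsdp_config={
        "backward_prefetch": "backward_pre",
        "forward_prefetch": False,
        "use_orig_params": True,
    },
)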
Model saving and checkpoint management
The training process automatically handles comprehensive model saving:
# Training execution
trainer.train()

# Final model save includes all components
trainer.save_model()  # Saves complete RAG pipeline

What gets saved:
- Complete RAG model: The entire RagSequenceForGeneration architecture.
- Question encoder: Fine-tuned DPR question encoder weights.
- Generator model: Fine-tuned BART generator weights.
- Tokenizers: Both question encoder and generator tokenizers.
- Configuration files: RAG config, tokenizer configs, and training arguments.
You can expect to see this structure in the specified output directory.
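If you want to confirm the checkpoint contents from the notebook, a quick directory listing is enough. The exact file names vary with the Transformers version, so treat the expected entries in the comment as illustrative:

import os

# List the saved checkpoint; expect the RAG config, model weights, and the
# tokenizer files for the question encoder and generator (names vary by version)
for entry in sorted(os.listdir(training_args.output_dir)):
    print(entry)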
Configure the client SDK
In this example, you use the Kubeflow training SDK to create the PyTorchJob resource. The Kubeflow Training Operator uses this resource to configure the PyTorch pods. For the SDK to authenticate to the OpenShift API server and be authorized to create that PyTorchJob resource, you need to provide a valid bearer token by filling the placeholders in the following cell of the notebook:
api_server = "<API_SERVER>"
token = "<TOKEN>"

# Uncomment if your cluster API server uses a self-signed certificate or an untrusted CA
# configuration.verify_ssl = False

Note: You can retrieve a valid bearer token and the OpenShift API server URL from the OpenShift web console. Select Copy login command from the drop-down menu in the top-right corner of the navigation bar.
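For reference, here is a minimal sketch of wiring these values into the Kubeflow training SDK, assuming the kubeflow-training TrainingClient accepts a Kubernetes client.Configuration through its client_configuration parameter (your notebook may already contain an equivalent cell):

from kubernetes import client as k8s_client
from kubeflow.training import TrainingClient

# Build an API client configuration from the placeholders above
configuration = k8s_client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
# configuration.verify_ssl = False  # uncomment for self-signed or untrusted CAs

client = TrainingClient(client_configuration=configuration)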
Create the fine-tuning job
You’re almost ready to create the fine-tuning job. If you are fine-tuning a gated model, set the HF_TOKEN environment variable to a valid Hugging Face user access token. Depending on your environment's capacity, you might also need to adjust the compute resources allocated to the job, such as the number of workers and the resources for each:
client.create_job(
    job_kind="PyTorchJob",
    name="sft",
    train_func=main,
    num_workers=8,
    num_procs_per_worker="1",
    resources_per_worker={
        "nvidia.com/gpu": 1,
        "memory": "64Gi",
        "cpu": 4,
    },
    base_image="quay.io/modh/training:py311-cuda121-torch241",
    env_vars={
        # HuggingFace
        "HF_HOME": "/mnt/shared/.cache",
        "HF_TOKEN": "",
        # CUDA
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # NCCL / RCCL
        "NCCL_DEBUG": "INFO",
    },
    parameters=parameters,
    volumes=[
        V1Volume(
            name="shared",
            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(claim_name="shared"),
        ),
    ],
    volume_mounts=[
        V1VolumeMount(name="shared", mount_path="/mnt/shared"),
    ],
)

If you use AMD accelerators, update these fields:
client.create_job(
    resources_per_worker={
        "amd.com/gpu": 1,
    },
    base_image="quay.io/modh/training:py311-rocm62-torch241",
    env_vars={
        # ROCm (HIP)
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
)

Note: You can find the list of the supported base container images in the Training images section of Red Hat OpenShift AI supported configurations.
After you create the fine-tuning job, you can follow its progress by watching the logs:
client.get_job_logs(
    name="sft",
    job_kind="PyTorchJob",
    follow=True,
)

Because HF_HOME is configured to point to the shared persistent storage, the pre-trained model from Hugging Face is downloaded once and written into the cache directory. Only one worker acquires the shared file-based lock and downloads the model the first time, while the other workers wait for the download to complete. During subsequent runs of the fine-tuning job, the system uses the checkpoint stored in the cache instead of redownloading the model. This speeds up experimenting with different hyperparameters.
Test the fine-tuned model
After the fine-tuning job completes, you can run inference from within the notebook if you attached an accelerator when you created it.
Load the fine-tuned model
The fine-tuned model checkpoint includes all components necessary for RAG inference: the question encoder, generator, and retriever configuration.
# Load fine-tuned RAG model
finetuned_rag_model = RagSequenceForGeneration.from_pretrained(
    FINETUNED_RAG_CHECKPOINT_DIR,
    retriever=rag_retriever_inference,
)
finetuned_rag_model.to(device)
finetuned_rag_model.eval()

The RagSequenceForGeneration.from_pretrained() method loads the complete model from the checkpoint directory. Passing the rag_retriever_inference object reconnects the model to the Feast feature store so it can retrieve context during generation.
Run inference
With the model loaded, you can generate answers for test queries. The RAG model retrieves relevant context from Feast and generates responses based on both the query and retrieved passages.
for test_query in test_queries:
    # Tokenize the question
    inputs = question_encoder_tokenizer(test_query, return_tensors="pt").to(device)

    # Generate answer with retrieval
    with torch.no_grad():
        generated_ids = finetuned_rag_model.generate(
            input_ids=inputs["input_ids"],
            max_new_tokens=200,
        )

    # Decode and print the answer
    answer = generator_tokenizer_inference.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Question: {test_query}")
    print(f"Answer: {answer}\n")

During inference:
- The question encoder tokenizes the input query.
- The model retrieves relevant passages from Feast's feature store.
- The generator produces an answer conditioned on both the query and retrieved context.
- The answer is decoded and returned.
The fine-tuned model should produce more accurate and contextually relevant answers than the base RAG model because it has learned to make better use of the retrieved information during training.
Conclusion
This walkthrough demonstrated how to fine-tune a RAG model using Feast for feature management and OpenShift AI for distributed training. Combining Feast's feature store capabilities with the RAG architecture lets you build production-ready RAG systems that efficiently manage knowledge bases and scale training across multiple GPUs.
The key advantages of this approach include:
- Dynamic knowledge management: Update embeddings in Feast without retraining the model.
- Distributed training: Scale fine-tuning across multiple GPUs using FSDP.
- Production deployment: Integration between training and inference through Feast.
This pipeline provides a foundation for building RAG systems that can evolve with your organization's needs while maintaining consistent, reproducible deployments.