{% 卡 %}
免責聲明:本指南僅供參考和教育用途,不能取代專業的醫療建議、診斷或治療。
MedGemma 不應在未經適當驗證、調整和/或開發者針對其特定用例進行有意義的修改的情況下使用。 MedGemma 產生的輸出並非旨在直接用於臨床診斷、患者管理決策、治療建議或任何其他直接的臨床實踐應用。效能基準測試突顯了相關基準測試中的基線能力,但即使對於構成大量訓練資料的圖像和文字領域,模型輸出仍可能不準確。 MedGemma 的所有輸出應視為初步結果,需要透過既定的研發方法進行獨立驗證、臨床關聯分析和進一步研究。
{% endcard %}
人工智慧(AI)正在革新醫療保健產業,但如何讓功能強大的通用AI模型掌握病理學家的專業技能呢?從原型到生產的這段旅程通常始於筆記本,而這正是我們將要開始的地方。
在本指南中,我們將邁出至關重要的第一步。我們將逐步完成 Gemma 3 變體MedGemma 的微調過程。 MedGemma 是谷歌面向醫學界推出的一系列開放模型,用於對乳癌組織病理影像進行分類。我們使用全精度 MedGemma 模型,因為這是在許多臨床任務中獲得最佳表現所必需的。如果您擔心計算成本,可以使用MedGemma 預先配置的微調 notebook進行量化和微調。
為了完成第一步,我們將使用Finetune Notebook 。該 Notebook 提供了所有程式碼以及逐步操作說明,是進行實驗的理想環境。我也會分享我在這個過程中學到的關鍵見解,包括一個至關重要的資料類型選擇,它最終產生了決定性的影響。
在原型設計階段完善模型後,我們就可以進入下一步了。在接下來的文章中,我們將向您展示如何使用Cloud Run 作業將此工作流程遷移到可擴展的、可用於生產環境的環境中。
在深入程式碼之前,我們先來了解一下背景。我們的目標是將乳房組織的顯微鏡影像分類為八類:四類良性(非癌性)和四類惡性(癌性)。這種分類是病理學家為做出準確診斷而執行的眾多關鍵任務之一,而我們擁有一套強大的工具來完成這項工作。
我們將使用MedGemma ,這是Google推出的一系列強大的開放模型,它基於與 Gemini 模型相同的研究和技術建置。 MedGemma 的獨特之處在於,它並非通用模型,而是專門針對醫療領域進行了最佳化。
MedGemma 的視覺元件MedSigLIP已使用大量去辨識化的醫學影像進行預先訓練,其中包括我們正在使用的組織病理切片類型。如果您不需要 MedGemma 的預測能力,可以單獨使用 MedSigLIP,它是一種更經濟高效的預測任務選擇,例如影像分類。您可以使用多個MedSigLIP 教學筆記本進行微調。
MedGemma 語言元件也經過了各種醫學文本的訓練,因此我們使用的google/medgemma-4b-it版本非常適合我們基於文字的提示。 Google 為 MedGemma 提供了一個強大的基礎,但它需要針對特定用例進行微調——而這正是我們即將要做的。
為了訓練我們的模型,我們將使用 乳癌組織病理影像分類(BreakHis)資料集。 BreakHis 資料集是一個公開的資料集,包含數千張乳房腫瘤組織的顯微鏡影像,這些影像來自 82 位患者,並使用了不同的放大倍率(40 倍、100 倍、200 倍和 400 倍)。此資料集可公開用於非商業研究,詳情請參閱論文:FA Spanhol、LS Oliveira、C. Petitjean 和 L. Heudel, 《乳癌組織病理影像分類資料集》 。 <sup> 1</sup>
處理一個包含 40 億個參數的模型需要強大的 GPU,因此我在Vertex AI Workbench中使用了配備40 GB記憶體的NVIDIA A100 。這款 GPU 擁有足夠的效能,並且還配備了NVIDIA Tensor Core ,能夠出色地處理現代資料格式,我們將利用這一點來加快訓練速度。在後續文章中,我們將解釋如何計算微調所需的記憶體。
我第一次嘗試載入模型時為了節省記憶體使用了常見的float16資料類型,結果慘敗。模型的輸出完全是亂碼,快速除錯後發現所有內部值都變成了NaN(非數字) 。
罪魁禍首是經典的數值溢位。
要理解其中的原因,你需要了解這些 16 位元格式之間的關鍵差異:
float16 (FP16) 的數值範圍非常小,無法表示任何大於 65,504 的數字。在轉換器執行數百萬次計算的過程中,中間值很容易超過這個限制,導致溢位並產生 NaN 值。一旦出現 NaN 值,就會污染後續的所有計算。
bfloat16 (BF16) :這種格式由谷歌大腦開發,它做出了一個關鍵的權衡。它犧牲了一些精確度,以保持與完整的 32 位元 float32 格式相同的巨大數值範圍。
bfloat16 的超大值範圍可以防止溢出,從而確保訓練過程的穩定性。修復方法雖然只是簡單地修改了一行程式碼,但卻是基於這個關鍵概念。
成功程式碼:
# The simple, stable solution
model_kwargs = dict(
torch_dtype=torch.bfloat16, # Use bfloat16 for its wide numerical range
device_map="auto",
attn_implementation="sdpa",
)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
經驗教訓:在微調大型模型時,始終優先選擇bfloat16 ,因為它更穩定。這個小小的改動可以讓你避免很多與 NaN 值相關的麻煩。
現在,讓我們來看程式碼。我會將我的Finetune Notebook分解成清晰、合乎邏輯的步驟。
首先,您需要從 Hugging Face 生態系統安裝必要的程式庫,然後登入您的帳戶以下載模型。
# Install required packages
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn
import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
⚠️ 重要安全提示:切勿將 API 金鑰或令牌等敏感資訊直接硬編碼到程式碼或筆記本中,尤其是在生產環境中。這種做法不安全,會造成嚴重的安全風險。
在 Vertex AI Workbench 中,處理金鑰(例如您的 Hugging Face 令牌)最安全、企業級的方法是使用 Google Cloud 的Secret Manager 。
如果您只是在進行實驗,暫時不想設定 Secret Manager,可以使用互動式登入小工具。該小部件會將令牌暫時保存在實例的檔案系統中。
# Hugging Face authentication using interactive login widget:
from huggingface_hub import notebook_login
notebook_login()
在即將發布的文章中,我們將把這個過程遷移到 Cloud Run Jobs,並向您展示使用 Secret Manager 來處理此令牌的正確和安全方法。
接下來,我們使用kagglehub庫從 Kaggle 下載BreakHis 資料集。該資料集包含一個Folds.csv文件,其中概述了實驗的資料劃分方式。原始研究使用了 5 折交叉驗證,但為了控製本次演示的訓練時間,我們將重點關注第一折,並且僅使用 100 倍放大倍率的圖像。您可以嘗試使用其他折數和放大倍率進行更廣泛的實驗。
! pip install -q kagglehub
import kagglehub
import os
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel
# Download the dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv('{}/Folds.csv'.format(path))
# Filter for 100X magnification from the first fold
folds_100x = folds[folds['mag']==100]
folds_100x = folds_100x[folds_100x['fold']==1]
# Get the train/test splits
folds_100x_test = folds_100x[folds_100x.grp=='test']
folds_100x_train = folds_100x[folds_100x.grp=='train']
# Define the base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/versions/4/BreaKHis_v1"
初始的100倍放大倍率訓練集和測試集劃分顯示良性和惡性樣本數量不平衡。為了解決這個問題,我們將對訓練集和測試集中數量較多的樣本進行欠採樣,以建立良惡性樣本數量比例為50/50的平衡資料集。
# --- 1. Create Balanced TRAIN Set ---
train_benign_df = folds_100x_train[folds_100x_train['filename'].str.contains('benign')]
train_malignant_df = folds_100x_train[folds_100x_train['filename'].str.contains('malignant')]
min_train_count = min(len(train_benign_df), len(train_malignant_df))
balanced_train_benign = train_benign_df.sample(n=min_train_count, random_state=42)
balanced_train_malignant = train_malignant_df.sample(n=min_train_count, random_state=42)
balanced_train_df = pd.concat([balanced_train_benign, balanced_train_malignant])
# --- 2. Create Balanced TEST Set ---
test_benign_df = folds_100x_test[folds_100x_test['filename'].str.contains('benign')]
test_malignant_df = folds_100x_test[folds_100x_test['filename'].str.contains('malignant')]
min_test_count = min(len(test_benign_df), len(test_malignant_df))
balanced_test_benign = test_benign_df.sample(n=min_test_count, random_state=42)
balanced_test_malignant = test_malignant_df.sample(n=min_test_count, random_state=42)
balanced_test_df = pd.concat([balanced_test_benign, balanced_test_malignant])
# --- 3. Get the Final Filename Lists ---
train_filenames = balanced_train_df['filename'].values
test_filenames = balanced_test_df['filename'].values
print(f"Balanced Train: {len(train_filenames)} files")
print(f"Balanced Test: {len(test_filenames)} files")
我們將資料轉換為 Hugging Face datasets集格式,因為這是使用其 Transformers 庫中的SFTTrainer最簡單方法。這種格式針對處理大型資料集(尤其是圖像)進行了最佳化,因為它可以在需要時有效地載入資料。此外,它還為我們提供了便捷的預處理工具,例如將我們的格式化函數應用於所有範例。
CLASS_NAMES = [
'benign_adenosis', 'benign_fibroadenoma', 'benign_phyllodes_tumor',
'benign_tubular_adenoma', 'malignant_ductal_carcinoma',
'malignant_lobular_carcinoma', 'malignant_mucinous_carcinoma',
'malignant_papillary_carcinoma'
]
def get_label_from_filename(filename):
filename = filename.replace('\\', '/').lower()
if '/adenosis/' in filename: return 0
if '/fibroadenoma/' in filename: return 1
if '/phyllodes_tumor/' in filename: return 2
if '/tubular_adenoma/' in filename: return 3
if '/ductal_carcinoma/' in filename: return 4
if '/lobular_carcinoma/' in filename: return 5
if '/mucinous_carcinoma/' in filename: return 6
if '/papillary_carcinoma/' in filename: return 7
return -1
train_data_dict = {
'image': [os.path.join(BASE_PATH, f) for f in train_filenames],
'label': [get_label_from_filename(f) for f in train_filenames]
}
test_data_dict = {
'image': [os.path.join(BASE_PATH, f) for f in test_filenames],
'label': [get_label_from_filename(f) for f in test_filenames]
}
features = Features({
'image': HFImage(),
'label': ClassLabel(names=CLASS_NAMES)
})
train_dataset = Dataset.from_dict(train_data_dict, features=features).cast_column("image", HFImage())
eval_dataset = Dataset.from_dict(test_data_dict, features=features).cast_column("image", HFImage())
print(train_dataset)
print(eval_dataset)
在這一步,我們需要告訴模型我們想要它做什麼。我們建立一個清晰、結構化的提示,指示模型分析影像並僅傳回與類別對應的數字。此提示使輸出簡潔易懂。然後,我們將此格式對應到整個資料集。
# Define the instruction prompt
PROMPT = """Analyze this breast tissue histopathology image and classify it.
Classes (0-7):
0: benign_adenosis
1: benign_fibroadenoma
2: benign_phyllodes_tumor
3: benign_tubular_adenoma
4: malignant_ductal_carcinoma
5: malignant_lobular_carcinoma
6: malignant_mucinous_carcinoma
7: malignant_papillary_carcinoma
Answer with only the number (0-7):"""
def format_data(example):
"""Format examples into the chat-style messages MedGemma expects."""
example["messages"] = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": PROMPT},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": str(example["label"])},
],
},
]
return example
# Apply formatting
formatted_train = train_dataset.map(format_data, batched=False)
formatted_eval = eval_dataset.map(format_data, batched=False)
print("✓ Data formatted with instruction prompts")
在這裡,我們載入 MedGemma 模型及其關聯的處理器。該處理器是一個便捷的工具,用於準備模型所需的圖像和文字。我們還將選擇兩個關鍵參數以提高效率:
torch_dtype=torch.bfloat16 :正如我們前面提到的,這種格式可以確保數值穩定性。
attn_implementation="sdpa" : 縮放點積注意力機制是 PyTorch 2.0 中提供的一種高度優化的注意力機制。你可以將此機制理解為告訴模型使用超快的內建引擎來執行其最重要的計算。它能夠加速訓練和推理,如果你的硬體支持,它甚至可以自動使用更高級的後端,例如 FlashAttention。
MODEL_ID = "google/medgemma-4b-it"
# Model configuration
model_kwargs = dict(
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa",
)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Ensure right padding for training
processor.tokenizer.padding_side = "right"
在投入時間和運算資源進行微調之前,我們先來看看預訓練模型本身的表現如何。這一步驟可以為我們提供一個基準,以便衡量我們改進的效果。
# Helper functions to run evaluation
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
def compute_metrics(predictions, references):
return {
**accuracy_metric.compute(predictions=predictions, references=references),
**f1_metric.compute(predictions=predictions, references=references, average="weighted")
}
def postprocess_prediction(text):
"""Extract just the number from the model's text output."""
digit_match = re.search(r'\b([0-7])\b', text.strip())
return int(digit_match.group(1)) if digit_match else -1
def batch_predict(model, processor, prompts, images, batch_size=8, max_new_tokens=40):
"""A function to run inference in batches."""
predictions = []
for i in range(0, len(prompts), batch_size):
batch_texts = prompts[i:i + batch_size]
batch_images = [[img] for img in images[i:i + batch_size]]
inputs = processor(text=batch_texts, images=images, padding=True, return_tensors="pt").to("cuda", torch.bfloat16)
prompt_lengths = inputs["attention_mask"].sum(dim=1)
with torch.inference_mode():
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=processor.tokenizer.pad_token_id)
for seq, length in zip(outputs, prompt_lengths):
generated = processor.decode(seq[length:], skip_special_tokens=True)
predictions.append(postprocess_prediction(generated))
return predictions
# Prepare data for evaluation
eval_prompts = [processor.apply_chat_template([msg[0]], add_generation_prompt=True, tokenize=False) for msg in formatted_eval["messages"]]
eval_images = formatted_eval["image"]
eval_labels = formatted_eval["label"]
# Run baseline evaluation
print("Running baseline evaluation...")
baseline_preds = batch_predict(model, processor, eval_prompts, eval_images)
baseline_metrics = compute_metrics(baseline_preds, eval_labels)
print(f"\n{'BASELINE RESULTS':-^80}")
print(f"Accuracy: {baseline_metrics['accuracy']:.1%}")
print(f"F1 Score: {baseline_metrics['f1']:.3f}")
print("-" * 80)
我們對基線模型在 8 類分類和二元(良性/惡性)分類上的表現進行了評估:
8類分類準確率:32.6%
8 類 F1 得分(加權):0.241
二元準確率:59.6%
二元F1評分(惡性):0.639
此輸出結果表明,該模型的性能優於隨機猜測(12.5%),但仍有很大的改進空間,尤其是在細粒度的 8 類分類方面。
在開始訓練之前,值得一問:微調是唯一的方法嗎?另一種流行的技術是少樣本學習。
小樣本學習就像在考試前給聰明的學生幾個新數學題的例子。你不是在重新教他們代數,而是在題目中直接提供例子,向他們展示你希望他們遵循的特定模式。這是一種強大的技巧,尤其是在使用透過 API 實現的封閉模型,而無法存取內部權重時。
那麼,我們為什麼選擇微調呢?
我們可以託管該模型:由於 MedGemma 是一個開放模型,我們可以直接存取其架構。這種存取權限使我們能夠進行微調,從而建立一個新的、永久更新的模型版本。
我們擁有一個很好的資料集:微調可以讓模型學習數百張訓練影像中深層、潛在的模式,比僅僅在提示中向它展示幾個例子要有效得多。
簡而言之,微調為我們的任務建立了一個真正的專業模型,這正是我們所想要的。
這才是重頭戲!我們將使用低秩自適應(LoRA) ,它比傳統的微調方法速度更快、記憶體效率更高。 LoRA 的工作原理是凍結原始模型權重,僅訓練一小部分新的適配器權重。以下是我們的參數選擇明細:
r=8 :LoRA 等級。等級越低,可訓練參數越少,速度越快,但表達能力越弱。等級越高,容量越大,但在小資料集上容易過度擬合。等級 8 是一個很好的起點,兼顧了性能和效率。
lora_alpha=16 :LoRA權重的縮放因子。通常的經驗法則是將其設定為秩的兩倍(2 × r)。
lora_dropout=0.1 :一種正規化技術。它在訓練過程中隨機停用一些 LoRA 神經元,以防止模型過度專業化而喪失泛化能力。
# LoRA Configuration
peft_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.1,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
)
# Custom data collator to handle images and text
def collate_fn(examples):
texts, images = [], []
for example in examples:
images.append([example["image"]])
texts.append(processor.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False).strip())
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
labels[labels == image_token_id] = -100
labels[labels == 262144] = -100
batch["labels"] = labels
return batch
# Training arguments
training_args = SFTConfig(
output_dir="medgemma-breastcancer-finetuned",
num_train_epochs=5,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
gradient_accumulation_steps=8,
gradient_checkpointing=True,
optim="paged_adamw_8bit",
learning_rate=5e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.03, # Warm up LR for first 3% of training
max_grad_norm=0.3, # Clip gradients to prevent instability
bf16=True, # Use bfloat16 precision
logging_steps=10,
save_strategy="steps",
save_steps=100,
eval_strategy="epoch",
push_to_hub=False,
report_to="none",
gradient_checkpointing_kwargs={"use_reentrant": False},
dataset_kwargs={"skip_prepare_dataset": True},
remove_unused_columns=False,
label_names=["labels"],
)
# Initialize and run the trainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=formatted_train,
eval_dataset=formatted_eval,
peft_config=peft_config,
processing_class=processor,
data_collator=collate_fn,
)
print("Starting training...")
trainer.train()
trainer.save_model()
在配備 40 GB 記憶體的 A100 GPU 上進行訓練約耗時80 分鐘。結果看起來很有希望,驗證損失穩定下降。
重要提示(節省時間!) :如果您的訓練因任何原因中斷(例如連接問題或超出資源限制),您可以使用trainer.train()中的resume_from_checkpoint參數從已保存的檢查點恢復訓練過程。檢查點可以節省您寶貴的時間,因為它們會按照TrainingArguments中定義的save_steps間隔進行保存。
訓練完成後,就到了檢驗結果的時刻。我們將載入新的 LoRa 適配器權重,將其與基礎模型合併,然後執行與基準模型相同的評估。
# Clear memory and load the final model
del model
torch.cuda.empty_cache()
gc.collect()
# Load base model again
base_model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa"
)
# Load LoRA adapters and merge them into a single model
finetuned_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
finetuned_model = finetuned_model.merge_and_unload()
# Configure for generation
finetuned_model.generation_config.max_new_tokens = 50
finetuned_model.generation_config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
finetuned_model.config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
# Load the processor and run evaluation
processor_finetuned = AutoProcessor.from_pretrained(training_args.output_dir)
finetuned_preds = batch_predict(finetuned_model, processor_finetuned, eval_prompts, eval_images, batch_size=4)
finetuned_metrics = compute_metrics(finetuned_preds, eval_labels)
那麼,微調對效能有何影響呢?讓我們來看看8級精度和宏F1的資料。
--- 8-Class Classification (0-7) ---
Model Accuracy F1 (Weighted)
-----------------------------------------------
Baseline 32.6% 0.241
Fine-tuned 87.2% 0.865
-----------------------------------------------
--- Binary (Benign/Malignant) Classification ---
Model Accuracy F1 (Malignant)
-----------------------------------------------
Baseline 59.6% 0.639
Fine-tuned 99.0% 0.991
-----------------------------------------------
結果非常棒!經過微調後,我們看到了顯著的提升:
8 類:準確率從 32.6% 躍升至 87.2%(+54.6%),F1 值從 0.241 躍升至 0.865。
二進位:準確率從 59.6% 提高到 99.0%(+39.4%),F1 值從 0.639 提高到 0.991。
這個專案展現了微調現代基礎模型的強大能力。我們選取了一個已經基於相關醫學資料預訓練的通用人工智慧模型,為其提供一個小型、專門的資料集,並以驚人的效率教導它一項新技能。從通用模型到專門分類器的轉變比以往任何時候都更加便捷,這為人工智慧在醫學及其他領域的應用開闢了令人興奮的可能性。
所有資訊都可以在Finetune Notebook中找到。您可以使用Vertex AI Workbench上的 GPU 執行個體來執行它。
想把它投入生產環境嗎?別忘了關注即將發布的文章,文章將向您展示如何將微調和評估功能引入Cloud Run 作業。
希望這篇指南對您有幫助。祝您程式愉快!
特別感謝 MedGemma 團隊的 Fereshteh Mahvar 和 Dave Steiner 對本文的寶貴意見和回饋。
1 IEEE生物醫學工程學報,第63卷,第7期,第1455-1462頁,2016年