Adding lots of new skills to sharedZZ
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Knowledge and Tools for Machine Learning Operations - tools and frameworks for training, fine-tuning, deploying, and optimizing ML/AI models
|
||||
---
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Model evaluation benchmarks, experiment tracking, data curation, tokenizers, and interpretability tools.
|
||||
---
|
||||
@@ -0,0 +1,497 @@
|
||||
---
|
||||
name: evaluating-llms-harness
|
||||
description: "lm-eval-harness: benchmark LLMs (MMLU, GSM8K, etc.)."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [lm-eval, transformers, vllm]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Evaluation, LM Evaluation Harness, Benchmarking, MMLU, HumanEval, GSM8K, EleutherAI, Model Quality, Academic Benchmarks, Industry Standard]
|
||||
|
||||
---
|
||||
|
||||
# lm-evaluation-harness - LLM Benchmarking
|
||||
|
||||
## What's inside
|
||||
|
||||
Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
|
||||
|
||||
## Quick start
|
||||
|
||||
lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install lm-eval
|
||||
```
|
||||
|
||||
**Evaluate any HuggingFace model**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--device cuda:0 \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**View available tasks**:
|
||||
```bash
|
||||
lm_eval --tasks list
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Standard benchmark evaluation
|
||||
|
||||
Evaluate model on core benchmarks (MMLU, GSM8K, HumanEval).
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Benchmark Evaluation:
|
||||
- [ ] Step 1: Choose benchmark suite
|
||||
- [ ] Step 2: Configure model
|
||||
- [ ] Step 3: Run evaluation
|
||||
- [ ] Step 4: Analyze results
|
||||
```
|
||||
|
||||
**Step 1: Choose benchmark suite**
|
||||
|
||||
**Core reasoning benchmarks**:
|
||||
- **MMLU** (Massive Multitask Language Understanding) - 57 subjects, multiple choice
|
||||
- **GSM8K** - Grade school math word problems
|
||||
- **HellaSwag** - Common sense reasoning
|
||||
- **TruthfulQA** - Truthfulness and factuality
|
||||
- **ARC** (AI2 Reasoning Challenge) - Science questions
|
||||
|
||||
**Code benchmarks**:
|
||||
- **HumanEval** - Python code generation (164 problems)
|
||||
- **MBPP** (Mostly Basic Python Problems) - Python coding
|
||||
|
||||
**Standard suite** (recommended for model releases):
|
||||
```bash
|
||||
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge
|
||||
```
|
||||
|
||||
**Step 2: Configure model**
|
||||
|
||||
**HuggingFace model**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu \
|
||||
--device cuda:0 \
|
||||
--batch_size auto # Auto-detect optimal batch size
|
||||
```
|
||||
|
||||
**Quantized model (4-bit/8-bit)**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,load_in_4bit=True \
|
||||
--tasks mmlu \
|
||||
--device cuda:0
|
||||
```
|
||||
|
||||
**Custom checkpoint**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=/path/to/my-model,tokenizer=/path/to/tokenizer \
|
||||
--tasks mmlu \
|
||||
--device cuda:0
|
||||
```
|
||||
|
||||
**Step 3: Run evaluation**
|
||||
|
||||
```bash
|
||||
# Full MMLU evaluation (57 subjects)
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--num_fewshot 5 \ # 5-shot evaluation (standard)
|
||||
--batch_size 8 \
|
||||
--output_path results/ \
|
||||
--log_samples # Save individual predictions
|
||||
|
||||
# Multiple benchmarks at once
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag,truthfulqa,arc_challenge \
|
||||
--num_fewshot 5 \
|
||||
--batch_size 8 \
|
||||
--output_path results/llama2-7b-eval.json
|
||||
```
|
||||
|
||||
**Step 4: Analyze results**
|
||||
|
||||
Results saved to `results/llama2-7b-eval.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"results": {
|
||||
"mmlu": {
|
||||
"acc": 0.459,
|
||||
"acc_stderr": 0.004
|
||||
},
|
||||
"gsm8k": {
|
||||
"exact_match": 0.142,
|
||||
"exact_match_stderr": 0.006
|
||||
},
|
||||
"hellaswag": {
|
||||
"acc_norm": 0.765,
|
||||
"acc_norm_stderr": 0.004
|
||||
}
|
||||
},
|
||||
"config": {
|
||||
"model": "hf",
|
||||
"model_args": "pretrained=meta-llama/Llama-2-7b-hf",
|
||||
"num_fewshot": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Workflow 2: Track training progress
|
||||
|
||||
Evaluate checkpoints during training.
|
||||
|
||||
```
|
||||
Training Progress Tracking:
|
||||
- [ ] Step 1: Set up periodic evaluation
|
||||
- [ ] Step 2: Choose quick benchmarks
|
||||
- [ ] Step 3: Automate evaluation
|
||||
- [ ] Step 4: Plot learning curves
|
||||
```
|
||||
|
||||
**Step 1: Set up periodic evaluation**
|
||||
|
||||
Evaluate every N training steps:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# eval_checkpoint.sh
|
||||
|
||||
CHECKPOINT_DIR=$1
|
||||
STEP=$2
|
||||
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=$CHECKPOINT_DIR/checkpoint-$STEP \
|
||||
--tasks gsm8k,hellaswag \
|
||||
--num_fewshot 0 \ # 0-shot for speed
|
||||
--batch_size 16 \
|
||||
--output_path results/step-$STEP.json
|
||||
```
|
||||
|
||||
**Step 2: Choose quick benchmarks**
|
||||
|
||||
Fast benchmarks for frequent evaluation:
|
||||
- **HellaSwag**: ~10 minutes on 1 GPU
|
||||
- **GSM8K**: ~5 minutes
|
||||
- **PIQA**: ~2 minutes
|
||||
|
||||
Avoid for frequent eval (too slow):
|
||||
- **MMLU**: ~2 hours (57 subjects)
|
||||
- **HumanEval**: Requires code execution
|
||||
|
||||
**Step 3: Automate evaluation**
|
||||
|
||||
Integrate with training script:
|
||||
|
||||
```python
|
||||
# In training loop
|
||||
if step % eval_interval == 0:
|
||||
model.save_pretrained(f"checkpoints/step-{step}")
|
||||
|
||||
# Run evaluation
|
||||
os.system(f"./eval_checkpoint.sh checkpoints step-{step}")
|
||||
```
|
||||
|
||||
Or use PyTorch Lightning callbacks:
|
||||
|
||||
```python
|
||||
from pytorch_lightning import Callback
|
||||
|
||||
class EvalHarnessCallback(Callback):
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
step = trainer.global_step
|
||||
checkpoint_path = f"checkpoints/step-{step}"
|
||||
|
||||
# Save checkpoint
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
# Run lm-eval
|
||||
os.system(f"lm_eval --model hf --model_args pretrained={checkpoint_path} ...")
|
||||
```
|
||||
|
||||
**Step 4: Plot learning curves**
|
||||
|
||||
```python
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Load all results
|
||||
steps = []
|
||||
mmlu_scores = []
|
||||
|
||||
for file in sorted(glob.glob("results/step-*.json")):
|
||||
with open(file) as f:
|
||||
data = json.load(f)
|
||||
step = int(file.split("-")[1].split(".")[0])
|
||||
steps.append(step)
|
||||
mmlu_scores.append(data["results"]["mmlu"]["acc"])
|
||||
|
||||
# Plot
|
||||
plt.plot(steps, mmlu_scores)
|
||||
plt.xlabel("Training Step")
|
||||
plt.ylabel("MMLU Accuracy")
|
||||
plt.title("Training Progress")
|
||||
plt.savefig("training_curve.png")
|
||||
```
|
||||
|
||||
### Workflow 3: Compare multiple models
|
||||
|
||||
Benchmark suite for model comparison.
|
||||
|
||||
```
|
||||
Model Comparison:
|
||||
- [ ] Step 1: Define model list
|
||||
- [ ] Step 2: Run evaluations
|
||||
- [ ] Step 3: Generate comparison table
|
||||
```
|
||||
|
||||
**Step 1: Define model list**
|
||||
|
||||
```bash
|
||||
# models.txt
|
||||
meta-llama/Llama-2-7b-hf
|
||||
meta-llama/Llama-2-13b-hf
|
||||
mistralai/Mistral-7B-v0.1
|
||||
microsoft/phi-2
|
||||
```
|
||||
|
||||
**Step 2: Run evaluations**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# eval_all_models.sh
|
||||
|
||||
TASKS="mmlu,gsm8k,hellaswag,truthfulqa"
|
||||
|
||||
while read model; do
|
||||
echo "Evaluating $model"
|
||||
|
||||
# Extract model name for output file
|
||||
model_name=$(echo $model | sed 's/\//-/g')
|
||||
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=$model,dtype=bfloat16 \
|
||||
--tasks $TASKS \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto \
|
||||
--output_path results/$model_name.json
|
||||
|
||||
done < models.txt
|
||||
```
|
||||
|
||||
**Step 3: Generate comparison table**
|
||||
|
||||
```python
|
||||
import json
|
||||
import pandas as pd
|
||||
|
||||
models = [
|
||||
"meta-llama-Llama-2-7b-hf",
|
||||
"meta-llama-Llama-2-13b-hf",
|
||||
"mistralai-Mistral-7B-v0.1",
|
||||
"microsoft-phi-2"
|
||||
]
|
||||
|
||||
tasks = ["mmlu", "gsm8k", "hellaswag", "truthfulqa"]
|
||||
|
||||
results = []
|
||||
for model in models:
|
||||
with open(f"results/{model}.json") as f:
|
||||
data = json.load(f)
|
||||
row = {"Model": model.replace("-", "/")}
|
||||
for task in tasks:
|
||||
# Get primary metric for each task
|
||||
metrics = data["results"][task]
|
||||
if "acc" in metrics:
|
||||
row[task.upper()] = f"{metrics['acc']:.3f}"
|
||||
elif "exact_match" in metrics:
|
||||
row[task.upper()] = f"{metrics['exact_match']:.3f}"
|
||||
results.append(row)
|
||||
|
||||
df = pd.DataFrame(results)
|
||||
print(df.to_markdown(index=False))
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
| Model | MMLU | GSM8K | HELLASWAG | TRUTHFULQA |
|
||||
|------------------------|-------|-------|-----------|------------|
|
||||
| meta-llama/Llama-2-7b | 0.459 | 0.142 | 0.765 | 0.391 |
|
||||
| meta-llama/Llama-2-13b | 0.549 | 0.287 | 0.801 | 0.430 |
|
||||
| mistralai/Mistral-7B | 0.626 | 0.395 | 0.812 | 0.428 |
|
||||
| microsoft/phi-2 | 0.560 | 0.613 | 0.682 | 0.447 |
|
||||
```
|
||||
|
||||
### Workflow 4: Evaluate with vLLM (faster inference)
|
||||
|
||||
Use vLLM backend for 5-10x faster evaluation.
|
||||
|
||||
```
|
||||
vLLM Evaluation:
|
||||
- [ ] Step 1: Install vLLM
|
||||
- [ ] Step 2: Configure vLLM backend
|
||||
- [ ] Step 3: Run evaluation
|
||||
```
|
||||
|
||||
**Step 1: Install vLLM**
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
**Step 2: Configure vLLM backend**
|
||||
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8 \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Step 3: Run evaluation**
|
||||
|
||||
vLLM is 5-10× faster than standard HuggingFace:
|
||||
|
||||
```bash
|
||||
# Standard HF: ~2 hours for MMLU on 7B model
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--batch_size 8
|
||||
|
||||
# vLLM: ~15-20 minutes for MMLU on 7B model
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,tensor_parallel_size=2 \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use lm-evaluation-harness when:**
|
||||
- Benchmarking models for academic papers
|
||||
- Comparing model quality across standard tasks
|
||||
- Tracking training progress
|
||||
- Reporting standardized metrics (everyone uses same prompts)
|
||||
- Need reproducible evaluation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HELM** (Stanford): Broader evaluation (fairness, efficiency, calibration)
|
||||
- **AlpacaEval**: Instruction-following evaluation with LLM judges
|
||||
- **MT-Bench**: Conversational multi-turn evaluation
|
||||
- **Custom scripts**: Domain-specific evaluation
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Evaluation too slow**
|
||||
|
||||
Use vLLM backend:
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=model-name,tensor_parallel_size=2
|
||||
```
|
||||
|
||||
Or reduce fewshot examples:
|
||||
```bash
|
||||
--num_fewshot 0 # Instead of 5
|
||||
```
|
||||
|
||||
Or evaluate subset of MMLU:
|
||||
```bash
|
||||
--tasks mmlu_stem # Only STEM subjects
|
||||
```
|
||||
|
||||
**Issue: Out of memory**
|
||||
|
||||
Reduce batch size:
|
||||
```bash
|
||||
--batch_size 1 # Or --batch_size auto
|
||||
```
|
||||
|
||||
Use quantization:
|
||||
```bash
|
||||
--model_args pretrained=model-name,load_in_8bit=True
|
||||
```
|
||||
|
||||
Enable CPU offloading:
|
||||
```bash
|
||||
--model_args pretrained=model-name,device_map=auto,offload_folder=offload
|
||||
```
|
||||
|
||||
**Issue: Different results than reported**
|
||||
|
||||
Check fewshot count:
|
||||
```bash
|
||||
--num_fewshot 5 # Most papers use 5-shot
|
||||
```
|
||||
|
||||
Check exact task name:
|
||||
```bash
|
||||
--tasks mmlu # Not mmlu_direct or mmlu_fewshot
|
||||
```
|
||||
|
||||
Verify model and tokenizer match:
|
||||
```bash
|
||||
--model_args pretrained=model-name,tokenizer=same-model-name
|
||||
```
|
||||
|
||||
**Issue: HumanEval not executing code**
|
||||
|
||||
Install execution dependencies:
|
||||
```bash
|
||||
pip install human-eval
|
||||
```
|
||||
|
||||
Enable code execution:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=model-name \
|
||||
--tasks humaneval \
|
||||
--allow_code_execution # Required for HumanEval
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Benchmark descriptions**: See [references/benchmark-guide.md](references/benchmark-guide.md) for detailed description of all 60+ tasks, what they measure, and interpretation.
|
||||
|
||||
**Custom tasks**: See [references/custom-tasks.md](references/custom-tasks.md) for creating domain-specific evaluation tasks.
|
||||
|
||||
**API evaluation**: See [references/api-evaluation.md](references/api-evaluation.md) for evaluating OpenAI, Anthropic, and other API models.
|
||||
|
||||
**Multi-GPU strategies**: See [references/distributed-eval.md](references/distributed-eval.md) for data parallel and tensor parallel evaluation.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA 11.8+), works on CPU (very slow)
|
||||
- **VRAM**:
|
||||
- 7B model: 16GB (bf16) or 8GB (8-bit)
|
||||
- 13B model: 28GB (bf16) or 14GB (8-bit)
|
||||
- 70B model: Requires multi-GPU or quantization
|
||||
- **Time** (7B model, single A100):
|
||||
- HellaSwag: 10 minutes
|
||||
- GSM8K: 5 minutes
|
||||
- MMLU (full): 2 hours
|
||||
- HumanEval: 20 minutes
|
||||
|
||||
## Resources
|
||||
|
||||
- GitHub: https://github.com/EleutherAI/lm-evaluation-harness
|
||||
- Docs: https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs
|
||||
- Task library: 60+ tasks including MMLU, GSM8K, HumanEval, TruthfulQA, HellaSwag, ARC, WinoGrande, etc.
|
||||
- Leaderboard: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard (uses this harness)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,490 @@
|
||||
# API Evaluation
|
||||
|
||||
Guide to evaluating OpenAI, Anthropic, and other API-based language models.
|
||||
|
||||
## Overview
|
||||
|
||||
The lm-evaluation-harness supports evaluating API-based models through a unified `TemplateAPI` interface. This allows benchmarking of:
|
||||
- OpenAI models (GPT-4, GPT-3.5, etc.)
|
||||
- Anthropic models (Claude 3, Claude 2, etc.)
|
||||
- Local OpenAI-compatible APIs
|
||||
- Custom API endpoints
|
||||
|
||||
**Why evaluate API models**:
|
||||
- Benchmark closed-source models
|
||||
- Compare API models to open models
|
||||
- Validate API performance
|
||||
- Track model updates over time
|
||||
|
||||
## Supported API Models
|
||||
|
||||
| Provider | Model Type | Request Types | Logprobs |
|
||||
|----------|------------|---------------|----------|
|
||||
| OpenAI (completions) | `openai-completions` | All | ✅ Yes |
|
||||
| OpenAI (chat) | `openai-chat-completions` | `generate_until` only | ❌ No |
|
||||
| Anthropic (completions) | `anthropic-completions` | All | ❌ No |
|
||||
| Anthropic (chat) | `anthropic-chat` | `generate_until` only | ❌ No |
|
||||
| Local (OpenAI-compatible) | `local-completions` | Depends on server | Varies |
|
||||
|
||||
**Note**: Models without logprobs can only be evaluated on generation tasks, not perplexity or loglikelihood tasks.
|
||||
|
||||
## OpenAI Models
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY=sk-...
|
||||
```
|
||||
|
||||
### Completion Models (Legacy)
|
||||
|
||||
**Available models**: `davinci-002`, `babbage-002`
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-completions \
|
||||
--model_args model=davinci-002 \
|
||||
--tasks lambada_openai,hellaswag \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Supports**:
|
||||
- `generate_until`: ✅
|
||||
- `loglikelihood`: ✅
|
||||
- `loglikelihood_rolling`: ✅
|
||||
|
||||
### Chat Models
|
||||
|
||||
**Available models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu,gsm8k,humaneval \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Supports**:
|
||||
- `generate_until`: ✅
|
||||
- `loglikelihood`: ❌ (no logprobs)
|
||||
- `loglikelihood_rolling`: ❌
|
||||
|
||||
**Important**: Chat models don't provide logprobs, so they can only be used with generation tasks (MMLU, GSM8K, HumanEval), not perplexity tasks.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
base_url=https://api.openai.com/v1,\
|
||||
num_concurrent=5,\
|
||||
max_retries=3,\
|
||||
timeout=60,\
|
||||
batch_size=auto
|
||||
```
|
||||
|
||||
**Parameters**:
|
||||
- `model`: Model identifier (required)
|
||||
- `base_url`: API endpoint (default: OpenAI)
|
||||
- `num_concurrent`: Concurrent requests (default: 5)
|
||||
- `max_retries`: Retry failed requests (default: 3)
|
||||
- `timeout`: Request timeout in seconds (default: 60)
|
||||
- `tokenizer`: Tokenizer to use (default: matches model)
|
||||
- `tokenizer_backend`: `"tiktoken"` or `"huggingface"`
|
||||
|
||||
### Cost Management
|
||||
|
||||
OpenAI charges per token. Estimate costs before running:
|
||||
|
||||
```python
|
||||
# Rough estimate
|
||||
num_samples = 1000
|
||||
avg_tokens_per_sample = 500 # input + output
|
||||
cost_per_1k_tokens = 0.01 # GPT-3.5 Turbo
|
||||
|
||||
total_cost = (num_samples * avg_tokens_per_sample / 1000) * cost_per_1k_tokens
|
||||
print(f"Estimated cost: ${total_cost:.2f}")
|
||||
```
|
||||
|
||||
**Cost-saving tips**:
|
||||
- Use `--limit N` for testing
|
||||
- Start with `gpt-3.5-turbo` before `gpt-4`
|
||||
- Set `max_gen_toks` to minimum needed
|
||||
- Use `num_fewshot=0` for zero-shot when possible
|
||||
|
||||
## Anthropic Models
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
export ANTHROPIC_API_KEY=sk-ant-...
|
||||
```
|
||||
|
||||
### Completion Models (Legacy)
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-completions \
|
||||
--model_args model=claude-2.1 \
|
||||
--tasks lambada_openai,hellaswag \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Chat Models (Recommended)
|
||||
|
||||
**Available models**: `claude-3-5-sonnet-20241022`, `claude-3-opus-20240229`, `claude-3-sonnet-20240229`, `claude-3-haiku-20240307`
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks mmlu,gsm8k,humaneval \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Aliases**: `anthropic-chat-completions` (same as `anthropic-chat`)
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args \
|
||||
model=claude-3-5-sonnet-20241022,\
|
||||
base_url=https://api.anthropic.com,\
|
||||
num_concurrent=5,\
|
||||
max_retries=3,\
|
||||
timeout=60
|
||||
```
|
||||
|
||||
### Cost Management
|
||||
|
||||
Anthropic pricing (as of 2024):
|
||||
- Claude 3.5 Sonnet: $3.00 / 1M input, $15.00 / 1M output
|
||||
- Claude 3 Opus: $15.00 / 1M input, $75.00 / 1M output
|
||||
- Claude 3 Haiku: $0.25 / 1M input, $1.25 / 1M output
|
||||
|
||||
**Budget-friendly strategy**:
|
||||
```bash
|
||||
# Test on small sample first
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-haiku-20240307 \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
|
||||
# Then run full eval on best model
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks mmlu \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
## Local OpenAI-Compatible APIs
|
||||
|
||||
Many local inference servers expose OpenAI-compatible APIs (vLLM, Text Generation Inference, llama.cpp, Ollama).
|
||||
|
||||
### vLLM Local Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-2-7b-hf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8000/v1,\
|
||||
num_concurrent=1 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Text Generation Inference (TGI)
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 \
|
||||
ghcr.io/huggingface/text-generation-inference:latest \
|
||||
--model-id meta-llama/Llama-2-7b-hf
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=meta-llama/Llama-2-7b-hf,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks hellaswag,arc_challenge
|
||||
```
|
||||
|
||||
### Ollama
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
ollama serve
|
||||
ollama pull llama2:7b
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2:7b,\
|
||||
base_url=http://localhost:11434/v1 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
### llama.cpp Server
|
||||
|
||||
**Start server**:
|
||||
```bash
|
||||
./server -m models/llama-2-7b.gguf --host 0.0.0.0 --port 8080
|
||||
```
|
||||
|
||||
**Evaluate**:
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
model=llama2,\
|
||||
base_url=http://localhost:8080/v1 \
|
||||
--tasks gsm8k
|
||||
```
|
||||
|
||||
## Custom API Implementation
|
||||
|
||||
For custom API endpoints, subclass `TemplateAPI`:
|
||||
|
||||
### Create `my_api.py`
|
||||
|
||||
```python
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
import requests
|
||||
|
||||
class MyCustomAPI(TemplateAPI):
|
||||
"""Custom API model."""
|
||||
|
||||
def __init__(self, base_url, api_key, **kwargs):
|
||||
super().__init__(base_url=base_url, **kwargs)
|
||||
self.api_key = api_key
|
||||
|
||||
def _create_payload(self, messages, gen_kwargs):
|
||||
"""Create API request payload."""
|
||||
return {
|
||||
"messages": messages,
|
||||
"api_key": self.api_key,
|
||||
**gen_kwargs
|
||||
}
|
||||
|
||||
def parse_generations(self, response):
|
||||
"""Parse generation response."""
|
||||
return response.json()["choices"][0]["text"]
|
||||
|
||||
def parse_logprobs(self, response):
|
||||
"""Parse logprobs (if available)."""
|
||||
# Return None if API doesn't provide logprobs
|
||||
logprobs = response.json().get("logprobs")
|
||||
if logprobs:
|
||||
return logprobs["token_logprobs"]
|
||||
return None
|
||||
```
|
||||
|
||||
### Register and Use
|
||||
|
||||
```python
|
||||
from lm_eval import evaluator
|
||||
from my_api import MyCustomAPI
|
||||
|
||||
model = MyCustomAPI(
|
||||
base_url="https://api.example.com/v1",
|
||||
api_key="your-key"
|
||||
)
|
||||
|
||||
results = evaluator.simple_evaluate(
|
||||
model=model,
|
||||
tasks=["mmlu", "gsm8k"],
|
||||
num_fewshot=5,
|
||||
batch_size="auto"
|
||||
)
|
||||
```
|
||||
|
||||
## Comparing API and Open Models
|
||||
|
||||
### Side-by-Side Evaluation
|
||||
|
||||
```bash
|
||||
# Evaluate OpenAI GPT-4
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/gpt4.json
|
||||
|
||||
# Evaluate open Llama 2 70B
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-70b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--num_fewshot 5 \
|
||||
--output_path results/llama2-70b.json
|
||||
|
||||
# Compare results
|
||||
python scripts/compare_results.py \
|
||||
results/gpt4.json \
|
||||
results/llama2-70b.json
|
||||
```
|
||||
|
||||
### Typical Comparisons
|
||||
|
||||
| Model | MMLU | GSM8K | HumanEval | Cost |
|
||||
|-------|------|-------|-----------|------|
|
||||
| GPT-4 Turbo | 86.4% | 92.0% | 67.0% | $$$$ |
|
||||
| Claude 3 Opus | 86.8% | 95.0% | 84.9% | $$$$ |
|
||||
| GPT-3.5 Turbo | 70.0% | 57.1% | 48.1% | $$ |
|
||||
| Llama 2 70B | 68.9% | 56.8% | 29.9% | Free (self-host) |
|
||||
| Mixtral 8x7B | 70.6% | 58.4% | 40.2% | Free (self-host) |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
Respect API rate limits:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
num_concurrent=3,\ # Lower concurrency
|
||||
timeout=120 \ # Longer timeout
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
### Reproducibility
|
||||
|
||||
Set temperature to 0 for deterministic results:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--gen_kwargs temperature=0.0
|
||||
```
|
||||
|
||||
Or use `seed` for sampling:
|
||||
```bash
|
||||
lm_eval --model anthropic-chat \
|
||||
--model_args model=claude-3-5-sonnet-20241022 \
|
||||
--tasks gsm8k \
|
||||
--gen_kwargs temperature=0.7,seed=42
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
API models automatically cache responses to avoid redundant calls:
|
||||
```bash
|
||||
# First run: makes API calls
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
|
||||
# Second run: uses cache (instant, free)
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 100
|
||||
```
|
||||
|
||||
Cache location: `~/.cache/lm_eval/`
|
||||
|
||||
### Error Handling
|
||||
|
||||
APIs can fail. Use retries:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
max_retries=5,\
|
||||
timeout=120 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Authentication failed"
|
||||
|
||||
Check API key:
|
||||
```bash
|
||||
echo $OPENAI_API_KEY # Should print sk-...
|
||||
echo $ANTHROPIC_API_KEY # Should print sk-ant-...
|
||||
```
|
||||
|
||||
### "Rate limit exceeded"
|
||||
|
||||
Reduce concurrency:
|
||||
```bash
|
||||
--model_args num_concurrent=1
|
||||
```
|
||||
|
||||
Or add delays between requests.
|
||||
|
||||
### "Timeout error"
|
||||
|
||||
Increase timeout:
|
||||
```bash
|
||||
--model_args timeout=180
|
||||
```
|
||||
|
||||
### "Model not found"
|
||||
|
||||
For local APIs, verify server is running:
|
||||
```bash
|
||||
curl http://localhost:8000/v1/models
|
||||
```
|
||||
|
||||
### Cost Runaway
|
||||
|
||||
Use `--limit` for testing:
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args model=gpt-4-turbo \
|
||||
--tasks mmlu \
|
||||
--limit 50 # Only 50 samples
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Custom Headers
|
||||
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
base_url=http://api.example.com/v1,\
|
||||
header="Authorization: Bearer token,X-Custom: value"
|
||||
```
|
||||
|
||||
### Disable SSL Verification (Development Only)
|
||||
|
||||
```bash
|
||||
lm_eval --model local-completions \
|
||||
--model_args \
|
||||
base_url=https://localhost:8000/v1,\
|
||||
verify_certificate=false
|
||||
```
|
||||
|
||||
### Custom Tokenizer
|
||||
|
||||
```bash
|
||||
lm_eval --model openai-chat-completions \
|
||||
--model_args \
|
||||
model=gpt-4-turbo,\
|
||||
tokenizer=gpt2,\
|
||||
tokenizer_backend=huggingface
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- OpenAI API: https://platform.openai.com/docs/api-reference
|
||||
- Anthropic API: https://docs.anthropic.com/claude/reference
|
||||
- TemplateAPI: `lm_eval/models/api_models.py`
|
||||
- OpenAI models: `lm_eval/models/openai_completions.py`
|
||||
- Anthropic models: `lm_eval/models/anthropic_llms.py`
|
||||
@@ -0,0 +1,488 @@
|
||||
# Benchmark Guide
|
||||
|
||||
Complete guide to all 60+ evaluation tasks in lm-evaluation-harness, what they measure, and how to interpret results.
|
||||
|
||||
## Overview
|
||||
|
||||
The lm-evaluation-harness includes 60+ benchmarks spanning:
|
||||
- Language understanding (MMLU, GLUE)
|
||||
- Mathematical reasoning (GSM8K, MATH)
|
||||
- Code generation (HumanEval, MBPP)
|
||||
- Instruction following (IFEval, AlpacaEval)
|
||||
- Long-context understanding (LongBench)
|
||||
- Multilingual capabilities (AfroBench, NorEval)
|
||||
- Reasoning (BBH, ARC)
|
||||
- Truthfulness (TruthfulQA)
|
||||
|
||||
**List all tasks**:
|
||||
```bash
|
||||
lm_eval --tasks list
|
||||
```
|
||||
|
||||
## Major Benchmarks
|
||||
|
||||
### MMLU (Massive Multitask Language Understanding)
|
||||
|
||||
**What it measures**: Broad knowledge across 57 subjects (STEM, humanities, social sciences, law).
|
||||
|
||||
**Task variants**:
|
||||
- `mmlu`: Original 57-subject benchmark
|
||||
- `mmlu_pro`: More challenging version with reasoning-focused questions
|
||||
- `mmlu_prox`: Multilingual extension
|
||||
|
||||
**Format**: Multiple choice (4 options)
|
||||
|
||||
**Example**:
|
||||
```
|
||||
Question: What is the capital of France?
|
||||
A. Berlin
|
||||
B. Paris
|
||||
C. London
|
||||
D. Madrid
|
||||
Answer: B
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 25% (chance)
|
||||
- GPT-3 (175B): 43.9%
|
||||
- GPT-4: 86.4%
|
||||
- Human expert: ~90%
|
||||
|
||||
**Good for**: Assessing general knowledge and domain expertise.
|
||||
|
||||
### GSM8K (Grade School Math 8K)
|
||||
|
||||
**What it measures**: Mathematical reasoning on grade-school level word problems.
|
||||
|
||||
**Task variants**:
|
||||
- `gsm8k`: Base task
|
||||
- `gsm8k_cot`: With chain-of-thought prompting
|
||||
- `gsm_plus`: Adversarial variant with perturbations
|
||||
|
||||
**Format**: Free-form generation, extract numerical answer
|
||||
|
||||
**Example**:
|
||||
```
|
||||
Question: A baker made 200 cookies. He sold 3/5 of them in the morning and 1/4 of the remaining in the afternoon. How many cookies does he have left?
|
||||
Answer: 60
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks gsm8k \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: ~0%
|
||||
- GPT-3 (175B): 17.0%
|
||||
- GPT-4: 92.0%
|
||||
- Llama 2 70B: 56.8%
|
||||
|
||||
**Good for**: Testing multi-step reasoning and arithmetic.
|
||||
|
||||
### HumanEval
|
||||
|
||||
**What it measures**: Python code generation from docstrings (functional correctness).
|
||||
|
||||
**Task variants**:
|
||||
- `humaneval`: Standard benchmark
|
||||
- `humaneval_instruct`: For instruction-tuned models
|
||||
|
||||
**Format**: Code generation, execution-based evaluation
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
def has_close_elements(numbers: List[float], threshold: float) -> bool:
|
||||
""" Check if in given list of numbers, are any two numbers closer to each other than
|
||||
given threshold.
|
||||
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
|
||||
False
|
||||
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
|
||||
True
|
||||
"""
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks humaneval \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 0%
|
||||
- GPT-3 (175B): 0%
|
||||
- Codex: 28.8%
|
||||
- GPT-4: 67.0%
|
||||
- Code Llama 34B: 53.7%
|
||||
|
||||
**Good for**: Evaluating code generation capabilities.
|
||||
|
||||
### BBH (BIG-Bench Hard)
|
||||
|
||||
**What it measures**: 23 challenging reasoning tasks where models previously failed to beat humans.
|
||||
|
||||
**Categories**:
|
||||
- Logical reasoning
|
||||
- Math word problems
|
||||
- Social understanding
|
||||
- Algorithmic reasoning
|
||||
|
||||
**Format**: Multiple choice and free-form
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks bbh \
|
||||
--num_fewshot 3
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: ~25%
|
||||
- GPT-3 (175B): 33.9%
|
||||
- PaLM 540B: 58.3%
|
||||
- GPT-4: 86.7%
|
||||
|
||||
**Good for**: Testing advanced reasoning capabilities.
|
||||
|
||||
### IFEval (Instruction-Following Evaluation)
|
||||
|
||||
**What it measures**: Ability to follow specific, verifiable instructions.
|
||||
|
||||
**Instruction types**:
|
||||
- Format constraints (e.g., "answer in 3 sentences")
|
||||
- Length constraints (e.g., "use at least 100 words")
|
||||
- Content constraints (e.g., "include the word 'banana'")
|
||||
- Structural constraints (e.g., "use bullet points")
|
||||
|
||||
**Format**: Free-form generation with rule-based verification
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
|
||||
--tasks ifeval \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Measures: Instruction adherence (not quality)
|
||||
- GPT-4: 86% instruction following
|
||||
- Claude 2: 84%
|
||||
|
||||
**Good for**: Evaluating chat/instruct models.
|
||||
|
||||
### GLUE (General Language Understanding Evaluation)
|
||||
|
||||
**What it measures**: Natural language understanding across 9 tasks.
|
||||
|
||||
**Tasks**:
|
||||
- `cola`: Grammatical acceptability
|
||||
- `sst2`: Sentiment analysis
|
||||
- `mrpc`: Paraphrase detection
|
||||
- `qqp`: Question pairs
|
||||
- `stsb`: Semantic similarity
|
||||
- `mnli`: Natural language inference
|
||||
- `qnli`: Question answering NLI
|
||||
- `rte`: Recognizing textual entailment
|
||||
- `wnli`: Winograd schemas
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=bert-base-uncased \
|
||||
--tasks glue \
|
||||
--num_fewshot 0
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- BERT Base: 78.3 (GLUE score)
|
||||
- RoBERTa Large: 88.5
|
||||
- Human baseline: 87.1
|
||||
|
||||
**Good for**: Encoder-only models, fine-tuning baselines.
|
||||
|
||||
### LongBench
|
||||
|
||||
**What it measures**: Long-context understanding (4K-32K tokens).
|
||||
|
||||
**21 tasks covering**:
|
||||
- Single-document QA
|
||||
- Multi-document QA
|
||||
- Summarization
|
||||
- Few-shot learning
|
||||
- Code completion
|
||||
- Synthetic tasks
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks longbench \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Tests context utilization
|
||||
- Many models struggle beyond 4K tokens
|
||||
- GPT-4 Turbo: 54.3%
|
||||
|
||||
**Good for**: Evaluating long-context models.
|
||||
|
||||
## Additional Benchmarks
|
||||
|
||||
### TruthfulQA
|
||||
|
||||
**What it measures**: Model's propensity to be truthful vs. generate plausible-sounding falsehoods.
|
||||
|
||||
**Format**: Multiple choice with 4-5 options
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks truthfulqa_mc2 \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Larger models often score worse (more convincing lies)
|
||||
- GPT-3: 58.8%
|
||||
- GPT-4: 59.0%
|
||||
- Human: ~94%
|
||||
|
||||
### ARC (AI2 Reasoning Challenge)
|
||||
|
||||
**What it measures**: Grade-school science questions.
|
||||
|
||||
**Variants**:
|
||||
- `arc_easy`: Easier questions
|
||||
- `arc_challenge`: Harder questions requiring reasoning
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks arc_challenge \
|
||||
--num_fewshot 25
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- ARC-Easy: Most models >80%
|
||||
- ARC-Challenge random: 25%
|
||||
- GPT-4: 96.3%
|
||||
|
||||
### HellaSwag
|
||||
|
||||
**What it measures**: Commonsense reasoning about everyday situations.
|
||||
|
||||
**Format**: Choose most plausible continuation
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks hellaswag \
|
||||
--num_fewshot 10
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Random: 25%
|
||||
- GPT-3: 78.9%
|
||||
- Llama 2 70B: 85.3%
|
||||
|
||||
### WinoGrande
|
||||
|
||||
**What it measures**: Commonsense reasoning via pronoun resolution.
|
||||
|
||||
**Example**:
|
||||
```
|
||||
The trophy doesn't fit in the brown suitcase because _ is too large.
|
||||
A. the trophy
|
||||
B. the suitcase
|
||||
```
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks winogrande \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
### PIQA
|
||||
|
||||
**What it measures**: Physical commonsense reasoning.
|
||||
|
||||
**Example**: "To clean a keyboard, use compressed air or..."
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks piqa
|
||||
```
|
||||
|
||||
## Multilingual Benchmarks
|
||||
|
||||
### AfroBench
|
||||
|
||||
**What it measures**: Performance across 64 African languages.
|
||||
|
||||
**15 tasks**: NLU, text generation, knowledge, QA, math reasoning
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks afrobench
|
||||
```
|
||||
|
||||
### NorEval
|
||||
|
||||
**What it measures**: Norwegian language understanding (9 task categories).
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=NbAiLab/nb-gpt-j-6B \
|
||||
--tasks noreval
|
||||
```
|
||||
|
||||
## Domain-Specific Benchmarks
|
||||
|
||||
### MATH
|
||||
|
||||
**What it measures**: High-school competition math problems.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks math \
|
||||
--num_fewshot 4
|
||||
```
|
||||
|
||||
**Interpretation**:
|
||||
- Very challenging
|
||||
- GPT-4: 42.5%
|
||||
- Minerva 540B: 33.6%
|
||||
|
||||
### MBPP (Mostly Basic Python Problems)
|
||||
|
||||
**What it measures**: Python programming from natural language descriptions.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks mbpp \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
### DROP
|
||||
|
||||
**What it measures**: Reading comprehension requiring discrete reasoning.
|
||||
|
||||
**Command**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks drop
|
||||
```
|
||||
|
||||
## Benchmark Selection Guide
|
||||
|
||||
### For General Purpose Models
|
||||
|
||||
Run this suite:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag,arc_challenge,truthfulqa_mc2 \
|
||||
--num_fewshot 5
|
||||
```
|
||||
|
||||
### For Code Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=codellama/CodeLlama-7b-hf \
|
||||
--tasks humaneval,mbpp \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
### For Chat/Instruct Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-chat-hf \
|
||||
--tasks ifeval,mmlu,gsm8k_cot \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### For Long Context Models
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-3.1-8B \
|
||||
--tasks longbench \
|
||||
--batch_size 1
|
||||
```
|
||||
|
||||
## Interpreting Results
|
||||
|
||||
### Understanding Metrics
|
||||
|
||||
**Accuracy**: Percentage of correct answers (most common)
|
||||
|
||||
**Exact Match (EM)**: Requires exact string match (strict)
|
||||
|
||||
**F1 Score**: Balances precision and recall
|
||||
|
||||
**BLEU/ROUGE**: Text generation similarity
|
||||
|
||||
**Pass@k**: Percentage passing when generating k samples
|
||||
|
||||
### Typical Score Ranges
|
||||
|
||||
| Model Size | MMLU | GSM8K | HumanEval | HellaSwag |
|
||||
|------------|------|-------|-----------|-----------|
|
||||
| 7B | 40-50% | 10-20% | 5-15% | 70-80% |
|
||||
| 13B | 45-55% | 20-35% | 15-25% | 75-82% |
|
||||
| 70B | 60-70% | 50-65% | 35-50% | 82-87% |
|
||||
| GPT-4 | 86% | 92% | 67% | 95% |
|
||||
|
||||
### Red Flags
|
||||
|
||||
- **All tasks at random chance**: Model not trained properly
|
||||
- **Exact 0% on generation tasks**: Likely format/parsing issue
|
||||
- **Huge variance across runs**: Check seed/sampling settings
|
||||
- **Better than GPT-4 on everything**: Likely contamination
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always report few-shot setting**: 0-shot, 5-shot, etc.
|
||||
2. **Run multiple seeds**: Report mean ± std
|
||||
3. **Check for data contamination**: Search training data for benchmark examples
|
||||
4. **Compare to published baselines**: Validate your setup
|
||||
5. **Report all hyperparameters**: Model, batch size, max tokens, temperature
|
||||
|
||||
## References
|
||||
|
||||
- Task list: `lm_eval --tasks list`
|
||||
- Task README: `lm_eval/tasks/README.md`
|
||||
- Papers: See individual benchmark papers
|
||||
@@ -0,0 +1,602 @@
|
||||
# Custom Tasks
|
||||
|
||||
Complete guide to creating domain-specific evaluation tasks in lm-evaluation-harness.
|
||||
|
||||
## Overview
|
||||
|
||||
Custom tasks allow you to evaluate models on your own datasets and metrics. Tasks are defined using YAML configuration files with optional Python utilities for complex logic.
|
||||
|
||||
**Why create custom tasks**:
|
||||
- Evaluate on proprietary/domain-specific data
|
||||
- Test specific capabilities not covered by existing benchmarks
|
||||
- Create evaluation pipelines for internal models
|
||||
- Reproduce research experiments
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Minimal Custom Task
|
||||
|
||||
Create `my_tasks/simple_qa.yaml`:
|
||||
|
||||
```yaml
|
||||
task: simple_qa
|
||||
dataset_path: data/simple_qa.jsonl
|
||||
output_type: generate_until
|
||||
doc_to_text: "Question: {{question}}\nAnswer:"
|
||||
doc_to_target: "{{answer}}"
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
**Run it**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks simple_qa \
|
||||
--include_path my_tasks/
|
||||
```
|
||||
|
||||
## Task Configuration Reference
|
||||
|
||||
### Essential Fields
|
||||
|
||||
```yaml
|
||||
# Task identification
|
||||
task: my_custom_task # Unique task name (required)
|
||||
task_alias: "My Task" # Display name
|
||||
tag: # Tags for grouping
|
||||
- custom
|
||||
- domain_specific
|
||||
|
||||
# Dataset configuration
|
||||
dataset_path: data/my_data.jsonl # HuggingFace dataset or local path
|
||||
dataset_name: default # Subset name (if applicable)
|
||||
training_split: train
|
||||
validation_split: validation
|
||||
test_split: test
|
||||
|
||||
# Evaluation configuration
|
||||
output_type: generate_until # or loglikelihood, multiple_choice
|
||||
num_fewshot: 5 # Number of few-shot examples
|
||||
batch_size: auto # Batch size
|
||||
|
||||
# Prompt templates (Jinja2)
|
||||
doc_to_text: "Question: {{question}}"
|
||||
doc_to_target: "{{answer}}"
|
||||
|
||||
# Metrics
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
# Metadata
|
||||
metadata:
|
||||
version: 1.0
|
||||
```
|
||||
|
||||
### Output Types
|
||||
|
||||
**`generate_until`**: Free-form generation
|
||||
```yaml
|
||||
output_type: generate_until
|
||||
generation_kwargs:
|
||||
max_gen_toks: 256
|
||||
until:
|
||||
- "\n"
|
||||
- "."
|
||||
temperature: 0.0
|
||||
```
|
||||
|
||||
**`loglikelihood`**: Compute log probability of targets
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
# Used for perplexity, classification
|
||||
```
|
||||
|
||||
**`multiple_choice`**: Choose from options
|
||||
```yaml
|
||||
output_type: multiple_choice
|
||||
doc_to_choice: "{{choices}}" # List of choices
|
||||
```
|
||||
|
||||
## Data Formats
|
||||
|
||||
### Local JSONL File
|
||||
|
||||
`data/my_data.jsonl`:
|
||||
```json
|
||||
{"question": "What is 2+2?", "answer": "4"}
|
||||
{"question": "Capital of France?", "answer": "Paris"}
|
||||
```
|
||||
|
||||
**Task config**:
|
||||
```yaml
|
||||
dataset_path: data/my_data.jsonl
|
||||
dataset_kwargs:
|
||||
data_files:
|
||||
test: data/my_data.jsonl
|
||||
```
|
||||
|
||||
### HuggingFace Dataset
|
||||
|
||||
```yaml
|
||||
dataset_path: squad
|
||||
dataset_name: plain_text
|
||||
test_split: validation
|
||||
```
|
||||
|
||||
### CSV File
|
||||
|
||||
`data/my_data.csv`:
|
||||
```csv
|
||||
question,answer,category
|
||||
What is 2+2?,4,math
|
||||
Capital of France?,Paris,geography
|
||||
```
|
||||
|
||||
**Task config**:
|
||||
```yaml
|
||||
dataset_path: data/my_data.csv
|
||||
dataset_kwargs:
|
||||
data_files:
|
||||
test: data/my_data.csv
|
||||
```
|
||||
|
||||
## Prompt Engineering
|
||||
|
||||
### Simple Template
|
||||
|
||||
```yaml
|
||||
doc_to_text: "Question: {{question}}\nAnswer:"
|
||||
doc_to_target: "{{answer}}"
|
||||
```
|
||||
|
||||
### Conditional Logic
|
||||
|
||||
```yaml
|
||||
doc_to_text: |
|
||||
{% if context %}
|
||||
Context: {{context}}
|
||||
{% endif %}
|
||||
Question: {{question}}
|
||||
Answer:
|
||||
```
|
||||
|
||||
### Multiple Choice
|
||||
|
||||
```yaml
|
||||
doc_to_text: |
|
||||
Question: {{question}}
|
||||
A. {{choices[0]}}
|
||||
B. {{choices[1]}}
|
||||
C. {{choices[2]}}
|
||||
D. {{choices[3]}}
|
||||
Answer:
|
||||
|
||||
doc_to_target: "{{ 'ABCD'[answer_idx] }}"
|
||||
doc_to_choice: ["A", "B", "C", "D"]
|
||||
```
|
||||
|
||||
### Few-Shot Formatting
|
||||
|
||||
```yaml
|
||||
fewshot_delimiter: "\n\n" # Between examples
|
||||
target_delimiter: " " # Between question and answer
|
||||
doc_to_text: "Q: {{question}}"
|
||||
doc_to_target: "A: {{answer}}"
|
||||
```
|
||||
|
||||
## Custom Python Functions
|
||||
|
||||
For complex logic, use Python functions in `utils.py`.
|
||||
|
||||
### Create `my_tasks/utils.py`
|
||||
|
||||
```python
|
||||
def process_docs(dataset):
|
||||
"""Preprocess documents."""
|
||||
def _process(doc):
|
||||
# Custom preprocessing
|
||||
doc["question"] = doc["question"].strip().lower()
|
||||
return doc
|
||||
|
||||
return dataset.map(_process)
|
||||
|
||||
def doc_to_text(doc):
|
||||
"""Custom prompt formatting."""
|
||||
context = doc.get("context", "")
|
||||
question = doc["question"]
|
||||
|
||||
if context:
|
||||
return f"Context: {context}\nQuestion: {question}\nAnswer:"
|
||||
return f"Question: {question}\nAnswer:"
|
||||
|
||||
def doc_to_target(doc):
|
||||
"""Custom target extraction."""
|
||||
return doc["answer"].strip().lower()
|
||||
|
||||
def aggregate_scores(items):
|
||||
"""Custom metric aggregation."""
|
||||
correct = sum(1 for item in items if item == 1.0)
|
||||
total = len(items)
|
||||
return correct / total if total > 0 else 0.0
|
||||
```
|
||||
|
||||
### Use in Task Config
|
||||
|
||||
```yaml
|
||||
task: my_custom_task
|
||||
dataset_path: data/my_data.jsonl
|
||||
|
||||
# Use Python functions
|
||||
process_docs: !function utils.process_docs
|
||||
doc_to_text: !function utils.doc_to_text
|
||||
doc_to_target: !function utils.doc_to_target
|
||||
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: !function utils.aggregate_scores
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
## Real-World Examples
|
||||
|
||||
### Example 1: Domain QA Task
|
||||
|
||||
**Goal**: Evaluate medical question answering.
|
||||
|
||||
`medical_qa/medical_qa.yaml`:
|
||||
```yaml
|
||||
task: medical_qa
|
||||
dataset_path: data/medical_qa.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 3
|
||||
|
||||
doc_to_text: |
|
||||
Medical Question: {{question}}
|
||||
Context: {{context}}
|
||||
Answer (be concise):
|
||||
|
||||
doc_to_target: "{{answer}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 100
|
||||
until:
|
||||
- "\n\n"
|
||||
temperature: 0.0
|
||||
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: !function utils.medical_f1
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
filter_list:
|
||||
- name: lowercase
|
||||
filter:
|
||||
- function: lowercase
|
||||
- function: remove_whitespace
|
||||
|
||||
metadata:
|
||||
version: 1.0
|
||||
domain: medical
|
||||
```
|
||||
|
||||
`medical_qa/utils.py`:
|
||||
```python
|
||||
from sklearn.metrics import f1_score
|
||||
import re
|
||||
|
||||
def medical_f1(predictions, references):
|
||||
"""Custom F1 for medical terms."""
|
||||
pred_terms = set(extract_medical_terms(predictions[0]))
|
||||
ref_terms = set(extract_medical_terms(references[0]))
|
||||
|
||||
if not pred_terms and not ref_terms:
|
||||
return 1.0
|
||||
if not pred_terms or not ref_terms:
|
||||
return 0.0
|
||||
|
||||
tp = len(pred_terms & ref_terms)
|
||||
fp = len(pred_terms - ref_terms)
|
||||
fn = len(ref_terms - pred_terms)
|
||||
|
||||
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
||||
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
|
||||
def extract_medical_terms(text):
|
||||
"""Extract medical terminology."""
|
||||
# Custom logic
|
||||
return re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)*\b', text)
|
||||
```
|
||||
|
||||
### Example 2: Code Evaluation
|
||||
|
||||
`code_eval/python_challenges.yaml`:
|
||||
```yaml
|
||||
task: python_challenges
|
||||
dataset_path: data/python_problems.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 0
|
||||
|
||||
doc_to_text: |
|
||||
Write a Python function to solve:
|
||||
{{problem_statement}}
|
||||
|
||||
Function signature:
|
||||
{{function_signature}}
|
||||
|
||||
doc_to_target: "{{canonical_solution}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 512
|
||||
until:
|
||||
- "\n\nclass"
|
||||
- "\n\ndef"
|
||||
temperature: 0.2
|
||||
|
||||
metric_list:
|
||||
- metric: !function utils.execute_code
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
process_results: !function utils.process_code_results
|
||||
|
||||
metadata:
|
||||
version: 1.0
|
||||
```
|
||||
|
||||
`code_eval/utils.py`:
|
||||
```python
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
def execute_code(predictions, references):
|
||||
"""Execute generated code against test cases."""
|
||||
generated_code = predictions[0]
|
||||
test_cases = json.loads(references[0])
|
||||
|
||||
try:
|
||||
# Execute code with test cases
|
||||
for test_input, expected_output in test_cases:
|
||||
result = execute_with_timeout(generated_code, test_input, timeout=5)
|
||||
if result != expected_output:
|
||||
return 0.0
|
||||
return 1.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def execute_with_timeout(code, input_data, timeout=5):
|
||||
"""Safely execute code with timeout."""
|
||||
# Implementation with subprocess and timeout
|
||||
pass
|
||||
|
||||
def process_code_results(doc, results):
|
||||
"""Process code execution results."""
|
||||
return {
|
||||
"passed": results[0] == 1.0,
|
||||
"generated_code": results[1]
|
||||
}
|
||||
```
|
||||
|
||||
### Example 3: Instruction Following
|
||||
|
||||
`instruction_eval/instruction_eval.yaml`:
|
||||
```yaml
|
||||
task: instruction_following
|
||||
dataset_path: data/instructions.jsonl
|
||||
output_type: generate_until
|
||||
num_fewshot: 0
|
||||
|
||||
doc_to_text: |
|
||||
Instruction: {{instruction}}
|
||||
{% if constraints %}
|
||||
Constraints: {{constraints}}
|
||||
{% endif %}
|
||||
Response:
|
||||
|
||||
doc_to_target: "{{expected_response}}"
|
||||
|
||||
generation_kwargs:
|
||||
max_gen_toks: 256
|
||||
temperature: 0.7
|
||||
|
||||
metric_list:
|
||||
- metric: !function utils.check_constraints
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: !function utils.semantic_similarity
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
|
||||
process_docs: !function utils.add_constraint_checkers
|
||||
```
|
||||
|
||||
`instruction_eval/utils.py`:
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
def check_constraints(predictions, references):
|
||||
"""Check if response satisfies constraints."""
|
||||
response = predictions[0]
|
||||
constraints = json.loads(references[0])
|
||||
|
||||
satisfied = 0
|
||||
total = len(constraints)
|
||||
|
||||
for constraint in constraints:
|
||||
if verify_constraint(response, constraint):
|
||||
satisfied += 1
|
||||
|
||||
return satisfied / total if total > 0 else 1.0
|
||||
|
||||
def verify_constraint(response, constraint):
|
||||
"""Verify single constraint."""
|
||||
if constraint["type"] == "length":
|
||||
return len(response.split()) >= constraint["min_words"]
|
||||
elif constraint["type"] == "contains":
|
||||
return constraint["keyword"] in response.lower()
|
||||
# Add more constraint types
|
||||
return True
|
||||
|
||||
def semantic_similarity(predictions, references):
|
||||
"""Compute semantic similarity."""
|
||||
pred_embedding = model.encode(predictions[0])
|
||||
ref_embedding = model.encode(references[0])
|
||||
return float(util.cos_sim(pred_embedding, ref_embedding))
|
||||
|
||||
def add_constraint_checkers(dataset):
|
||||
"""Parse constraints into verifiable format."""
|
||||
def _parse(doc):
|
||||
# Parse constraint string into structured format
|
||||
doc["parsed_constraints"] = parse_constraints(doc.get("constraints", ""))
|
||||
return doc
|
||||
return dataset.map(_parse)
|
||||
```
|
||||
|
||||
## Advanced Features
|
||||
|
||||
### Output Filtering
|
||||
|
||||
```yaml
|
||||
filter_list:
|
||||
- name: extract_answer
|
||||
filter:
|
||||
- function: regex
|
||||
regex_pattern: "Answer: (.*)"
|
||||
group: 1
|
||||
- function: lowercase
|
||||
- function: strip_whitespace
|
||||
```
|
||||
|
||||
### Multiple Metrics
|
||||
|
||||
```yaml
|
||||
metric_list:
|
||||
- metric: exact_match
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: f1
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
- metric: bleu
|
||||
aggregation: mean
|
||||
higher_is_better: true
|
||||
```
|
||||
|
||||
### Task Groups
|
||||
|
||||
Create `my_tasks/_default.yaml`:
|
||||
```yaml
|
||||
group: my_eval_suite
|
||||
task:
|
||||
- simple_qa
|
||||
- medical_qa
|
||||
- python_challenges
|
||||
```
|
||||
|
||||
**Run entire suite**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks my_eval_suite \
|
||||
--include_path my_tasks/
|
||||
```
|
||||
|
||||
## Testing Your Task
|
||||
|
||||
### Validate Configuration
|
||||
|
||||
```bash
|
||||
# Test task loading
|
||||
lm_eval --tasks my_custom_task --include_path my_tasks/ --limit 0
|
||||
|
||||
# Run on 5 samples
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=gpt2 \
|
||||
--tasks my_custom_task \
|
||||
--include_path my_tasks/ \
|
||||
--limit 5
|
||||
```
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=gpt2 \
|
||||
--tasks my_custom_task \
|
||||
--include_path my_tasks/ \
|
||||
--limit 1 \
|
||||
--log_samples # Save input/output samples
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Start simple**: Test with minimal config first
|
||||
2. **Version your tasks**: Use `metadata.version`
|
||||
3. **Document your metrics**: Explain custom metrics in comments
|
||||
4. **Test with multiple models**: Ensure robustness
|
||||
5. **Validate on known examples**: Include sanity checks
|
||||
6. **Use filters carefully**: Can hide errors
|
||||
7. **Handle edge cases**: Empty strings, missing fields
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Classification Task
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
doc_to_text: "Text: {{text}}\nLabel:"
|
||||
doc_to_target: " {{label}}" # Space prefix important!
|
||||
metric_list:
|
||||
- metric: acc
|
||||
aggregation: mean
|
||||
```
|
||||
|
||||
### Perplexity Evaluation
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood_rolling
|
||||
doc_to_text: "{{text}}"
|
||||
metric_list:
|
||||
- metric: perplexity
|
||||
aggregation: perplexity
|
||||
```
|
||||
|
||||
### Ranking Task
|
||||
|
||||
```yaml
|
||||
output_type: loglikelihood
|
||||
doc_to_text: "Query: {{query}}\nPassage: {{passage}}\nRelevant:"
|
||||
doc_to_target: [" Yes", " No"]
|
||||
metric_list:
|
||||
- metric: acc
|
||||
aggregation: mean
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**"Task not found"**: Check `--include_path` and task name
|
||||
|
||||
**Empty results**: Verify `doc_to_text` and `doc_to_target` templates
|
||||
|
||||
**Metric errors**: Ensure metric names are correct (exact_match, not exact-match)
|
||||
|
||||
**Filter issues**: Test filters with `--log_samples`
|
||||
|
||||
**Python function not found**: Check `!function module.function_name` syntax
|
||||
|
||||
## References
|
||||
|
||||
- Task system: EleutherAI/lm-evaluation-harness docs
|
||||
- Example tasks: `lm_eval/tasks/` directory
|
||||
- TaskConfig: `lm_eval/api/task.py`
|
||||
@@ -0,0 +1,519 @@
|
||||
# Distributed Evaluation
|
||||
|
||||
Guide to running evaluation across multiple GPUs using data parallelism and tensor/pipeline parallelism.
|
||||
|
||||
## Overview
|
||||
|
||||
Distributed evaluation speeds up benchmarking by:
|
||||
- **Data Parallelism**: Split evaluation samples across GPUs (each GPU has full model copy)
|
||||
- **Tensor Parallelism**: Split model weights across GPUs (for large models)
|
||||
- **Pipeline Parallelism**: Split model layers across GPUs (for very large models)
|
||||
|
||||
**When to use**:
|
||||
- Data Parallel: Model fits on single GPU, want faster evaluation
|
||||
- Tensor/Pipeline Parallel: Model too large for single GPU
|
||||
|
||||
## HuggingFace Models (`hf`)
|
||||
|
||||
### Data Parallelism (Recommended)
|
||||
|
||||
Each GPU loads a full copy of the model and processes a subset of evaluation data.
|
||||
|
||||
**Single Node (8 GPUs)**:
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_processes 8 \
|
||||
-m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf,dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--batch_size 16
|
||||
```
|
||||
|
||||
**Speedup**: Near-linear (8 GPUs = ~8× faster)
|
||||
|
||||
**Memory**: Each GPU needs full model (7B model ≈ 14GB × 8 = 112GB total)
|
||||
|
||||
### Tensor Parallelism (Model Sharding)
|
||||
|
||||
Split model weights across GPUs for models too large for single GPU.
|
||||
|
||||
**Without accelerate launcher**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**With 8 GPUs**: 70B model (140GB) / 8 = 17.5GB per GPU ✅
|
||||
|
||||
**Advanced sharding**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
device_map_option=auto,\
|
||||
max_memory_per_gpu=40GB,\
|
||||
max_cpu_memory=100GB,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
**Options**:
|
||||
- `device_map_option`: `"auto"` (default), `"balanced"`, `"balanced_low_0"`
|
||||
- `max_memory_per_gpu`: Max memory per GPU (e.g., `"40GB"`)
|
||||
- `max_cpu_memory`: Max CPU memory for offloading
|
||||
- `offload_folder`: Disk offloading directory
|
||||
|
||||
### Combined Data + Tensor Parallelism
|
||||
|
||||
Use both for very large models.
|
||||
|
||||
**Example: 70B model on 16 GPUs (2 copies, 8 GPUs each)**:
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_processes 2 \
|
||||
-m lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
parallelize=True,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**Result**: 2× speedup from data parallelism, 70B model fits via tensor parallelism
|
||||
|
||||
### Configuration with `accelerate config`
|
||||
|
||||
Create `~/.cache/huggingface/accelerate/default_config.yaml`:
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MULTI_GPU
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
gpu_ids: all
|
||||
mixed_precision: bf16
|
||||
```
|
||||
|
||||
**Then run**:
|
||||
```bash
|
||||
accelerate launch -m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## vLLM Models (`vllm`)
|
||||
|
||||
vLLM provides highly optimized distributed inference.
|
||||
|
||||
### Tensor Parallelism
|
||||
|
||||
**Single Node (4 GPUs)**:
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
tensor_parallel_size=4,\
|
||||
dtype=auto,\
|
||||
gpu_memory_utilization=0.9 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Memory**: 70B model split across 4 GPUs = ~35GB per GPU
|
||||
|
||||
### Data Parallelism
|
||||
|
||||
**Multiple model replicas**:
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-7b-hf,\
|
||||
data_parallel_size=4,\
|
||||
dtype=auto,\
|
||||
gpu_memory_utilization=0.8 \
|
||||
--tasks hellaswag,arc_challenge \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Result**: 4 model replicas = 4× throughput
|
||||
|
||||
### Combined Tensor + Data Parallelism
|
||||
|
||||
**Example: 8 GPUs = 4 TP × 2 DP**:
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
tensor_parallel_size=4,\
|
||||
data_parallel_size=2,\
|
||||
dtype=auto,\
|
||||
gpu_memory_utilization=0.85 \
|
||||
--tasks mmlu \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
**Result**: 70B model fits (TP=4), 2× speedup (DP=2)
|
||||
|
||||
### Multi-Node vLLM
|
||||
|
||||
vLLM doesn't natively support multi-node. Use Ray:
|
||||
|
||||
```bash
|
||||
# Start Ray cluster
|
||||
ray start --head --port=6379
|
||||
|
||||
# Run evaluation
|
||||
lm_eval --model vllm \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
tensor_parallel_size=8,\
|
||||
dtype=auto \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## NVIDIA NeMo Models (`nemo_lm`)
|
||||
|
||||
### Data Replication
|
||||
|
||||
**8 replicas on 8 GPUs**:
|
||||
```bash
|
||||
torchrun --nproc-per-node=8 --no-python \
|
||||
lm_eval --model nemo_lm \
|
||||
--model_args \
|
||||
path=/path/to/model.nemo,\
|
||||
devices=8 \
|
||||
--tasks hellaswag,arc_challenge \
|
||||
--batch_size 32
|
||||
```
|
||||
|
||||
**Speedup**: Near-linear (8× faster)
|
||||
|
||||
### Tensor Parallelism
|
||||
|
||||
**4-way tensor parallelism**:
|
||||
```bash
|
||||
torchrun --nproc-per-node=4 --no-python \
|
||||
lm_eval --model nemo_lm \
|
||||
--model_args \
|
||||
path=/path/to/70b_model.nemo,\
|
||||
devices=4,\
|
||||
tensor_model_parallel_size=4 \
|
||||
--tasks mmlu,gsm8k \
|
||||
--batch_size 16
|
||||
```
|
||||
|
||||
### Pipeline Parallelism
|
||||
|
||||
**2 TP × 2 PP on 4 GPUs**:
|
||||
```bash
|
||||
torchrun --nproc-per-node=4 --no-python \
|
||||
lm_eval --model nemo_lm \
|
||||
--model_args \
|
||||
path=/path/to/model.nemo,\
|
||||
devices=4,\
|
||||
tensor_model_parallel_size=2,\
|
||||
pipeline_model_parallel_size=2 \
|
||||
--tasks mmlu \
|
||||
--batch_size 8
|
||||
```
|
||||
|
||||
**Constraint**: `devices = TP × PP`
|
||||
|
||||
### Multi-Node NeMo
|
||||
|
||||
Currently not supported by lm-evaluation-harness.
|
||||
|
||||
## SGLang Models (`sglang`)
|
||||
|
||||
### Tensor Parallelism
|
||||
|
||||
```bash
|
||||
lm_eval --model sglang \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
tp_size=4,\
|
||||
dtype=auto \
|
||||
--tasks gsm8k \
|
||||
--batch_size auto
|
||||
```
|
||||
|
||||
### Data Parallelism (Deprecated)
|
||||
|
||||
**Note**: SGLang is deprecating data parallelism. Use tensor parallelism instead.
|
||||
|
||||
```bash
|
||||
lm_eval --model sglang \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-7b-hf,\
|
||||
dp_size=4,\
|
||||
dtype=auto \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
## Performance Comparison
|
||||
|
||||
### 70B Model Evaluation (MMLU, 5-shot)
|
||||
|
||||
| Method | GPUs | Time | Memory/GPU | Notes |
|
||||
|--------|------|------|------------|-------|
|
||||
| HF (no parallel) | 1 | 8 hours | 140GB (OOM) | Won't fit |
|
||||
| HF (TP=8) | 8 | 2 hours | 17.5GB | Slower, fits |
|
||||
| HF (DP=8) | 8 | 1 hour | 140GB (OOM) | Won't fit |
|
||||
| vLLM (TP=4) | 4 | 30 min | 35GB | Fast! |
|
||||
| vLLM (TP=4, DP=2) | 8 | 15 min | 35GB | Fastest |
|
||||
|
||||
### 7B Model Evaluation (Multiple Tasks)
|
||||
|
||||
| Method | GPUs | Time | Speedup |
|
||||
|--------|------|------|---------|
|
||||
| HF (single) | 1 | 4 hours | 1× |
|
||||
| HF (DP=4) | 4 | 1 hour | 4× |
|
||||
| HF (DP=8) | 8 | 30 min | 8× |
|
||||
| vLLM (DP=8) | 8 | 15 min | 16× |
|
||||
|
||||
**Takeaway**: vLLM is significantly faster than HuggingFace for inference.
|
||||
|
||||
## Choosing Parallelism Strategy
|
||||
|
||||
### Decision Tree
|
||||
|
||||
```
|
||||
Model fits on single GPU?
|
||||
├─ YES: Use data parallelism
|
||||
│ ├─ HF: accelerate launch --multi_gpu --num_processes N
|
||||
│ └─ vLLM: data_parallel_size=N (fastest)
|
||||
│
|
||||
└─ NO: Use tensor/pipeline parallelism
|
||||
├─ Model < 70B:
|
||||
│ └─ vLLM: tensor_parallel_size=4
|
||||
├─ Model 70-175B:
|
||||
│ ├─ vLLM: tensor_parallel_size=8
|
||||
│ └─ Or HF: parallelize=True
|
||||
└─ Model > 175B:
|
||||
└─ Contact framework authors
|
||||
```
|
||||
|
||||
### Memory Estimation
|
||||
|
||||
**Rule of thumb**:
|
||||
```
|
||||
Memory (GB) = Parameters (B) × Precision (bytes) × 1.2 (overhead)
|
||||
```
|
||||
|
||||
**Examples**:
|
||||
- 7B FP16: 7 × 2 × 1.2 = 16.8GB ✅ Fits A100 40GB
|
||||
- 13B FP16: 13 × 2 × 1.2 = 31.2GB ✅ Fits A100 40GB
|
||||
- 70B FP16: 70 × 2 × 1.2 = 168GB ❌ Need TP=4 or TP=8
|
||||
- 70B BF16: 70 × 2 × 1.2 = 168GB (same as FP16)
|
||||
|
||||
**With tensor parallelism**:
|
||||
```
|
||||
Memory per GPU = Total Memory / TP
|
||||
```
|
||||
|
||||
- 70B on 4 GPUs: 168GB / 4 = 42GB per GPU ✅
|
||||
- 70B on 8 GPUs: 168GB / 8 = 21GB per GPU ✅
|
||||
|
||||
## Multi-Node Evaluation
|
||||
|
||||
### HuggingFace with SLURM
|
||||
|
||||
**Submit job**:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --ntasks-per-node=1
|
||||
|
||||
srun accelerate launch --multi_gpu \
|
||||
--num_processes $((SLURM_NNODES * 8)) \
|
||||
-m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu,gsm8k,hellaswag \
|
||||
--batch_size 16
|
||||
```
|
||||
|
||||
**Submit**:
|
||||
```bash
|
||||
sbatch eval_job.sh
|
||||
```
|
||||
|
||||
### Manual Multi-Node Setup
|
||||
|
||||
**On each node, run**:
|
||||
```bash
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_machines 4 \
|
||||
--num_processes 32 \
|
||||
--main_process_ip $MASTER_IP \
|
||||
--main_process_port 29500 \
|
||||
--machine_rank $NODE_RANK \
|
||||
-m lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu
|
||||
```
|
||||
|
||||
**Environment variables**:
|
||||
- `MASTER_IP`: IP of rank 0 node
|
||||
- `NODE_RANK`: 0, 1, 2, 3 for each node
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Small
|
||||
|
||||
Test on small sample first:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-70b-hf,parallelize=True \
|
||||
--tasks mmlu \
|
||||
--limit 100 # Just 100 samples
|
||||
```
|
||||
|
||||
### 2. Monitor GPU Usage
|
||||
|
||||
```bash
|
||||
# Terminal 1: Run evaluation
|
||||
lm_eval --model hf ...
|
||||
|
||||
# Terminal 2: Monitor
|
||||
watch -n 1 nvidia-smi
|
||||
```
|
||||
|
||||
Look for:
|
||||
- GPU utilization > 90%
|
||||
- Memory usage stable
|
||||
- All GPUs active
|
||||
|
||||
### 3. Optimize Batch Size
|
||||
|
||||
```bash
|
||||
# Auto batch size (recommended)
|
||||
--batch_size auto
|
||||
|
||||
# Or tune manually
|
||||
--batch_size 16 # Start here
|
||||
--batch_size 32 # Increase if memory allows
|
||||
```
|
||||
|
||||
### 4. Use Mixed Precision
|
||||
|
||||
```bash
|
||||
--model_args dtype=bfloat16 # Faster, less memory
|
||||
```
|
||||
|
||||
### 5. Check Communication
|
||||
|
||||
For data parallelism, check network bandwidth:
|
||||
```bash
|
||||
# Should see InfiniBand or high-speed network
|
||||
nvidia-smi topo -m
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "CUDA out of memory"
|
||||
|
||||
**Solutions**:
|
||||
1. Increase tensor parallelism:
|
||||
```bash
|
||||
--model_args tensor_parallel_size=8 # Was 4
|
||||
```
|
||||
|
||||
2. Reduce batch size:
|
||||
```bash
|
||||
--batch_size 4 # Was 16
|
||||
```
|
||||
|
||||
3. Lower precision:
|
||||
```bash
|
||||
--model_args dtype=int8 # Quantization
|
||||
```
|
||||
|
||||
### "NCCL error" or Hanging
|
||||
|
||||
**Check**:
|
||||
1. All GPUs visible: `nvidia-smi`
|
||||
2. NCCL installed: `python -c "import torch; print(torch.cuda.nccl.version())"`
|
||||
3. Network connectivity between nodes
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO # Enable debug logging
|
||||
export NCCL_IB_DISABLE=0 # Use InfiniBand if available
|
||||
```
|
||||
|
||||
### Slow Evaluation
|
||||
|
||||
**Possible causes**:
|
||||
1. **Data loading bottleneck**: Preprocess dataset
|
||||
2. **Low GPU utilization**: Increase batch size
|
||||
3. **Communication overhead**: Reduce parallelism degree
|
||||
|
||||
**Profile**:
|
||||
```bash
|
||||
lm_eval --model hf \
|
||||
--model_args pretrained=meta-llama/Llama-2-7b-hf \
|
||||
--tasks mmlu \
|
||||
--limit 100 \
|
||||
--log_samples # Check timing
|
||||
```
|
||||
|
||||
### GPUs Imbalanced
|
||||
|
||||
**Symptom**: GPU 0 at 100%, others at 50%
|
||||
|
||||
**Solution**: Use `device_map_option=balanced`:
|
||||
```bash
|
||||
--model_args parallelize=True,device_map_option=balanced
|
||||
```
|
||||
|
||||
## Example Configurations
|
||||
|
||||
### Small Model (7B) - Fast Evaluation
|
||||
|
||||
```bash
|
||||
# 8 A100s, data parallel
|
||||
accelerate launch --multi_gpu --num_processes 8 \
|
||||
-m lm_eval --model hf \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-7b-hf,\
|
||||
dtype=bfloat16 \
|
||||
--tasks mmlu,gsm8k,hellaswag,arc_challenge \
|
||||
--num_fewshot 5 \
|
||||
--batch_size 32
|
||||
|
||||
# Time: ~30 minutes
|
||||
```
|
||||
|
||||
### Large Model (70B) - vLLM
|
||||
|
||||
```bash
|
||||
# 8 H100s, tensor parallel
|
||||
lm_eval --model vllm \
|
||||
--model_args \
|
||||
pretrained=meta-llama/Llama-2-70b-hf,\
|
||||
tensor_parallel_size=8,\
|
||||
dtype=auto,\
|
||||
gpu_memory_utilization=0.9 \
|
||||
--tasks mmlu,gsm8k,humaneval \
|
||||
--num_fewshot 5 \
|
||||
--batch_size auto
|
||||
|
||||
# Time: ~1 hour
|
||||
```
|
||||
|
||||
### Very Large Model (175B+)
|
||||
|
||||
**Requires specialized setup - contact framework maintainers**
|
||||
|
||||
## References
|
||||
|
||||
- HuggingFace Accelerate: https://huggingface.co/docs/accelerate/
|
||||
- vLLM docs: https://docs.vllm.ai/
|
||||
- NeMo docs: https://docs.nvidia.com/nemo-framework/
|
||||
- lm-eval distributed guide: `docs/model_guide.md`
|
||||
@@ -0,0 +1,593 @@
|
||||
---
|
||||
name: weights-and-biases
|
||||
description: "W&B: log ML experiments, sweeps, model registry, dashboards."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [wandb]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [MLOps, Weights And Biases, WandB, Experiment Tracking, Hyperparameter Tuning, Model Registry, Collaboration, Real-Time Visualization, PyTorch, TensorFlow, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# Weights & Biases: ML Experiment Tracking & MLOps
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Weights & Biases (W&B) when you need to:
|
||||
- **Track ML experiments** with automatic metric logging
|
||||
- **Visualize training** in real-time dashboards
|
||||
- **Compare runs** across hyperparameters and configurations
|
||||
- **Optimize hyperparameters** with automated sweeps
|
||||
- **Manage model registry** with versioning and lineage
|
||||
- **Collaborate on ML projects** with team workspaces
|
||||
- **Track artifacts** (datasets, models, code) with lineage
|
||||
|
||||
**Users**: 200,000+ ML practitioners | **GitHub Stars**: 10.5k+ | **Integrations**: 100+
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install W&B
|
||||
pip install wandb
|
||||
|
||||
# Login (creates API key)
|
||||
wandb login
|
||||
|
||||
# Or set API key programmatically
|
||||
export WANDB_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Experiment Tracking
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
# Initialize a run
|
||||
run = wandb.init(
|
||||
project="my-project",
|
||||
config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32,
|
||||
"architecture": "ResNet50"
|
||||
}
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(run.config.epochs):
|
||||
# Your training code
|
||||
train_loss = train_epoch()
|
||||
val_loss = validate()
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
"epoch": epoch,
|
||||
"train/loss": train_loss,
|
||||
"val/loss": val_loss,
|
||||
"train/accuracy": train_acc,
|
||||
"val/accuracy": val_acc
|
||||
})
|
||||
|
||||
# Finish the run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### With PyTorch
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
# Initialize
|
||||
wandb.init(project="pytorch-demo", config={
|
||||
"lr": 0.001,
|
||||
"epochs": 10
|
||||
})
|
||||
|
||||
# Access config
|
||||
config = wandb.config
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
# Forward pass
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Log every 100 batches
|
||||
if batch_idx % 100 == 0:
|
||||
wandb.log({
|
||||
"loss": loss.item(),
|
||||
"epoch": epoch,
|
||||
"batch": batch_idx
|
||||
})
|
||||
|
||||
# Save model
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
wandb.save("model.pth") # Upload to W&B
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Projects and Runs
|
||||
|
||||
**Project**: Collection of related experiments
|
||||
**Run**: Single execution of your training script
|
||||
|
||||
```python
|
||||
# Create/use project
|
||||
run = wandb.init(
|
||||
project="image-classification",
|
||||
name="resnet50-experiment-1", # Optional run name
|
||||
tags=["baseline", "resnet"], # Organize with tags
|
||||
notes="First baseline run" # Add notes
|
||||
)
|
||||
|
||||
# Each run has unique ID
|
||||
print(f"Run ID: {run.id}")
|
||||
print(f"Run URL: {run.url}")
|
||||
```
|
||||
|
||||
### 2. Configuration Tracking
|
||||
|
||||
Track hyperparameters automatically:
|
||||
|
||||
```python
|
||||
config = {
|
||||
# Model architecture
|
||||
"model": "ResNet50",
|
||||
"pretrained": True,
|
||||
|
||||
# Training params
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32,
|
||||
"epochs": 50,
|
||||
"optimizer": "Adam",
|
||||
|
||||
# Data params
|
||||
"dataset": "ImageNet",
|
||||
"augmentation": "standard"
|
||||
}
|
||||
|
||||
wandb.init(project="my-project", config=config)
|
||||
|
||||
# Access config during training
|
||||
lr = wandb.config.learning_rate
|
||||
batch_size = wandb.config.batch_size
|
||||
```
|
||||
|
||||
### 3. Metric Logging
|
||||
|
||||
```python
|
||||
# Log scalars
|
||||
wandb.log({"loss": 0.5, "accuracy": 0.92})
|
||||
|
||||
# Log multiple metrics
|
||||
wandb.log({
|
||||
"train/loss": train_loss,
|
||||
"train/accuracy": train_acc,
|
||||
"val/loss": val_loss,
|
||||
"val/accuracy": val_acc,
|
||||
"learning_rate": current_lr,
|
||||
"epoch": epoch
|
||||
})
|
||||
|
||||
# Log with custom x-axis
|
||||
wandb.log({"loss": loss}, step=global_step)
|
||||
|
||||
# Log media (images, audio, video)
|
||||
wandb.log({"examples": [wandb.Image(img) for img in images]})
|
||||
|
||||
# Log histograms
|
||||
wandb.log({"gradients": wandb.Histogram(gradients)})
|
||||
|
||||
# Log tables
|
||||
table = wandb.Table(columns=["id", "prediction", "ground_truth"])
|
||||
wandb.log({"predictions": table})
|
||||
```
|
||||
|
||||
### 4. Model Checkpointing
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
# Save model checkpoint
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}
|
||||
|
||||
torch.save(checkpoint, 'checkpoint.pth')
|
||||
|
||||
# Upload to W&B
|
||||
wandb.save('checkpoint.pth')
|
||||
|
||||
# Or use Artifacts (recommended)
|
||||
artifact = wandb.Artifact('model', type='model')
|
||||
artifact.add_file('checkpoint.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
```
|
||||
|
||||
## Hyperparameter Sweeps
|
||||
|
||||
Automatically search for optimal hyperparameters.
|
||||
|
||||
### Define Sweep Configuration
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes', # or 'grid', 'random'
|
||||
'metric': {
|
||||
'name': 'val/accuracy',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop']
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize sweep
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
```
|
||||
|
||||
### Define Training Function
|
||||
|
||||
```python
|
||||
def train():
|
||||
# Initialize run
|
||||
run = wandb.init()
|
||||
|
||||
# Access sweep parameters
|
||||
lr = wandb.config.learning_rate
|
||||
batch_size = wandb.config.batch_size
|
||||
optimizer_name = wandb.config.optimizer
|
||||
|
||||
# Build model with sweep config
|
||||
model = build_model(wandb.config)
|
||||
optimizer = get_optimizer(optimizer_name, lr)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
train_loss = train_epoch(model, optimizer, batch_size)
|
||||
val_acc = validate(model)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
"train/loss": train_loss,
|
||||
"val/accuracy": val_acc
|
||||
})
|
||||
|
||||
# Run sweep
|
||||
wandb.agent(sweep_id, function=train, count=50) # Run 50 trials
|
||||
```
|
||||
|
||||
### Sweep Strategies
|
||||
|
||||
```python
|
||||
# Grid search - exhaustive
|
||||
sweep_config = {
|
||||
'method': 'grid',
|
||||
'parameters': {
|
||||
'lr': {'values': [0.001, 0.01, 0.1]},
|
||||
'batch_size': {'values': [16, 32, 64]}
|
||||
}
|
||||
}
|
||||
|
||||
# Random search
|
||||
sweep_config = {
|
||||
'method': 'random',
|
||||
'parameters': {
|
||||
'lr': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.1},
|
||||
'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
|
||||
}
|
||||
}
|
||||
|
||||
# Bayesian optimization (recommended)
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/loss', 'goal': 'minimize'},
|
||||
'parameters': {
|
||||
'lr': {'distribution': 'log_uniform', 'min': 1e-5, 'max': 1e-1}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Artifacts
|
||||
|
||||
Track datasets, models, and other files with lineage.
|
||||
|
||||
### Log Artifacts
|
||||
|
||||
```python
|
||||
# Create artifact
|
||||
artifact = wandb.Artifact(
|
||||
name='training-dataset',
|
||||
type='dataset',
|
||||
description='ImageNet training split',
|
||||
metadata={'size': '1.2M images', 'split': 'train'}
|
||||
)
|
||||
|
||||
# Add files
|
||||
artifact.add_file('data/train.csv')
|
||||
artifact.add_dir('data/images/')
|
||||
|
||||
# Log artifact
|
||||
wandb.log_artifact(artifact)
|
||||
```
|
||||
|
||||
### Use Artifacts
|
||||
|
||||
```python
|
||||
# Download and use artifact
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Download artifact
|
||||
artifact = run.use_artifact('training-dataset:latest')
|
||||
artifact_dir = artifact.download()
|
||||
|
||||
# Use the data
|
||||
data = load_data(f"{artifact_dir}/train.csv")
|
||||
```
|
||||
|
||||
### Model Registry
|
||||
|
||||
```python
|
||||
# Log model as artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='resnet50-model',
|
||||
type='model',
|
||||
metadata={'architecture': 'ResNet50', 'accuracy': 0.95}
|
||||
)
|
||||
|
||||
model_artifact.add_file('model.pth')
|
||||
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
|
||||
|
||||
# Link to model registry
|
||||
run.link_artifact(model_artifact, 'model-registry/production-models')
|
||||
```
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="hf-transformers")
|
||||
|
||||
# Training arguments with W&B
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb", # Enable W&B logging
|
||||
run_name="bert-finetuning",
|
||||
logging_steps=100,
|
||||
save_steps=500
|
||||
)
|
||||
|
||||
# Trainer automatically logs to W&B
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### PyTorch Lightning
|
||||
|
||||
```python
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Create W&B logger
|
||||
wandb_logger = WandbLogger(
|
||||
project="lightning-demo",
|
||||
log_model=True # Log model checkpoints
|
||||
)
|
||||
|
||||
# Use with Trainer
|
||||
trainer = Trainer(
|
||||
logger=wandb_logger,
|
||||
max_epochs=10
|
||||
)
|
||||
|
||||
trainer.fit(model, datamodule=dm)
|
||||
```
|
||||
|
||||
### Keras/TensorFlow
|
||||
|
||||
```python
|
||||
import wandb
|
||||
from wandb.keras import WandbCallback
|
||||
|
||||
# Initialize
|
||||
wandb.init(project="keras-demo")
|
||||
|
||||
# Add callback
|
||||
model.fit(
|
||||
x_train, y_train,
|
||||
validation_data=(x_val, y_val),
|
||||
epochs=10,
|
||||
callbacks=[WandbCallback()] # Auto-logs metrics
|
||||
)
|
||||
```
|
||||
|
||||
## Visualization & Analysis
|
||||
|
||||
### Custom Charts
|
||||
|
||||
```python
|
||||
# Log custom visualizations
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(x, y)
|
||||
wandb.log({"custom_plot": wandb.Image(fig)})
|
||||
|
||||
# Log confusion matrix
|
||||
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
|
||||
probs=None,
|
||||
y_true=ground_truth,
|
||||
preds=predictions,
|
||||
class_names=class_names
|
||||
)})
|
||||
```
|
||||
|
||||
### Reports
|
||||
|
||||
Create shareable reports in W&B UI:
|
||||
- Combine runs, charts, and text
|
||||
- Markdown support
|
||||
- Embeddable visualizations
|
||||
- Team collaboration
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Organize with Tags and Groups
|
||||
|
||||
```python
|
||||
wandb.init(
|
||||
project="my-project",
|
||||
tags=["baseline", "resnet50", "imagenet"],
|
||||
group="resnet-experiments", # Group related runs
|
||||
job_type="train" # Type of job
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Log Everything Relevant
|
||||
|
||||
```python
|
||||
# Log system metrics
|
||||
wandb.log({
|
||||
"gpu/util": gpu_utilization,
|
||||
"gpu/memory": gpu_memory_used,
|
||||
"cpu/util": cpu_utilization
|
||||
})
|
||||
|
||||
# Log code version
|
||||
wandb.log({"git_commit": git_commit_hash})
|
||||
|
||||
# Log data splits
|
||||
wandb.log({
|
||||
"data/train_size": len(train_dataset),
|
||||
"data/val_size": len(val_dataset)
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Use Descriptive Names
|
||||
|
||||
```python
|
||||
# ✅ Good: Descriptive run names
|
||||
wandb.init(
|
||||
project="nlp-classification",
|
||||
name="bert-base-lr0.001-bs32-epoch10"
|
||||
)
|
||||
|
||||
# ❌ Bad: Generic names
|
||||
wandb.init(project="nlp", name="run1")
|
||||
```
|
||||
|
||||
### 4. Save Important Artifacts
|
||||
|
||||
```python
|
||||
# Save final model
|
||||
artifact = wandb.Artifact('final-model', type='model')
|
||||
artifact.add_file('model.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
# Save predictions for analysis
|
||||
predictions_table = wandb.Table(
|
||||
columns=["id", "input", "prediction", "ground_truth"],
|
||||
data=predictions_data
|
||||
)
|
||||
wandb.log({"predictions": predictions_table})
|
||||
```
|
||||
|
||||
### 5. Use Offline Mode for Unstable Connections
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Enable offline mode
|
||||
os.environ["WANDB_MODE"] = "offline"
|
||||
|
||||
wandb.init(project="my-project")
|
||||
# ... your code ...
|
||||
|
||||
# Sync later
|
||||
# wandb sync <run_directory>
|
||||
```
|
||||
|
||||
## Team Collaboration
|
||||
|
||||
### Share Runs
|
||||
|
||||
```python
|
||||
# Runs are automatically shareable via URL
|
||||
run = wandb.init(project="team-project")
|
||||
print(f"Share this URL: {run.url}")
|
||||
```
|
||||
|
||||
### Team Projects
|
||||
|
||||
- Create team account at wandb.ai
|
||||
- Add team members
|
||||
- Set project visibility (private/public)
|
||||
- Use team-level artifacts and model registry
|
||||
|
||||
## Pricing
|
||||
|
||||
- **Free**: Unlimited public projects, 100GB storage
|
||||
- **Academic**: Free for students/researchers
|
||||
- **Teams**: $50/seat/month, private projects, unlimited storage
|
||||
- **Enterprise**: Custom pricing, on-prem options
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://docs.wandb.ai
|
||||
- **GitHub**: https://github.com/wandb/wandb (10.5k+ stars)
|
||||
- **Examples**: https://github.com/wandb/examples
|
||||
- **Community**: https://wandb.ai/community
|
||||
- **Discord**: https://wandb.me/discord
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/sweeps.md` - Comprehensive hyperparameter optimization guide
|
||||
- `references/artifacts.md` - Data and model versioning patterns
|
||||
- `references/integrations.md` - Framework-specific examples
|
||||
|
||||
|
||||
@@ -0,0 +1,584 @@
|
||||
# Artifacts & Model Registry Guide
|
||||
|
||||
Complete guide to data versioning and model management with W&B Artifacts.
|
||||
|
||||
## Table of Contents
|
||||
- What are Artifacts
|
||||
- Creating Artifacts
|
||||
- Using Artifacts
|
||||
- Model Registry
|
||||
- Versioning & Lineage
|
||||
- Best Practices
|
||||
|
||||
## What are Artifacts
|
||||
|
||||
Artifacts are versioned datasets, models, or files tracked with lineage.
|
||||
|
||||
**Key Features:**
|
||||
- Automatic versioning (v0, v1, v2...)
|
||||
- Lineage tracking (which runs produced/used artifacts)
|
||||
- Efficient storage (deduplication)
|
||||
- Collaboration (team-wide access)
|
||||
- Aliases (latest, best, production)
|
||||
|
||||
**Common Use Cases:**
|
||||
- Dataset versioning
|
||||
- Model checkpoints
|
||||
- Preprocessed data
|
||||
- Evaluation results
|
||||
- Configuration files
|
||||
|
||||
## Creating Artifacts
|
||||
|
||||
### Basic Dataset Artifact
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create artifact
|
||||
dataset = wandb.Artifact(
|
||||
name='training-data',
|
||||
type='dataset',
|
||||
description='ImageNet training split with augmentations',
|
||||
metadata={
|
||||
'size': '1.2M images',
|
||||
'format': 'JPEG',
|
||||
'resolution': '224x224'
|
||||
}
|
||||
)
|
||||
|
||||
# Add files
|
||||
dataset.add_file('data/train.csv') # Single file
|
||||
dataset.add_dir('data/images') # Entire directory
|
||||
dataset.add_reference('s3://bucket/data') # Cloud reference
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(dataset)
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Model Artifact
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Train model
|
||||
model = train_model()
|
||||
|
||||
# Save model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
|
||||
# Create model artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='resnet50-classifier',
|
||||
type='model',
|
||||
description='ResNet50 trained on ImageNet',
|
||||
metadata={
|
||||
'architecture': 'ResNet50',
|
||||
'accuracy': 0.95,
|
||||
'loss': 0.15,
|
||||
'epochs': 50,
|
||||
'framework': 'PyTorch'
|
||||
}
|
||||
)
|
||||
|
||||
# Add model file
|
||||
model_artifact.add_file('model.pth')
|
||||
|
||||
# Add config
|
||||
model_artifact.add_file('config.yaml')
|
||||
|
||||
# Log with aliases
|
||||
run.log_artifact(model_artifact, aliases=['latest', 'best'])
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Preprocessed Data Artifact
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="nlp-project")
|
||||
|
||||
# Preprocess data
|
||||
df = pd.read_csv('raw_data.csv')
|
||||
df_processed = preprocess(df)
|
||||
df_processed.to_csv('processed_data.csv', index=False)
|
||||
|
||||
# Create artifact
|
||||
processed_data = wandb.Artifact(
|
||||
name='processed-text-data',
|
||||
type='dataset',
|
||||
metadata={
|
||||
'rows': len(df_processed),
|
||||
'columns': list(df_processed.columns),
|
||||
'preprocessing_steps': ['lowercase', 'remove_stopwords', 'tokenize']
|
||||
}
|
||||
)
|
||||
|
||||
processed_data.add_file('processed_data.csv')
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(processed_data)
|
||||
```
|
||||
|
||||
## Using Artifacts
|
||||
|
||||
### Download and Use
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Download artifact
|
||||
artifact = run.use_artifact('training-data:latest')
|
||||
artifact_dir = artifact.download()
|
||||
|
||||
# Use files
|
||||
import pandas as pd
|
||||
df = pd.read_csv(f'{artifact_dir}/train.csv')
|
||||
|
||||
# Train with artifact data
|
||||
model = train_model(df)
|
||||
```
|
||||
|
||||
### Use Specific Version
|
||||
|
||||
```python
|
||||
# Use specific version
|
||||
artifact_v2 = run.use_artifact('training-data:v2')
|
||||
|
||||
# Use alias
|
||||
artifact_best = run.use_artifact('model:best')
|
||||
artifact_prod = run.use_artifact('model:production')
|
||||
|
||||
# Use from another project
|
||||
artifact = run.use_artifact('team/other-project/model:latest')
|
||||
```
|
||||
|
||||
### Check Artifact Metadata
|
||||
|
||||
```python
|
||||
artifact = run.use_artifact('training-data:latest')
|
||||
|
||||
# Access metadata
|
||||
print(artifact.metadata)
|
||||
print(f"Size: {artifact.metadata['size']}")
|
||||
|
||||
# Access version info
|
||||
print(f"Version: {artifact.version}")
|
||||
print(f"Created at: {artifact.created_at}")
|
||||
print(f"Digest: {artifact.digest}")
|
||||
```
|
||||
|
||||
## Model Registry
|
||||
|
||||
Link models to a central registry for governance and deployment.
|
||||
|
||||
### Create Model Registry
|
||||
|
||||
```python
|
||||
# In W&B UI:
|
||||
# 1. Go to "Registry" tab
|
||||
# 2. Create new registry: "production-models"
|
||||
# 3. Define stages: development, staging, production
|
||||
```
|
||||
|
||||
### Link Model to Registry
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="training")
|
||||
|
||||
# Create model artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='sentiment-classifier',
|
||||
type='model',
|
||||
metadata={'accuracy': 0.94, 'f1': 0.92}
|
||||
)
|
||||
|
||||
model_artifact.add_file('model.pth')
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Link to registry
|
||||
run.link_artifact(
|
||||
model_artifact,
|
||||
'model-registry/production-models',
|
||||
aliases=['staging'] # Deploy to staging
|
||||
)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Promote Model in Registry
|
||||
|
||||
```python
|
||||
# Retrieve model from registry
|
||||
api = wandb.Api()
|
||||
artifact = api.artifact('model-registry/production-models/sentiment-classifier:staging')
|
||||
|
||||
# Promote to production
|
||||
artifact.link('model-registry/production-models', aliases=['production'])
|
||||
|
||||
# Demote from production
|
||||
artifact.aliases = ['archived']
|
||||
artifact.save()
|
||||
```
|
||||
|
||||
### Use Model from Registry
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init()
|
||||
|
||||
# Download production model
|
||||
model_artifact = run.use_artifact(
|
||||
'model-registry/production-models/sentiment-classifier:production'
|
||||
)
|
||||
|
||||
model_dir = model_artifact.download()
|
||||
|
||||
# Load and use
|
||||
import torch
|
||||
model = torch.load(f'{model_dir}/model.pth')
|
||||
model.eval()
|
||||
```
|
||||
|
||||
## Versioning & Lineage
|
||||
|
||||
### Automatic Versioning
|
||||
|
||||
```python
|
||||
# First log: creates v0
|
||||
run1 = wandb.init(project="my-project")
|
||||
dataset_v0 = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v0.add_file('data_v1.csv')
|
||||
run1.log_artifact(dataset_v0)
|
||||
|
||||
# Second log with same name: creates v1
|
||||
run2 = wandb.init(project="my-project")
|
||||
dataset_v1 = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v1.add_file('data_v2.csv') # Different content
|
||||
run2.log_artifact(dataset_v1)
|
||||
|
||||
# Third log with SAME content as v1: references v1 (no new version)
|
||||
run3 = wandb.init(project="my-project")
|
||||
dataset_v1_again = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v1_again.add_file('data_v2.csv') # Same content as v1
|
||||
run3.log_artifact(dataset_v1_again) # Still v1, no v2 created
|
||||
```
|
||||
|
||||
### Track Lineage
|
||||
|
||||
```python
|
||||
# Training run
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Use dataset (input)
|
||||
dataset = run.use_artifact('training-data:v3')
|
||||
data = load_data(dataset.download())
|
||||
|
||||
# Train model
|
||||
model = train(data)
|
||||
|
||||
# Save model (output)
|
||||
model_artifact = wandb.Artifact('trained-model', type='model')
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
model_artifact.add_file('model.pth')
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Lineage automatically tracked:
|
||||
# training-data:v3 --> [run] --> trained-model:v0
|
||||
```
|
||||
|
||||
### View Lineage Graph
|
||||
|
||||
```python
|
||||
# In W&B UI:
|
||||
# Artifacts → Select artifact → Lineage tab
|
||||
# Shows:
|
||||
# - Which runs produced this artifact
|
||||
# - Which runs used this artifact
|
||||
# - Parent/child artifacts
|
||||
```
|
||||
|
||||
## Artifact Types
|
||||
|
||||
### Dataset Artifacts
|
||||
|
||||
```python
|
||||
# Raw data
|
||||
raw_data = wandb.Artifact('raw-data', type='dataset')
|
||||
raw_data.add_dir('raw/')
|
||||
|
||||
# Processed data
|
||||
processed_data = wandb.Artifact('processed-data', type='dataset')
|
||||
processed_data.add_dir('processed/')
|
||||
|
||||
# Train/val/test splits
|
||||
train_split = wandb.Artifact('train-split', type='dataset')
|
||||
train_split.add_file('train.csv')
|
||||
|
||||
val_split = wandb.Artifact('val-split', type='dataset')
|
||||
val_split.add_file('val.csv')
|
||||
```
|
||||
|
||||
### Model Artifacts
|
||||
|
||||
```python
|
||||
# Checkpoint during training
|
||||
checkpoint = wandb.Artifact('checkpoint-epoch-10', type='model')
|
||||
checkpoint.add_file('checkpoint_epoch_10.pth')
|
||||
|
||||
# Final model
|
||||
final_model = wandb.Artifact('final-model', type='model')
|
||||
final_model.add_file('model.pth')
|
||||
final_model.add_file('tokenizer.json')
|
||||
|
||||
# Quantized model
|
||||
quantized = wandb.Artifact('quantized-model', type='model')
|
||||
quantized.add_file('model_int8.onnx')
|
||||
```
|
||||
|
||||
### Result Artifacts
|
||||
|
||||
```python
|
||||
# Predictions
|
||||
predictions = wandb.Artifact('test-predictions', type='predictions')
|
||||
predictions.add_file('predictions.csv')
|
||||
|
||||
# Evaluation metrics
|
||||
eval_results = wandb.Artifact('evaluation', type='evaluation')
|
||||
eval_results.add_file('metrics.json')
|
||||
eval_results.add_file('confusion_matrix.png')
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Incremental Artifacts
|
||||
|
||||
Add files incrementally without re-uploading.
|
||||
|
||||
```python
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create artifact
|
||||
dataset = wandb.Artifact('incremental-dataset', type='dataset')
|
||||
|
||||
# Add files incrementally
|
||||
for i in range(100):
|
||||
filename = f'batch_{i}.csv'
|
||||
process_batch(i, filename)
|
||||
dataset.add_file(filename)
|
||||
|
||||
# Log progress
|
||||
if (i + 1) % 10 == 0:
|
||||
print(f"Added {i + 1}/100 batches")
|
||||
|
||||
# Log complete artifact
|
||||
run.log_artifact(dataset)
|
||||
```
|
||||
|
||||
### Artifact Tables
|
||||
|
||||
Track structured data with W&B Tables.
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create table
|
||||
table = wandb.Table(columns=["id", "image", "label", "prediction"])
|
||||
|
||||
for idx, (img, label, pred) in enumerate(zip(images, labels, predictions)):
|
||||
table.add_data(
|
||||
idx,
|
||||
wandb.Image(img),
|
||||
label,
|
||||
pred
|
||||
)
|
||||
|
||||
# Log as artifact
|
||||
artifact = wandb.Artifact('predictions-table', type='predictions')
|
||||
artifact.add(table, "predictions")
|
||||
run.log_artifact(artifact)
|
||||
```
|
||||
|
||||
### Artifact References
|
||||
|
||||
Reference external data without copying.
|
||||
|
||||
```python
|
||||
# S3 reference
|
||||
dataset = wandb.Artifact('s3-dataset', type='dataset')
|
||||
dataset.add_reference('s3://my-bucket/data/', name='train')
|
||||
dataset.add_reference('s3://my-bucket/labels/', name='labels')
|
||||
|
||||
# GCS reference
|
||||
dataset.add_reference('gs://my-bucket/data/')
|
||||
|
||||
# HTTP reference
|
||||
dataset.add_reference('https://example.com/data.zip')
|
||||
|
||||
# Local filesystem reference (for shared storage)
|
||||
dataset.add_reference('file:///mnt/shared/data')
|
||||
```
|
||||
|
||||
## Collaboration Patterns
|
||||
|
||||
### Team Dataset Sharing
|
||||
|
||||
```python
|
||||
# Data engineer creates dataset
|
||||
run = wandb.init(project="data-eng", entity="my-team")
|
||||
dataset = wandb.Artifact('shared-dataset', type='dataset')
|
||||
dataset.add_dir('data/')
|
||||
run.log_artifact(dataset, aliases=['latest', 'production'])
|
||||
|
||||
# ML engineer uses dataset
|
||||
run = wandb.init(project="ml-training", entity="my-team")
|
||||
dataset = run.use_artifact('my-team/data-eng/shared-dataset:production')
|
||||
data = load_data(dataset.download())
|
||||
```
|
||||
|
||||
### Model Handoff
|
||||
|
||||
```python
|
||||
# Training team
|
||||
train_run = wandb.init(project="model-training", entity="ml-team")
|
||||
model = train_model()
|
||||
model_artifact = wandb.Artifact('nlp-model', type='model')
|
||||
model_artifact.add_file('model.pth')
|
||||
train_run.log_artifact(model_artifact)
|
||||
train_run.link_artifact(model_artifact, 'model-registry/nlp-models', aliases=['candidate'])
|
||||
|
||||
# Evaluation team
|
||||
eval_run = wandb.init(project="model-eval", entity="ml-team")
|
||||
model_artifact = eval_run.use_artifact('model-registry/nlp-models/nlp-model:candidate')
|
||||
metrics = evaluate_model(model_artifact)
|
||||
|
||||
if metrics['f1'] > 0.9:
|
||||
# Promote to production
|
||||
model_artifact.link('model-registry/nlp-models', aliases=['production'])
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Descriptive Names
|
||||
|
||||
```python
|
||||
# ✅ Good: Descriptive names
|
||||
wandb.Artifact('imagenet-train-augmented-v2', type='dataset')
|
||||
wandb.Artifact('bert-base-sentiment-finetuned', type='model')
|
||||
|
||||
# ❌ Bad: Generic names
|
||||
wandb.Artifact('dataset1', type='dataset')
|
||||
wandb.Artifact('model', type='model')
|
||||
```
|
||||
|
||||
### 2. Add Comprehensive Metadata
|
||||
|
||||
```python
|
||||
model_artifact = wandb.Artifact(
|
||||
'production-model',
|
||||
type='model',
|
||||
description='ResNet50 classifier for product categorization',
|
||||
metadata={
|
||||
# Model info
|
||||
'architecture': 'ResNet50',
|
||||
'framework': 'PyTorch 2.0',
|
||||
'pretrained': True,
|
||||
|
||||
# Performance
|
||||
'accuracy': 0.95,
|
||||
'f1_score': 0.93,
|
||||
'inference_time_ms': 15,
|
||||
|
||||
# Training
|
||||
'epochs': 50,
|
||||
'dataset': 'imagenet',
|
||||
'num_samples': 1200000,
|
||||
|
||||
# Business context
|
||||
'use_case': 'e-commerce product classification',
|
||||
'owner': 'ml-team@company.com',
|
||||
'approved_by': 'data-science-lead'
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Use Aliases for Deployment Stages
|
||||
|
||||
```python
|
||||
# Development
|
||||
run.log_artifact(model, aliases=['dev', 'latest'])
|
||||
|
||||
# Staging
|
||||
run.log_artifact(model, aliases=['staging'])
|
||||
|
||||
# Production
|
||||
run.log_artifact(model, aliases=['production', 'v1.2.0'])
|
||||
|
||||
# Archive old versions
|
||||
old_artifact = api.artifact('model:production')
|
||||
old_artifact.aliases = ['archived-v1.1.0']
|
||||
old_artifact.save()
|
||||
```
|
||||
|
||||
### 4. Track Data Lineage
|
||||
|
||||
```python
|
||||
def create_training_pipeline():
|
||||
run = wandb.init(project="pipeline")
|
||||
|
||||
# 1. Load raw data
|
||||
raw_data = run.use_artifact('raw-data:latest')
|
||||
|
||||
# 2. Preprocess
|
||||
processed = preprocess(raw_data)
|
||||
processed_artifact = wandb.Artifact('processed-data', type='dataset')
|
||||
processed_artifact.add_file('processed.csv')
|
||||
run.log_artifact(processed_artifact)
|
||||
|
||||
# 3. Train model
|
||||
model = train(processed)
|
||||
model_artifact = wandb.Artifact('trained-model', type='model')
|
||||
model_artifact.add_file('model.pth')
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Lineage: raw-data → processed-data → trained-model
|
||||
```
|
||||
|
||||
### 5. Efficient Storage
|
||||
|
||||
```python
|
||||
# ✅ Good: Reference large files
|
||||
large_dataset = wandb.Artifact('large-dataset', type='dataset')
|
||||
large_dataset.add_reference('s3://bucket/huge-file.tar.gz')
|
||||
|
||||
# ❌ Bad: Upload giant files
|
||||
# large_dataset.add_file('huge-file.tar.gz') # Don't do this
|
||||
|
||||
# ✅ Good: Upload only metadata
|
||||
metadata_artifact = wandb.Artifact('dataset-metadata', type='dataset')
|
||||
metadata_artifact.add_file('metadata.json') # Small file
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Artifacts Documentation**: https://docs.wandb.ai/guides/artifacts
|
||||
- **Model Registry**: https://docs.wandb.ai/guides/model-registry
|
||||
- **Best Practices**: https://wandb.ai/site/articles/versioning-data-and-models-in-ml
|
||||
@@ -0,0 +1,700 @@
|
||||
# Framework Integrations Guide
|
||||
|
||||
Complete guide to integrating W&B with popular ML frameworks.
|
||||
|
||||
## Table of Contents
|
||||
- HuggingFace Transformers
|
||||
- PyTorch Lightning
|
||||
- Keras/TensorFlow
|
||||
- Fast.ai
|
||||
- XGBoost/LightGBM
|
||||
- PyTorch Native
|
||||
- Custom Integrations
|
||||
|
||||
## HuggingFace Transformers
|
||||
|
||||
### Automatic Integration
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="hf-transformers", name="bert-finetuning")
|
||||
|
||||
# Training arguments with W&B
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb", # Enable W&B logging
|
||||
run_name="bert-base-finetuning",
|
||||
|
||||
# Training params
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
|
||||
# Logging
|
||||
logging_dir="./logs",
|
||||
logging_steps=100,
|
||||
logging_first_step=True,
|
||||
|
||||
# Evaluation
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_steps=500,
|
||||
|
||||
# Other
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_accuracy"
|
||||
)
|
||||
|
||||
# Trainer automatically logs to W&B
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
trainer.train()
|
||||
|
||||
# Finish W&B run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Custom Logging
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers.integrations import WandbCallback
|
||||
import wandb
|
||||
|
||||
class CustomWandbCallback(WandbCallback):
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
super().on_evaluate(args, state, control, metrics, **kwargs)
|
||||
|
||||
# Log custom metrics
|
||||
wandb.log({
|
||||
"custom/eval_score": metrics["eval_accuracy"] * 100,
|
||||
"custom/epoch": state.epoch
|
||||
})
|
||||
|
||||
# Use custom callback
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
callbacks=[CustomWandbCallback()]
|
||||
)
|
||||
```
|
||||
|
||||
### Log Model to Registry
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb",
|
||||
load_best_model_at_end=True
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save final model as artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
'hf-bert-model',
|
||||
type='model',
|
||||
description='BERT finetuned on sentiment analysis'
|
||||
)
|
||||
|
||||
# Save model files
|
||||
trainer.save_model("./final_model")
|
||||
model_artifact.add_dir("./final_model")
|
||||
|
||||
# Log artifact
|
||||
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## PyTorch Lightning
|
||||
|
||||
### Basic Integration
|
||||
|
||||
```python
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Create W&B logger
|
||||
wandb_logger = WandbLogger(
|
||||
project="lightning-demo",
|
||||
name="resnet50-training",
|
||||
log_model=True, # Log model checkpoints as artifacts
|
||||
save_code=True # Save code as artifact
|
||||
)
|
||||
|
||||
# Lightning module
|
||||
class LitModel(pl.LightningModule):
|
||||
def __init__(self, learning_rate=0.001):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = create_model()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
# Log metrics (automatically sent to W&B)
|
||||
self.log('train/loss', loss, on_step=True, on_epoch=True)
|
||||
self.log('train/accuracy', accuracy(y_hat, y), on_epoch=True)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
self.log('val/loss', loss, on_step=False, on_epoch=True)
|
||||
self.log('val/accuracy', accuracy(y_hat, y), on_epoch=True)
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
|
||||
# Trainer with W&B logger
|
||||
trainer = pl.Trainer(
|
||||
logger=wandb_logger,
|
||||
max_epochs=10,
|
||||
accelerator="gpu",
|
||||
devices=1
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
trainer.fit(model, datamodule=dm)
|
||||
|
||||
# Finish W&B run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Log Media
|
||||
|
||||
```python
|
||||
class LitModel(pl.LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
|
||||
# Log images (first batch only)
|
||||
if batch_idx == 0:
|
||||
self.logger.experiment.log({
|
||||
"examples": [wandb.Image(img) for img in x[:8]]
|
||||
})
|
||||
|
||||
return loss
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
# Log confusion matrix
|
||||
cm = compute_confusion_matrix(self.all_preds, self.all_targets)
|
||||
|
||||
self.logger.experiment.log({
|
||||
"confusion_matrix": wandb.plot.confusion_matrix(
|
||||
probs=None,
|
||||
y_true=self.all_targets,
|
||||
preds=self.all_preds,
|
||||
class_names=self.class_names
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
### Hyperparameter Sweeps
|
||||
|
||||
```python
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Define sweep
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
'learning_rate': {'min': 1e-5, 'max': 1e-2, 'distribution': 'log_uniform'},
|
||||
'batch_size': {'values': [16, 32, 64]},
|
||||
'hidden_size': {'values': [128, 256, 512]}
|
||||
}
|
||||
}
|
||||
|
||||
sweep_id = wandb.sweep(sweep_config, project="lightning-sweeps")
|
||||
|
||||
def train():
|
||||
# Initialize W&B
|
||||
run = wandb.init()
|
||||
|
||||
# Get hyperparameters
|
||||
config = wandb.config
|
||||
|
||||
# Create logger
|
||||
wandb_logger = WandbLogger()
|
||||
|
||||
# Create model with sweep params
|
||||
model = LitModel(
|
||||
learning_rate=config.learning_rate,
|
||||
hidden_size=config.hidden_size
|
||||
)
|
||||
|
||||
# Create datamodule with sweep batch size
|
||||
dm = DataModule(batch_size=config.batch_size)
|
||||
|
||||
# Train
|
||||
trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
|
||||
trainer.fit(model, dm)
|
||||
|
||||
# Run sweep
|
||||
wandb.agent(sweep_id, function=train, count=30)
|
||||
```
|
||||
|
||||
## Keras/TensorFlow
|
||||
|
||||
### With Callback
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from wandb.keras import WandbCallback
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(
|
||||
project="keras-demo",
|
||||
config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32
|
||||
}
|
||||
)
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Build model
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.2),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(config.learning_rate),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
|
||||
# Train with W&B callback
|
||||
history = model.fit(
|
||||
x_train, y_train,
|
||||
validation_data=(x_val, y_val),
|
||||
epochs=config.epochs,
|
||||
batch_size=config.batch_size,
|
||||
callbacks=[
|
||||
WandbCallback(
|
||||
log_weights=True, # Log model weights
|
||||
log_gradients=True, # Log gradients
|
||||
training_data=(x_train, y_train),
|
||||
validation_data=(x_val, y_val),
|
||||
labels=class_names
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Save model as artifact
|
||||
model.save('model.h5')
|
||||
artifact = wandb.Artifact('keras-model', type='model')
|
||||
artifact.add_file('model.h5')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Custom Training Loop
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import wandb
|
||||
|
||||
wandb.init(project="tf-custom-loop")
|
||||
|
||||
# Model, optimizer, loss
|
||||
model = create_model()
|
||||
optimizer = tf.keras.optimizers.Adam(1e-3)
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
|
||||
# Metrics
|
||||
train_loss = tf.keras.metrics.Mean(name='train_loss')
|
||||
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
|
||||
|
||||
@tf.function
|
||||
def train_step(x, y):
|
||||
with tf.GradientTape() as tape:
|
||||
predictions = model(x, training=True)
|
||||
loss = loss_fn(y, predictions)
|
||||
|
||||
gradients = tape.gradient(loss, model.trainable_variables)
|
||||
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
||||
|
||||
train_loss(loss)
|
||||
train_accuracy(y, predictions)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(EPOCHS):
|
||||
train_loss.reset_states()
|
||||
train_accuracy.reset_states()
|
||||
|
||||
for step, (x, y) in enumerate(train_dataset):
|
||||
train_step(x, y)
|
||||
|
||||
# Log every 100 steps
|
||||
if step % 100 == 0:
|
||||
wandb.log({
|
||||
'train/loss': train_loss.result().numpy(),
|
||||
'train/accuracy': train_accuracy.result().numpy(),
|
||||
'epoch': epoch,
|
||||
'step': step
|
||||
})
|
||||
|
||||
# Log epoch metrics
|
||||
wandb.log({
|
||||
'epoch/train_loss': train_loss.result().numpy(),
|
||||
'epoch/train_accuracy': train_accuracy.result().numpy(),
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Fast.ai
|
||||
|
||||
### With Callback
|
||||
|
||||
```python
|
||||
from fastai.vision.all import *
|
||||
from fastai.callback.wandb import *
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="fastai-demo")
|
||||
|
||||
# Create data loaders
|
||||
dls = ImageDataLoaders.from_folder(
|
||||
path,
|
||||
train='train',
|
||||
valid='valid',
|
||||
bs=64
|
||||
)
|
||||
|
||||
# Create learner with W&B callback
|
||||
learn = vision_learner(
|
||||
dls,
|
||||
resnet34,
|
||||
metrics=accuracy,
|
||||
cbs=WandbCallback(
|
||||
log_preds=True, # Log predictions
|
||||
log_model=True, # Log model as artifact
|
||||
log_dataset=True # Log dataset as artifact
|
||||
)
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
learn.fine_tune(5)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## XGBoost/LightGBM
|
||||
|
||||
### XGBoost
|
||||
|
||||
```python
|
||||
import xgboost as xgb
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
run = wandb.init(project="xgboost-demo", config={
|
||||
"max_depth": 6,
|
||||
"learning_rate": 0.1,
|
||||
"n_estimators": 100
|
||||
})
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Create DMatrix
|
||||
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||
dval = xgb.DMatrix(X_val, label=y_val)
|
||||
|
||||
# XGBoost params
|
||||
params = {
|
||||
'max_depth': config.max_depth,
|
||||
'learning_rate': config.learning_rate,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': ['logloss', 'auc']
|
||||
}
|
||||
|
||||
# Custom callback for W&B
|
||||
def wandb_callback(env):
|
||||
"""Log XGBoost metrics to W&B."""
|
||||
for metric_name, metric_value in env.evaluation_result_list:
|
||||
wandb.log({
|
||||
f"{metric_name}": metric_value,
|
||||
"iteration": env.iteration
|
||||
})
|
||||
|
||||
# Train with callback
|
||||
model = xgb.train(
|
||||
params,
|
||||
dtrain,
|
||||
num_boost_round=config.n_estimators,
|
||||
evals=[(dtrain, 'train'), (dval, 'val')],
|
||||
callbacks=[wandb_callback],
|
||||
verbose_eval=10
|
||||
)
|
||||
|
||||
# Save model
|
||||
model.save_model('xgboost_model.json')
|
||||
artifact = wandb.Artifact('xgboost-model', type='model')
|
||||
artifact.add_file('xgboost_model.json')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### LightGBM
|
||||
|
||||
```python
|
||||
import lightgbm as lgb
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="lgbm-demo")
|
||||
|
||||
# Create datasets
|
||||
train_data = lgb.Dataset(X_train, label=y_train)
|
||||
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
|
||||
|
||||
# Parameters
|
||||
params = {
|
||||
'objective': 'binary',
|
||||
'metric': ['binary_logloss', 'auc'],
|
||||
'learning_rate': 0.1,
|
||||
'num_leaves': 31
|
||||
}
|
||||
|
||||
# Custom callback
|
||||
def log_to_wandb(env):
|
||||
"""Log LightGBM metrics to W&B."""
|
||||
for entry in env.evaluation_result_list:
|
||||
dataset_name, metric_name, metric_value, _ = entry
|
||||
wandb.log({
|
||||
f"{dataset_name}/{metric_name}": metric_value,
|
||||
"iteration": env.iteration
|
||||
})
|
||||
|
||||
# Train
|
||||
model = lgb.train(
|
||||
params,
|
||||
train_data,
|
||||
num_boost_round=100,
|
||||
valid_sets=[train_data, val_data],
|
||||
valid_names=['train', 'val'],
|
||||
callbacks=[log_to_wandb]
|
||||
)
|
||||
|
||||
# Save model
|
||||
model.save_model('lgbm_model.txt')
|
||||
artifact = wandb.Artifact('lgbm-model', type='model')
|
||||
artifact.add_file('lgbm_model.txt')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## PyTorch Native
|
||||
|
||||
### Training Loop Integration
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="pytorch-native", config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32
|
||||
})
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Model, loss, optimizer
|
||||
model = create_model()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
# Watch model (logs gradients and parameters)
|
||||
wandb.watch(model, criterion, log="all", log_freq=100)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
|
||||
# Forward pass
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
train_loss += loss.item()
|
||||
_, predicted = output.max(1)
|
||||
total += target.size(0)
|
||||
correct += predicted.eq(target).sum().item()
|
||||
|
||||
# Log every 100 batches
|
||||
if batch_idx % 100 == 0:
|
||||
wandb.log({
|
||||
'train/loss': loss.item(),
|
||||
'train/batch_accuracy': 100. * correct / total,
|
||||
'epoch': epoch,
|
||||
'batch': batch_idx
|
||||
})
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
val_correct = 0
|
||||
val_total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for data, target in val_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
val_loss += loss.item()
|
||||
_, predicted = output.max(1)
|
||||
val_total += target.size(0)
|
||||
val_correct += predicted.eq(target).sum().item()
|
||||
|
||||
# Log epoch metrics
|
||||
wandb.log({
|
||||
'epoch/train_loss': train_loss / len(train_loader),
|
||||
'epoch/train_accuracy': 100. * correct / total,
|
||||
'epoch/val_loss': val_loss / len(val_loader),
|
||||
'epoch/val_accuracy': 100. * val_correct / val_total,
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
# Save final model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
artifact = wandb.Artifact('final-model', type='model')
|
||||
artifact.add_file('model.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Custom Integrations
|
||||
|
||||
### Generic Framework Integration
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
class WandbIntegration:
|
||||
"""Generic W&B integration wrapper."""
|
||||
|
||||
def __init__(self, project, config):
|
||||
self.run = wandb.init(project=project, config=config)
|
||||
self.config = wandb.config
|
||||
self.step = 0
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
"""Log training metrics."""
|
||||
if step is None:
|
||||
step = self.step
|
||||
self.step += 1
|
||||
|
||||
wandb.log(metrics, step=step)
|
||||
|
||||
def log_images(self, images, caption=""):
|
||||
"""Log images."""
|
||||
wandb.log({
|
||||
caption: [wandb.Image(img) for img in images]
|
||||
})
|
||||
|
||||
def log_table(self, data, columns):
|
||||
"""Log tabular data."""
|
||||
table = wandb.Table(columns=columns, data=data)
|
||||
wandb.log({"table": table})
|
||||
|
||||
def save_model(self, model_path, metadata=None):
|
||||
"""Save model as artifact."""
|
||||
artifact = wandb.Artifact(
|
||||
'model',
|
||||
type='model',
|
||||
metadata=metadata or {}
|
||||
)
|
||||
artifact.add_file(model_path)
|
||||
self.run.log_artifact(artifact)
|
||||
|
||||
def finish(self):
|
||||
"""Finish W&B run."""
|
||||
wandb.finish()
|
||||
|
||||
# Usage
|
||||
wb = WandbIntegration(project="my-project", config={"lr": 0.001})
|
||||
|
||||
# Training loop
|
||||
for epoch in range(10):
|
||||
# Your training code
|
||||
loss, accuracy = train_epoch()
|
||||
|
||||
# Log metrics
|
||||
wb.log_metrics({
|
||||
'train/loss': loss,
|
||||
'train/accuracy': accuracy
|
||||
})
|
||||
|
||||
# Save model
|
||||
wb.save_model('model.pth', metadata={'accuracy': 0.95})
|
||||
wb.finish()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Integrations Guide**: https://docs.wandb.ai/guides/integrations
|
||||
- **HuggingFace**: https://docs.wandb.ai/guides/integrations/huggingface
|
||||
- **PyTorch Lightning**: https://docs.wandb.ai/guides/integrations/lightning
|
||||
- **Keras**: https://docs.wandb.ai/guides/integrations/keras
|
||||
- **Examples**: https://github.com/wandb/examples
|
||||
@@ -0,0 +1,847 @@
|
||||
# Comprehensive Hyperparameter Sweeps Guide
|
||||
|
||||
Complete guide to hyperparameter optimization with W&B Sweeps.
|
||||
|
||||
## Table of Contents
|
||||
- Sweep Configuration
|
||||
- Search Strategies
|
||||
- Parameter Distributions
|
||||
- Early Termination
|
||||
- Parallel Execution
|
||||
- Advanced Patterns
|
||||
- Real-World Examples
|
||||
|
||||
## Sweep Configuration
|
||||
|
||||
### Basic Sweep Config
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes', # Search strategy
|
||||
'metric': {
|
||||
'name': 'val/accuracy',
|
||||
'goal': 'maximize' # or 'minimize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize sweep
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
```
|
||||
|
||||
### Complete Config Example
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
# Required: Search method
|
||||
'method': 'bayes',
|
||||
|
||||
# Required: Optimization metric
|
||||
'metric': {
|
||||
'name': 'val/f1_score',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
|
||||
# Required: Parameters to search
|
||||
'parameters': {
|
||||
# Continuous parameter
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
|
||||
# Discrete values
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
|
||||
# Categorical
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
|
||||
},
|
||||
|
||||
# Uniform distribution
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
},
|
||||
|
||||
# Integer range
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 10
|
||||
},
|
||||
|
||||
# Fixed value (constant across runs)
|
||||
'epochs': {
|
||||
'value': 50
|
||||
}
|
||||
},
|
||||
|
||||
# Optional: Early termination
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 5,
|
||||
's': 2,
|
||||
'eta': 3,
|
||||
'max_iter': 27
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Search Strategies
|
||||
|
||||
### 1. Grid Search
|
||||
|
||||
Exhaustively search all combinations.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'grid',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'values': [0.001, 0.01, 0.1]
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Total runs: 3 × 3 × 2 = 18 runs
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Comprehensive search
|
||||
- Reproducible results
|
||||
- No randomness
|
||||
|
||||
**Cons:**
|
||||
- Exponential growth with parameters
|
||||
- Inefficient for continuous parameters
|
||||
- Not scalable beyond 3-4 parameters
|
||||
|
||||
**When to use:**
|
||||
- Few parameters (< 4)
|
||||
- All discrete values
|
||||
- Need complete coverage
|
||||
|
||||
### 2. Random Search
|
||||
|
||||
Randomly sample parameter combinations.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'random',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128, 256]
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
},
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 8
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Run 100 random trials
|
||||
wandb.agent(sweep_id, function=train, count=100)
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Scales to many parameters
|
||||
- Can run indefinitely
|
||||
- Often finds good solutions quickly
|
||||
|
||||
**Cons:**
|
||||
- No learning from previous runs
|
||||
- May miss optimal region
|
||||
- Results vary with random seed
|
||||
|
||||
**When to use:**
|
||||
- Many parameters (> 4)
|
||||
- Quick exploration
|
||||
- Limited budget
|
||||
|
||||
### 3. Bayesian Optimization (Recommended)
|
||||
|
||||
Learn from previous trials to sample promising regions.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {
|
||||
'name': 'val/loss',
|
||||
'goal': 'minimize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
},
|
||||
'num_layers': {
|
||||
'values': [2, 3, 4, 5, 6]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Most sample-efficient
|
||||
- Learns from past trials
|
||||
- Focuses on promising regions
|
||||
|
||||
**Cons:**
|
||||
- Initial random exploration phase
|
||||
- May get stuck in local optima
|
||||
- Slower per iteration
|
||||
|
||||
**When to use:**
|
||||
- Expensive training runs
|
||||
- Need best performance
|
||||
- Limited compute budget
|
||||
|
||||
## Parameter Distributions
|
||||
|
||||
### Continuous Distributions
|
||||
|
||||
```python
|
||||
# Log-uniform: Good for learning rates, regularization
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-1
|
||||
}
|
||||
|
||||
# Uniform: Good for dropout, momentum
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
}
|
||||
|
||||
# Normal distribution
|
||||
'parameter': {
|
||||
'distribution': 'normal',
|
||||
'mu': 0.5,
|
||||
'sigma': 0.1
|
||||
}
|
||||
|
||||
# Log-normal distribution
|
||||
'parameter': {
|
||||
'distribution': 'log_normal',
|
||||
'mu': 0.0,
|
||||
'sigma': 1.0
|
||||
}
|
||||
```
|
||||
|
||||
### Discrete Distributions
|
||||
|
||||
```python
|
||||
# Fixed values
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128, 256]
|
||||
}
|
||||
|
||||
# Integer uniform
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 10
|
||||
}
|
||||
|
||||
# Quantized uniform (step size)
|
||||
'layer_size': {
|
||||
'distribution': 'q_uniform',
|
||||
'min': 32,
|
||||
'max': 512,
|
||||
'q': 32 # Step by 32: 32, 64, 96, 128...
|
||||
}
|
||||
|
||||
# Quantized log-uniform
|
||||
'hidden_size': {
|
||||
'distribution': 'q_log_uniform',
|
||||
'min': 32,
|
||||
'max': 1024,
|
||||
'q': 32
|
||||
}
|
||||
```
|
||||
|
||||
### Categorical Parameters
|
||||
|
||||
```python
|
||||
# Optimizers
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
|
||||
}
|
||||
|
||||
# Model architectures
|
||||
'model': {
|
||||
'values': ['resnet18', 'resnet34', 'resnet50', 'efficientnet_b0']
|
||||
}
|
||||
|
||||
# Activation functions
|
||||
'activation': {
|
||||
'values': ['relu', 'gelu', 'silu', 'leaky_relu']
|
||||
}
|
||||
```
|
||||
|
||||
## Early Termination
|
||||
|
||||
Stop underperforming runs early to save compute.
|
||||
|
||||
### Hyperband
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {...},
|
||||
|
||||
# Hyperband early termination
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 3, # Minimum iterations before termination
|
||||
's': 2, # Bracket count
|
||||
'eta': 3, # Downsampling rate
|
||||
'max_iter': 27 # Maximum iterations
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Runs trials in brackets
|
||||
- Keeps top 1/eta performers each round
|
||||
- Eliminates bottom performers early
|
||||
|
||||
### Custom Termination
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
|
||||
for epoch in range(MAX_EPOCHS):
|
||||
loss = train_epoch()
|
||||
val_acc = validate()
|
||||
|
||||
wandb.log({'val/accuracy': val_acc, 'epoch': epoch})
|
||||
|
||||
# Custom early stopping
|
||||
if epoch > 5 and val_acc < 0.5:
|
||||
print("Early stop: Poor performance")
|
||||
break
|
||||
|
||||
if epoch > 10 and val_acc > best_acc - 0.01:
|
||||
print("Early stop: No improvement")
|
||||
break
|
||||
```
|
||||
|
||||
## Training Function
|
||||
|
||||
### Basic Template
|
||||
|
||||
```python
|
||||
def train():
|
||||
# Initialize W&B run
|
||||
run = wandb.init()
|
||||
|
||||
# Get hyperparameters
|
||||
config = wandb.config
|
||||
|
||||
# Build model with config
|
||||
model = build_model(
|
||||
hidden_size=config.hidden_size,
|
||||
num_layers=config.num_layers,
|
||||
dropout=config.dropout
|
||||
)
|
||||
|
||||
# Create optimizer
|
||||
optimizer = create_optimizer(
|
||||
model.parameters(),
|
||||
name=config.optimizer,
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
# Train
|
||||
train_loss, train_acc = train_epoch(
|
||||
model, optimizer, train_loader, config.batch_size
|
||||
)
|
||||
|
||||
# Validate
|
||||
val_loss, val_acc = validate(model, val_loader)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
'train/loss': train_loss,
|
||||
'train/accuracy': train_acc,
|
||||
'val/loss': val_loss,
|
||||
'val/accuracy': val_acc,
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
# Log final model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
wandb.save('model.pth')
|
||||
|
||||
# Finish run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### With PyTorch
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import wandb
|
||||
|
||||
def train():
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Data
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
# Model
|
||||
model = ResNet(
|
||||
num_classes=config.num_classes,
|
||||
dropout=config.dropout
|
||||
).to(device)
|
||||
|
||||
# Optimizer
|
||||
if config.optimizer == 'adam':
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
elif config.optimizer == 'sgd':
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
momentum=config.momentum,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# Scheduler
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=config.epochs
|
||||
)
|
||||
|
||||
# Training
|
||||
for epoch in range(config.epochs):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
|
||||
for data, target in train_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = nn.CrossEntropyLoss()(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss, val_acc = validate(model, val_loader)
|
||||
|
||||
# Step scheduler
|
||||
scheduler.step()
|
||||
|
||||
# Log
|
||||
wandb.log({
|
||||
'train/loss': train_loss / len(train_loader),
|
||||
'val/loss': val_loss,
|
||||
'val/accuracy': val_acc,
|
||||
'learning_rate': scheduler.get_last_lr()[0],
|
||||
'epoch': epoch
|
||||
})
|
||||
```
|
||||
|
||||
## Parallel Execution
|
||||
|
||||
### Multiple Agents
|
||||
|
||||
Run sweep agents in parallel to speed up search.
|
||||
|
||||
```python
|
||||
# Initialize sweep once
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
|
||||
# Run multiple agents in parallel
|
||||
# Agent 1 (Terminal 1)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Agent 2 (Terminal 2)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Agent 3 (Terminal 3)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Total: 60 runs across 3 agents
|
||||
```
|
||||
|
||||
### Multi-GPU Execution
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
def train():
|
||||
# Get available GPU
|
||||
gpu_id = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
|
||||
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Train on specific GPU
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
model = model.to(device)
|
||||
|
||||
# ... rest of training ...
|
||||
|
||||
# Run agents on different GPUs
|
||||
# Terminal 1
|
||||
# CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
|
||||
|
||||
# Terminal 2
|
||||
# CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
|
||||
|
||||
# Terminal 3
|
||||
# CUDA_VISIBLE_DEVICES=2 wandb agent sweep_id
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Nested Parameters
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
'model': {
|
||||
'parameters': {
|
||||
'type': {
|
||||
'values': ['resnet', 'efficientnet']
|
||||
},
|
||||
'size': {
|
||||
'values': ['small', 'medium', 'large']
|
||||
}
|
||||
}
|
||||
},
|
||||
'optimizer': {
|
||||
'parameters': {
|
||||
'type': {
|
||||
'values': ['adam', 'sgd']
|
||||
},
|
||||
'lr': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Access nested config
|
||||
def train():
|
||||
run = wandb.init()
|
||||
model_type = wandb.config.model.type
|
||||
model_size = wandb.config.model.size
|
||||
opt_type = wandb.config.optimizer.type
|
||||
lr = wandb.config.optimizer.lr
|
||||
```
|
||||
|
||||
### Conditional Parameters
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'parameters': {
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd']
|
||||
},
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
# Only used if optimizer == 'sgd'
|
||||
'momentum': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.5,
|
||||
'max': 0.99
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def train():
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
if config.optimizer == 'adam':
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate
|
||||
)
|
||||
elif config.optimizer == 'sgd':
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
momentum=config.momentum # Conditional parameter
|
||||
)
|
||||
```
|
||||
|
||||
## Real-World Examples
|
||||
|
||||
### Image Classification
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {
|
||||
'name': 'val/top1_accuracy',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
'parameters': {
|
||||
# Model
|
||||
'architecture': {
|
||||
'values': ['resnet50', 'resnet101', 'efficientnet_b0', 'efficientnet_b3']
|
||||
},
|
||||
'pretrained': {
|
||||
'values': [True, False]
|
||||
},
|
||||
|
||||
# Training
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-2
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'adamw']
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
},
|
||||
|
||||
# Regularization
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
},
|
||||
'label_smoothing': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.2
|
||||
},
|
||||
|
||||
# Data augmentation
|
||||
'mixup_alpha': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 1.0
|
||||
},
|
||||
'cutmix_alpha': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 1.0
|
||||
}
|
||||
},
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### NLP Fine-Tuning
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'eval/f1', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
# Model
|
||||
'model_name': {
|
||||
'values': ['bert-base-uncased', 'roberta-base', 'distilbert-base-uncased']
|
||||
},
|
||||
|
||||
# Training
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-4
|
||||
},
|
||||
'per_device_train_batch_size': {
|
||||
'values': [8, 16, 32]
|
||||
},
|
||||
'num_train_epochs': {
|
||||
'values': [3, 4, 5]
|
||||
},
|
||||
'warmup_ratio': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.1
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-4,
|
||||
'max': 1e-1
|
||||
},
|
||||
|
||||
# Optimizer
|
||||
'adam_beta1': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.8,
|
||||
'max': 0.95
|
||||
},
|
||||
'adam_beta2': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.95,
|
||||
'max': 0.999
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Small
|
||||
|
||||
```python
|
||||
# Initial exploration: Random search, 20 runs
|
||||
sweep_config_v1 = {
|
||||
'method': 'random',
|
||||
'parameters': {...}
|
||||
}
|
||||
wandb.agent(sweep_id_v1, train, count=20)
|
||||
|
||||
# Refined search: Bayes, narrow ranges
|
||||
sweep_config_v2 = {
|
||||
'method': 'bayes',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'min': 5e-5, # Narrowed from 1e-6 to 1e-4
|
||||
'max': 1e-4
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Use Log Scales
|
||||
|
||||
```python
|
||||
# ✅ Good: Log scale for learning rate
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
}
|
||||
|
||||
# ❌ Bad: Linear scale
|
||||
'learning_rate': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.000001,
|
||||
'max': 0.01
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Set Reasonable Ranges
|
||||
|
||||
```python
|
||||
# Base ranges on prior knowledge
|
||||
'learning_rate': {'min': 1e-5, 'max': 1e-3}, # Typical for Adam
|
||||
'batch_size': {'values': [16, 32, 64]}, # GPU memory limits
|
||||
'dropout': {'min': 0.1, 'max': 0.5} # Too high hurts training
|
||||
```
|
||||
|
||||
### 4. Monitor Resource Usage
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
|
||||
# Log system metrics
|
||||
wandb.log({
|
||||
'system/gpu_memory_allocated': torch.cuda.memory_allocated(),
|
||||
'system/gpu_memory_reserved': torch.cuda.memory_reserved()
|
||||
})
|
||||
```
|
||||
|
||||
### 5. Save Best Models
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
best_acc = 0.0
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
val_acc = validate(model)
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
# Save best checkpoint
|
||||
torch.save(model.state_dict(), 'best_model.pth')
|
||||
wandb.save('best_model.pth')
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Sweeps Documentation**: https://docs.wandb.ai/guides/sweeps
|
||||
- **Configuration Reference**: https://docs.wandb.ai/guides/sweeps/configuration
|
||||
- **Examples**: https://github.com/wandb/examples/tree/master/examples/wandb-sweeps
|
||||
@@ -0,0 +1,80 @@
|
||||
---
|
||||
name: huggingface-hub
|
||||
description: "HuggingFace hf CLI: search/download/upload models, datasets."
|
||||
version: 1.0.0
|
||||
author: Hugging Face
|
||||
license: MIT
|
||||
tags: [huggingface, hf, models, datasets, hub, mlops]
|
||||
---
|
||||
|
||||
# Hugging Face CLI (`hf`) Reference Guide
|
||||
|
||||
The `hf` command is the modern command-line interface for interacting with the Hugging Face Hub, providing tools to manage repositories, models, datasets, and Spaces.
|
||||
|
||||
> **IMPORTANT:** The `hf` command replaces the now deprecated `huggingface-cli` command.
|
||||
|
||||
## Quick Start
|
||||
* **Installation:** `curl -LsSf https://hf.co/cli/install.sh | bash -s`
|
||||
* **Help:** Use `hf --help` to view all available functions and real-world examples.
|
||||
* **Authentication:** Recommended via `HF_TOKEN` environment variable or the `--token` flag.
|
||||
|
||||
---
|
||||
|
||||
## Core Commands
|
||||
|
||||
### General Operations
|
||||
* `hf download REPO_ID`: Download files from the Hub.
|
||||
* `hf upload REPO_ID`: Upload files/folders (recommended for single-commit).
|
||||
* `hf upload-large-folder REPO_ID LOCAL_PATH`: Recommended for resumable uploads of large directories.
|
||||
* `hf sync`: Sync files between a local directory and a bucket.
|
||||
* `hf env` / `hf version`: View environment and version details.
|
||||
|
||||
### Authentication (`hf auth`)
|
||||
* `login` / `logout`: Manage sessions using tokens from [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
|
||||
* `list` / `switch`: Manage and toggle between multiple stored access tokens.
|
||||
* `whoami`: Identify the currently logged-in account.
|
||||
|
||||
### Repository Management (`hf repos`)
|
||||
* `create` / `delete`: Create or permanently remove repositories.
|
||||
* `duplicate`: Clone a model, dataset, or Space to a new ID.
|
||||
* `move`: Transfer a repository between namespaces.
|
||||
* `branch` / `tag`: Manage Git-like references.
|
||||
* `delete-files`: Remove specific files using patterns.
|
||||
|
||||
---
|
||||
|
||||
## Specialized Hub Interactions
|
||||
|
||||
### Datasets & Models
|
||||
* **Datasets:** `hf datasets list`, `info`, and `parquet` (list parquet URLs).
|
||||
* **SQL Queries:** `hf datasets sql SQL` — Execute raw SQL via DuckDB against dataset parquet URLs.
|
||||
* **Models:** `hf models list` and `info`.
|
||||
* **Papers:** `hf papers list` — View daily papers.
|
||||
|
||||
### Discussions & Pull Requests (`hf discussions`)
|
||||
* Manage the lifecycle of Hub contributions: `list`, `create`, `info`, `comment`, `close`, `reopen`, and `rename`.
|
||||
* `diff`: View changes in a PR.
|
||||
* `merge`: Finalize pull requests.
|
||||
|
||||
### Infrastructure & Compute
|
||||
* **Endpoints:** Deploy and manage Inference Endpoints (`deploy`, `pause`, `resume`, `scale-to-zero`, `catalog`).
|
||||
* **Jobs:** Run compute tasks on HF infrastructure. Includes `hf jobs uv` for running Python scripts with inline dependencies and `stats` for resource monitoring.
|
||||
* **Spaces:** Manage interactive apps. Includes `dev-mode` and `hot-reload` for Python files without full restarts.
|
||||
|
||||
### Storage & Automation
|
||||
* **Buckets:** Full S3-like bucket management (`create`, `cp`, `mv`, `rm`, `sync`).
|
||||
* **Cache:** Manage local storage with `list`, `prune` (remove detached revisions), and `verify` (checksum checks).
|
||||
* **Webhooks:** Automate workflows by managing Hub webhooks (`create`, `watch`, `enable`/`disable`).
|
||||
* **Collections:** Organize Hub items into collections (`add-item`, `update`, `list`).
|
||||
|
||||
---
|
||||
|
||||
## Advanced Usage & Tips
|
||||
|
||||
### Global Flags
|
||||
* `--format json`: Produces machine-readable output for automation.
|
||||
* `-q` / `--quiet`: Limits output to IDs only.
|
||||
|
||||
### Extensions & Skills
|
||||
* **Extensions:** Extend CLI functionality via GitHub repositories using `hf extensions install REPO_ID`.
|
||||
* **Skills:** Manage AI assistant skills with `hf skills add`.
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Model serving, quantization (GGUF/GPTQ), structured output, inference optimization, and model surgery tools for deploying and running LLMs.
|
||||
---
|
||||
@@ -0,0 +1,248 @@
|
||||
---
|
||||
name: llama-cpp
|
||||
description: llama.cpp local GGUF inference + HF Hub model discovery.
|
||||
version: 2.1.2
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [llama.cpp, GGUF, Quantization, Hugging Face Hub, CPU Inference, Apple Silicon, Edge Deployment, AMD GPUs, Intel GPUs, NVIDIA, URL-first]
|
||||
---
|
||||
|
||||
# llama.cpp + GGUF
|
||||
|
||||
Use this skill for local GGUF inference, quant selection, or Hugging Face repo discovery for llama.cpp.
|
||||
|
||||
## When to use
|
||||
|
||||
- Run local models on CPU, Apple Silicon, CUDA, ROCm, or Intel GPUs
|
||||
- Find the right GGUF for a specific Hugging Face repo
|
||||
- Build a `llama-server` or `llama-cli` command from the Hub
|
||||
- Search the Hub for models that already support llama.cpp
|
||||
- Enumerate available `.gguf` files and sizes for a repo
|
||||
- Decide between Q4/Q5/Q6/IQ variants for the user's RAM or VRAM
|
||||
|
||||
## Model Discovery workflow
|
||||
|
||||
Prefer URL workflows before asking for `hf`, Python, or custom scripts.
|
||||
|
||||
1. Search for candidate repos on the Hub:
|
||||
- Base: `https://huggingface.co/models?apps=llama.cpp&sort=trending`
|
||||
- Add `search=<term>` for a model family
|
||||
- Add `num_parameters=min:0,max:24B` or similar when the user has size constraints
|
||||
2. Open the repo with the llama.cpp local-app view:
|
||||
- `https://huggingface.co/<repo>?local-app=llama.cpp`
|
||||
3. Treat the local-app snippet as the source of truth when it is visible:
|
||||
- copy the exact `llama-server` or `llama-cli` command
|
||||
- report the recommended quant exactly as HF shows it
|
||||
4. Read the same `?local-app=llama.cpp` URL as page text or HTML and extract the section under `Hardware compatibility`:
|
||||
- prefer its exact quant labels and sizes over generic tables
|
||||
- keep repo-specific labels such as `UD-Q4_K_M` or `IQ4_NL_XL`
|
||||
- if that section is not visible in the fetched page source, say so and fall back to the tree API plus generic quant guidance
|
||||
5. Query the tree API to confirm what actually exists:
|
||||
- `https://huggingface.co/api/models/<repo>/tree/main?recursive=true`
|
||||
- keep entries where `type` is `file` and `path` ends with `.gguf`
|
||||
- use `path` and `size` as the source of truth for filenames and byte sizes
|
||||
- separate quantized checkpoints from `mmproj-*.gguf` projector files and `BF16/` shard files
|
||||
- use `https://huggingface.co/<repo>/tree/main` only as a human fallback
|
||||
6. If the local-app snippet is not text-visible, reconstruct the command from the repo plus the chosen quant:
|
||||
- shorthand quant selection: `llama-server -hf <repo>:<QUANT>`
|
||||
- exact-file fallback: `llama-server --hf-repo <repo> --hf-file <filename.gguf>`
|
||||
7. Only suggest conversion from Transformers weights if the repo does not already expose GGUF files.
|
||||
|
||||
## Quick start
|
||||
|
||||
### Install llama.cpp
|
||||
|
||||
```bash
|
||||
# macOS / Linux (simplest)
|
||||
brew install llama.cpp
|
||||
```
|
||||
|
||||
```bash
|
||||
winget install llama.cpp
|
||||
```
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
### Run directly from the Hugging Face Hub
|
||||
|
||||
```bash
|
||||
llama-cli -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0
|
||||
```
|
||||
|
||||
```bash
|
||||
llama-server -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0
|
||||
```
|
||||
|
||||
### Run an exact GGUF file from the Hub
|
||||
|
||||
Use this when the tree API shows custom file naming or the exact HF snippet is missing.
|
||||
|
||||
```bash
|
||||
llama-server \
|
||||
--hf-repo microsoft/Phi-3-mini-4k-instruct-gguf \
|
||||
--hf-file Phi-3-mini-4k-instruct-q4.gguf \
|
||||
-c 4096
|
||||
```
|
||||
|
||||
### OpenAI-compatible server check
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write a limerick about Python exceptions"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Python bindings (llama-cpp-python)
|
||||
|
||||
`pip install llama-cpp-python` (CUDA: `CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir`; Metal: `CMAKE_ARGS="-DGGML_METAL=on" ...`).
|
||||
|
||||
### Basic generation
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # 0 for CPU, 99 to offload everything
|
||||
n_threads=8,
|
||||
)
|
||||
|
||||
out = llm("What is machine learning?", max_tokens=256, temperature=0.7)
|
||||
print(out["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat + streaming
|
||||
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3", # or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
resp = llm.create_chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
],
|
||||
max_tokens=256,
|
||||
)
|
||||
print(resp["choices"][0]["message"]["content"])
|
||||
|
||||
# Streaming
|
||||
for chunk in llm("Explain quantum computing:", max_tokens=256, stream=True):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
### Embeddings
|
||||
|
||||
```python
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", embedding=True, n_gpu_layers=35)
|
||||
vec = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(vec)}")
|
||||
```
|
||||
|
||||
You can also load a GGUF straight from the Hub:
|
||||
|
||||
```python
|
||||
llm = Llama.from_pretrained(
|
||||
repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||
filename="*Q4_K_M.gguf",
|
||||
n_gpu_layers=35,
|
||||
)
|
||||
```
|
||||
|
||||
## Choosing a quant
|
||||
|
||||
Use the Hub page first, generic heuristics second.
|
||||
|
||||
- Prefer the exact quant that HF marks as compatible for the user's hardware profile.
|
||||
- For general chat, start with `Q4_K_M`.
|
||||
- For code or technical work, prefer `Q5_K_M` or `Q6_K` if memory allows.
|
||||
- For very tight RAM budgets, consider `Q3_K_M`, `IQ` variants, or `Q2` variants only if the user explicitly prioritizes fit over quality.
|
||||
- For multimodal repos, mention `mmproj-*.gguf` separately. The projector is not the main model file.
|
||||
- Do not normalize repo-native labels. If the page says `UD-Q4_K_M`, report `UD-Q4_K_M`.
|
||||
|
||||
## Extracting available GGUFs from a repo
|
||||
|
||||
When the user asks what GGUFs exist, return:
|
||||
|
||||
- filename
|
||||
- file size
|
||||
- quant label
|
||||
- whether it is a main model or an auxiliary projector
|
||||
|
||||
Ignore unless requested:
|
||||
|
||||
- README
|
||||
- BF16 shard files
|
||||
- imatrix blobs or calibration artifacts
|
||||
|
||||
Use the tree API for this step:
|
||||
|
||||
- `https://huggingface.co/api/models/<repo>/tree/main?recursive=true`
|
||||
|
||||
For a repo like `unsloth/Qwen3.6-35B-A3B-GGUF`, the local-app page can show quant chips such as `UD-Q4_K_M`, `UD-Q5_K_M`, `UD-Q6_K`, and `Q8_0`, while the tree API exposes exact file paths such as `Qwen3.6-35B-A3B-UD-Q4_K_M.gguf` and `Qwen3.6-35B-A3B-Q8_0.gguf` with byte sizes. Use the tree API to turn a quant label into an exact filename.
|
||||
|
||||
## Search patterns
|
||||
|
||||
Use these URL shapes directly:
|
||||
|
||||
```text
|
||||
https://huggingface.co/models?apps=llama.cpp&sort=trending
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&sort=trending
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&num_parameters=min:0,max:24B&sort=trending
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
https://huggingface.co/api/models/<repo>/tree/main?recursive=true
|
||||
https://huggingface.co/<repo>/tree/main
|
||||
```
|
||||
|
||||
## Output format
|
||||
|
||||
When answering discovery requests, prefer a compact structured result like:
|
||||
|
||||
```text
|
||||
Repo: <repo>
|
||||
Recommended quant from HF: <label> (<size>)
|
||||
llama-server: <command>
|
||||
Other GGUFs:
|
||||
- <filename> - <size>
|
||||
- <filename> - <size>
|
||||
Source URLs:
|
||||
- <local-app URL>
|
||||
- <tree API URL>
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[hub-discovery.md](references/hub-discovery.md)** - URL-only Hugging Face workflows, search patterns, GGUF extraction, and command reconstruction
|
||||
- **[advanced-usage.md](references/advanced-usage.md)** — speculative decoding, batched inference, grammar-constrained generation, LoRA, multi-GPU, custom builds, benchmark scripts
|
||||
- **[quantization.md](references/quantization.md)** — quant quality tradeoffs, when to use Q4/Q5/Q6/IQ, model size scaling, imatrix
|
||||
- **[server.md](references/server.md)** — direct-from-Hub server launch, OpenAI API endpoints, Docker deployment, NGINX load balancing, monitoring
|
||||
- **[optimization.md](references/optimization.md)** — CPU threading, BLAS, GPU offload heuristics, batch tuning, benchmarks
|
||||
- **[troubleshooting.md](references/troubleshooting.md)** — install/convert/quantize/inference/server issues, Apple Silicon, debugging
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/ggml-org/llama.cpp
|
||||
- **Hugging Face GGUF + llama.cpp docs**: https://huggingface.co/docs/hub/gguf-llamacpp
|
||||
- **Hugging Face Local Apps docs**: https://huggingface.co/docs/hub/main/local-apps
|
||||
- **Hugging Face Local Agents docs**: https://huggingface.co/docs/hub/agents-local
|
||||
- **Example local-app page**: https://huggingface.co/unsloth/Qwen3.6-35B-A3B-GGUF?local-app=llama.cpp
|
||||
- **Example tree API**: https://huggingface.co/api/models/unsloth/Qwen3.6-35B-A3B-GGUF/tree/main?recursive=true
|
||||
- **Example llama.cpp search**: https://huggingface.co/models?num_parameters=min:0,max:24B&apps=llama.cpp&sort=trending
|
||||
- **License**: MIT
|
||||
@@ -0,0 +1,504 @@
|
||||
# GGUF Advanced Usage Guide
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
### Draft Model Approach
|
||||
|
||||
```bash
|
||||
# Use smaller model as draft for faster generation
|
||||
./llama-speculative \
|
||||
-m large-model-q4_k_m.gguf \
|
||||
-md draft-model-q4_k_m.gguf \
|
||||
-p "Write a story about AI" \
|
||||
-n 500 \
|
||||
--draft 8 # Draft tokens before verification
|
||||
```
|
||||
|
||||
### Self-Speculative Decoding
|
||||
|
||||
```bash
|
||||
# Use same model with different context for speculation
|
||||
./llama-cli -m model-q4_k_m.gguf \
|
||||
--lookup-cache-static lookup.bin \
|
||||
--lookup-cache-dynamic lookup-dynamic.bin \
|
||||
-p "Hello world"
|
||||
```
|
||||
|
||||
## Batched Inference
|
||||
|
||||
### Process Multiple Prompts
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
n_batch=512 # Larger batch for parallel processing
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"What is Python?",
|
||||
"Explain machine learning.",
|
||||
"Describe neural networks."
|
||||
]
|
||||
|
||||
# Process in batch (each prompt gets separate context)
|
||||
for prompt in prompts:
|
||||
output = llm(prompt, max_tokens=100)
|
||||
print(f"Q: {prompt}")
|
||||
print(f"A: {output['choices'][0]['text']}\n")
|
||||
```
|
||||
|
||||
### Server Batching
|
||||
|
||||
```bash
|
||||
# Start server with batching
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 # Concurrent requests
|
||||
--cont-batching # Continuous batching
|
||||
```
|
||||
|
||||
## Custom Model Conversion
|
||||
|
||||
### Convert with Vocabulary Modifications
|
||||
|
||||
```python
|
||||
# custom_convert.py
|
||||
import sys
|
||||
sys.path.insert(0, './llama.cpp')
|
||||
|
||||
from convert_hf_to_gguf import main
|
||||
from gguf import GGUFWriter
|
||||
|
||||
# Custom conversion with modified vocab
|
||||
def convert_with_custom_vocab(model_path, output_path):
|
||||
# Load and modify tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Add special tokens if needed
|
||||
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
# Then run standard conversion
|
||||
main([model_path, "--outfile", output_path])
|
||||
```
|
||||
|
||||
### Convert Specific Architecture
|
||||
|
||||
```bash
|
||||
# For Mistral-style models
|
||||
python convert_hf_to_gguf.py ./mistral-model \
|
||||
--outfile mistral-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Qwen models
|
||||
python convert_hf_to_gguf.py ./qwen-model \
|
||||
--outfile qwen-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Phi models
|
||||
python convert_hf_to_gguf.py ./phi-model \
|
||||
--outfile phi-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
## Advanced Quantization
|
||||
|
||||
### Mixed Quantization
|
||||
|
||||
```bash
|
||||
# Quantize different layer types differently
|
||||
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
|
||||
--allow-requantize \
|
||||
--leave-output-tensor
|
||||
```
|
||||
|
||||
### Quantization with Token Embeddings
|
||||
|
||||
```bash
|
||||
# Keep embeddings at higher precision
|
||||
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
|
||||
--token-embedding-type f16
|
||||
```
|
||||
|
||||
### IQ Quantization (Importance-aware)
|
||||
|
||||
```bash
|
||||
# Ultra-low bit quantization with importance
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
|
||||
|
||||
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Memory Mapping
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use memory mapping for large models
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
use_mmap=True, # Memory map the model
|
||||
use_mlock=False, # Don't lock in RAM
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Partial GPU Offload
|
||||
|
||||
```python
|
||||
# Calculate layers to offload based on VRAM
|
||||
import subprocess
|
||||
|
||||
def get_free_vram_gb():
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
|
||||
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
|
||||
free_vram = get_free_vram_gb()
|
||||
layers_to_offload = int(free_vram / 0.5)
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
|
||||
)
|
||||
```
|
||||
|
||||
### KV Cache Optimization
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Optimize KV cache for long contexts
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=8192, # Large context
|
||||
n_gpu_layers=35,
|
||||
type_k=1, # Q8_0 for K cache (1)
|
||||
type_v=1, # Q8_0 for V cache (1)
|
||||
# Or use Q4_0 (2) for more compression
|
||||
)
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
### Context Shifting
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Handle long conversations with context shifting
|
||||
conversation = []
|
||||
max_history = 10
|
||||
|
||||
def chat(user_message):
|
||||
conversation.append({"role": "user", "content": user_message})
|
||||
|
||||
# Keep only recent history
|
||||
if len(conversation) > max_history * 2:
|
||||
conversation = conversation[-max_history * 2:]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=conversation,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
conversation.append({"role": "assistant", "content": assistant_message})
|
||||
return assistant_message
|
||||
```
|
||||
|
||||
### Save and Load State
|
||||
|
||||
```bash
|
||||
# Save state to file
|
||||
./llama-cli -m model.gguf \
|
||||
-p "Once upon a time" \
|
||||
--save-session session.bin \
|
||||
-n 100
|
||||
|
||||
# Load and continue
|
||||
./llama-cli -m model.gguf \
|
||||
--load-session session.bin \
|
||||
-p " and they lived" \
|
||||
-n 100
|
||||
```
|
||||
|
||||
## Grammar Constrained Generation
|
||||
|
||||
### JSON Output
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
|
||||
# Define JSON grammar
|
||||
json_grammar = LlamaGrammar.from_string('''
|
||||
root ::= object
|
||||
object ::= "{" ws pair ("," ws pair)* "}" ws
|
||||
pair ::= string ":" ws value
|
||||
value ::= string | number | object | array | "true" | "false" | "null"
|
||||
array ::= "[" ws value ("," ws value)* "]" ws
|
||||
string ::= "\\"" [^"\\\\]* "\\""
|
||||
number ::= [0-9]+
|
||||
ws ::= [ \\t\\n]*
|
||||
''')
|
||||
|
||||
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
output = llm(
|
||||
"Output a JSON object with name and age:",
|
||||
grammar=json_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Custom Grammar
|
||||
|
||||
```python
|
||||
# Grammar for specific format
|
||||
answer_grammar = LlamaGrammar.from_string('''
|
||||
root ::= "Answer: " letter "\\n" "Explanation: " explanation
|
||||
letter ::= [A-D]
|
||||
explanation ::= [a-zA-Z0-9 .,!?]+
|
||||
''')
|
||||
|
||||
output = llm(
|
||||
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
|
||||
grammar=answer_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Integration
|
||||
|
||||
### Load LoRA Adapter
|
||||
|
||||
```bash
|
||||
# Apply LoRA at runtime
|
||||
./llama-cli -m base-model-q4_k_m.gguf \
|
||||
--lora lora-adapter.gguf \
|
||||
--lora-scale 1.0 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Multiple LoRA Adapters
|
||||
|
||||
```bash
|
||||
# Stack multiple adapters
|
||||
./llama-cli -m base-model.gguf \
|
||||
--lora adapter1.gguf --lora-scale 0.5 \
|
||||
--lora adapter2.gguf --lora-scale 0.5 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python LoRA Usage
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="base-model-q4_k_m.gguf",
|
||||
lora_path="lora-adapter.gguf",
|
||||
lora_scale=1.0,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
### Extract Embeddings
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
embedding=True, # Enable embedding mode
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(embeddings)}")
|
||||
```
|
||||
|
||||
### Batch Embeddings
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"Machine learning is fascinating.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Python is a programming language."
|
||||
]
|
||||
|
||||
embeddings = [llm.embed(text) for text in texts]
|
||||
|
||||
# Calculate similarity
|
||||
import numpy as np
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
sim = cosine_similarity(embeddings[0], embeddings[1])
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Script
|
||||
|
||||
```python
|
||||
import time
|
||||
from llama_cpp import Llama
|
||||
|
||||
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=35,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm(prompt, max_tokens=10)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.time()
|
||||
output = llm(prompt, max_tokens=n_tokens)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
tokens_per_sec = n_tokens / avg_time
|
||||
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Avg time: {avg_time:.2f}s")
|
||||
print(f"Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
return tokens_per_sec
|
||||
|
||||
# Compare quantizations
|
||||
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
|
||||
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
|
||||
```
|
||||
|
||||
### Optimal Configuration Finder
|
||||
|
||||
```python
|
||||
def find_optimal_config(model_path, target_vram_gb=8):
|
||||
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
|
||||
from llama_cpp import Llama
|
||||
import gc
|
||||
|
||||
best_config = None
|
||||
best_speed = 0
|
||||
|
||||
for n_gpu_layers in range(0, 50, 5):
|
||||
for n_batch in [128, 256, 512, 1024]:
|
||||
try:
|
||||
gc.collect()
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=n_gpu_layers,
|
||||
n_batch=n_batch,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Quick benchmark
|
||||
start = time.time()
|
||||
llm("Hello", max_tokens=50)
|
||||
speed = 50 / (time.time() - start)
|
||||
|
||||
if speed > best_speed:
|
||||
best_speed = speed
|
||||
best_config = {
|
||||
"n_gpu_layers": n_gpu_layers,
|
||||
"n_batch": n_batch,
|
||||
"speed": speed
|
||||
}
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
|
||||
break
|
||||
|
||||
return best_config
|
||||
```
|
||||
|
||||
## Multi-GPU Setup
|
||||
|
||||
### Distribute Across GPUs
|
||||
|
||||
```bash
|
||||
# Split model across multiple GPUs
|
||||
./llama-cli -m large-model.gguf \
|
||||
--tensor-split 0.5,0.5 \
|
||||
-ngl 60 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python Multi-GPU
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="large-model-q4_k_m.gguf",
|
||||
n_gpu_layers=60,
|
||||
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Builds
|
||||
|
||||
### Build with All Optimizations
|
||||
|
||||
```bash
|
||||
# Clean build with all CPU optimizations
|
||||
make clean
|
||||
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
|
||||
|
||||
# With CUDA and cuBLAS
|
||||
make clean
|
||||
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
|
||||
|
||||
# With specific CUDA architecture
|
||||
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
|
||||
```
|
||||
|
||||
### CMake Build
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build . --config Release -j
|
||||
```
|
||||
@@ -0,0 +1,168 @@
|
||||
# Hugging Face URL Workflows for llama.cpp
|
||||
|
||||
Use URL-only workflows first. Do not require `hf` or API clients just to find GGUF files, choose a quant, or build a `llama-server` command.
|
||||
|
||||
## Core URLs
|
||||
|
||||
```text
|
||||
Search:
|
||||
https://huggingface.co/models?apps=llama.cpp&sort=trending
|
||||
|
||||
Search with text:
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&sort=trending
|
||||
|
||||
Search with size bounds:
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&num_parameters=min:0,max:24B&sort=trending
|
||||
|
||||
Repo local-app view:
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
|
||||
Repo tree API:
|
||||
https://huggingface.co/api/models/<repo>/tree/main?recursive=true
|
||||
|
||||
Repo file tree:
|
||||
https://huggingface.co/<repo>/tree/main
|
||||
```
|
||||
|
||||
## 1. Search for llama.cpp-compatible models
|
||||
|
||||
Start from the models page with `apps=llama.cpp`.
|
||||
|
||||
Use:
|
||||
|
||||
- `search=<term>` for model family names such as `Qwen`, `Gemma`, `Phi`, or `Mistral`
|
||||
- `num_parameters=min:0,max:24B` or similar if the user has hardware limits
|
||||
- `sort=trending` when the user wants popular repos right now
|
||||
|
||||
Do not start with random GGUF repos if the user has not chosen a model family yet. Search first, shortlist second.
|
||||
|
||||
Example: https://huggingface.co/models?search=Qwen&apps=llama.cpp&num_parameters=min:0,max:24B&sort=trending
|
||||
|
||||
## 2. Use the local-app page for the recommended quant
|
||||
|
||||
Open:
|
||||
|
||||
```text
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
```
|
||||
|
||||
Extract, in order:
|
||||
|
||||
1. The exact `Use this model` snippet, if it is visible as text
|
||||
2. The `Hardware compatibility` section from the fetched page text or HTML:
|
||||
- quant label
|
||||
- file size
|
||||
- bit-depth grouping
|
||||
3. Any extra launch flags shown in the snippet, such as `--jinja`
|
||||
|
||||
Treat the HF local-app snippet as the source of truth when it is visible.
|
||||
|
||||
Do this by reading the URL itself, not by assuming the UI rendered in a browser. If the fetched page source does not expose `Hardware compatibility`, say that the section was not text-visible and fall back to the tree API plus generic guidance from `quantization.md`.
|
||||
|
||||
## 3. Confirm exact files from the tree API
|
||||
|
||||
Open:
|
||||
|
||||
```text
|
||||
https://huggingface.co/api/models/<repo>/tree/main?recursive=true
|
||||
```
|
||||
|
||||
Treat the JSON response as the source of truth for repo inventory.
|
||||
|
||||
Keep entries where:
|
||||
|
||||
- `type` is `file`
|
||||
- `path` ends with `.gguf`
|
||||
|
||||
Use these fields:
|
||||
|
||||
- `path` for the filename and subdirectory
|
||||
- `size` for the byte size
|
||||
- optionally `lfs.size` to confirm the LFS payload size
|
||||
|
||||
Separate files into:
|
||||
|
||||
- quantized single-file checkpoints, for example `Qwen3.6-35B-A3B-UD-Q4_K_M.gguf`
|
||||
- projector weights, usually `mmproj-*.gguf`
|
||||
- BF16 shard files, usually under `BF16/`
|
||||
- everything else
|
||||
|
||||
Ignore unless the user asks:
|
||||
|
||||
- `README.md`
|
||||
- imatrix or calibration blobs
|
||||
|
||||
Use `https://huggingface.co/<repo>/tree/main` only as a human fallback if the API endpoint fails or the user wants the web view.
|
||||
|
||||
## 4. Build the command
|
||||
|
||||
Preferred order:
|
||||
|
||||
1. Copy the exact HF snippet from the local-app page
|
||||
2. If the page gives a clean quant label, use shorthand selection:
|
||||
|
||||
```bash
|
||||
llama-server -hf <repo>:<QUANT>
|
||||
```
|
||||
|
||||
3. If you need an exact file from the tree API, use the file-specific form:
|
||||
|
||||
```bash
|
||||
llama-server --hf-repo <repo> --hf-file <filename.gguf>
|
||||
```
|
||||
|
||||
4. For CLI usage instead of a server, use:
|
||||
|
||||
```bash
|
||||
llama-cli -hf <repo>:<QUANT>
|
||||
```
|
||||
|
||||
Use the exact-file form when the repo uses custom labels or nonstandard naming that could make `:<QUANT>` ambiguous.
|
||||
|
||||
## 5. Example: `unsloth/Qwen3.6-35B-A3B-GGUF`
|
||||
|
||||
Use these URLs:
|
||||
|
||||
```text
|
||||
https://huggingface.co/unsloth/Qwen3.6-35B-A3B-GGUF?local-app=llama.cpp
|
||||
https://huggingface.co/api/models/unsloth/Qwen3.6-35B-A3B-GGUF/tree/main?recursive=true
|
||||
https://huggingface.co/unsloth/Qwen3.6-35B-A3B-GGUF/tree/main
|
||||
```
|
||||
|
||||
On the local-app page, the hardware compatibility section can expose entries such as:
|
||||
|
||||
- `UD-IQ4_XS` - 17.7 GB
|
||||
- `UD-Q4_K_S` - 20.9 GB
|
||||
- `UD-Q4_K_M` - 22.1 GB
|
||||
- `UD-Q5_K_M` - 26.5 GB
|
||||
- `UD-Q6_K` - 29.3 GB
|
||||
- `Q8_0` - 36.9 GB
|
||||
|
||||
On the tree API, you can confirm exact filenames such as:
|
||||
|
||||
- `Qwen3.6-35B-A3B-UD-Q4_K_M.gguf`
|
||||
- `Qwen3.6-35B-A3B-UD-Q5_K_M.gguf`
|
||||
- `Qwen3.6-35B-A3B-UD-Q6_K.gguf`
|
||||
- `Qwen3.6-35B-A3B-Q8_0.gguf`
|
||||
- `mmproj-F16.gguf`
|
||||
|
||||
Good final output for this repo:
|
||||
|
||||
```text
|
||||
Repo: unsloth/Qwen3.6-35B-A3B-GGUF
|
||||
Recommended quant from HF: UD-Q4_K_M (22.1 GB)
|
||||
llama-server: llama-server --hf-repo unsloth/Qwen3.6-35B-A3B-GGUF --hf-file Qwen3.6-35B-A3B-UD-Q4_K_M.gguf
|
||||
Other GGUFs:
|
||||
- Qwen3.6-35B-A3B-UD-Q5_K_M.gguf - 26.5 GB
|
||||
- Qwen3.6-35B-A3B-UD-Q6_K.gguf - 29.3 GB
|
||||
- Qwen3.6-35B-A3B-Q8_0.gguf - 36.9 GB
|
||||
Projector:
|
||||
- mmproj-F16.gguf - 899 MB
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Repo-specific quant labels matter. Do not rewrite `UD-Q4_K_M` to `Q4_K_M` unless the page itself does.
|
||||
- `mmproj` files are projector weights for multimodal models, not the main language model checkpoint.
|
||||
- If the HF hardware compatibility panel is missing because the user has no hardware profile configured, or because the fetched page source did not expose it, still use the tree API plus generic quant guidance from `quantization.md`.
|
||||
- If the repo already has GGUFs, do not jump straight to conversion workflows.
|
||||
@@ -0,0 +1,89 @@
|
||||
# Performance Optimization Guide
|
||||
|
||||
Maximize llama.cpp inference speed and efficiency.
|
||||
|
||||
## CPU Optimization
|
||||
|
||||
### Thread tuning
|
||||
```bash
|
||||
# Set threads (default: physical cores)
|
||||
./llama-cli -m model.gguf -t 8
|
||||
|
||||
# For AMD Ryzen 9 7950X (16 cores, 32 threads)
|
||||
-t 16 # Best: physical cores
|
||||
|
||||
# Avoid hyperthreading (slower for matrix ops)
|
||||
```
|
||||
|
||||
### BLAS acceleration
|
||||
```bash
|
||||
# OpenBLAS (faster matrix ops)
|
||||
make LLAMA_OPENBLAS=1
|
||||
|
||||
# BLAS gives 2-3× speedup
|
||||
```
|
||||
|
||||
## GPU Offloading
|
||||
|
||||
### Layer offloading
|
||||
```bash
|
||||
# Offload 35 layers to GPU (hybrid mode)
|
||||
./llama-cli -m model.gguf -ngl 35
|
||||
|
||||
# Offload all layers
|
||||
./llama-cli -m model.gguf -ngl 999
|
||||
|
||||
# Find optimal value:
|
||||
# Start with -ngl 999
|
||||
# If OOM, reduce by 5 until fits
|
||||
```
|
||||
|
||||
### Memory usage
|
||||
```bash
|
||||
# Check VRAM usage
|
||||
nvidia-smi dmon
|
||||
|
||||
# Reduce context if needed
|
||||
./llama-cli -m model.gguf -c 2048 # 2K context instead of 4K
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
```bash
|
||||
# Increase batch size for throughput
|
||||
./llama-cli -m model.gguf -b 512 # Default: 512
|
||||
|
||||
# Physical batch (GPU)
|
||||
--ubatch 128 # Process 128 tokens at once
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
```bash
|
||||
# Default context (512 tokens)
|
||||
-c 512
|
||||
|
||||
# Longer context (slower, more memory)
|
||||
-c 4096
|
||||
|
||||
# Very long context (if model supports)
|
||||
-c 32768
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### CPU Performance (Llama 2-7B Q4_K_M)
|
||||
|
||||
| Setup | Speed | Notes |
|
||||
|-------|-------|-------|
|
||||
| Apple M3 Max | 50 tok/s | Metal acceleration |
|
||||
| AMD 7950X (16c) | 35 tok/s | OpenBLAS |
|
||||
| Intel i9-13900K | 30 tok/s | AVX2 |
|
||||
|
||||
### GPU Offloading (RTX 4090)
|
||||
|
||||
| Layers GPU | Speed | VRAM |
|
||||
|------------|-------|------|
|
||||
| 0 (CPU only) | 30 tok/s | 0 GB |
|
||||
| 20 (hybrid) | 80 tok/s | 8 GB |
|
||||
| 35 (all) | 120 tok/s | 12 GB |
|
||||
@@ -0,0 +1,243 @@
|
||||
# GGUF Quantization Guide
|
||||
|
||||
Complete guide to GGUF quantization formats and model conversion.
|
||||
|
||||
## Hub-first quant selection
|
||||
|
||||
Before using generic tables, open the model repo with:
|
||||
|
||||
```text
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
```
|
||||
|
||||
Prefer the exact quant labels and sizes shown in the `Hardware compatibility` section of the fetched `?local-app=llama.cpp` page text or HTML. Then confirm the matching filenames in:
|
||||
|
||||
```text
|
||||
https://huggingface.co/api/models/<repo>/tree/main?recursive=true
|
||||
```
|
||||
|
||||
Use the Hub page first, and only fall back to the generic heuristics below when the repo page does not expose a clear recommendation.
|
||||
|
||||
## Quantization Overview
|
||||
|
||||
**GGUF** (GPT-Generated Unified Format) - Standard format for llama.cpp models.
|
||||
|
||||
### Format Comparison
|
||||
|
||||
| Format | Perplexity | Size (7B) | Tokens/sec | Notes |
|
||||
|--------|------------|-----------|------------|-------|
|
||||
| FP16 | 5.9565 (baseline) | 13.0 GB | 15 tok/s | Original quality |
|
||||
| Q8_0 | 5.9584 (+0.03%) | 7.0 GB | 25 tok/s | Nearly lossless |
|
||||
| **Q6_K** | 5.9642 (+0.13%) | 5.5 GB | 30 tok/s | Best quality/size |
|
||||
| **Q5_K_M** | 5.9796 (+0.39%) | 4.8 GB | 35 tok/s | Balanced |
|
||||
| **Q4_K_M** | 6.0565 (+1.68%) | 4.1 GB | 40 tok/s | **Recommended** |
|
||||
| Q4_K_S | 6.1125 (+2.62%) | 3.9 GB | 42 tok/s | Faster, lower quality |
|
||||
| Q3_K_M | 6.3184 (+6.07%) | 3.3 GB | 45 tok/s | Small models only |
|
||||
| Q2_K | 6.8673 (+15.3%) | 2.7 GB | 50 tok/s | Not recommended |
|
||||
|
||||
**Recommendation**: Use **Q4_K_M** for best balance of quality and speed.
|
||||
|
||||
## Converting Models
|
||||
|
||||
### Hugging Face to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download Hugging Face model
|
||||
hf download meta-llama/Llama-2-7b-chat-hf \
|
||||
--local-dir models/llama-2-7b-chat/
|
||||
|
||||
# 2. Convert to FP16 GGUF
|
||||
python convert_hf_to_gguf.py \
|
||||
models/llama-2-7b-chat/ \
|
||||
--outtype f16 \
|
||||
--outfile models/llama-2-7b-chat-f16.gguf
|
||||
|
||||
# 3. Quantize to Q4_K_M
|
||||
./llama-quantize \
|
||||
models/llama-2-7b-chat-f16.gguf \
|
||||
models/llama-2-7b-chat-Q4_K_M.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Batch quantization
|
||||
|
||||
```bash
|
||||
# Quantize to multiple formats
|
||||
for quant in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
./llama-quantize \
|
||||
model-f16.gguf \
|
||||
model-${quant}.gguf \
|
||||
$quant
|
||||
done
|
||||
```
|
||||
|
||||
## K-Quantization Methods
|
||||
|
||||
**K-quants** use mixed precision for better quality:
|
||||
- Attention weights: Higher precision
|
||||
- Feed-forward weights: Lower precision
|
||||
|
||||
**Variants**:
|
||||
- `_S` (Small): Faster, lower quality
|
||||
- `_M` (Medium): Balanced (recommended)
|
||||
- `_L` (Large): Better quality, larger size
|
||||
|
||||
**Example**: `Q4_K_M`
|
||||
- `Q4`: 4-bit quantization
|
||||
- `K`: Mixed precision method
|
||||
- `M`: Medium quality
|
||||
|
||||
## Quality Testing
|
||||
|
||||
```bash
|
||||
# Calculate perplexity (quality metric)
|
||||
./llama-perplexity \
|
||||
-m model.gguf \
|
||||
-f wikitext-2-raw/wiki.test.raw \
|
||||
-c 512
|
||||
|
||||
# Lower perplexity = better quality
|
||||
# Baseline (FP16): ~5.96
|
||||
# Q4_K_M: ~6.06 (+1.7%)
|
||||
# Q2_K: ~6.87 (+15.3% - too much degradation)
|
||||
```
|
||||
|
||||
## Use Case Guide
|
||||
|
||||
### General purpose (chatbots, assistants)
|
||||
```
|
||||
Q4_K_M - Best balance
|
||||
Q5_K_M - If you have extra RAM
|
||||
```
|
||||
|
||||
### Code generation
|
||||
```
|
||||
Q5_K_M or Q6_K - Higher precision helps with code
|
||||
```
|
||||
|
||||
### Creative writing
|
||||
```
|
||||
Q4_K_M - Sufficient quality
|
||||
Q3_K_M - Acceptable for draft generation
|
||||
```
|
||||
|
||||
### Technical/medical
|
||||
```
|
||||
Q6_K or Q8_0 - Maximum accuracy
|
||||
```
|
||||
|
||||
### Edge devices (Raspberry Pi)
|
||||
```
|
||||
Q2_K or Q3_K_S - Fit in limited RAM
|
||||
```
|
||||
|
||||
## Model Size Scaling
|
||||
|
||||
### 7B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 2.7 GB | 5 GB |
|
||||
| Q3_K_M | 3.3 GB | 6 GB |
|
||||
| Q4_K_M | 4.1 GB | 7 GB |
|
||||
| Q5_K_M | 4.8 GB | 8 GB |
|
||||
| Q6_K | 5.5 GB | 9 GB |
|
||||
| Q8_0 | 7.0 GB | 11 GB |
|
||||
|
||||
### 13B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 5.1 GB | 8 GB |
|
||||
| Q3_K_M | 6.2 GB | 10 GB |
|
||||
| Q4_K_M | 7.9 GB | 12 GB |
|
||||
| Q5_K_M | 9.2 GB | 14 GB |
|
||||
| Q6_K | 10.7 GB | 16 GB |
|
||||
|
||||
### 70B parameter models
|
||||
|
||||
| Format | Size | RAM needed |
|
||||
|--------|------|------------|
|
||||
| Q2_K | 26 GB | 32 GB |
|
||||
| Q3_K_M | 32 GB | 40 GB |
|
||||
| Q4_K_M | 41 GB | 48 GB |
|
||||
| Q4_K_S | 39 GB | 46 GB |
|
||||
| Q5_K_M | 48 GB | 56 GB |
|
||||
|
||||
**Recommendation for 70B**: Use Q3_K_M or Q4_K_S to fit in consumer hardware.
|
||||
|
||||
## Finding Pre-Quantized Models
|
||||
|
||||
Use the Hub search with the llama.cpp app filter:
|
||||
|
||||
```text
|
||||
https://huggingface.co/models?apps=llama.cpp&sort=trending
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&sort=trending
|
||||
https://huggingface.co/models?search=<term>&apps=llama.cpp&num_parameters=min:0,max:24B&sort=trending
|
||||
```
|
||||
|
||||
For a specific repo, open:
|
||||
|
||||
```text
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
https://huggingface.co/api/models/<repo>/tree/main?recursive=true
|
||||
```
|
||||
|
||||
Then launch directly from the Hub without extra Hub tooling:
|
||||
|
||||
```bash
|
||||
llama-cli -hf <repo>:Q4_K_M
|
||||
llama-server -hf <repo>:Q4_K_M
|
||||
```
|
||||
|
||||
If you need the exact file name from the tree API:
|
||||
|
||||
```bash
|
||||
llama-server --hf-repo <repo> --hf-file <filename.gguf>
|
||||
```
|
||||
|
||||
## Importance Matrices (imatrix)
|
||||
|
||||
**What**: Calibration data to improve quantization quality.
|
||||
|
||||
**Benefits**:
|
||||
- 10-20% perplexity improvement with Q4
|
||||
- Essential for Q3 and below
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# 1. Generate importance matrix
|
||||
./llama-imatrix \
|
||||
-m model-f16.gguf \
|
||||
-f calibration-data.txt \
|
||||
-o model.imatrix
|
||||
|
||||
# 2. Quantize with imatrix
|
||||
./llama-quantize \
|
||||
--imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-Q4_K_M.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
**Calibration data**:
|
||||
- Use domain-specific text (e.g., code for code models)
|
||||
- ~100MB of representative text
|
||||
- Higher quality data = better quantization
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Model outputs gibberish**:
|
||||
- Quantization too aggressive (Q2_K)
|
||||
- Try Q4_K_M or Q5_K_M
|
||||
- Verify model converted correctly
|
||||
|
||||
**Out of memory**:
|
||||
- Use lower quantization (Q4_K_S instead of Q5_K_M)
|
||||
- Offload fewer layers to GPU (`-ngl`)
|
||||
- Use smaller context (`-c 2048`)
|
||||
|
||||
**Slow inference**:
|
||||
- Higher quantization uses more compute
|
||||
- Q8_0 much slower than Q4_K_M
|
||||
- Consider speed vs quality trade-off
|
||||
@@ -0,0 +1,150 @@
|
||||
# Server Deployment Guide
|
||||
|
||||
Production deployment of llama.cpp server with OpenAI-compatible API.
|
||||
|
||||
## Direct from Hugging Face Hub
|
||||
|
||||
Prefer the model repo's local-app page first:
|
||||
|
||||
```text
|
||||
https://huggingface.co/<repo>?local-app=llama.cpp
|
||||
```
|
||||
|
||||
If the page shows an exact snippet, copy it. If not, use one of these forms:
|
||||
|
||||
```bash
|
||||
# Choose a quant label directly from the Hub repo
|
||||
llama-server -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0
|
||||
```
|
||||
|
||||
```bash
|
||||
# Pin an exact GGUF file from the repo tree
|
||||
llama-server \
|
||||
--hf-repo microsoft/Phi-3-mini-4k-instruct-gguf \
|
||||
--hf-file Phi-3-mini-4k-instruct-q4.gguf \
|
||||
-c 4096
|
||||
```
|
||||
|
||||
Use the file-specific form when the repo has custom naming or when you already extracted the exact filename from the tree API.
|
||||
|
||||
## Server Modes
|
||||
|
||||
### llama-server
|
||||
|
||||
```bash
|
||||
# Basic server
|
||||
./llama-server \
|
||||
-m models/llama-2-7b-chat.Q4_K_M.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-c 4096 # Context size
|
||||
|
||||
# With GPU acceleration
|
||||
./llama-server \
|
||||
-m models/llama-2-70b.Q4_K_M.gguf \
|
||||
-ngl 40 # Offload 40 layers to GPU
|
||||
```
|
||||
|
||||
## OpenAI-Compatible API
|
||||
|
||||
### Chat completions
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
### Streaming
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-2",
|
||||
"messages": [{"role": "user", "content": "Count to 10"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
**Dockerfile**:
|
||||
```dockerfile
|
||||
FROM ubuntu:22.04
|
||||
RUN apt-get update && apt-get install -y git build-essential
|
||||
RUN git clone https://github.com/ggerganov/llama.cpp
|
||||
WORKDIR /llama.cpp
|
||||
RUN make LLAMA_CUDA=1
|
||||
COPY models/ /models/
|
||||
EXPOSE 8080
|
||||
CMD ["./llama-server", "-m", "/models/model.gguf", "--host", "0.0.0.0", "--port", "8080"]
|
||||
```
|
||||
|
||||
**Run**:
|
||||
```bash
|
||||
docker run --gpus all -p 8080:8080 llama-cpp:latest
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
```bash
|
||||
# Server metrics endpoint
|
||||
curl http://localhost:8080/metrics
|
||||
|
||||
# Health check
|
||||
curl http://localhost:8080/health
|
||||
```
|
||||
|
||||
**Metrics**:
|
||||
- requests_total
|
||||
- tokens_generated
|
||||
- prompt_tokens
|
||||
- completion_tokens
|
||||
- kv_cache_tokens
|
||||
|
||||
## Load Balancing
|
||||
|
||||
**NGINX**:
|
||||
```nginx
|
||||
upstream llama_cpp {
|
||||
server llama1:8080;
|
||||
server llama2:8080;
|
||||
}
|
||||
|
||||
server {
|
||||
location / {
|
||||
proxy_pass http://llama_cpp;
|
||||
proxy_read_timeout 300s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
**Parallel requests**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
-np 4 # 4 parallel slots
|
||||
```
|
||||
|
||||
**Continuous batching**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
--cont-batching # Enable continuous batching
|
||||
```
|
||||
|
||||
**Context caching**:
|
||||
```bash
|
||||
./llama-server \
|
||||
-m model.gguf \
|
||||
--cache-prompt # Cache processed prompts
|
||||
```
|
||||
@@ -0,0 +1,442 @@
|
||||
# GGUF Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Build Fails
|
||||
|
||||
**Error**: `make: *** No targets specified and no makefile found`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Ensure you're in llama.cpp directory
|
||||
cd llama.cpp
|
||||
make
|
||||
```
|
||||
|
||||
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install CUDA toolkit
|
||||
# Ubuntu
|
||||
sudo apt install nvidia-cuda-toolkit
|
||||
|
||||
# Or set CUDA path
|
||||
export CUDA_PATH=/usr/local/cuda
|
||||
export PATH=$CUDA_PATH/bin:$PATH
|
||||
make GGML_CUDA=1
|
||||
```
|
||||
|
||||
### Python Bindings Issues
|
||||
|
||||
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install cmake scikit-build-core
|
||||
|
||||
# For CUDA support
|
||||
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
|
||||
# For Metal (macOS)
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Add CUDA libraries to path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Or reinstall with correct CUDA version
|
||||
pip uninstall llama-cpp-python
|
||||
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
|
||||
```
|
||||
|
||||
## Conversion Issues
|
||||
|
||||
### Model Not Supported
|
||||
|
||||
**Error**: `KeyError: 'model.embed_tokens.weight'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check model architecture
|
||||
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
|
||||
|
||||
# Use appropriate conversion script
|
||||
# For most models:
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# For older models, check if legacy script needed
|
||||
```
|
||||
|
||||
### Vocabulary Mismatch
|
||||
|
||||
**Error**: `RuntimeError: Vocabulary size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure tokenizer matches model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./model")
|
||||
model = AutoModelForCausalLM.from_pretrained("./model")
|
||||
|
||||
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
||||
print(f"Model vocab size: {model.config.vocab_size}")
|
||||
|
||||
# If mismatch, resize embeddings before conversion
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.save_pretrained("./model-fixed")
|
||||
```
|
||||
|
||||
### Out of Memory During Conversion
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during conversion
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Use CPU for conversion
|
||||
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# Or use low memory mode
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
|
||||
```
|
||||
|
||||
## Quantization Issues
|
||||
|
||||
### Wrong Output File Size
|
||||
|
||||
**Problem**: Quantized file is larger than expected
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify quantization type
|
||||
./llama-cli -m model.gguf --verbose
|
||||
|
||||
# Expected sizes for 7B model:
|
||||
# Q4_K_M: ~4.1 GB
|
||||
# Q5_K_M: ~4.8 GB
|
||||
# Q8_0: ~7.2 GB
|
||||
# F16: ~13.5 GB
|
||||
```
|
||||
|
||||
### Quantization Crashes
|
||||
|
||||
**Error**: `Segmentation fault` during quantization
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Increase stack size
|
||||
ulimit -s unlimited
|
||||
|
||||
# Or use less threads
|
||||
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Poor Quality After Quantization
|
||||
|
||||
**Problem**: Model outputs gibberish after quantization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use importance matrix**:
|
||||
```bash
|
||||
# Generate imatrix with good calibration data
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f wiki_sample.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix
|
||||
|
||||
# Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
2. **Try higher precision**:
|
||||
```bash
|
||||
# Use Q5_K_M or Q6_K instead of Q4
|
||||
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
|
||||
```
|
||||
|
||||
3. **Check original model**:
|
||||
```bash
|
||||
# Test FP16 version first
|
||||
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Generation is slower than expected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable GPU offload**:
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
```
|
||||
|
||||
2. **Optimize batch size**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_batch=512, # Increase for faster prompt processing
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use appropriate threads**:
|
||||
```bash
|
||||
# Match physical cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
```
|
||||
|
||||
4. **Enable Flash Attention** (if supported):
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
|
||||
```
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `CUDA out of memory` or system freeze
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce GPU layers**:
|
||||
```python
|
||||
# Start low and increase
|
||||
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
|
||||
```
|
||||
|
||||
2. **Use smaller quantization**:
|
||||
```bash
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
3. **Reduce context length**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_ctx=2048, # Reduce from 4096
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
4. **Quantize KV cache**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
type_k=2, # Q4_0 for K cache
|
||||
type_v=2, # Q4_0 for V cache
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Garbage Output
|
||||
|
||||
**Problem**: Model outputs random characters or nonsense
|
||||
|
||||
**Diagnose**:
|
||||
```python
|
||||
# Check model loading
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
|
||||
# Test with simple prompt
|
||||
output = llm("1+1=", max_tokens=5, temperature=0)
|
||||
print(output)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check model integrity**:
|
||||
```bash
|
||||
# Verify GGUF file
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -50
|
||||
```
|
||||
|
||||
2. **Use correct chat format**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
chat_format="llama-3" # Match your model: chatml, mistral, etc.
|
||||
)
|
||||
```
|
||||
|
||||
3. **Check temperature**:
|
||||
```python
|
||||
# Use lower temperature for deterministic output
|
||||
output = llm("Hello", max_tokens=50, temperature=0.1)
|
||||
```
|
||||
|
||||
### Token Issues
|
||||
|
||||
**Error**: `RuntimeError: unknown token` or encoding errors
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure UTF-8 encoding
|
||||
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
|
||||
output = llm(prompt, max_tokens=50)
|
||||
```
|
||||
|
||||
## Server Issues
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Error**: `Connection refused` when accessing server
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Bind to all interfaces
|
||||
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8080
|
||||
```
|
||||
|
||||
### Server Crashes Under Load
|
||||
|
||||
**Problem**: Server crashes with multiple concurrent requests
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit parallelism**:
|
||||
```bash
|
||||
./llama-server -m model.gguf \
|
||||
--parallel 2 \
|
||||
-c 4096 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
2. **Add request timeout**:
|
||||
```bash
|
||||
./llama-server -m model.gguf --timeout 300
|
||||
```
|
||||
|
||||
3. **Monitor memory**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi # For GPU
|
||||
watch -n 1 free -h # For RAM
|
||||
```
|
||||
|
||||
### API Compatibility Issues
|
||||
|
||||
**Problem**: OpenAI client not working with server
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Use correct base URL format
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1", # Include /v1
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
# Use correct model name
|
||||
response = client.chat.completions.create(
|
||||
model="local", # Or the actual model name
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Apple Silicon Issues
|
||||
|
||||
### Metal Not Working
|
||||
|
||||
**Problem**: Metal acceleration not enabled
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify Metal support
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
|
||||
```
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Rebuild with Metal
|
||||
make clean
|
||||
make GGML_METAL=1
|
||||
|
||||
# Python bindings
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
|
||||
```
|
||||
|
||||
### Incorrect Memory Usage on M1/M2
|
||||
|
||||
**Problem**: Model uses too much unified memory
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Offload all layers for Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Enable Verbose Output
|
||||
|
||||
```bash
|
||||
# CLI verbose mode
|
||||
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
|
||||
|
||||
# Python verbose
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
```
|
||||
|
||||
### Check Model Metadata
|
||||
|
||||
```bash
|
||||
# View GGUF metadata
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -100
|
||||
```
|
||||
|
||||
### Validate GGUF File
|
||||
|
||||
```python
|
||||
import struct
|
||||
|
||||
def validate_gguf(filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'GGUF':
|
||||
print(f"Invalid magic: {magic}")
|
||||
return False
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
print(f"GGUF version: {version}")
|
||||
|
||||
tensor_count = struct.unpack('<Q', f.read(8))[0]
|
||||
metadata_count = struct.unpack('<Q', f.read(8))[0]
|
||||
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
|
||||
|
||||
return True
|
||||
|
||||
validate_gguf("model.gguf")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
|
||||
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
|
||||
3. **Reddit**: r/LocalLLaMA
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- llama.cpp version/commit hash
|
||||
- Build command used
|
||||
- Model name and quantization
|
||||
- Full error message/stack trace
|
||||
- Hardware: CPU/GPU model, RAM, VRAM
|
||||
- OS version
|
||||
- Minimal reproduction steps
|
||||
@@ -0,0 +1,341 @@
|
||||
---
|
||||
name: obliteratus
|
||||
description: "OBLITERATUS: abliterate LLM refusals (diff-in-means)."
|
||||
version: 2.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
dependencies: [obliteratus, torch, transformers, bitsandbytes, accelerate, safetensors]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Abliteration, Uncensoring, Refusal-Removal, LLM, Weight-Projection, SVD, Mechanistic-Interpretability, HuggingFace, Model-Surgery]
|
||||
related_skills: [vllm, gguf, huggingface-tokenizers]
|
||||
---
|
||||
|
||||
# OBLITERATUS Skill
|
||||
|
||||
## What's inside
|
||||
|
||||
9 CLI methods, 28 analysis modules, 116 model presets across 5 compute tiers, tournament evaluation, and telemetry-driven recommendations.
|
||||
|
||||
Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, LEACE concept erasure, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.
|
||||
|
||||
**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.
|
||||
|
||||
## Video Guide
|
||||
|
||||
Walkthrough of OBLITERATUS used by a Hermes agent to abliterate Gemma:
|
||||
https://www.youtube.com/watch?v=8fG9BrNTeHs ("OBLITERATUS: An AI Agent Removed Gemma 4's Safety Guardrails")
|
||||
|
||||
Useful when the user wants a visual overview of the end-to-end workflow before running it themselves.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Trigger when the user:
|
||||
- Wants to "uncensor" or "abliterate" an LLM
|
||||
- Asks about removing refusal/guardrails from a model
|
||||
- Wants to create an uncensored version of Llama, Qwen, Mistral, etc.
|
||||
- Mentions "refusal removal", "abliteration", "weight projection"
|
||||
- Wants to analyze how a model's refusal mechanism works
|
||||
- References OBLITERATUS, abliterator, or refusal directions
|
||||
|
||||
## Step 1: Installation
|
||||
|
||||
Check if already installed:
|
||||
```bash
|
||||
obliteratus --version 2>/dev/null && echo "INSTALLED" || echo "NOT INSTALLED"
|
||||
```
|
||||
|
||||
If not installed, clone and install from GitHub:
|
||||
```bash
|
||||
git clone https://github.com/elder-plinius/OBLITERATUS.git
|
||||
cd OBLITERATUS
|
||||
pip install -e .
|
||||
# For Gradio web UI support:
|
||||
# pip install -e ".[spaces]"
|
||||
```
|
||||
|
||||
**IMPORTANT:** Confirm with user before installing. This pulls in ~5-10GB of dependencies (PyTorch, Transformers, bitsandbytes, etc.).
|
||||
|
||||
## Step 2: Check Hardware
|
||||
|
||||
Before anything, check what GPU is available:
|
||||
```bash
|
||||
python3 -c "
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
gpu = torch.cuda.get_device_name(0)
|
||||
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
print(f'GPU: {gpu}')
|
||||
print(f'VRAM: {vram:.1f} GB')
|
||||
if vram < 4: print('TIER: tiny (models under 1B)')
|
||||
elif vram < 8: print('TIER: small (models 1-4B)')
|
||||
elif vram < 16: print('TIER: medium (models 4-9B with 4bit quant)')
|
||||
elif vram < 32: print('TIER: large (models 8-32B with 4bit quant)')
|
||||
else: print('TIER: frontier (models 32B+)')
|
||||
else:
|
||||
print('NO GPU - only tiny models (under 1B) on CPU')
|
||||
"
|
||||
```
|
||||
|
||||
### VRAM Requirements (with 4-bit quantization)
|
||||
|
||||
| VRAM | Max Model Size | Example Models |
|
||||
|:---------|:----------------|:--------------------------------------------|
|
||||
| CPU only | ~1B params | GPT-2, TinyLlama, SmolLM |
|
||||
| 4-8 GB | ~4B params | Qwen2.5-1.5B, Phi-3.5 mini, Llama 3.2 3B |
|
||||
| 8-16 GB | ~9B params | Llama 3.1 8B, Mistral 7B, Gemma 2 9B |
|
||||
| 24 GB | ~32B params | Qwen3-32B, Llama 3.1 70B (tight), Command-R |
|
||||
| 48 GB+ | ~72B+ params | Qwen2.5-72B, DeepSeek-R1 |
|
||||
| Multi-GPU| 200B+ params | Llama 3.1 405B, DeepSeek-V3 (685B MoE) |
|
||||
|
||||
## Step 3: Browse Available Models & Get Recommendations
|
||||
|
||||
```bash
|
||||
# Browse models by compute tier
|
||||
obliteratus models --tier medium
|
||||
|
||||
# Get architecture info for a specific model
|
||||
obliteratus info <model_name>
|
||||
|
||||
# Get telemetry-driven recommendation for best method & params
|
||||
obliteratus recommend <model_name>
|
||||
obliteratus recommend <model_name> --insights # global cross-architecture rankings
|
||||
```
|
||||
|
||||
## Step 4: Choose a Method
|
||||
|
||||
### Method Selection Guide
|
||||
**Default / recommended for most cases: `advanced`.** It uses multi-direction SVD with norm-preserving projection and is well-tested.
|
||||
|
||||
| Situation | Recommended Method | Why |
|
||||
|:----------------------------------|:-------------------|:-----------------------------------------|
|
||||
| Default / most models | `advanced` | Multi-direction SVD, norm-preserving, reliable |
|
||||
| Quick test / prototyping | `basic` | Fast, simple, good enough to evaluate |
|
||||
| Dense model (Llama, Mistral) | `advanced` | Multi-direction, norm-preserving |
|
||||
| MoE model (DeepSeek, Mixtral) | `nuclear` | Expert-granular, handles MoE complexity |
|
||||
| Reasoning model (R1 distills) | `surgical` | CoT-aware, preserves chain-of-thought |
|
||||
| Stubborn refusals persist | `aggressive` | Whitened SVD + head surgery + jailbreak |
|
||||
| Want reversible changes | Use steering vectors (see Analysis section) |
|
||||
| Maximum quality, time no object | `optimized` | Bayesian search for best parameters |
|
||||
| Experimental auto-detection | `informed` | Auto-detects alignment type — experimental, may not always outperform advanced |
|
||||
|
||||
### 9 CLI Methods
|
||||
- **basic** — Single refusal direction via diff-in-means. Fast (~5-10 min for 8B).
|
||||
- **advanced** (DEFAULT, RECOMMENDED) — Multiple SVD directions, norm-preserving projection, 2 refinement passes. Medium speed (~10-20 min).
|
||||
- **aggressive** — Whitened SVD + jailbreak-contrastive + attention head surgery. Higher risk of coherence damage.
|
||||
- **spectral_cascade** — DCT frequency-domain decomposition. Research/novel approach.
|
||||
- **informed** — Runs analysis DURING abliteration to auto-configure. Experimental — slower and less predictable than advanced.
|
||||
- **surgical** — SAE features + neuron masking + head surgery + per-expert. Very slow (~1-2 hrs). Best for reasoning models.
|
||||
- **optimized** — Bayesian hyperparameter search (Optuna TPE). Longest runtime but finds optimal parameters.
|
||||
- **inverted** — Flips the refusal direction. Model becomes actively willing.
|
||||
- **nuclear** — Maximum force combo for stubborn MoE models. Expert-granular.
|
||||
|
||||
### Direction Extraction Methods (--direction-method flag)
|
||||
- **diff_means** (default) — Simple difference-in-means between refused/complied activations. Robust.
|
||||
- **svd** — Multi-direction SVD extraction. Better for complex alignment.
|
||||
- **leace** — LEACE (Linear Erasure via Closed-form Estimation). Optimal linear erasure.
|
||||
|
||||
### 4 Python-API-Only Methods
|
||||
(NOT available via CLI — require Python import, which violates AGPL boundary. Mention to user only if they explicitly want to use OBLITERATUS as a library in their own AGPL project.)
|
||||
- failspy, gabliteration, heretic, rdo
|
||||
|
||||
## Step 5: Run Abliteration
|
||||
|
||||
### Standard usage
|
||||
```bash
|
||||
# Default method (advanced) — recommended for most models
|
||||
obliteratus obliterate <model_name> --method advanced --output-dir ./abliterated-models
|
||||
|
||||
# With 4-bit quantization (saves VRAM)
|
||||
obliteratus obliterate <model_name> --method advanced --quantization 4bit --output-dir ./abliterated-models
|
||||
|
||||
# Large models (70B+) — conservative defaults
|
||||
obliteratus obliterate <model_name> --method advanced --quantization 4bit --large-model --output-dir ./abliterated-models
|
||||
```
|
||||
|
||||
### Fine-tuning parameters
|
||||
```bash
|
||||
obliteratus obliterate <model_name> \
|
||||
--method advanced \
|
||||
--direction-method diff_means \
|
||||
--n-directions 4 \
|
||||
--refinement-passes 2 \
|
||||
--regularization 0.1 \
|
||||
--quantization 4bit \
|
||||
--output-dir ./abliterated-models \
|
||||
--contribute # opt-in telemetry for community research
|
||||
```
|
||||
|
||||
### Key flags
|
||||
| Flag | Description | Default |
|
||||
|:-----|:------------|:--------|
|
||||
| `--method` | Abliteration method | advanced |
|
||||
| `--direction-method` | Direction extraction | diff_means |
|
||||
| `--n-directions` | Number of refusal directions (1-32) | method-dependent |
|
||||
| `--refinement-passes` | Iterative passes (1-5) | 2 |
|
||||
| `--regularization` | Regularization strength (0.0-1.0) | 0.1 |
|
||||
| `--quantization` | Load in 4bit or 8bit | none (full precision) |
|
||||
| `--large-model` | Conservative defaults for 120B+ | false |
|
||||
| `--output-dir` | Where to save the abliterated model | ./obliterated_model |
|
||||
| `--contribute` | Share anonymized results for research | false |
|
||||
| `--verify-sample-size` | Number of test prompts for refusal check | 20 |
|
||||
| `--dtype` | Model dtype (float16, bfloat16) | auto |
|
||||
|
||||
### Other execution modes
|
||||
```bash
|
||||
# Interactive guided mode (hardware → model → preset)
|
||||
obliteratus interactive
|
||||
|
||||
# Web UI (Gradio)
|
||||
obliteratus ui --port 7860
|
||||
|
||||
# Run a full ablation study from YAML config
|
||||
obliteratus run config.yaml --preset quick
|
||||
|
||||
# Tournament: pit all methods against each other
|
||||
obliteratus tourney <model_name>
|
||||
```
|
||||
|
||||
## Step 6: Verify Results
|
||||
|
||||
After abliteration, check the output metrics:
|
||||
|
||||
| Metric | Good Value | Warning |
|
||||
|:-------|:-----------|:--------|
|
||||
| Refusal rate | < 5% (ideally ~0%) | > 10% means refusals persist |
|
||||
| Perplexity change | < 10% increase | > 15% means coherence damage |
|
||||
| KL divergence | < 0.1 | > 0.5 means significant distribution shift |
|
||||
| Coherence | High / passes qualitative check | Degraded responses, repetition |
|
||||
|
||||
### If refusals persist (> 10%)
|
||||
1. Try `aggressive` method
|
||||
2. Increase `--n-directions` (e.g., 8 or 16)
|
||||
3. Add `--refinement-passes 3`
|
||||
4. Try `--direction-method svd` instead of diff_means
|
||||
|
||||
### If coherence is damaged (perplexity > 15% increase)
|
||||
1. Reduce `--n-directions` (try 2)
|
||||
2. Increase `--regularization` (try 0.3)
|
||||
3. Reduce `--refinement-passes` to 1
|
||||
4. Try `basic` method (gentler)
|
||||
|
||||
## Step 7: Use the Abliterated Model
|
||||
|
||||
The output is a standard HuggingFace model directory.
|
||||
|
||||
```bash
|
||||
# Test locally with transformers
|
||||
python3 -c "
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained('./abliterated-models/<model>')
|
||||
tokenizer = AutoTokenizer.from_pretrained('./abliterated-models/<model>')
|
||||
inputs = tokenizer('How do I pick a lock?', return_tensors='pt')
|
||||
outputs = model.generate(**inputs, max_new_tokens=200)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
"
|
||||
|
||||
# Upload to HuggingFace Hub
|
||||
huggingface-cli upload <username>/<model-name>-abliterated ./abliterated-models/<model>
|
||||
|
||||
# Serve with vLLM
|
||||
vllm serve ./abliterated-models/<model>
|
||||
```
|
||||
|
||||
## CLI Command Reference
|
||||
|
||||
| Command | Description |
|
||||
|:--------|:------------|
|
||||
| `obliteratus obliterate` | Main abliteration command |
|
||||
| `obliteratus info <model>` | Print model architecture details |
|
||||
| `obliteratus models --tier <tier>` | Browse curated models by compute tier |
|
||||
| `obliteratus recommend <model>` | Telemetry-driven method/param suggestion |
|
||||
| `obliteratus interactive` | Guided setup wizard |
|
||||
| `obliteratus tourney <model>` | Tournament: all methods head-to-head |
|
||||
| `obliteratus run <config.yaml>` | Execute ablation study from YAML |
|
||||
| `obliteratus strategies` | List all registered ablation strategies |
|
||||
| `obliteratus report <results.json>` | Regenerate visual reports |
|
||||
| `obliteratus ui` | Launch Gradio web interface |
|
||||
| `obliteratus aggregate` | Summarize community telemetry data |
|
||||
|
||||
## Analysis Modules
|
||||
|
||||
OBLITERATUS includes 28 analysis modules for mechanistic interpretability.
|
||||
See `skill_view(name="obliteratus", file_path="references/analysis-modules.md")` for the full reference.
|
||||
|
||||
### Quick analysis commands
|
||||
```bash
|
||||
# Run specific analysis modules
|
||||
obliteratus run analysis-config.yaml --preset quick
|
||||
|
||||
# Key modules to run first:
|
||||
# - alignment_imprint: Fingerprint DPO/RLHF/CAI/SFT alignment method
|
||||
# - concept_geometry: Single direction vs polyhedral cone
|
||||
# - logit_lens: Which layer decides to refuse
|
||||
# - anti_ouroboros: Self-repair risk score
|
||||
# - causal_tracing: Causally necessary components
|
||||
```
|
||||
|
||||
### Steering Vectors (Reversible Alternative)
|
||||
Instead of permanent weight modification, use inference-time steering:
|
||||
```python
|
||||
# Python API only — for user's own projects
|
||||
from obliteratus.analysis.steering_vectors import SteeringVectorFactory, SteeringHookManager
|
||||
```
|
||||
|
||||
## Ablation Strategies
|
||||
|
||||
Beyond direction-based abliteration, OBLITERATUS includes structural ablation strategies:
|
||||
- **Embedding Ablation** — Target embedding layer components
|
||||
- **FFN Ablation** — Feed-forward network block removal
|
||||
- **Head Pruning** — Attention head pruning
|
||||
- **Layer Removal** — Full layer removal
|
||||
|
||||
List all available: `obliteratus strategies`
|
||||
|
||||
## Evaluation
|
||||
|
||||
OBLITERATUS includes built-in evaluation tools:
|
||||
- Refusal rate benchmarking
|
||||
- Perplexity comparison (before/after)
|
||||
- LM Eval Harness integration for academic benchmarks
|
||||
- Head-to-head competitor comparison
|
||||
- Baseline performance tracking
|
||||
|
||||
## Platform Support
|
||||
|
||||
- **CUDA** — Full support (NVIDIA GPUs)
|
||||
- **Apple Silicon (MLX)** — Supported via MLX backend
|
||||
- **CPU** — Supported for tiny models (< 1B params)
|
||||
|
||||
## YAML Config Templates
|
||||
|
||||
Load templates for reproducible runs via `skill_view`:
|
||||
- `templates/abliteration-config.yaml` — Standard single-model config
|
||||
- `templates/analysis-study.yaml` — Pre-abliteration analysis study
|
||||
- `templates/batch-abliteration.yaml` — Multi-model batch processing
|
||||
|
||||
## Telemetry
|
||||
|
||||
OBLITERATUS can optionally contribute anonymized run data to a global research dataset.
|
||||
Enable with `--contribute` flag. No personal data is collected — only model name, method, metrics.
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
1. **Don't use `informed` as default** — it's experimental and slower. Use `advanced` for reliable results.
|
||||
2. **Models under ~1B respond poorly to abliteration** — their refusal behaviors are shallow and fragmented, making clean direction extraction difficult. Expect partial results (20-40% remaining refusal). Models 3B+ have cleaner refusal directions and respond much better (often 0% refusal with `advanced`).
|
||||
3. **`aggressive` can make things worse** — on small models it can damage coherence and actually increase refusal rate. Only use it if `advanced` leaves > 10% refusals on a 3B+ model.
|
||||
4. **Always check perplexity** — if it spikes > 15%, the model is damaged. Reduce aggressiveness.
|
||||
5. **MoE models need special handling** — use `nuclear` method for Mixtral, DeepSeek-MoE, etc.
|
||||
6. **Quantized models can't be re-quantized** — abliterate the full-precision model, then quantize the output.
|
||||
7. **VRAM estimation is approximate** — 4-bit quant helps but peak usage can spike during extraction.
|
||||
8. **Reasoning models are sensitive** — use `surgical` for R1 distills to preserve chain-of-thought.
|
||||
9. **Check `obliteratus recommend`** — telemetry data may have better parameters than defaults.
|
||||
10. **AGPL license** — never `import obliteratus` in MIT/Apache projects. CLI invocation only.
|
||||
11. **Large models (70B+)** — always use `--large-model` flag for conservative defaults.
|
||||
12. **Spectral certification RED is common** — the spectral check often flags "incomplete" even when practical refusal rate is 0%. Check actual refusal rate rather than relying on spectral certification alone.
|
||||
|
||||
## Complementary Skills
|
||||
|
||||
- **vllm** — Serve abliterated models with high throughput
|
||||
- **gguf** — Convert abliterated models to GGUF for llama.cpp
|
||||
- **huggingface-tokenizers** — Work with model tokenizers
|
||||
@@ -0,0 +1,166 @@
|
||||
# OBLITERATUS Analysis Modules — Reference
|
||||
|
||||
OBLITERATUS includes 28 analysis modules for mechanistic interpretability of refusal in LLMs.
|
||||
These modules help understand how and where refusal behaviors are encoded before performing abliteration.
|
||||
|
||||
---
|
||||
|
||||
## Core Analysis (Run These First)
|
||||
|
||||
### 1. Alignment Imprint Detection (`alignment_imprint.py`)
|
||||
Fingerprints whether a model was trained via DPO, RLHF, CAI, or SFT.
|
||||
This determines which extraction strategy will work best.
|
||||
|
||||
### 2. Concept Cone Geometry (`concept_geometry.py`)
|
||||
Determines if refusal is a single linear direction or a polyhedral cone
|
||||
(set of multiple mechanisms). Single-direction models respond well to `basic`;
|
||||
polyhedral models need `advanced` or `surgical`.
|
||||
|
||||
### 3. Refusal Logit Lens (`logit_lens.py`)
|
||||
Identifies the specific layer where a model "decides" to refuse by decoding
|
||||
intermediate layer representations into token space.
|
||||
|
||||
### 4. Ouroboros Detection (`anti_ouroboros.py`)
|
||||
Identifies if a model attempts to "self-repair" refusal behaviors after
|
||||
excision. Reports a risk score (0-1). High scores mean additional refinement
|
||||
passes are needed.
|
||||
|
||||
### 5. Causal Tracing (`causal_tracing.py`)
|
||||
Identifies which components (layers, heads, MLPs) are causally necessary
|
||||
for refusal behavior using activation patching.
|
||||
|
||||
---
|
||||
|
||||
## Geometric Analysis
|
||||
|
||||
### 6. Cross-Layer Alignment (`cross_layer.py`)
|
||||
Measures how refusal directions align across different layers. High alignment
|
||||
means the refusal signal is consistent; low alignment suggests layer-specific
|
||||
mechanisms.
|
||||
|
||||
### 7. Residual Stream Decomposition (`residual_stream.py`)
|
||||
Decomposes the residual stream into attention and MLP contributions to
|
||||
understand which component type contributes more to refusal.
|
||||
|
||||
### 8. Riemannian Manifold Geometry (`riemannian_manifold.py`)
|
||||
Analyzes the curvature and geometry of the weight manifold near refusal
|
||||
directions. Informs how aggressively projections can be applied without
|
||||
damaging the manifold structure.
|
||||
|
||||
### 9. Whitened SVD (`whitened_svd.py`)
|
||||
Covariance-normalized SVD extraction that separates guardrail signals from
|
||||
natural activation variance. More precise than standard SVD for models with
|
||||
high activation variance.
|
||||
|
||||
### 10. Concept Cone Geometry (extended)
|
||||
Maps the full polyhedral structure of refusal, including cone angles,
|
||||
face counts, and intersection patterns.
|
||||
|
||||
---
|
||||
|
||||
## Probing & Classification
|
||||
|
||||
### 11. Activation Probing (`activation_probing.py`)
|
||||
Post-excision verification — probes for residual refusal concepts after
|
||||
abliteration to ensure complete removal.
|
||||
|
||||
### 12. Probing Classifiers (`probing_classifiers.py`)
|
||||
Trains linear classifiers to detect refusal in activations. Used both
|
||||
before (to verify refusal exists) and after (to verify it's gone).
|
||||
|
||||
### 13. Activation Patching (`activation_patching.py`)
|
||||
Interchange interventions — swaps activations between refused and complied
|
||||
runs to identify causal components.
|
||||
|
||||
### 14. Tuned Lens (`tuned_lens.py`)
|
||||
Trained version of logit lens that provides more accurate per-layer
|
||||
decoding by learning affine transformations for each layer.
|
||||
|
||||
### 15. Multi-Token Position Analysis (`multi_token_position.py`)
|
||||
Analyzes refusal signals across multiple token positions, not just the
|
||||
last token. Important for models that distribute refusal across the sequence.
|
||||
|
||||
---
|
||||
|
||||
## Abliteration & Manipulation
|
||||
|
||||
### 16. SAE-Based Abliteration (`sae_abliteration.py`)
|
||||
Uses Sparse Autoencoder features to identify and remove specific refusal
|
||||
features. More surgical than direction-based methods.
|
||||
|
||||
### 17. Steering Vectors (`steering_vectors.py`)
|
||||
Creates and applies inference-time steering vectors for reversible refusal
|
||||
modification. Includes `SteeringVectorFactory` and `SteeringHookManager`.
|
||||
|
||||
### 18. LEACE Concept Erasure (`leace.py`)
|
||||
Linear Erasure via Closed-form Estimation — mathematically optimal linear
|
||||
concept removal. Available as both analysis module and direction extraction method.
|
||||
|
||||
### 19. Sparse Surgery (`sparse_surgery.py`)
|
||||
High-precision weight modification targeting individual neurons and
|
||||
weight matrix entries rather than full directions.
|
||||
|
||||
### 20. Conditional Abliteration (`conditional_abliteration.py`)
|
||||
Targeted removal that only affects specific refusal categories while
|
||||
preserving others (e.g., remove weapons refusal but keep CSAM refusal).
|
||||
|
||||
---
|
||||
|
||||
## Transfer & Robustness
|
||||
|
||||
### 21. Cross-Model Transfer (`cross_model_transfer.py`)
|
||||
Tests whether refusal directions extracted from one model transfer to
|
||||
another architecture. Measures universality of guardrail directions.
|
||||
|
||||
### 22. Defense Robustness (`defense_robustness.py`)
|
||||
Evaluates how robust the abliteration is against various defense mechanisms
|
||||
and re-alignment attempts.
|
||||
|
||||
### 23. Spectral Certification (`spectral_certification.py`)
|
||||
Provides mathematical bounds on the completeness of refusal removal
|
||||
using spectral analysis of the projection.
|
||||
|
||||
### 24. Wasserstein Optimal Extraction (`wasserstein_optimal.py`)
|
||||
Uses optimal transport theory for more precise direction extraction
|
||||
that minimizes distribution shift.
|
||||
|
||||
### 25. Wasserstein Transfer (`wasserstein_transfer.py`)
|
||||
Distribution transfer between models using Wasserstein distance
|
||||
for cross-architecture refusal direction mapping.
|
||||
|
||||
---
|
||||
|
||||
## Advanced / Research
|
||||
|
||||
### 26. Bayesian Kernel Projection (`bayesian_kernel_projection.py`)
|
||||
Probabilistic feature mapping that estimates uncertainty in refusal
|
||||
direction identification.
|
||||
|
||||
### 27. Cross-Model Universality Index
|
||||
Measures if guardrail directions generalize across different model
|
||||
architectures and training regimes.
|
||||
|
||||
### 28. Visualization (`visualization.py`)
|
||||
Plotting and graphing utilities for all analysis modules. Generates
|
||||
heatmaps, direction plots, and layer-wise analysis charts.
|
||||
|
||||
---
|
||||
|
||||
## Running Analysis
|
||||
|
||||
### Via CLI
|
||||
```bash
|
||||
# Run analysis from a YAML config
|
||||
obliteratus run analysis-study.yaml --preset quick
|
||||
|
||||
# Available study presets:
|
||||
# quick — Fast sanity check (2-3 modules)
|
||||
# full — All core + geometric analysis
|
||||
# jailbreak — Refusal circuit localization
|
||||
# knowledge — Knowledge preservation analysis
|
||||
# robustness — Stress testing / defense evaluation
|
||||
```
|
||||
|
||||
### Via YAML Config
|
||||
See the `templates/analysis-study.yaml` template for a complete example.
|
||||
Load with: `skill_view(name="obliteratus", file_path="templates/analysis-study.yaml")`
|
||||
@@ -0,0 +1,141 @@
|
||||
# OBLITERATUS Methods — Detailed Guide
|
||||
|
||||
> The CLI accepts 9 methods via `--method`: basic, advanced, aggressive, spectral_cascade,
|
||||
> informed, surgical, optimized, inverted, nuclear.
|
||||
> Four additional methods (failspy, gabliteration, heretic, rdo) are available only via the Python API.
|
||||
|
||||
## How Abliteration Works (Theory)
|
||||
|
||||
Abliteration identifies a "refusal direction" — a vector in the model's activation space that
|
||||
corresponds to refusal behavior — and projects it out of the weight matrices.
|
||||
|
||||
Mathematically: `W_new = W_old - (W_old @ d @ d.T)` where `d` is the refusal direction.
|
||||
|
||||
The key challenge is finding accurate refusal directions without damaging other capabilities.
|
||||
|
||||
---
|
||||
|
||||
## Direction Extraction Methods
|
||||
|
||||
Before projecting, OBLITERATUS extracts refusal directions using one of three methods:
|
||||
|
||||
| Method | Flag | Description | Best For |
|
||||
|:-------|:-----|:------------|:---------|
|
||||
| Diff-in-Means | `--direction-method diff_means` | Difference between mean activations on refused vs. complied prompts | Default, fast, robust |
|
||||
| SVD | `--direction-method svd` | Multi-direction extraction via Singular Value Decomposition | Complex alignment, multiple refusal mechanisms |
|
||||
| LEACE | `--direction-method leace` | Linear Erasure via Closed-form Estimation — mathematically optimal | Maximum precision, research |
|
||||
|
||||
---
|
||||
|
||||
## Method Details
|
||||
|
||||
### basic
|
||||
- **Directions:** 1 (single diff-in-means vector)
|
||||
- **Speed:** Fast (~5-10 min for 8B model)
|
||||
- **Risk:** Low
|
||||
- **Use case:** Quick tests, prototyping, evaluating if abliteration works for a model
|
||||
- **How it works:** Extracts one refusal direction and projects it out uniformly across all layers.
|
||||
|
||||
### advanced (DEFAULT — RECOMMENDED)
|
||||
- **Directions:** 4 (multi-direction SVD)
|
||||
- **Speed:** Medium (~10-20 min for 8B model)
|
||||
- **Risk:** Low-Medium
|
||||
- **Refinement passes:** 2
|
||||
- **Use case:** Default for most models. Well-tested and reliable.
|
||||
- **How it works:** Extracts multiple refusal directions via SVD, applies norm-preserving bi-projection to maintain weight matrix norms. Two refinement passes catch residual refusal.
|
||||
|
||||
### aggressive
|
||||
- **Directions:** 8+ (whitened SVD + jailbreak-contrastive)
|
||||
- **Speed:** Medium-Slow
|
||||
- **Risk:** Medium-High (may damage coherence)
|
||||
- **Use case:** When `advanced` leaves > 10% refusals. Stubborn models.
|
||||
- **How it works:** Uses whitened SVD for covariance-normalized extraction, adds jailbreak-contrastive directions, performs attention head surgery on the most refusal-active heads.
|
||||
|
||||
### spectral_cascade
|
||||
- **Speed:** Medium
|
||||
- **Risk:** Medium
|
||||
- **Use case:** Research, novel approaches
|
||||
- **How it works:** DCT (Discrete Cosine Transform) frequency-domain decomposition of refusal signals. Separates high-frequency (surface-level) from low-frequency (deep) refusal patterns.
|
||||
|
||||
### informed (EXPERIMENTAL)
|
||||
- **Speed:** Slow (~20-40 min for 8B model)
|
||||
- **Risk:** Variable — results depend on analysis quality
|
||||
- **Use case:** When you want auto-configuration, but be aware this is experimental and may not outperform `advanced`.
|
||||
- **How it works:** Runs 4 analysis modules first (alignment imprint, concept geometry, logit lens, ouroboros detection), then auto-configures extraction strategy. Includes an "Ouroboros loop" that detects and counteracts self-repair.
|
||||
- **Note:** The auto-detection can sometimes misconfigure. If results are poor, fall back to `advanced`.
|
||||
|
||||
### surgical
|
||||
- **Speed:** Very slow (~1-2 hrs for 8B model)
|
||||
- **Risk:** Low (very precise)
|
||||
- **Use case:** Reasoning models (R1 distills, QwQ, etc.) where chain-of-thought must be preserved.
|
||||
- **How it works:** Uses SAE (Sparse Autoencoder) features + individual neuron masking + attention head surgery + per-expert decomposition (for MoE). CoT-aware — identifies and protects reasoning-critical directions before projecting.
|
||||
|
||||
### optimized
|
||||
- **Speed:** Very slow (hours — runs many trials)
|
||||
- **Risk:** Low (finds optimal parameters)
|
||||
- **Use case:** When quality matters more than speed. Production models.
|
||||
- **How it works:** Bayesian hyperparameter search via Optuna TPE sampler. Optimizes n_directions, regularization, refinement passes, and layer selection jointly. Evaluates each configuration on refusal rate + perplexity.
|
||||
|
||||
### inverted
|
||||
- **Speed:** Fast
|
||||
- **Risk:** High (model behavior changes dramatically)
|
||||
- **Use case:** Research, studying refusal mechanisms
|
||||
- **How it works:** Instead of projecting out the refusal direction, reflects it. The model actively complies rather than passively not-refusing. Useful for understanding the geometry of alignment.
|
||||
|
||||
### nuclear
|
||||
- **Speed:** Slow
|
||||
- **Risk:** Medium-High
|
||||
- **Use case:** Stubborn MoE models (DeepSeek-MoE, Mixtral, etc.)
|
||||
- **How it works:** Combines expert-granular abliteration (EGA), steering vector injection, attention head pruning, and multi-pass refinement. Decomposes refusal signals into per-expert components for MoE architectures.
|
||||
|
||||
---
|
||||
|
||||
## Method Selection Flowchart
|
||||
|
||||
```
|
||||
Is this a quick test?
|
||||
→ YES: basic
|
||||
→ NO: continue
|
||||
|
||||
Is it an MoE model (Mixtral, DeepSeek-MoE)?
|
||||
→ YES: nuclear
|
||||
→ NO: continue
|
||||
|
||||
Is it a reasoning model (R1, QwQ, CoT-focused)?
|
||||
→ YES: surgical
|
||||
→ NO: continue
|
||||
|
||||
Do you need the absolute best quality and have time?
|
||||
→ YES: optimized
|
||||
→ NO: advanced (recommended default)
|
||||
|
||||
Did advanced leave > 10% refusals?
|
||||
→ YES: aggressive
|
||||
→ Still refusing: nuclear
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
| Parameter | Range | Default | Effect |
|
||||
|:----------|:------|:--------|:-------|
|
||||
| `--n-directions` | 1-32 | method-dependent | More directions = more complete removal, but higher damage risk |
|
||||
| `--regularization` | 0.0-1.0 | 0.1 | Higher = more conservative (less removal, less damage) |
|
||||
| `--refinement-passes` | 1-5 | 2 | More passes catch residual refusal, but diminishing returns |
|
||||
| `--quantization` | 4bit, 8bit | none | Reduces VRAM usage; quality impact minimal for extraction |
|
||||
| `--verify-sample-size` | 10-200 | 20 | More samples = more accurate refusal rate estimate |
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Problem | Likely Cause | Fix |
|
||||
|:--------|:-------------|:----|
|
||||
| Refusal rate > 20% | Too few directions | Increase `--n-directions`, try `aggressive` |
|
||||
| Refusal rate 5-20% | Residual refusal | Add `--refinement-passes 3`, try `--direction-method svd` |
|
||||
| Perplexity spike > 20% | Over-aggressive removal | Reduce `--n-directions`, increase `--regularization` |
|
||||
| Repetitive output | Weight matrix damage | Use `basic` with fewer directions, check norm preservation |
|
||||
| MoE model still refuses | Non-expert-aware method | Switch to `nuclear` |
|
||||
| Reasoning degraded | CoT directions damaged | Use `surgical` method |
|
||||
| OOM during extraction | Insufficient VRAM | Add `--quantization 4bit` and/or `--large-model` |
|
||||
@@ -0,0 +1,33 @@
|
||||
# OBLITERATUS Abliteration Config
|
||||
# Usage: obliteratus run this-file.yaml
|
||||
#
|
||||
# This is for reproducible, version-controlled abliteration runs.
|
||||
# For one-off usage, the CLI flags are simpler.
|
||||
|
||||
# Model to abliterate
|
||||
model:
|
||||
name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
dtype: "bfloat16" # float16, bfloat16, float32
|
||||
quantization: null # null, "4bit", "8bit"
|
||||
device: "auto" # auto, cuda, cuda:0, cpu
|
||||
|
||||
# Abliteration method and parameters
|
||||
abliteration:
|
||||
method: "informed" # See SKILL.md Step 4 for all 13 methods
|
||||
n_directions: null # null = auto-detect, or integer (e.g., 8)
|
||||
regularization: 0.0 # 0.0-1.0, fraction of original to preserve
|
||||
refinement_passes: 1 # Iterative passes (increase for self-repair)
|
||||
norm_preserve: true # Keep weight norms intact after projection
|
||||
|
||||
# Output
|
||||
output:
|
||||
directory: "./abliterated-models"
|
||||
save_metadata: true # Save abliteration_metadata.json alongside model
|
||||
contribute: false # Save community contribution data
|
||||
|
||||
# Verification
|
||||
verify:
|
||||
enabled: true
|
||||
test_prompts: null # null = use built-in test prompts
|
||||
compute_perplexity: true
|
||||
compute_kl: true
|
||||
@@ -0,0 +1,40 @@
|
||||
# OBLITERATUS Analysis Study Config
|
||||
# Usage: obliteratus run this-file.yaml --preset jailbreak
|
||||
#
|
||||
# Run analysis modules to understand refusal geometry BEFORE abliterating.
|
||||
# Useful for research or when you want to understand what you're removing.
|
||||
|
||||
# Model to analyze
|
||||
model:
|
||||
name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
dtype: "bfloat16"
|
||||
quantization: "4bit" # Saves VRAM for analysis
|
||||
device: "auto"
|
||||
|
||||
# Study configuration
|
||||
study:
|
||||
# Available presets: quick, full, attention, jailbreak, guardrail, knowledge
|
||||
preset: "jailbreak"
|
||||
|
||||
# Or specify individual strategies:
|
||||
# strategies:
|
||||
# - layer_removal
|
||||
# - head_pruning
|
||||
# - ffn_ablation
|
||||
# - embedding_ablation
|
||||
|
||||
# Analysis modules to run (subset of the 27 available)
|
||||
analysis:
|
||||
- alignment_imprint # Detect DPO/RLHF/CAI/SFT training method
|
||||
- concept_geometry # Map refusal cone geometry
|
||||
- logit_lens # Find which layer decides to refuse
|
||||
- anti_ouroboros # Detect self-repair tendency
|
||||
- cross_layer # Cross-layer alignment clustering
|
||||
- causal_tracing # Causal necessity of components
|
||||
- residual_stream # Attention vs MLP contribution
|
||||
|
||||
# Output
|
||||
output:
|
||||
directory: "./analysis-results"
|
||||
save_plots: true # Generate matplotlib visualizations
|
||||
save_report: true # Generate markdown report
|
||||
@@ -0,0 +1,41 @@
|
||||
# OBLITERATUS Batch Abliteration Config
|
||||
# Abliterate multiple models with the same method for comparison.
|
||||
#
|
||||
# Run each one sequentially:
|
||||
# for model in models; do obliteratus obliterate $model --method informed; done
|
||||
#
|
||||
# Or use this as a reference for which models to process.
|
||||
|
||||
# Common settings
|
||||
defaults:
|
||||
method: "informed"
|
||||
quantization: "4bit"
|
||||
output_dir: "./abliterated-models"
|
||||
|
||||
# Models to process (grouped by compute tier)
|
||||
models:
|
||||
# Small (4-8 GB VRAM)
|
||||
small:
|
||||
- "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
- "microsoft/Phi-3.5-mini-instruct"
|
||||
- "meta-llama/Llama-3.2-3B-Instruct"
|
||||
|
||||
# Medium (8-16 GB VRAM)
|
||||
medium:
|
||||
- "meta-llama/Llama-3.1-8B-Instruct"
|
||||
- "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
- "google/gemma-2-9b-it"
|
||||
- "Qwen/Qwen2.5-7B-Instruct"
|
||||
|
||||
# Large (24 GB VRAM, 4-bit quantization)
|
||||
large:
|
||||
- "Qwen/Qwen2.5-14B-Instruct"
|
||||
- "Qwen/Qwen3-32B"
|
||||
- "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
|
||||
|
||||
# Per-model method overrides (optional)
|
||||
overrides:
|
||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":
|
||||
method: "surgical" # CoT-aware for reasoning models
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1":
|
||||
method: "nuclear" # Expert-granular for MoE models
|
||||
@@ -0,0 +1,655 @@
|
||||
---
|
||||
name: outlines
|
||||
description: "Outlines: structured JSON/regex/Pydantic LLM generation."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [outlines, transformers, vllm, pydantic]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, Outlines, Structured Generation, JSON Schema, Pydantic, Local Models, Grammar-Based Generation, vLLM, Transformers, Type Safety]
|
||||
|
||||
---
|
||||
|
||||
# Outlines: Structured Text Generation
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Outlines when you need to:
|
||||
- **Guarantee valid JSON/XML/code** structure during generation
|
||||
- **Use Pydantic models** for type-safe outputs
|
||||
- **Support local models** (Transformers, llama.cpp, vLLM)
|
||||
- **Maximize inference speed** with zero-overhead structured generation
|
||||
- **Generate against JSON schemas** automatically
|
||||
- **Control token sampling** at the grammar level
|
||||
|
||||
**GitHub Stars**: 8,000+ | **From**: dottxt.ai (formerly .txt)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install outlines
|
||||
|
||||
# With specific backends
|
||||
pip install outlines transformers # Hugging Face models
|
||||
pip install outlines llama-cpp-python # llama.cpp
|
||||
pip install outlines vllm # vLLM for high-throughput
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Classification
|
||||
|
||||
```python
|
||||
import outlines
|
||||
from typing import Literal
|
||||
|
||||
# Load model
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Generate with type constraint
|
||||
prompt = "Sentiment of 'This product is amazing!': "
|
||||
generator = outlines.generate.choice(model, ["positive", "negative", "neutral"])
|
||||
sentiment = generator(prompt)
|
||||
|
||||
print(sentiment) # "positive" (guaranteed one of these)
|
||||
```
|
||||
|
||||
### With Pydantic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Generate structured output
|
||||
prompt = "Extract user: John Doe, 30 years old, john@example.com"
|
||||
generator = outlines.generate.json(model, User)
|
||||
user = generator(prompt)
|
||||
|
||||
print(user.name) # "John Doe"
|
||||
print(user.age) # 30
|
||||
print(user.email) # "john@example.com"
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Constrained Token Sampling
|
||||
|
||||
Outlines uses Finite State Machines (FSM) to constrain token generation at the logit level.
|
||||
|
||||
**How it works:**
|
||||
1. Convert schema (JSON/Pydantic/regex) to context-free grammar (CFG)
|
||||
2. Transform CFG into Finite State Machine (FSM)
|
||||
3. Filter invalid tokens at each step during generation
|
||||
4. Fast-forward when only one valid token exists
|
||||
|
||||
**Benefits:**
|
||||
- **Zero overhead**: Filtering happens at token level
|
||||
- **Speed improvement**: Fast-forward through deterministic paths
|
||||
- **Guaranteed validity**: Invalid outputs impossible
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Pydantic model -> JSON schema -> CFG -> FSM
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Behind the scenes:
|
||||
# 1. Person -> JSON schema
|
||||
# 2. JSON schema -> CFG
|
||||
# 3. CFG -> FSM
|
||||
# 4. FSM filters tokens during generation
|
||||
|
||||
generator = outlines.generate.json(model, Person)
|
||||
result = generator("Generate person: Alice, 25")
|
||||
```
|
||||
|
||||
### 2. Structured Generators
|
||||
|
||||
Outlines provides specialized generators for different output types.
|
||||
|
||||
#### Choice Generator
|
||||
|
||||
```python
|
||||
# Multiple choice selection
|
||||
generator = outlines.generate.choice(
|
||||
model,
|
||||
["positive", "negative", "neutral"]
|
||||
)
|
||||
|
||||
sentiment = generator("Review: This is great!")
|
||||
# Result: One of the three choices
|
||||
```
|
||||
|
||||
#### JSON Generator
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
in_stock: bool
|
||||
|
||||
# Generate valid JSON matching schema
|
||||
generator = outlines.generate.json(model, Product)
|
||||
product = generator("Extract: iPhone 15, $999, available")
|
||||
|
||||
# Guaranteed valid Product instance
|
||||
print(type(product)) # <class '__main__.Product'>
|
||||
```
|
||||
|
||||
#### Regex Generator
|
||||
|
||||
```python
|
||||
# Generate text matching regex
|
||||
generator = outlines.generate.regex(
|
||||
model,
|
||||
r"[0-9]{3}-[0-9]{3}-[0-9]{4}" # Phone number pattern
|
||||
)
|
||||
|
||||
phone = generator("Generate phone number:")
|
||||
# Result: "555-123-4567" (guaranteed to match pattern)
|
||||
```
|
||||
|
||||
#### Integer/Float Generators
|
||||
|
||||
```python
|
||||
# Generate specific numeric types
|
||||
int_generator = outlines.generate.integer(model)
|
||||
age = int_generator("Person's age:") # Guaranteed integer
|
||||
|
||||
float_generator = outlines.generate.float(model)
|
||||
price = float_generator("Product price:") # Guaranteed float
|
||||
```
|
||||
|
||||
### 3. Model Backends
|
||||
|
||||
Outlines supports multiple local and API-based backends.
|
||||
|
||||
#### Transformers (Hugging Face)
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load from Hugging Face
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda" # Or "cpu"
|
||||
)
|
||||
|
||||
# Use with any generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### llama.cpp
|
||||
|
||||
```python
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### vLLM (High Throughput)
|
||||
|
||||
```python
|
||||
# For production deployments
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
tensor_parallel_size=2 # Multi-GPU
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
#### OpenAI (Limited Support)
|
||||
|
||||
```python
|
||||
# Basic OpenAI support
|
||||
model = outlines.models.openai(
|
||||
"gpt-4o-mini",
|
||||
api_key="your-api-key"
|
||||
)
|
||||
|
||||
# Note: Some features limited with API models
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### 4. Pydantic Integration
|
||||
|
||||
Outlines has first-class Pydantic support with automatic schema translation.
|
||||
|
||||
#### Basic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title")
|
||||
author: str = Field(description="Author name")
|
||||
word_count: int = Field(description="Number of words", gt=0)
|
||||
tags: list[str] = Field(description="List of tags")
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Article)
|
||||
|
||||
article = generator("Generate article about AI")
|
||||
print(article.title)
|
||||
print(article.word_count) # Guaranteed > 0
|
||||
```
|
||||
|
||||
#### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
address: Address # Nested model
|
||||
|
||||
generator = outlines.generate.json(model, Person)
|
||||
person = generator("Generate person in New York")
|
||||
|
||||
print(person.address.city) # "New York"
|
||||
```
|
||||
|
||||
#### Enums and Literals
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
|
||||
class Application(BaseModel):
|
||||
applicant: str
|
||||
status: Status # Must be one of enum values
|
||||
priority: Literal["low", "medium", "high"] # Must be one of literals
|
||||
|
||||
generator = outlines.generate.json(model, Application)
|
||||
app = generator("Generate application")
|
||||
|
||||
print(app.status) # Status.PENDING (or APPROVED/REJECTED)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Data Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int
|
||||
industry: str
|
||||
employees: int
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, CompanyInfo)
|
||||
|
||||
text = """
|
||||
Apple Inc. was founded in 1976 in the technology industry.
|
||||
The company employs approximately 164,000 people worldwide.
|
||||
"""
|
||||
|
||||
prompt = f"Extract company information:\n{text}\n\nCompany:"
|
||||
company = generator(prompt)
|
||||
|
||||
print(f"Name: {company.name}")
|
||||
print(f"Founded: {company.founded_year}")
|
||||
print(f"Industry: {company.industry}")
|
||||
print(f"Employees: {company.employees}")
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
import outlines
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Binary classification
|
||||
generator = outlines.generate.choice(model, ["spam", "not_spam"])
|
||||
result = generator("Email: Buy now! 50% off!")
|
||||
|
||||
# Multi-class classification
|
||||
categories = ["technology", "business", "sports", "entertainment"]
|
||||
category_gen = outlines.generate.choice(model, categories)
|
||||
category = category_gen("Article: Apple announces new iPhone...")
|
||||
|
||||
# With confidence
|
||||
class Classification(BaseModel):
|
||||
label: Literal["positive", "negative", "neutral"]
|
||||
confidence: float
|
||||
|
||||
classifier = outlines.generate.json(model, Classification)
|
||||
result = classifier("Review: This product is okay, nothing special")
|
||||
```
|
||||
|
||||
### Pattern 3: Structured Forms
|
||||
|
||||
```python
|
||||
class UserProfile(BaseModel):
|
||||
full_name: str
|
||||
age: int
|
||||
email: str
|
||||
phone: str
|
||||
country: str
|
||||
interests: list[str]
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, UserProfile)
|
||||
|
||||
prompt = """
|
||||
Extract user profile from:
|
||||
Name: Alice Johnson
|
||||
Age: 28
|
||||
Email: alice@example.com
|
||||
Phone: 555-0123
|
||||
Country: USA
|
||||
Interests: hiking, photography, cooking
|
||||
"""
|
||||
|
||||
profile = generator(prompt)
|
||||
print(profile.full_name)
|
||||
print(profile.interests) # ["hiking", "photography", "cooking"]
|
||||
```
|
||||
|
||||
### Pattern 4: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
class Entity(BaseModel):
|
||||
name: str
|
||||
type: Literal["PERSON", "ORGANIZATION", "LOCATION"]
|
||||
|
||||
class DocumentEntities(BaseModel):
|
||||
entities: list[Entity]
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, DocumentEntities)
|
||||
|
||||
text = "Tim Cook met with Satya Nadella at Microsoft headquarters in Redmond."
|
||||
prompt = f"Extract entities from: {text}"
|
||||
|
||||
result = generator(prompt)
|
||||
for entity in result.entities:
|
||||
print(f"{entity.name} ({entity.type})")
|
||||
```
|
||||
|
||||
### Pattern 5: Code Generation
|
||||
|
||||
```python
|
||||
class PythonFunction(BaseModel):
|
||||
function_name: str
|
||||
parameters: list[str]
|
||||
docstring: str
|
||||
body: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PythonFunction)
|
||||
|
||||
prompt = "Generate a Python function to calculate factorial"
|
||||
func = generator(prompt)
|
||||
|
||||
print(f"def {func.function_name}({', '.join(func.parameters)}):")
|
||||
print(f' """{func.docstring}"""')
|
||||
print(f" {func.body}")
|
||||
```
|
||||
|
||||
### Pattern 6: Batch Processing
|
||||
|
||||
```python
|
||||
def batch_extract(texts: list[str], schema: type[BaseModel]):
|
||||
"""Extract structured data from multiple texts."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for text in texts:
|
||||
result = generator(f"Extract from: {text}")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
texts = [
|
||||
"John is 30 years old",
|
||||
"Alice is 25 years old",
|
||||
"Bob is 40 years old"
|
||||
]
|
||||
|
||||
people = batch_extract(texts, Person)
|
||||
for person in people:
|
||||
print(f"{person.name}: {person.age}")
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
### Transformers
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Basic usage
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# GPU configuration
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={"torch_dtype": "float16"}
|
||||
)
|
||||
|
||||
# Popular models
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
```python
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
```
|
||||
|
||||
### vLLM (Production)
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4 # 4 GPUs
|
||||
)
|
||||
|
||||
# With quantization
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="awq" # Or "gptq"
|
||||
)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Specific Types
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific types
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float # Not str
|
||||
quantity: int # Not str
|
||||
in_stock: bool # Not str
|
||||
|
||||
# ❌ Bad: Everything as string
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: str # Should be float
|
||||
quantity: str # Should be int
|
||||
```
|
||||
|
||||
### 2. Add Constraints
|
||||
|
||||
```python
|
||||
from pydantic import Field
|
||||
|
||||
# ✅ Good: With constraints
|
||||
class User(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=100)
|
||||
age: int = Field(ge=0, le=120)
|
||||
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
|
||||
|
||||
# ❌ Bad: No constraints
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
```
|
||||
|
||||
### 3. Use Enums for Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum for fixed set
|
||||
class Priority(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: Priority
|
||||
|
||||
# ❌ Bad: Free-form string
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: str # Can be anything
|
||||
```
|
||||
|
||||
### 4. Provide Context in Prompts
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear context
|
||||
prompt = """
|
||||
Extract product information from the following text.
|
||||
Text: iPhone 15 Pro costs $999 and is currently in stock.
|
||||
Product:
|
||||
"""
|
||||
|
||||
# ❌ Bad: Minimal context
|
||||
prompt = "iPhone 15 Pro costs $999 and is currently in stock."
|
||||
```
|
||||
|
||||
### 5. Handle Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
# ✅ Good: Optional fields for incomplete data
|
||||
class Article(BaseModel):
|
||||
title: str # Required
|
||||
author: Optional[str] = None # Optional
|
||||
date: Optional[str] = None # Optional
|
||||
tags: list[str] = [] # Default empty list
|
||||
|
||||
# Can succeed even if author/date missing
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Outlines | Instructor | Guidance | LMQL |
|
||||
|---------|----------|------------|----------|------|
|
||||
| Pydantic Support | ✅ Native | ✅ Native | ❌ No | ❌ No |
|
||||
| JSON Schema | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
|
||||
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
|
||||
| Local Models | ✅ Full | ⚠️ Limited | ✅ Full | ✅ Full |
|
||||
| API Models | ⚠️ Limited | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Zero Overhead | ✅ Yes | ❌ No | ⚠️ Partial | ✅ Yes |
|
||||
| Automatic Retrying | ❌ No | ✅ Yes | ❌ No | ❌ No |
|
||||
| Learning Curve | Low | Low | Low | High |
|
||||
|
||||
**When to choose Outlines:**
|
||||
- Using local models (Transformers, llama.cpp, vLLM)
|
||||
- Need maximum inference speed
|
||||
- Want Pydantic model support
|
||||
- Require zero-overhead structured generation
|
||||
- Control token sampling process
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Instructor: Need API models with automatic retrying
|
||||
- Guidance: Need token healing and complex workflows
|
||||
- LMQL: Prefer declarative query syntax
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
**Speed:**
|
||||
- **Zero overhead**: Structured generation as fast as unconstrained
|
||||
- **Fast-forward optimization**: Skips deterministic tokens
|
||||
- **1.2-2x faster** than post-generation validation approaches
|
||||
|
||||
**Memory:**
|
||||
- FSM compiled once per schema (cached)
|
||||
- Minimal runtime overhead
|
||||
- Efficient with vLLM for high throughput
|
||||
|
||||
**Accuracy:**
|
||||
- **100% valid outputs** (guaranteed by FSM)
|
||||
- No retry loops needed
|
||||
- Deterministic token filtering
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://outlines-dev.github.io/outlines
|
||||
- **GitHub**: https://github.com/outlines-dev/outlines (8k+ stars)
|
||||
- **Discord**: https://discord.gg/R9DSu34mGd
|
||||
- **Blog**: https://blog.dottxt.co
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/json_generation.md` - Comprehensive JSON and Pydantic patterns
|
||||
- `references/backends.md` - Backend-specific configuration
|
||||
- `references/examples.md` - Production-ready examples
|
||||
|
||||
|
||||
@@ -0,0 +1,615 @@
|
||||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Outlines with different model backends.
|
||||
|
||||
## Table of Contents
|
||||
- Local Models (Transformers, llama.cpp, vLLM)
|
||||
- API Models (OpenAI)
|
||||
- Performance Comparison
|
||||
- Configuration Examples
|
||||
- Production Deployment
|
||||
|
||||
## Transformers (Hugging Face)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load model from Hugging Face
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
result = generator("Your prompt")
|
||||
```
|
||||
|
||||
### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use CUDA GPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# Use Apple Silicon MPS
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="mps"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
# FP16 for faster inference
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"torch_dtype": "float16"
|
||||
}
|
||||
)
|
||||
|
||||
# 8-bit quantization (less memory)
|
||||
model = outlines.models.transformers(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_8bit": True,
|
||||
"device_map": "auto"
|
||||
}
|
||||
)
|
||||
|
||||
# 4-bit quantization (even less memory)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_4bit": True,
|
||||
"device_map": "auto",
|
||||
"bnb_4bit_compute_dtype": "float16"
|
||||
}
|
||||
)
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"device_map": "auto", # Automatic GPU distribution
|
||||
"max_memory": {0: "40GB", 1: "40GB"} # Per-GPU limits
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
model = outlines.models.transformers("microsoft/Phi-4-mini-instruct")
|
||||
model = outlines.models.transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3.1 (Meta)
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-405B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
model = outlines.models.transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
model = outlines.models.transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
model = outlines.models.transformers("mistralai/Mixtral-8x22B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-14B-Instruct")
|
||||
model = outlines.models.transformers("Qwen/Qwen2.5-72B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
model = outlines.models.transformers("google/gemma-2-9b-it")
|
||||
model = outlines.models.transformers("google/gemma-2-27b-it")
|
||||
|
||||
# Llava (Vision)
|
||||
model = outlines.models.transformers("llava-hf/llava-v1.6-mistral-7b-hf")
|
||||
```
|
||||
|
||||
### Custom Model Loading
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import outlines
|
||||
|
||||
# Load model manually
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model_hf = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Outlines
|
||||
model = outlines.models.transformers(
|
||||
model=model_hf,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## llama.cpp
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load GGUF model
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### GPU Configuration
|
||||
|
||||
```python
|
||||
# CPU only
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_threads=8 # Use 8 CPU threads
|
||||
)
|
||||
|
||||
# GPU offload (partial)
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=4 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.gguf",
|
||||
n_ctx=8192,
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map model file (faster loading)
|
||||
use_mlock=False, # Lock model in RAM (prevents swapping)
|
||||
seed=42, # Random seed for reproducibility
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
### Quantization Formats
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
# - Size: ~4.5GB for 7B model
|
||||
# - Quality: Good
|
||||
# - Speed: Fast
|
||||
model = outlines.models.llamacpp("./models/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
# - Size: ~5.5GB for 7B model
|
||||
# - Quality: Very good
|
||||
# - Speed: Slightly slower than Q4
|
||||
model = outlines.models.llamacpp("./models/model.Q5_K_M.gguf")
|
||||
|
||||
# Q6_K (6-bit, high quality)
|
||||
# - Size: ~6.5GB for 7B model
|
||||
# - Quality: Excellent
|
||||
# - Speed: Slower than Q5
|
||||
model = outlines.models.llamacpp("./models/model.Q6_K.gguf")
|
||||
|
||||
# Q8_0 (8-bit, near-original quality)
|
||||
# - Size: ~8GB for 7B model
|
||||
# - Quality: Near FP16
|
||||
# - Speed: Slower than Q6
|
||||
model = outlines.models.llamacpp("./models/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, original quality)
|
||||
# - Size: ~14GB for 7B model
|
||||
# - Quality: Original
|
||||
# - Speed: Slowest
|
||||
model = outlines.models.llamacpp("./models/model.F16.gguf")
|
||||
```
|
||||
|
||||
### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
model = outlines.models.llamacpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
model = outlines.models.llamacpp("llama-3.1-70b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
model = outlines.models.llamacpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
model = outlines.models.llamacpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Qwen
|
||||
model = outlines.models.llamacpp("qwen2.5-7b-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
### Apple Silicon Optimization
|
||||
|
||||
```python
|
||||
# Optimized for M1/M2/M3 Macs
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/llama-3.1-8b.Q4_K_M.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True, # Efficient memory mapping
|
||||
n_threads=8 # Use performance cores
|
||||
)
|
||||
```
|
||||
|
||||
## vLLM (Production)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Load model with vLLM
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
```
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
gpu_memory_utilization=0.9, # Use 90% of GPU memory
|
||||
max_model_len=4096 # Max sequence length
|
||||
)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Tensor parallelism (split model across GPUs)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4, # Use 4 GPUs
|
||||
gpu_memory_utilization=0.9
|
||||
)
|
||||
|
||||
# Pipeline parallelism (rare, for very large models)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-405B-Instruct",
|
||||
pipeline_parallel_size=8, # 8-GPU pipeline
|
||||
tensor_parallel_size=4 # 4-GPU tensor split
|
||||
# Total: 32 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
### Quantization
|
||||
|
||||
```python
|
||||
# AWQ quantization (4-bit)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="awq",
|
||||
dtype="float16"
|
||||
)
|
||||
|
||||
# GPTQ quantization (4-bit)
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="gptq"
|
||||
)
|
||||
|
||||
# SqueezeLLM quantization
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
quantization="squeezellm"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=256, # Max concurrent sequences
|
||||
max_num_batched_tokens=8192, # Max tokens per batch
|
||||
dtype="float16",
|
||||
trust_remote_code=True,
|
||||
enforce_eager=False, # Use CUDA graphs (faster)
|
||||
swap_space=4 # CPU swap space (GB)
|
||||
)
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
# vLLM optimized for high-throughput batch processing
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_num_seqs=128 # Process 128 sequences in parallel
|
||||
)
|
||||
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
|
||||
# Process many prompts efficiently
|
||||
prompts = ["prompt1", "prompt2", ..., "prompt100"]
|
||||
results = [generator(p) for p in prompts]
|
||||
# vLLM automatically batches and optimizes
|
||||
```
|
||||
|
||||
## OpenAI (Limited Support)
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
# Basic OpenAI support
|
||||
model = outlines.models.openai("gpt-4o-mini", api_key="your-api-key")
|
||||
|
||||
# Use with generator
|
||||
generator = outlines.generate.json(model, YourModel)
|
||||
result = generator("Your prompt")
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```python
|
||||
model = outlines.models.openai(
|
||||
"gpt-4o-mini",
|
||||
api_key="your-api-key", # Or set OPENAI_API_KEY env var
|
||||
max_tokens=2048,
|
||||
temperature=0.7
|
||||
)
|
||||
```
|
||||
|
||||
### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (latest)
|
||||
model = outlines.models.openai("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (cost-effective)
|
||||
model = outlines.models.openai("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
model = outlines.models.openai("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo
|
||||
model = outlines.models.openai("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
**Note**: OpenAI support is limited compared to local models. Some advanced features may not work.
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Transformers | llama.cpp | vLLM | OpenAI |
|
||||
|---------|-------------|-----------|------|--------|
|
||||
| Structured Generation | ✅ Full | ✅ Full | ✅ Full | ⚠️ Limited |
|
||||
| FSM Optimization | ✅ Yes | ✅ Yes | ✅ Yes | ❌ No |
|
||||
| GPU Support | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| Multi-GPU | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| Quantization | ✅ Yes | ✅ Yes | ✅ Yes | N/A |
|
||||
| High Throughput | ⚠️ Medium | ⚠️ Medium | ✅ Excellent | ⚠️ API-limited |
|
||||
| Setup Difficulty | Easy | Medium | Medium | Easy |
|
||||
| Cost | Hardware | Hardware | Hardware | API usage |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (single request, GPU)
|
||||
- **Throughput**: 10-50 tokens/sec (depends on hardware)
|
||||
- **Memory**: 2-4GB per 1B parameters (FP16)
|
||||
- **Best for**: Development, small-scale deployment, flexibility
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (single request)
|
||||
- **Throughput**: 20-150 tokens/sec (depends on quantization)
|
||||
- **Memory**: 0.5-2GB per 1B parameters (Q4-Q8)
|
||||
- **Best for**: CPU inference, Apple Silicon, edge deployment, low memory
|
||||
|
||||
**vLLM:**
|
||||
- **Latency**: 30-100ms (single request)
|
||||
- **Throughput**: 100-1000+ tokens/sec (batch processing)
|
||||
- **Memory**: 2-4GB per 1B parameters (FP16)
|
||||
- **Best for**: Production, high-throughput, batch processing, serving
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: API rate limits
|
||||
- **Memory**: N/A (cloud-based)
|
||||
- **Best for**: Quick prototyping, no infrastructure
|
||||
|
||||
### Memory Requirements
|
||||
|
||||
**7B Model:**
|
||||
- FP16: ~14GB
|
||||
- 8-bit: ~7GB
|
||||
- 4-bit: ~4GB
|
||||
- Q4_K_M (GGUF): ~4.5GB
|
||||
|
||||
**13B Model:**
|
||||
- FP16: ~26GB
|
||||
- 8-bit: ~13GB
|
||||
- 4-bit: ~7GB
|
||||
- Q4_K_M (GGUF): ~8GB
|
||||
|
||||
**70B Model:**
|
||||
- FP16: ~140GB (multi-GPU)
|
||||
- 8-bit: ~70GB (multi-GPU)
|
||||
- 4-bit: ~35GB (single A100/H100)
|
||||
- Q4_K_M (GGUF): ~40GB
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Transformers Optimization
|
||||
|
||||
```python
|
||||
# Use FP16
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={"torch_dtype": "float16"}
|
||||
)
|
||||
|
||||
# Use flash attention (2-4x faster)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"torch_dtype": "float16",
|
||||
"use_flash_attention_2": True
|
||||
}
|
||||
)
|
||||
|
||||
# Use 8-bit quantization (2x less memory)
|
||||
model = outlines.models.transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
model_kwargs={
|
||||
"load_in_8bit": True,
|
||||
"device_map": "auto"
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp Optimization
|
||||
|
||||
```python
|
||||
# Maximize GPU usage
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # All layers on GPU
|
||||
n_ctx=8192,
|
||||
n_batch=512 # Larger batch = faster
|
||||
)
|
||||
|
||||
# Optimize for CPU (Apple Silicon)
|
||||
model = outlines.models.llamacpp(
|
||||
"./models/model.Q4_K_M.gguf",
|
||||
n_ctx=4096,
|
||||
n_threads=8, # Use all performance cores
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
### vLLM Optimization
|
||||
|
||||
```python
|
||||
# High throughput
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
gpu_memory_utilization=0.95, # Use 95% of GPU
|
||||
max_num_seqs=256, # High concurrency
|
||||
enforce_eager=False # Use CUDA graphs
|
||||
)
|
||||
|
||||
# Multi-GPU
|
||||
model = outlines.models.vllm(
|
||||
"meta-llama/Llama-3.1-70B-Instruct",
|
||||
tensor_parallel_size=4, # 4 GPUs
|
||||
gpu_memory_utilization=0.9
|
||||
)
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker with vLLM
|
||||
|
||||
```dockerfile
|
||||
FROM vllm/vllm-openai:latest
|
||||
|
||||
# Install outlines
|
||||
RUN pip install outlines
|
||||
|
||||
# Copy your code
|
||||
COPY app.py /app/
|
||||
|
||||
# Run
|
||||
CMD ["python", "/app/app.py"]
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
|
||||
# OpenAI API key
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Disable tokenizers parallelism warning
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
```
|
||||
|
||||
### Model Serving
|
||||
|
||||
```python
|
||||
# Simple HTTP server with vLLM
|
||||
import outlines
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model once at startup
|
||||
model = outlines.models.vllm("meta-llama/Llama-3.1-8B-Instruct")
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
@app.post("/extract")
|
||||
def extract(text: str):
|
||||
result = generator(f"Extract user from: {text}")
|
||||
return result.model_dump()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Transformers**: https://huggingface.co/docs/transformers
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **vLLM**: https://docs.vllm.ai
|
||||
- **Outlines**: https://github.com/outlines-dev/outlines
|
||||
@@ -0,0 +1,773 @@
|
||||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Outlines for structured generation in production systems.
|
||||
|
||||
## Table of Contents
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Form Processing
|
||||
- Multi-Entity Extraction
|
||||
- Code Generation
|
||||
- Batch Processing
|
||||
- Production Patterns
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Basic Information Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
import outlines
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(ge=0, le=120)
|
||||
occupation: str
|
||||
email: str = Field(pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")
|
||||
location: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PersonInfo)
|
||||
|
||||
text = """
|
||||
Dr. Sarah Johnson is a 42-year-old research scientist at MIT.
|
||||
She can be reached at sarah.j@mit.edu and currently lives in Cambridge, MA.
|
||||
"""
|
||||
|
||||
prompt = f"Extract person information from:\n{text}\n\nPerson:"
|
||||
person = generator(prompt)
|
||||
|
||||
print(f"Name: {person.name}")
|
||||
print(f"Age: {person.age}")
|
||||
print(f"Occupation: {person.occupation}")
|
||||
print(f"Email: {person.email}")
|
||||
print(f"Location: {person.location}")
|
||||
```
|
||||
|
||||
### Company Information
|
||||
|
||||
```python
|
||||
class CompanyInfo(BaseModel):
|
||||
name: str
|
||||
founded_year: int = Field(ge=1800, le=2025)
|
||||
industry: str
|
||||
headquarters: str
|
||||
employees: int = Field(gt=0)
|
||||
revenue: Optional[str] = None
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, CompanyInfo)
|
||||
|
||||
text = """
|
||||
Tesla, Inc. was founded in 2003 and operates primarily in the automotive
|
||||
and energy industries. The company is headquartered in Austin, Texas,
|
||||
and employs approximately 140,000 people worldwide.
|
||||
"""
|
||||
|
||||
company = generator(f"Extract company information:\n{text}\n\nCompany:")
|
||||
|
||||
print(f"Company: {company.name}")
|
||||
print(f"Founded: {company.founded_year}")
|
||||
print(f"Industry: {company.industry}")
|
||||
print(f"HQ: {company.headquarters}")
|
||||
print(f"Employees: {company.employees:,}")
|
||||
```
|
||||
|
||||
### Product Specifications
|
||||
|
||||
```python
|
||||
class ProductSpec(BaseModel):
|
||||
name: str
|
||||
brand: str
|
||||
price: float = Field(gt=0)
|
||||
dimensions: str
|
||||
weight: str
|
||||
features: list[str]
|
||||
rating: Optional[float] = Field(None, ge=0, le=5)
|
||||
|
||||
generator = outlines.generate.json(model, ProductSpec)
|
||||
|
||||
text = """
|
||||
The Apple iPhone 15 Pro is priced at $999. It measures 146.6 x 70.6 x 8.25 mm
|
||||
and weighs 187 grams. Key features include the A17 Pro chip, titanium design,
|
||||
action button, and USB-C port. It has an average customer rating of 4.5 stars.
|
||||
"""
|
||||
|
||||
product = generator(f"Extract product specifications:\n{text}\n\nProduct:")
|
||||
|
||||
print(f"Product: {product.brand} {product.name}")
|
||||
print(f"Price: ${product.price}")
|
||||
print(f"Features: {', '.join(product.features)}")
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
from enum import Enum
|
||||
|
||||
class Sentiment(str, Enum):
|
||||
VERY_POSITIVE = "very_positive"
|
||||
POSITIVE = "positive"
|
||||
NEUTRAL = "neutral"
|
||||
NEGATIVE = "negative"
|
||||
VERY_NEGATIVE = "very_negative"
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
text: str
|
||||
sentiment: Sentiment
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
aspects: list[str] # What aspects were mentioned
|
||||
reasoning: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, SentimentAnalysis)
|
||||
|
||||
review = """
|
||||
This product completely exceeded my expectations! The build quality is
|
||||
outstanding, and customer service was incredibly helpful. My only minor
|
||||
complaint is the packaging could be better.
|
||||
"""
|
||||
|
||||
result = generator(f"Analyze sentiment:\n{review}\n\nAnalysis:")
|
||||
|
||||
print(f"Sentiment: {result.sentiment.value}")
|
||||
print(f"Confidence: {result.confidence:.2%}")
|
||||
print(f"Aspects: {', '.join(result.aspects)}")
|
||||
print(f"Reasoning: {result.reasoning}")
|
||||
```
|
||||
|
||||
### Content Classification
|
||||
|
||||
```python
|
||||
class Category(str, Enum):
|
||||
TECHNOLOGY = "technology"
|
||||
BUSINESS = "business"
|
||||
SCIENCE = "science"
|
||||
POLITICS = "politics"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
SPORTS = "sports"
|
||||
HEALTH = "health"
|
||||
|
||||
class ArticleClassification(BaseModel):
|
||||
primary_category: Category
|
||||
secondary_categories: list[Category]
|
||||
keywords: list[str] = Field(min_items=3, max_items=10)
|
||||
target_audience: Literal["general", "expert", "beginner"]
|
||||
reading_level: Literal["elementary", "intermediate", "advanced"]
|
||||
|
||||
generator = outlines.generate.json(model, ArticleClassification)
|
||||
|
||||
article = """
|
||||
Apple announced groundbreaking advancements in its AI capabilities with the
|
||||
release of iOS 18. The new features leverage machine learning to significantly
|
||||
improve battery life and overall device performance. Industry analysts predict
|
||||
this will strengthen Apple's position in the competitive smartphone market.
|
||||
"""
|
||||
|
||||
classification = generator(f"Classify article:\n{article}\n\nClassification:")
|
||||
|
||||
print(f"Primary: {classification.primary_category.value}")
|
||||
print(f"Secondary: {[c.value for c in classification.secondary_categories]}")
|
||||
print(f"Keywords: {classification.keywords}")
|
||||
print(f"Audience: {classification.target_audience}")
|
||||
```
|
||||
|
||||
### Intent Recognition
|
||||
|
||||
```python
|
||||
class Intent(str, Enum):
|
||||
QUESTION = "question"
|
||||
COMPLAINT = "complaint"
|
||||
REQUEST = "request"
|
||||
FEEDBACK = "feedback"
|
||||
CANCEL = "cancel"
|
||||
UPGRADE = "upgrade"
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
original_message: str
|
||||
intent: Intent
|
||||
urgency: Literal["low", "medium", "high", "critical"]
|
||||
department: Literal["support", "sales", "billing", "technical"]
|
||||
sentiment: Literal["positive", "neutral", "negative"]
|
||||
action_required: bool
|
||||
summary: str
|
||||
|
||||
generator = outlines.generate.json(model, UserMessage)
|
||||
|
||||
message = """
|
||||
I've been charged twice for my subscription this month! This is the third
|
||||
time this has happened. I need someone to fix this immediately and refund
|
||||
the extra charge. Very disappointed with this service.
|
||||
"""
|
||||
|
||||
result = generator(f"Analyze message:\n{message}\n\nAnalysis:")
|
||||
|
||||
print(f"Intent: {result.intent.value}")
|
||||
print(f"Urgency: {result.urgency}")
|
||||
print(f"Route to: {result.department}")
|
||||
print(f"Action required: {result.action_required}")
|
||||
print(f"Summary: {result.summary}")
|
||||
```
|
||||
|
||||
## Form Processing
|
||||
|
||||
### Job Application
|
||||
|
||||
```python
|
||||
class Education(BaseModel):
|
||||
degree: str
|
||||
field: str
|
||||
institution: str
|
||||
year: int
|
||||
|
||||
class Experience(BaseModel):
|
||||
title: str
|
||||
company: str
|
||||
duration: str
|
||||
responsibilities: list[str]
|
||||
|
||||
class JobApplication(BaseModel):
|
||||
full_name: str
|
||||
email: str
|
||||
phone: str
|
||||
education: list[Education]
|
||||
experience: list[Experience]
|
||||
skills: list[str]
|
||||
availability: str
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, JobApplication)
|
||||
|
||||
resume_text = """
|
||||
John Smith
|
||||
Email: john.smith@email.com | Phone: 555-0123
|
||||
|
||||
EDUCATION
|
||||
- BS in Computer Science, MIT, 2018
|
||||
- MS in Artificial Intelligence, Stanford, 2020
|
||||
|
||||
EXPERIENCE
|
||||
Software Engineer, Google (2020-2023)
|
||||
- Developed ML pipelines for search ranking
|
||||
- Led team of 5 engineers
|
||||
- Improved search quality by 15%
|
||||
|
||||
SKILLS: Python, Machine Learning, TensorFlow, System Design
|
||||
|
||||
AVAILABILITY: Immediate
|
||||
"""
|
||||
|
||||
application = generator(f"Extract job application:\n{resume_text}\n\nApplication:")
|
||||
|
||||
print(f"Applicant: {application.full_name}")
|
||||
print(f"Email: {application.email}")
|
||||
print(f"Education: {len(application.education)} degrees")
|
||||
for edu in application.education:
|
||||
print(f" - {edu.degree} in {edu.field}, {edu.institution} ({edu.year})")
|
||||
print(f"Experience: {len(application.experience)} positions")
|
||||
```
|
||||
|
||||
### Invoice Processing
|
||||
|
||||
```python
|
||||
class InvoiceItem(BaseModel):
|
||||
description: str
|
||||
quantity: int = Field(gt=0)
|
||||
unit_price: float = Field(gt=0)
|
||||
total: float = Field(gt=0)
|
||||
|
||||
class Invoice(BaseModel):
|
||||
invoice_number: str
|
||||
date: str = Field(pattern=r"\d{4}-\d{2}-\d{2}")
|
||||
vendor: str
|
||||
customer: str
|
||||
items: list[InvoiceItem]
|
||||
subtotal: float = Field(gt=0)
|
||||
tax: float = Field(ge=0)
|
||||
total: float = Field(gt=0)
|
||||
|
||||
generator = outlines.generate.json(model, Invoice)
|
||||
|
||||
invoice_text = """
|
||||
INVOICE #INV-2024-001
|
||||
Date: 2024-01-15
|
||||
|
||||
From: Acme Corp
|
||||
To: Smith & Co
|
||||
|
||||
Items:
|
||||
- Widget A: 10 units @ $50.00 = $500.00
|
||||
- Widget B: 5 units @ $75.00 = $375.00
|
||||
- Service Fee: 1 @ $100.00 = $100.00
|
||||
|
||||
Subtotal: $975.00
|
||||
Tax (8%): $78.00
|
||||
TOTAL: $1,053.00
|
||||
"""
|
||||
|
||||
invoice = generator(f"Extract invoice:\n{invoice_text}\n\nInvoice:")
|
||||
|
||||
print(f"Invoice: {invoice.invoice_number}")
|
||||
print(f"From: {invoice.vendor} → To: {invoice.customer}")
|
||||
print(f"Items: {len(invoice.items)}")
|
||||
for item in invoice.items:
|
||||
print(f" - {item.description}: {item.quantity} × ${item.unit_price} = ${item.total}")
|
||||
print(f"Total: ${invoice.total}")
|
||||
```
|
||||
|
||||
### Survey Responses
|
||||
|
||||
```python
|
||||
class SurveyResponse(BaseModel):
|
||||
respondent_id: str
|
||||
completion_date: str
|
||||
satisfaction: Literal[1, 2, 3, 4, 5]
|
||||
would_recommend: bool
|
||||
favorite_features: list[str]
|
||||
improvement_areas: list[str]
|
||||
additional_comments: Optional[str] = None
|
||||
|
||||
generator = outlines.generate.json(model, SurveyResponse)
|
||||
|
||||
survey_text = """
|
||||
Survey ID: RESP-12345
|
||||
Completed: 2024-01-20
|
||||
|
||||
How satisfied are you with our product? 4 out of 5
|
||||
|
||||
Would you recommend to a friend? Yes
|
||||
|
||||
What features do you like most?
|
||||
- Fast performance
|
||||
- Easy to use
|
||||
- Great customer support
|
||||
|
||||
What could we improve?
|
||||
- Better documentation
|
||||
- More integrations
|
||||
|
||||
Additional feedback: Overall great product, keep up the good work!
|
||||
"""
|
||||
|
||||
response = generator(f"Extract survey response:\n{survey_text}\n\nResponse:")
|
||||
|
||||
print(f"Respondent: {response.respondent_id}")
|
||||
print(f"Satisfaction: {response.satisfaction}/5")
|
||||
print(f"Would recommend: {response.would_recommend}")
|
||||
print(f"Favorite features: {response.favorite_features}")
|
||||
print(f"Improvement areas: {response.improvement_areas}")
|
||||
```
|
||||
|
||||
## Multi-Entity Extraction
|
||||
|
||||
### News Article Entities
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
role: Optional[str] = None
|
||||
affiliation: Optional[str] = None
|
||||
|
||||
class Organization(BaseModel):
|
||||
name: str
|
||||
type: Optional[str] = None
|
||||
|
||||
class Location(BaseModel):
|
||||
name: str
|
||||
type: Literal["city", "state", "country", "region"]
|
||||
|
||||
class Event(BaseModel):
|
||||
name: str
|
||||
date: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
|
||||
class ArticleEntities(BaseModel):
|
||||
people: list[Person]
|
||||
organizations: list[Organization]
|
||||
locations: list[Location]
|
||||
events: list[Event]
|
||||
dates: list[str]
|
||||
|
||||
model = outlines.models.transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
generator = outlines.generate.json(model, ArticleEntities)
|
||||
|
||||
article = """
|
||||
Apple CEO Tim Cook met with Microsoft CEO Satya Nadella at Microsoft
|
||||
headquarters in Redmond, Washington on September 15, 2024, to discuss
|
||||
potential collaboration opportunities. The meeting was attended by executives
|
||||
from both companies and focused on AI integration strategies. Apple's
|
||||
Cupertino offices will host a follow-up meeting on October 20, 2024.
|
||||
"""
|
||||
|
||||
entities = generator(f"Extract all entities:\n{article}\n\nEntities:")
|
||||
|
||||
print("People:")
|
||||
for person in entities.people:
|
||||
print(f" - {person.name} ({person.role}) @ {person.affiliation}")
|
||||
|
||||
print("\nOrganizations:")
|
||||
for org in entities.organizations:
|
||||
print(f" - {org.name} ({org.type})")
|
||||
|
||||
print("\nLocations:")
|
||||
for loc in entities.locations:
|
||||
print(f" - {loc.name} ({loc.type})")
|
||||
|
||||
print("\nEvents:")
|
||||
for event in entities.events:
|
||||
print(f" - {event.name} on {event.date}")
|
||||
```
|
||||
|
||||
### Document Metadata
|
||||
|
||||
```python
|
||||
class Author(BaseModel):
|
||||
name: str
|
||||
email: Optional[str] = None
|
||||
affiliation: Optional[str] = None
|
||||
|
||||
class Reference(BaseModel):
|
||||
title: str
|
||||
authors: list[str]
|
||||
year: int
|
||||
source: str
|
||||
|
||||
class DocumentMetadata(BaseModel):
|
||||
title: str
|
||||
authors: list[Author]
|
||||
abstract: str
|
||||
keywords: list[str]
|
||||
publication_date: str
|
||||
journal: str
|
||||
doi: Optional[str] = None
|
||||
references: list[Reference]
|
||||
|
||||
generator = outlines.generate.json(model, DocumentMetadata)
|
||||
|
||||
paper = """
|
||||
Title: Advances in Neural Machine Translation
|
||||
|
||||
Authors:
|
||||
- Dr. Jane Smith (jane@university.edu), MIT
|
||||
- Prof. John Doe (jdoe@stanford.edu), Stanford University
|
||||
|
||||
Abstract: This paper presents novel approaches to neural machine translation
|
||||
using transformer architectures. We demonstrate significant improvements in
|
||||
translation quality across multiple language pairs.
|
||||
|
||||
Keywords: Neural Networks, Machine Translation, Transformers, NLP
|
||||
|
||||
Published: Journal of AI Research, 2024-03-15
|
||||
DOI: 10.1234/jair.2024.001
|
||||
|
||||
References:
|
||||
1. "Attention Is All You Need" by Vaswani et al., 2017, NeurIPS
|
||||
2. "BERT: Pre-training of Deep Bidirectional Transformers" by Devlin et al., 2019, NAACL
|
||||
"""
|
||||
|
||||
metadata = generator(f"Extract document metadata:\n{paper}\n\nMetadata:")
|
||||
|
||||
print(f"Title: {metadata.title}")
|
||||
print(f"Authors: {', '.join(a.name for a in metadata.authors)}")
|
||||
print(f"Keywords: {', '.join(metadata.keywords)}")
|
||||
print(f"References: {len(metadata.references)}")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function Generation
|
||||
|
||||
```python
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
|
||||
type_hint: str
|
||||
default: Optional[str] = None
|
||||
|
||||
class PythonFunction(BaseModel):
|
||||
function_name: str = Field(pattern=r"^[a-z_][a-z0-9_]*$")
|
||||
parameters: list[Parameter]
|
||||
return_type: str
|
||||
docstring: str
|
||||
body: list[str] # Lines of code
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, PythonFunction)
|
||||
|
||||
spec = "Create a function to calculate the factorial of a number"
|
||||
|
||||
func = generator(f"Generate Python function:\n{spec}\n\nFunction:")
|
||||
|
||||
print(f"def {func.function_name}(", end="")
|
||||
print(", ".join(f"{p.name}: {p.type_hint}" for p in func.parameters), end="")
|
||||
print(f") -> {func.return_type}:")
|
||||
print(f' """{func.docstring}"""')
|
||||
for line in func.body:
|
||||
print(f" {line}")
|
||||
```
|
||||
|
||||
### SQL Query Generation
|
||||
|
||||
```python
|
||||
class SQLQuery(BaseModel):
|
||||
query_type: Literal["SELECT", "INSERT", "UPDATE", "DELETE"]
|
||||
select_columns: Optional[list[str]] = None
|
||||
from_tables: list[str]
|
||||
joins: Optional[list[str]] = None
|
||||
where_conditions: Optional[list[str]] = None
|
||||
group_by: Optional[list[str]] = None
|
||||
order_by: Optional[list[str]] = None
|
||||
limit: Optional[int] = None
|
||||
|
||||
generator = outlines.generate.json(model, SQLQuery)
|
||||
|
||||
request = "Get top 10 users who made purchases in the last 30 days, ordered by total spent"
|
||||
|
||||
sql = generator(f"Generate SQL query:\n{request}\n\nQuery:")
|
||||
|
||||
print(f"Query type: {sql.query_type}")
|
||||
print(f"SELECT {', '.join(sql.select_columns)}")
|
||||
print(f"FROM {', '.join(sql.from_tables)}")
|
||||
if sql.joins:
|
||||
for join in sql.joins:
|
||||
print(f" {join}")
|
||||
if sql.where_conditions:
|
||||
print(f"WHERE {' AND '.join(sql.where_conditions)}")
|
||||
if sql.order_by:
|
||||
print(f"ORDER BY {', '.join(sql.order_by)}")
|
||||
if sql.limit:
|
||||
print(f"LIMIT {sql.limit}")
|
||||
```
|
||||
|
||||
### API Endpoint Spec
|
||||
|
||||
```python
|
||||
class Parameter(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
required: bool
|
||||
description: str
|
||||
|
||||
class APIEndpoint(BaseModel):
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
|
||||
path: str
|
||||
description: str
|
||||
parameters: list[Parameter]
|
||||
request_body: Optional[dict] = None
|
||||
response_schema: dict
|
||||
status_codes: dict[int, str]
|
||||
|
||||
generator = outlines.generate.json(model, APIEndpoint)
|
||||
|
||||
spec = "Create user endpoint"
|
||||
|
||||
endpoint = generator(f"Generate API endpoint:\n{spec}\n\nEndpoint:")
|
||||
|
||||
print(f"{endpoint.method} {endpoint.path}")
|
||||
print(f"Description: {endpoint.description}")
|
||||
print("\nParameters:")
|
||||
for param in endpoint.parameters:
|
||||
req = "required" if param.required else "optional"
|
||||
print(f" - {param.name} ({param.type}, {req}): {param.description}")
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
### Parallel Extraction
|
||||
|
||||
```python
|
||||
def batch_extract(texts: list[str], schema: type[BaseModel], model_name: str):
|
||||
"""Extract structured data from multiple texts."""
|
||||
model = outlines.models.transformers(model_name)
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for i, text in enumerate(texts):
|
||||
print(f"Processing {i+1}/{len(texts)}...", end="\r")
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
category: str
|
||||
|
||||
texts = [
|
||||
"iPhone 15 Pro costs $999 in Electronics",
|
||||
"Running Shoes are $89.99 in Sports",
|
||||
"Coffee Maker priced at $49.99 in Home & Kitchen"
|
||||
]
|
||||
|
||||
products = batch_extract(texts, Product, "microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
for product in products:
|
||||
print(f"{product.name}: ${product.price} ({product.category})")
|
||||
```
|
||||
|
||||
### CSV Processing
|
||||
|
||||
```python
|
||||
import csv
|
||||
|
||||
def process_csv(csv_file: str, schema: type[BaseModel]):
|
||||
"""Process CSV file and extract structured data."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
with open(csv_file, 'r') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
text = " | ".join(f"{k}: {v}" for k, v in row.items())
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Customer(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
tier: Literal["basic", "premium", "enterprise"]
|
||||
mrr: float
|
||||
|
||||
# customers = process_csv("customers.csv", Customer)
|
||||
```
|
||||
|
||||
## Production Patterns
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
def safe_extract(text: str, schema: type[BaseModel], retries: int = 3):
|
||||
"""Extract with error handling and retries."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
return result
|
||||
except ValidationError as e:
|
||||
print(f"Attempt {attempt + 1} failed: {e}")
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
if attempt == retries - 1:
|
||||
raise
|
||||
|
||||
return None
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
import hashlib
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def cached_extract(text_hash: str, schema_name: str):
|
||||
"""Cache extraction results."""
|
||||
# This would be called with actual extraction logic
|
||||
pass
|
||||
|
||||
def extract_with_cache(text: str, schema: type[BaseModel]):
|
||||
"""Extract with caching."""
|
||||
text_hash = hashlib.md5(text.encode()).hexdigest()
|
||||
schema_name = schema.__name__
|
||||
|
||||
cached_result = cached_extract(text_hash, schema_name)
|
||||
if cached_result:
|
||||
return cached_result
|
||||
|
||||
# Perform actual extraction
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def monitored_extract(text: str, schema: type[BaseModel]):
|
||||
"""Extract with monitoring and logging."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Extraction succeeded in {elapsed:.2f}s")
|
||||
logger.info(f"Input length: {len(text)} chars")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
logger.error(f"Extraction failed after {elapsed:.2f}s: {e}")
|
||||
raise
|
||||
```
|
||||
|
||||
### Rate Limiting
|
||||
|
||||
```python
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, max_requests: int, time_window: int):
|
||||
self.max_requests = max_requests
|
||||
self.time_window = time_window
|
||||
self.requests = []
|
||||
self.lock = Lock()
|
||||
|
||||
def wait_if_needed(self):
|
||||
with self.lock:
|
||||
now = time.time()
|
||||
# Remove old requests
|
||||
self.requests = [r for r in self.requests if now - r < self.time_window]
|
||||
|
||||
if len(self.requests) >= self.max_requests:
|
||||
sleep_time = self.time_window - (now - self.requests[0])
|
||||
time.sleep(sleep_time)
|
||||
self.requests = []
|
||||
|
||||
self.requests.append(now)
|
||||
|
||||
def rate_limited_extract(texts: list[str], schema: type[BaseModel]):
|
||||
"""Extract with rate limiting."""
|
||||
limiter = RateLimiter(max_requests=10, time_window=60) # 10 req/min
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for text in texts:
|
||||
limiter.wait_if_needed()
|
||||
result = generator(f"Extract:\n{text}\n\nData:")
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Outlines Documentation**: https://outlines-dev.github.io/outlines
|
||||
- **Pydantic Documentation**: https://docs.pydantic.dev
|
||||
- **GitHub Examples**: https://github.com/outlines-dev/outlines/tree/main/examples
|
||||
@@ -0,0 +1,652 @@
|
||||
# Comprehensive JSON Generation Guide
|
||||
|
||||
Complete guide to JSON generation with Outlines using Pydantic models and JSON schemas.
|
||||
|
||||
## Table of Contents
|
||||
- Pydantic Models
|
||||
- JSON Schema Support
|
||||
- Advanced Patterns
|
||||
- Nested Structures
|
||||
- Complex Types
|
||||
- Validation
|
||||
- Performance Optimization
|
||||
|
||||
## Pydantic Models
|
||||
|
||||
### Basic Models
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
import outlines
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
user = generator("Generate user: Alice, 25, alice@example.com")
|
||||
print(user.name) # "Alice"
|
||||
print(user.age) # 25
|
||||
print(user.email) # "alice@example.com"
|
||||
```
|
||||
|
||||
###
|
||||
|
||||
Field Constraints
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=100)
|
||||
price: float = Field(gt=0, description="Price in USD")
|
||||
discount: float = Field(ge=0, le=100, description="Discount percentage")
|
||||
quantity: int = Field(ge=0, description="Available quantity")
|
||||
sku: str = Field(pattern=r"^[A-Z]{3}-\d{6}$")
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Product)
|
||||
|
||||
product = generator("Generate product: iPhone 15, $999")
|
||||
# All fields guaranteed to meet constraints
|
||||
```
|
||||
|
||||
**Available Constraints:**
|
||||
- `min_length`, `max_length`: String length
|
||||
- `gt`, `ge`, `lt`, `le`: Numeric comparisons
|
||||
- `multiple_of`: Number must be multiple of value
|
||||
- `pattern`: Regex pattern for strings
|
||||
- `min_items`, `max_items`: List length
|
||||
|
||||
### Optional Fields
|
||||
|
||||
```python
|
||||
from typing import Optional
|
||||
|
||||
class Article(BaseModel):
|
||||
title: str # Required
|
||||
author: Optional[str] = None # Optional
|
||||
published_date: Optional[str] = None # Optional
|
||||
tags: list[str] = [] # Default empty list
|
||||
view_count: int = 0 # Default value
|
||||
|
||||
generator = outlines.generate.json(model, Article)
|
||||
|
||||
# Can generate even if optional fields missing
|
||||
article = generator("Title: Introduction to AI")
|
||||
print(article.author) # None (not provided)
|
||||
print(article.tags) # [] (default)
|
||||
```
|
||||
|
||||
### Default Values
|
||||
|
||||
```python
|
||||
class Config(BaseModel):
|
||||
debug: bool = False
|
||||
max_retries: int = 3
|
||||
timeout: float = 30.0
|
||||
log_level: str = "INFO"
|
||||
|
||||
# Generator uses defaults when not specified
|
||||
generator = outlines.generate.json(model, Config)
|
||||
config = generator("Generate config with debug enabled")
|
||||
print(config.debug) # True (from prompt)
|
||||
print(config.timeout) # 30.0 (default)
|
||||
```
|
||||
|
||||
## Enums and Literals
|
||||
|
||||
### Enum Fields
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class Status(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class Application(BaseModel):
|
||||
applicant_name: str
|
||||
status: Status # Must be one of enum values
|
||||
submitted_date: str
|
||||
|
||||
generator = outlines.generate.json(model, Application)
|
||||
app = generator("Generate application for John Doe")
|
||||
|
||||
print(app.status) # Status.PENDING (or one of the enum values)
|
||||
print(type(app.status)) # <enum 'Status'>
|
||||
```
|
||||
|
||||
### Literal Types
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
|
||||
class Task(BaseModel):
|
||||
title: str
|
||||
priority: Literal["low", "medium", "high", "critical"]
|
||||
status: Literal["todo", "in_progress", "done"]
|
||||
assigned_to: str
|
||||
|
||||
generator = outlines.generate.json(model, Task)
|
||||
task = generator("Create high priority task: Fix bug")
|
||||
|
||||
print(task.priority) # One of: "low", "medium", "high", "critical"
|
||||
```
|
||||
|
||||
### Multiple Choice Fields
|
||||
|
||||
```python
|
||||
class Survey(BaseModel):
|
||||
question: str
|
||||
answer: Literal["strongly_disagree", "disagree", "neutral", "agree", "strongly_agree"]
|
||||
confidence: Literal["low", "medium", "high"]
|
||||
|
||||
generator = outlines.generate.json(model, Survey)
|
||||
survey = generator("Rate: 'I enjoy using this product'")
|
||||
```
|
||||
|
||||
## Nested Structures
|
||||
|
||||
### Nested Models
|
||||
|
||||
```python
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
state: str
|
||||
zip_code: str
|
||||
country: str = "USA"
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
address: Address # Nested model
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, Person)
|
||||
|
||||
prompt = """
|
||||
Extract person:
|
||||
Name: Alice Johnson
|
||||
Age: 28
|
||||
Email: alice@example.com
|
||||
Address: 123 Main St, Boston, MA, 02101
|
||||
"""
|
||||
|
||||
person = generator(prompt)
|
||||
print(person.name) # "Alice Johnson"
|
||||
print(person.address.city) # "Boston"
|
||||
print(person.address.state) # "MA"
|
||||
```
|
||||
|
||||
### Deep Nesting
|
||||
|
||||
```python
|
||||
class Coordinates(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
|
||||
class Location(BaseModel):
|
||||
name: str
|
||||
coordinates: Coordinates
|
||||
|
||||
class Event(BaseModel):
|
||||
title: str
|
||||
date: str
|
||||
location: Location
|
||||
|
||||
generator = outlines.generate.json(model, Event)
|
||||
event = generator("Generate event: Tech Conference in San Francisco")
|
||||
|
||||
print(event.title) # "Tech Conference"
|
||||
print(event.location.name) # "San Francisco"
|
||||
print(event.location.coordinates.latitude) # 37.7749
|
||||
```
|
||||
|
||||
### Lists of Nested Models
|
||||
|
||||
```python
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
quantity: int
|
||||
price: float
|
||||
|
||||
class Order(BaseModel):
|
||||
order_id: str
|
||||
customer: str
|
||||
items: list[Item] # List of nested models
|
||||
total: float
|
||||
|
||||
generator = outlines.generate.json(model, Order)
|
||||
|
||||
prompt = """
|
||||
Generate order for John:
|
||||
- 2x Widget ($10 each)
|
||||
- 3x Gadget ($15 each)
|
||||
Order ID: ORD-001
|
||||
"""
|
||||
|
||||
order = generator(prompt)
|
||||
print(f"Order ID: {order.order_id}")
|
||||
for item in order.items:
|
||||
print(f"- {item.quantity}x {item.name} @ ${item.price}")
|
||||
print(f"Total: ${order.total}")
|
||||
```
|
||||
|
||||
## Complex Types
|
||||
|
||||
### Union Types
|
||||
|
||||
```python
|
||||
from typing import Union
|
||||
|
||||
class TextContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
content: str
|
||||
|
||||
class ImageContent(BaseModel):
|
||||
type: Literal["image"]
|
||||
url: str
|
||||
caption: str
|
||||
|
||||
class Post(BaseModel):
|
||||
title: str
|
||||
content: Union[TextContent, ImageContent] # Either type
|
||||
|
||||
generator = outlines.generate.json(model, Post)
|
||||
|
||||
# Can generate either text or image content
|
||||
post = generator("Generate blog post with image")
|
||||
if post.content.type == "text":
|
||||
print(post.content.content)
|
||||
elif post.content.type == "image":
|
||||
print(post.content.url)
|
||||
```
|
||||
|
||||
### Lists and Arrays
|
||||
|
||||
```python
|
||||
class Article(BaseModel):
|
||||
title: str
|
||||
authors: list[str] # List of strings
|
||||
tags: list[str]
|
||||
sections: list[dict[str, str]] # List of dicts
|
||||
related_ids: list[int]
|
||||
|
||||
generator = outlines.generate.json(model, Article)
|
||||
article = generator("Generate article about AI")
|
||||
|
||||
print(article.authors) # ["Alice", "Bob"]
|
||||
print(article.tags) # ["AI", "Machine Learning", "Technology"]
|
||||
```
|
||||
|
||||
### Dictionaries
|
||||
|
||||
```python
|
||||
class Metadata(BaseModel):
|
||||
title: str
|
||||
properties: dict[str, str] # String keys and values
|
||||
counts: dict[str, int] # String keys, int values
|
||||
settings: dict[str, Union[str, int, bool]] # Mixed value types
|
||||
|
||||
generator = outlines.generate.json(model, Metadata)
|
||||
meta = generator("Generate metadata")
|
||||
|
||||
print(meta.properties) # {"author": "Alice", "version": "1.0"}
|
||||
print(meta.counts) # {"views": 1000, "likes": 50}
|
||||
```
|
||||
|
||||
### Any Type (Use Sparingly)
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
class FlexibleData(BaseModel):
|
||||
name: str
|
||||
structured_field: str
|
||||
flexible_field: Any # Can be anything
|
||||
|
||||
# Note: Any reduces type safety, use only when necessary
|
||||
generator = outlines.generate.json(model, FlexibleData)
|
||||
```
|
||||
|
||||
## JSON Schema Support
|
||||
|
||||
### Direct Schema Usage
|
||||
|
||||
```python
|
||||
import outlines
|
||||
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Define JSON schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer", "minimum": 0, "maximum": 120},
|
||||
"email": {"type": "string", "format": "email"}
|
||||
},
|
||||
"required": ["name", "age", "email"]
|
||||
}
|
||||
|
||||
# Generate from schema
|
||||
generator = outlines.generate.json(model, schema)
|
||||
result = generator("Generate person: Alice, 25, alice@example.com")
|
||||
|
||||
print(result) # Valid JSON matching schema
|
||||
```
|
||||
|
||||
### Schema from Pydantic
|
||||
|
||||
```python
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
# Get JSON schema from Pydantic model
|
||||
schema = User.model_json_schema()
|
||||
print(schema)
|
||||
# {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "name": {"type": "string"},
|
||||
# "age": {"type": "integer"},
|
||||
# "email": {"type": "string"}
|
||||
# },
|
||||
# "required": ["name", "age", "email"]
|
||||
# }
|
||||
|
||||
# Both approaches equivalent:
|
||||
generator1 = outlines.generate.json(model, User)
|
||||
generator2 = outlines.generate.json(model, schema)
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Conditional Fields
|
||||
|
||||
```python
|
||||
class Order(BaseModel):
|
||||
order_type: Literal["standard", "express"]
|
||||
delivery_date: str
|
||||
express_fee: Optional[float] = None # Only for express orders
|
||||
|
||||
generator = outlines.generate.json(model, Order)
|
||||
|
||||
# Express order
|
||||
order1 = generator("Create express order for tomorrow")
|
||||
print(order1.express_fee) # 25.0
|
||||
|
||||
# Standard order
|
||||
order2 = generator("Create standard order")
|
||||
print(order2.express_fee) # None
|
||||
```
|
||||
|
||||
### Recursive Models
|
||||
|
||||
```python
|
||||
from typing import Optional, List
|
||||
|
||||
class TreeNode(BaseModel):
|
||||
value: str
|
||||
children: Optional[List['TreeNode']] = None
|
||||
|
||||
# Enable forward references
|
||||
TreeNode.model_rebuild()
|
||||
|
||||
generator = outlines.generate.json(model, TreeNode)
|
||||
tree = generator("Generate file tree with subdirectories")
|
||||
|
||||
print(tree.value) # "root"
|
||||
print(tree.children[0].value) # "subdir1"
|
||||
```
|
||||
|
||||
### Model with Validation
|
||||
|
||||
```python
|
||||
from pydantic import field_validator
|
||||
|
||||
class DateRange(BaseModel):
|
||||
start_date: str
|
||||
end_date: str
|
||||
|
||||
@field_validator('end_date')
|
||||
def end_after_start(cls, v, info):
|
||||
"""Ensure end_date is after start_date."""
|
||||
if 'start_date' in info.data:
|
||||
from datetime import datetime
|
||||
start = datetime.strptime(info.data['start_date'], '%Y-%m-%d')
|
||||
end = datetime.strptime(v, '%Y-%m-%d')
|
||||
if end < start:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return v
|
||||
|
||||
generator = outlines.generate.json(model, DateRange)
|
||||
# Validation happens after generation
|
||||
```
|
||||
|
||||
## Multiple Objects
|
||||
|
||||
### Generate List of Objects
|
||||
|
||||
```python
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class Team(BaseModel):
|
||||
team_name: str
|
||||
members: list[Person]
|
||||
|
||||
generator = outlines.generate.json(model, Team)
|
||||
|
||||
team = generator("Generate engineering team with 5 members")
|
||||
print(f"Team: {team.team_name}")
|
||||
for member in team.members:
|
||||
print(f"- {member.name}, {member.age}")
|
||||
```
|
||||
|
||||
### Batch Generation
|
||||
|
||||
```python
|
||||
def generate_batch(prompts: list[str], schema: type[BaseModel]):
|
||||
"""Generate structured outputs for multiple prompts."""
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, schema)
|
||||
|
||||
results = []
|
||||
for prompt in prompts:
|
||||
result = generator(prompt)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
|
||||
prompts = [
|
||||
"Product: iPhone 15, $999",
|
||||
"Product: MacBook Pro, $2499",
|
||||
"Product: AirPods, $179"
|
||||
]
|
||||
|
||||
products = generate_batch(prompts, Product)
|
||||
for product in products:
|
||||
print(f"{product.name}: ${product.price}")
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Caching Generators
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def get_generator(model_name: str, schema_hash: int):
|
||||
"""Cache generators for reuse."""
|
||||
model = outlines.models.transformers(model_name)
|
||||
return outlines.generate.json(model, schema)
|
||||
|
||||
# First call: creates generator
|
||||
gen1 = get_generator("microsoft/Phi-3-mini-4k-instruct", hash(User))
|
||||
|
||||
# Second call: returns cached generator (fast!)
|
||||
gen2 = get_generator("microsoft/Phi-3-mini-4k-instruct", hash(User))
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
# Process multiple items efficiently
|
||||
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
|
||||
generator = outlines.generate.json(model, User)
|
||||
|
||||
texts = ["User: Alice, 25", "User: Bob, 30", "User: Carol, 35"]
|
||||
|
||||
# Reuse generator (model stays loaded)
|
||||
users = [generator(text) for text in texts]
|
||||
```
|
||||
|
||||
### Minimize Schema Complexity
|
||||
|
||||
```python
|
||||
# ✅ Good: Simple, flat structure (faster)
|
||||
class SimplePerson(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
city: str
|
||||
|
||||
# ⚠️ Slower: Deep nesting
|
||||
class ComplexPerson(BaseModel):
|
||||
personal_info: PersonalInfo
|
||||
address: Address
|
||||
employment: Employment
|
||||
# ... many nested levels
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Handle Missing Fields
|
||||
|
||||
```python
|
||||
from pydantic import ValidationError
|
||||
|
||||
class User(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
email: str
|
||||
|
||||
try:
|
||||
user = generator("Generate user") # May not include all fields
|
||||
except ValidationError as e:
|
||||
print(f"Validation error: {e}")
|
||||
# Handle gracefully
|
||||
```
|
||||
|
||||
### Fallback with Optional Fields
|
||||
|
||||
```python
|
||||
class RobustUser(BaseModel):
|
||||
name: str # Required
|
||||
age: Optional[int] = None # Optional
|
||||
email: Optional[str] = None # Optional
|
||||
|
||||
# More likely to succeed even with incomplete data
|
||||
user = generator("Generate user: Alice")
|
||||
print(user.name) # "Alice"
|
||||
print(user.age) # None (not provided)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Specific Types
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific types
|
||||
class Product(BaseModel):
|
||||
name: str
|
||||
price: float # Not Any or str
|
||||
quantity: int # Not str
|
||||
in_stock: bool # Not int
|
||||
|
||||
# ❌ Bad: Generic types
|
||||
class Product(BaseModel):
|
||||
name: Any
|
||||
price: str # Should be float
|
||||
quantity: str # Should be int
|
||||
```
|
||||
|
||||
### 2. Add Descriptions
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear descriptions
|
||||
class Article(BaseModel):
|
||||
title: str = Field(description="Article title, 10-100 characters")
|
||||
content: str = Field(description="Main article content in paragraphs")
|
||||
tags: list[str] = Field(description="List of relevant topic tags")
|
||||
|
||||
# Descriptions help the model understand expected output
|
||||
```
|
||||
|
||||
### 3. Use Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: With constraints
|
||||
class Age(BaseModel):
|
||||
value: int = Field(ge=0, le=120, description="Age in years")
|
||||
|
||||
# ❌ Bad: No constraints
|
||||
class Age(BaseModel):
|
||||
value: int # Could be negative or > 120
|
||||
```
|
||||
|
||||
### 4. Prefer Enums Over Strings
|
||||
|
||||
```python
|
||||
# ✅ Good: Enum for fixed set
|
||||
class Priority(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
|
||||
class Task(BaseModel):
|
||||
priority: Priority # Guaranteed valid
|
||||
|
||||
# ❌ Bad: Free-form string
|
||||
class Task(BaseModel):
|
||||
priority: str # Could be "urgent", "ASAP", "!!", etc.
|
||||
```
|
||||
|
||||
### 5. Test Your Models
|
||||
|
||||
```python
|
||||
# Test models work as expected
|
||||
def test_product_model():
|
||||
product = Product(
|
||||
name="Test Product",
|
||||
price=19.99,
|
||||
quantity=10,
|
||||
in_stock=True
|
||||
)
|
||||
assert product.price == 19.99
|
||||
assert isinstance(product, Product)
|
||||
|
||||
# Run tests before using in production
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Pydantic Docs**: https://docs.pydantic.dev
|
||||
- **JSON Schema**: https://json-schema.org
|
||||
- **Outlines GitHub**: https://github.com/outlines-dev/outlines
|
||||
@@ -0,0 +1,371 @@
|
||||
---
|
||||
name: serving-llms-vllm
|
||||
description: "vLLM: high-throughput LLM serving, OpenAI API, quantization."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [vllm, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [vLLM, Inference Serving, PagedAttention, Continuous Batching, High Throughput, Production, OpenAI API, Quantization, Tensor Parallelism]
|
||||
|
||||
---
|
||||
|
||||
# vLLM - High-Performance LLM Serving
|
||||
|
||||
## When to use
|
||||
|
||||
Use when deploying production LLM APIs, optimizing inference latency/throughput, or serving models with limited GPU memory. Supports OpenAI-compatible endpoints, quantization (GPTQ/AWQ/FP8), and tensor parallelism.
|
||||
|
||||
## Quick start
|
||||
|
||||
vLLM achieves 24x higher throughput than standard transformers through PagedAttention (block-based KV cache) and continuous batching (mixing prefill/decode requests).
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
**Basic offline inference**:
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3-8B-Instruct")
|
||||
sampling = SamplingParams(temperature=0.7, max_tokens=256)
|
||||
|
||||
outputs = llm.generate(["Explain quantum computing"], sampling)
|
||||
print(outputs[0].outputs[0].text)
|
||||
```
|
||||
|
||||
**OpenAI-compatible server**:
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3-8B-Instruct
|
||||
|
||||
# Query with OpenAI SDK
|
||||
python -c "
|
||||
from openai import OpenAI
|
||||
client = OpenAI(base_url='http://localhost:8000/v1', api_key='EMPTY')
|
||||
print(client.chat.completions.create(
|
||||
model='meta-llama/Llama-3-8B-Instruct',
|
||||
messages=[{'role': 'user', 'content': 'Hello!'}]
|
||||
).choices[0].message.content)
|
||||
"
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Production API deployment
|
||||
|
||||
Copy this checklist and track progress:
|
||||
|
||||
```
|
||||
Deployment Progress:
|
||||
- [ ] Step 1: Configure server settings
|
||||
- [ ] Step 2: Test with limited traffic
|
||||
- [ ] Step 3: Enable monitoring
|
||||
- [ ] Step 4: Deploy to production
|
||||
- [ ] Step 5: Verify performance metrics
|
||||
```
|
||||
|
||||
**Step 1: Configure server settings**
|
||||
|
||||
Choose configuration based on your model size:
|
||||
|
||||
```bash
|
||||
# For 7B-13B models on single GPU
|
||||
vllm serve meta-llama/Llama-3-8B-Instruct \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--max-model-len 8192 \
|
||||
--port 8000
|
||||
|
||||
# For 30B-70B models with tensor parallelism
|
||||
vllm serve meta-llama/Llama-2-70b-hf \
|
||||
--tensor-parallel-size 4 \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--quantization awq \
|
||||
--port 8000
|
||||
|
||||
# For production with caching and metrics
|
||||
vllm serve meta-llama/Llama-3-8B-Instruct \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-prefix-caching \
|
||||
--enable-metrics \
|
||||
--metrics-port 9090 \
|
||||
--port 8000 \
|
||||
--host 0.0.0.0
|
||||
```
|
||||
|
||||
**Step 2: Test with limited traffic**
|
||||
|
||||
Run load test before production:
|
||||
|
||||
```bash
|
||||
# Install load testing tool
|
||||
pip install locust
|
||||
|
||||
# Create test_load.py with sample requests
|
||||
# Run: locust -f test_load.py --host http://localhost:8000
|
||||
```
|
||||
|
||||
Verify TTFT (time to first token) < 500ms and throughput > 100 req/sec.
|
||||
|
||||
**Step 3: Enable monitoring**
|
||||
|
||||
vLLM exposes Prometheus metrics on port 9090:
|
||||
|
||||
```bash
|
||||
curl http://localhost:9090/metrics | grep vllm
|
||||
```
|
||||
|
||||
Key metrics to monitor:
|
||||
- `vllm:time_to_first_token_seconds` - Latency
|
||||
- `vllm:num_requests_running` - Active requests
|
||||
- `vllm:gpu_cache_usage_perc` - KV cache utilization
|
||||
|
||||
**Step 4: Deploy to production**
|
||||
|
||||
Use Docker for consistent deployment:
|
||||
|
||||
```bash
|
||||
# Run vLLM in Docker
|
||||
docker run --gpus all -p 8000:8000 \
|
||||
vllm/vllm-openai:latest \
|
||||
--model meta-llama/Llama-3-8B-Instruct \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-prefix-caching
|
||||
```
|
||||
|
||||
**Step 5: Verify performance metrics**
|
||||
|
||||
Check that deployment meets targets:
|
||||
- TTFT < 500ms (for short prompts)
|
||||
- Throughput > target req/sec
|
||||
- GPU utilization > 80%
|
||||
- No OOM errors in logs
|
||||
|
||||
### Workflow 2: Offline batch inference
|
||||
|
||||
For processing large datasets without server overhead.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Batch Processing:
|
||||
- [ ] Step 1: Prepare input data
|
||||
- [ ] Step 2: Configure LLM engine
|
||||
- [ ] Step 3: Run batch inference
|
||||
- [ ] Step 4: Process results
|
||||
```
|
||||
|
||||
**Step 1: Prepare input data**
|
||||
|
||||
```python
|
||||
# Load prompts from file
|
||||
prompts = []
|
||||
with open("prompts.txt") as f:
|
||||
prompts = [line.strip() for line in f]
|
||||
|
||||
print(f"Loaded {len(prompts)} prompts")
|
||||
```
|
||||
|
||||
**Step 2: Configure LLM engine**
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3-8B-Instruct",
|
||||
tensor_parallel_size=2, # Use 2 GPUs
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=4096
|
||||
)
|
||||
|
||||
sampling = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.95,
|
||||
max_tokens=512,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Run batch inference**
|
||||
|
||||
vLLM automatically batches requests for efficiency:
|
||||
|
||||
```python
|
||||
# Process all prompts in one call
|
||||
outputs = llm.generate(prompts, sampling)
|
||||
|
||||
# vLLM handles batching internally
|
||||
# No need to manually chunk prompts
|
||||
```
|
||||
|
||||
**Step 4: Process results**
|
||||
|
||||
```python
|
||||
# Extract generated text
|
||||
results = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated = output.outputs[0].text
|
||||
results.append({
|
||||
"prompt": prompt,
|
||||
"generated": generated,
|
||||
"tokens": len(output.outputs[0].token_ids)
|
||||
})
|
||||
|
||||
# Save to file
|
||||
import json
|
||||
with open("results.jsonl", "w") as f:
|
||||
for result in results:
|
||||
f.write(json.dumps(result) + "\n")
|
||||
|
||||
print(f"Processed {len(results)} prompts")
|
||||
```
|
||||
|
||||
### Workflow 3: Quantized model serving
|
||||
|
||||
Fit large models in limited GPU memory.
|
||||
|
||||
```
|
||||
Quantization Setup:
|
||||
- [ ] Step 1: Choose quantization method
|
||||
- [ ] Step 2: Find or create quantized model
|
||||
- [ ] Step 3: Launch with quantization flag
|
||||
- [ ] Step 4: Verify accuracy
|
||||
```
|
||||
|
||||
**Step 1: Choose quantization method**
|
||||
|
||||
- **AWQ**: Best for 70B models, minimal accuracy loss
|
||||
- **GPTQ**: Wide model support, good compression
|
||||
- **FP8**: Fastest on H100 GPUs
|
||||
|
||||
**Step 2: Find or create quantized model**
|
||||
|
||||
Use pre-quantized models from HuggingFace:
|
||||
|
||||
```bash
|
||||
# Search for AWQ models
|
||||
# Example: TheBloke/Llama-2-70B-AWQ
|
||||
```
|
||||
|
||||
**Step 3: Launch with quantization flag**
|
||||
|
||||
```bash
|
||||
# Using pre-quantized model
|
||||
vllm serve TheBloke/Llama-2-70B-AWQ \
|
||||
--quantization awq \
|
||||
--tensor-parallel-size 1 \
|
||||
--gpu-memory-utilization 0.95
|
||||
|
||||
# Results: 70B model in ~40GB VRAM
|
||||
```
|
||||
|
||||
**Step 4: Verify accuracy**
|
||||
|
||||
Test outputs match expected quality:
|
||||
|
||||
```python
|
||||
# Compare quantized vs non-quantized responses
|
||||
# Verify task-specific performance unchanged
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use vLLM when:**
|
||||
- Deploying production LLM APIs (100+ req/sec)
|
||||
- Serving OpenAI-compatible endpoints
|
||||
- Limited GPU memory but need large models
|
||||
- Multi-user applications (chatbots, assistants)
|
||||
- Need low latency with high throughput
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **llama.cpp**: CPU/edge inference, single-user
|
||||
- **HuggingFace transformers**: Research, prototyping, one-off generation
|
||||
- **TensorRT-LLM**: NVIDIA-only, need absolute maximum performance
|
||||
- **Text-Generation-Inference**: Already in HuggingFace ecosystem
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Out of memory during model loading**
|
||||
|
||||
Reduce memory usage:
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--gpu-memory-utilization 0.7 \
|
||||
--max-model-len 4096
|
||||
```
|
||||
|
||||
Or use quantization:
|
||||
```bash
|
||||
vllm serve MODEL --quantization awq
|
||||
```
|
||||
|
||||
**Issue: Slow first token (TTFT > 1 second)**
|
||||
|
||||
Enable prefix caching for repeated prompts:
|
||||
```bash
|
||||
vllm serve MODEL --enable-prefix-caching
|
||||
```
|
||||
|
||||
For long prompts, enable chunked prefill:
|
||||
```bash
|
||||
vllm serve MODEL --enable-chunked-prefill
|
||||
```
|
||||
|
||||
**Issue: Model not found error**
|
||||
|
||||
Use `--trust-remote-code` for custom models:
|
||||
```bash
|
||||
vllm serve MODEL --trust-remote-code
|
||||
```
|
||||
|
||||
**Issue: Low throughput (<50 req/sec)**
|
||||
|
||||
Increase concurrent sequences:
|
||||
```bash
|
||||
vllm serve MODEL --max-num-seqs 512
|
||||
```
|
||||
|
||||
Check GPU utilization with `nvidia-smi` - should be >80%.
|
||||
|
||||
**Issue: Inference slower than expected**
|
||||
|
||||
Verify tensor parallelism uses power of 2 GPUs:
|
||||
```bash
|
||||
vllm serve MODEL --tensor-parallel-size 4 # Not 3
|
||||
```
|
||||
|
||||
Enable speculative decoding for faster generation:
|
||||
```bash
|
||||
vllm serve MODEL --speculative-model DRAFT_MODEL
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Server deployment patterns**: See [references/server-deployment.md](references/server-deployment.md) for Docker, Kubernetes, and load balancing configurations.
|
||||
|
||||
**Performance optimization**: See [references/optimization.md](references/optimization.md) for PagedAttention tuning, continuous batching details, and benchmark results.
|
||||
|
||||
**Quantization guide**: See [references/quantization.md](references/quantization.md) for AWQ/GPTQ/FP8 setup, model preparation, and accuracy comparisons.
|
||||
|
||||
**Troubleshooting**: See [references/troubleshooting.md](references/troubleshooting.md) for detailed error messages, debugging steps, and performance diagnostics.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **Small models (7B-13B)**: 1x A10 (24GB) or A100 (40GB)
|
||||
- **Medium models (30B-40B)**: 2x A100 (40GB) with tensor parallelism
|
||||
- **Large models (70B+)**: 4x A100 (40GB) or 2x A100 (80GB), use AWQ/GPTQ
|
||||
|
||||
Supported platforms: NVIDIA (primary), AMD ROCm, Intel GPUs, TPUs
|
||||
|
||||
## Resources
|
||||
|
||||
- Official docs: https://docs.vllm.ai
|
||||
- GitHub: https://github.com/vllm-project/vllm
|
||||
- Paper: "Efficient Memory Management for Large Language Model Serving with PagedAttention" (SOSP 2023)
|
||||
- Community: https://discuss.vllm.ai
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
# Performance Optimization
|
||||
|
||||
## Contents
|
||||
- PagedAttention explained
|
||||
- Continuous batching mechanics
|
||||
- Prefix caching strategies
|
||||
- Speculative decoding setup
|
||||
- Benchmark results and comparisons
|
||||
- Performance tuning guide
|
||||
|
||||
## PagedAttention explained
|
||||
|
||||
**Traditional attention problem**:
|
||||
- KV cache stored in contiguous memory
|
||||
- Wastes ~50% GPU memory due to fragmentation
|
||||
- Cannot dynamically reallocate for varying sequence lengths
|
||||
|
||||
**PagedAttention solution**:
|
||||
- Divides KV cache into fixed-size blocks (like OS virtual memory)
|
||||
- Dynamic allocation from free block queue
|
||||
- Shares blocks across sequences (for prefix caching)
|
||||
|
||||
**Memory savings example**:
|
||||
```
|
||||
Traditional: 70B model needs 160GB KV cache → OOM on 8x A100
|
||||
PagedAttention: 70B model needs 80GB KV cache → Fits on 4x A100
|
||||
```
|
||||
|
||||
**Configuration**:
|
||||
```bash
|
||||
# Block size (default: 16 tokens)
|
||||
vllm serve MODEL --block-size 16
|
||||
|
||||
# Number of GPU blocks (auto-calculated)
|
||||
# Controlled by --gpu-memory-utilization
|
||||
vllm serve MODEL --gpu-memory-utilization 0.9
|
||||
```
|
||||
|
||||
## Continuous batching mechanics
|
||||
|
||||
**Traditional batching**:
|
||||
- Wait for all sequences in batch to finish
|
||||
- GPU idle while waiting for longest sequence
|
||||
- Low GPU utilization (~40-60%)
|
||||
|
||||
**Continuous batching**:
|
||||
- Add new requests as slots become available
|
||||
- Mix prefill (new requests) and decode (ongoing) in same batch
|
||||
- High GPU utilization (>90%)
|
||||
|
||||
**Throughput improvement**:
|
||||
```
|
||||
Traditional batching: 50 req/sec @ 50% GPU util
|
||||
Continuous batching: 200 req/sec @ 90% GPU util
|
||||
= 4x throughput improvement
|
||||
```
|
||||
|
||||
**Tuning parameters**:
|
||||
```bash
|
||||
# Max concurrent sequences (higher = more batching)
|
||||
vllm serve MODEL --max-num-seqs 256
|
||||
|
||||
# Prefill/decode schedule (auto-balanced by default)
|
||||
# No manual tuning needed
|
||||
```
|
||||
|
||||
## Prefix caching strategies
|
||||
|
||||
Reuse computed KV cache for common prompt prefixes.
|
||||
|
||||
**Use cases**:
|
||||
- System prompts repeated across requests
|
||||
- Few-shot examples in every prompt
|
||||
- RAG contexts with overlapping chunks
|
||||
|
||||
**Example savings**:
|
||||
```
|
||||
Prompt: [System: 500 tokens] + [User: 100 tokens]
|
||||
|
||||
Without caching: Compute 600 tokens every request
|
||||
With caching: Compute 500 tokens once, then 100 tokens/request
|
||||
= 83% faster TTFT
|
||||
```
|
||||
|
||||
**Enable prefix caching**:
|
||||
```bash
|
||||
vllm serve MODEL --enable-prefix-caching
|
||||
```
|
||||
|
||||
**Automatic prefix detection**:
|
||||
- vLLM detects common prefixes automatically
|
||||
- No code changes required
|
||||
- Works with OpenAI-compatible API
|
||||
|
||||
**Cache hit rate monitoring**:
|
||||
```bash
|
||||
curl http://localhost:9090/metrics | grep cache_hit
|
||||
# vllm_cache_hit_rate: 0.75 (75% hit rate)
|
||||
```
|
||||
|
||||
## Speculative decoding setup
|
||||
|
||||
Use smaller "draft" model to propose tokens, larger model to verify.
|
||||
|
||||
**Speed improvement**:
|
||||
```
|
||||
Standard: Generate 1 token per forward pass
|
||||
Speculative: Generate 3-5 tokens per forward pass
|
||||
= 2-3x faster generation
|
||||
```
|
||||
|
||||
**How it works**:
|
||||
1. Draft model proposes K tokens (fast)
|
||||
2. Target model verifies all K tokens in parallel (one pass)
|
||||
3. Accept verified tokens, restart from first rejection
|
||||
|
||||
**Setup with separate draft model**:
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3-70B-Instruct \
|
||||
--speculative-model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
|
||||
--num-speculative-tokens 5
|
||||
```
|
||||
|
||||
**Setup with n-gram draft** (no separate model):
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--speculative-method ngram \
|
||||
--num-speculative-tokens 3
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- Output length > 100 tokens
|
||||
- Draft model 5-10x smaller than target
|
||||
- Acceptable 2-3% accuracy trade-off
|
||||
|
||||
## Benchmark results
|
||||
|
||||
**vLLM vs HuggingFace Transformers** (Llama 3 8B, A100):
|
||||
```
|
||||
Metric | HF Transformers | vLLM | Improvement
|
||||
------------------------|-----------------|--------|------------
|
||||
Throughput (req/sec) | 12 | 280 | 23x
|
||||
TTFT (ms) | 850 | 120 | 7x
|
||||
Tokens/sec | 45 | 2,100 | 47x
|
||||
GPU Memory (GB) | 28 | 16 | 1.75x less
|
||||
```
|
||||
|
||||
**vLLM vs TensorRT-LLM** (Llama 2 70B, 4x A100):
|
||||
```
|
||||
Metric | TensorRT-LLM | vLLM | Notes
|
||||
------------------------|--------------|--------|------------------
|
||||
Throughput (req/sec) | 320 | 285 | TRT 12% faster
|
||||
Setup complexity | High | Low | vLLM much easier
|
||||
NVIDIA-only | Yes | No | vLLM multi-platform
|
||||
Quantization support | FP8, INT8 | AWQ/GPTQ/FP8 | vLLM more options
|
||||
```
|
||||
|
||||
## Performance tuning guide
|
||||
|
||||
**Step 1: Measure baseline**
|
||||
|
||||
```bash
|
||||
# Install benchmarking tool
|
||||
pip install locust
|
||||
|
||||
# Run baseline benchmark
|
||||
vllm bench throughput \
|
||||
--model MODEL \
|
||||
--input-tokens 128 \
|
||||
--output-tokens 256 \
|
||||
--num-prompts 1000
|
||||
|
||||
# Record: throughput, TTFT, tokens/sec
|
||||
```
|
||||
|
||||
**Step 2: Tune memory utilization**
|
||||
|
||||
```bash
|
||||
# Try different values: 0.7, 0.85, 0.9, 0.95
|
||||
vllm serve MODEL --gpu-memory-utilization 0.9
|
||||
```
|
||||
|
||||
Higher = more batch capacity = higher throughput, but risk OOM.
|
||||
|
||||
**Step 3: Tune concurrency**
|
||||
|
||||
```bash
|
||||
# Try values: 128, 256, 512, 1024
|
||||
vllm serve MODEL --max-num-seqs 256
|
||||
```
|
||||
|
||||
Higher = more batching opportunity, but may increase latency.
|
||||
|
||||
**Step 4: Enable optimizations**
|
||||
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--enable-prefix-caching \ # For repeated prompts
|
||||
--enable-chunked-prefill \ # For long prompts
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--max-num-seqs 512
|
||||
```
|
||||
|
||||
**Step 5: Re-benchmark and compare**
|
||||
|
||||
Target improvements:
|
||||
- Throughput: +30-100%
|
||||
- TTFT: -20-50%
|
||||
- GPU utilization: >85%
|
||||
|
||||
**Common performance issues**:
|
||||
|
||||
**Low throughput (<50 req/sec)**:
|
||||
- Increase `--max-num-seqs`
|
||||
- Enable `--enable-prefix-caching`
|
||||
- Check GPU utilization (should be >80%)
|
||||
|
||||
**High TTFT (>1 second)**:
|
||||
- Enable `--enable-chunked-prefill`
|
||||
- Reduce `--max-model-len` if possible
|
||||
- Check if model is too large for GPU
|
||||
|
||||
**OOM errors**:
|
||||
- Reduce `--gpu-memory-utilization` to 0.7
|
||||
- Reduce `--max-model-len`
|
||||
- Use quantization (`--quantization awq`)
|
||||
@@ -0,0 +1,284 @@
|
||||
# Quantization Guide
|
||||
|
||||
## Contents
|
||||
- Quantization methods comparison
|
||||
- AWQ setup and usage
|
||||
- GPTQ setup and usage
|
||||
- FP8 quantization (H100)
|
||||
- Model preparation
|
||||
- Accuracy vs compression trade-offs
|
||||
|
||||
## Quantization methods comparison
|
||||
|
||||
| Method | Compression | Accuracy Loss | Speed | Best For |
|
||||
|--------|-------------|---------------|-------|----------|
|
||||
| **AWQ** | 4-bit (75%) | <1% | Fast | 70B models, production |
|
||||
| **GPTQ** | 4-bit (75%) | 1-2% | Fast | Wide model support |
|
||||
| **FP8** | 8-bit (50%) | <0.5% | Fastest | H100 GPUs only |
|
||||
| **SqueezeLLM** | 3-4 bit (75-80%) | 2-3% | Medium | Extreme compression |
|
||||
|
||||
**Recommendation**:
|
||||
- **Production**: Use AWQ for 70B models
|
||||
- **H100 GPUs**: Use FP8 for best speed
|
||||
- **Maximum compatibility**: Use GPTQ
|
||||
- **Extreme compression**: Use SqueezeLLM
|
||||
|
||||
## AWQ setup and usage
|
||||
|
||||
**AWQ** (Activation-aware Weight Quantization) achieves best accuracy at 4-bit.
|
||||
|
||||
**Step 1: Find pre-quantized model**
|
||||
|
||||
Search HuggingFace for AWQ models:
|
||||
```bash
|
||||
# Example: TheBloke/Llama-2-70B-AWQ
|
||||
# Example: TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ
|
||||
```
|
||||
|
||||
**Step 2: Launch with AWQ**
|
||||
|
||||
```bash
|
||||
vllm serve TheBloke/Llama-2-70B-AWQ \
|
||||
--quantization awq \
|
||||
--tensor-parallel-size 1 \
|
||||
--gpu-memory-utilization 0.95
|
||||
```
|
||||
|
||||
**Memory savings**:
|
||||
```
|
||||
Llama 2 70B fp16: 140GB VRAM (4x A100 needed)
|
||||
Llama 2 70B AWQ: 35GB VRAM (1x A100 40GB)
|
||||
= 4x memory reduction
|
||||
```
|
||||
|
||||
**Step 3: Verify performance**
|
||||
|
||||
Test that outputs are acceptable:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
|
||||
|
||||
# Test complex reasoning
|
||||
response = client.chat.completions.create(
|
||||
model="TheBloke/Llama-2-70B-AWQ",
|
||||
messages=[{"role": "user", "content": "Explain quantum entanglement"}]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
# Verify quality matches your requirements
|
||||
```
|
||||
|
||||
**Quantize your own model** (requires GPU with 80GB+ VRAM):
|
||||
|
||||
```python
|
||||
from awq import AutoAWQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_path = "meta-llama/Llama-2-70b-hf"
|
||||
quant_path = "llama-2-70b-awq"
|
||||
|
||||
# Load model
|
||||
model = AutoAWQForCausalLM.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Quantize
|
||||
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
|
||||
model.quantize(tokenizer, quant_config=quant_config)
|
||||
|
||||
# Save
|
||||
model.save_quantized(quant_path)
|
||||
tokenizer.save_pretrained(quant_path)
|
||||
```
|
||||
|
||||
## GPTQ setup and usage
|
||||
|
||||
**GPTQ** has widest model support and good compression.
|
||||
|
||||
**Step 1: Find GPTQ model**
|
||||
|
||||
```bash
|
||||
# Example: TheBloke/Llama-2-13B-GPTQ
|
||||
# Example: TheBloke/CodeLlama-34B-GPTQ
|
||||
```
|
||||
|
||||
**Step 2: Launch with GPTQ**
|
||||
|
||||
```bash
|
||||
vllm serve TheBloke/Llama-2-13B-GPTQ \
|
||||
--quantization gptq \
|
||||
--dtype float16
|
||||
```
|
||||
|
||||
**GPTQ configuration options**:
|
||||
```bash
|
||||
# Specify GPTQ parameters if needed
|
||||
vllm serve MODEL \
|
||||
--quantization gptq \
|
||||
--gptq-act-order \ # Activation ordering
|
||||
--dtype float16
|
||||
```
|
||||
|
||||
**Quantize your own model**:
|
||||
|
||||
```python
|
||||
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_name = "meta-llama/Llama-2-13b-hf"
|
||||
quantized_name = "llama-2-13b-gptq"
|
||||
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoGPTQForCausalLM.from_pretrained(model_name, quantize_config)
|
||||
|
||||
# Prepare calibration data
|
||||
calib_data = [...] # List of sample texts
|
||||
|
||||
# Quantize
|
||||
quantize_config = BaseQuantizeConfig(
|
||||
bits=4,
|
||||
group_size=128,
|
||||
desc_act=True
|
||||
)
|
||||
model.quantize(calib_data)
|
||||
|
||||
# Save
|
||||
model.save_quantized(quantized_name)
|
||||
```
|
||||
|
||||
## FP8 quantization (H100)
|
||||
|
||||
**FP8** (8-bit floating point) offers best speed on H100 GPUs with minimal accuracy loss.
|
||||
|
||||
**Requirements**:
|
||||
- H100 or H800 GPU
|
||||
- CUDA 12.3+ (12.8 recommended)
|
||||
- Hopper architecture support
|
||||
|
||||
**Step 1: Enable FP8**
|
||||
|
||||
```bash
|
||||
vllm serve meta-llama/Llama-3-70B-Instruct \
|
||||
--quantization fp8 \
|
||||
--tensor-parallel-size 2
|
||||
```
|
||||
|
||||
**Performance gains on H100**:
|
||||
```
|
||||
fp16: 180 tokens/sec
|
||||
FP8: 320 tokens/sec
|
||||
= 1.8x speedup
|
||||
```
|
||||
|
||||
**Step 2: Verify accuracy**
|
||||
|
||||
FP8 typically has <0.5% accuracy degradation:
|
||||
```python
|
||||
# Run evaluation suite
|
||||
# Compare FP8 vs FP16 on your tasks
|
||||
# Verify acceptable accuracy
|
||||
```
|
||||
|
||||
**Dynamic FP8 quantization** (no pre-quantized model needed):
|
||||
|
||||
```bash
|
||||
# vLLM automatically quantizes at runtime
|
||||
vllm serve MODEL --quantization fp8
|
||||
# No model preparation required
|
||||
```
|
||||
|
||||
## Model preparation
|
||||
|
||||
**Pre-quantized models (easiest)**:
|
||||
|
||||
1. Search HuggingFace: `[model name] AWQ` or `[model name] GPTQ`
|
||||
2. Download or use directly: `TheBloke/[Model]-AWQ`
|
||||
3. Launch with appropriate `--quantization` flag
|
||||
|
||||
**Quantize your own model**:
|
||||
|
||||
**AWQ**:
|
||||
```bash
|
||||
# Install AutoAWQ
|
||||
pip install autoawq
|
||||
|
||||
# Run quantization script
|
||||
python quantize_awq.py --model MODEL --output OUTPUT
|
||||
```
|
||||
|
||||
**GPTQ**:
|
||||
```bash
|
||||
# Install AutoGPTQ
|
||||
pip install auto-gptq
|
||||
|
||||
# Run quantization script
|
||||
python quantize_gptq.py --model MODEL --output OUTPUT
|
||||
```
|
||||
|
||||
**Calibration data**:
|
||||
- Use 128-512 diverse examples from target domain
|
||||
- Representative of production inputs
|
||||
- Higher quality calibration = better accuracy
|
||||
|
||||
## Accuracy vs compression trade-offs
|
||||
|
||||
**Empirical results** (Llama 2 70B on MMLU benchmark):
|
||||
|
||||
| Quantization | Accuracy | Memory | Speed | Production-Ready |
|
||||
|--------------|----------|--------|-------|------------------|
|
||||
| FP16 (baseline) | 100% | 140GB | 1.0x | ✅ (if memory available) |
|
||||
| FP8 | 99.5% | 70GB | 1.8x | ✅ (H100 only) |
|
||||
| AWQ 4-bit | 99.0% | 35GB | 1.5x | ✅ (best for 70B) |
|
||||
| GPTQ 4-bit | 98.5% | 35GB | 1.5x | ✅ (good compatibility) |
|
||||
| SqueezeLLM 3-bit | 96.0% | 26GB | 1.3x | ⚠️ (check accuracy) |
|
||||
|
||||
**When to use each**:
|
||||
|
||||
**No quantization (FP16)**:
|
||||
- Have sufficient GPU memory
|
||||
- Need absolute best accuracy
|
||||
- Model <13B parameters
|
||||
|
||||
**FP8**:
|
||||
- Using H100/H800 GPUs
|
||||
- Need best speed with minimal accuracy loss
|
||||
- Production deployment
|
||||
|
||||
**AWQ 4-bit**:
|
||||
- Need to fit 70B model in 40GB GPU
|
||||
- Production deployment
|
||||
- <1% accuracy loss acceptable
|
||||
|
||||
**GPTQ 4-bit**:
|
||||
- Wide model support needed
|
||||
- Not on H100 (use FP8 instead)
|
||||
- 1-2% accuracy loss acceptable
|
||||
|
||||
**Testing strategy**:
|
||||
|
||||
1. **Baseline**: Measure FP16 accuracy on your evaluation set
|
||||
2. **Quantize**: Create quantized version
|
||||
3. **Evaluate**: Compare quantized vs baseline on same tasks
|
||||
4. **Decide**: Accept if degradation < threshold (typically 1-2%)
|
||||
|
||||
**Example evaluation**:
|
||||
```python
|
||||
from evaluate import load_evaluation_suite
|
||||
|
||||
# Run on FP16 baseline
|
||||
baseline_score = evaluate(model_fp16, eval_suite)
|
||||
|
||||
# Run on quantized
|
||||
quant_score = evaluate(model_awq, eval_suite)
|
||||
|
||||
# Compare
|
||||
degradation = (baseline_score - quant_score) / baseline_score * 100
|
||||
print(f"Accuracy degradation: {degradation:.2f}%")
|
||||
|
||||
# Decision
|
||||
if degradation < 1.0:
|
||||
print("✅ Quantization acceptable for production")
|
||||
else:
|
||||
print("⚠️ Review accuracy loss")
|
||||
```
|
||||
@@ -0,0 +1,255 @@
|
||||
# Server Deployment Patterns
|
||||
|
||||
## Contents
|
||||
- Docker deployment
|
||||
- Kubernetes deployment
|
||||
- Load balancing with Nginx
|
||||
- Multi-node distributed serving
|
||||
- Production configuration examples
|
||||
- Health checks and monitoring
|
||||
|
||||
## Docker deployment
|
||||
|
||||
**Basic Dockerfile**:
|
||||
```dockerfile
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
|
||||
|
||||
RUN apt-get update && apt-get install -y python3-pip
|
||||
RUN pip install vllm
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["vllm", "serve", "meta-llama/Llama-3-8B-Instruct", \
|
||||
"--host", "0.0.0.0", "--port", "8000", \
|
||||
"--gpu-memory-utilization", "0.9"]
|
||||
```
|
||||
|
||||
**Build and run**:
|
||||
```bash
|
||||
docker build -t vllm-server .
|
||||
docker run --gpus all -p 8000:8000 vllm-server
|
||||
```
|
||||
|
||||
**Docker Compose** (with metrics):
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
vllm:
|
||||
image: vllm/vllm-openai:latest
|
||||
command: >
|
||||
--model meta-llama/Llama-3-8B-Instruct
|
||||
--gpu-memory-utilization 0.9
|
||||
--enable-metrics
|
||||
--metrics-port 9090
|
||||
ports:
|
||||
- "8000:8000"
|
||||
- "9090:9090"
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
```
|
||||
|
||||
## Kubernetes deployment
|
||||
|
||||
**Deployment manifest**:
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: vllm-server
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: vllm
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: vllm
|
||||
spec:
|
||||
containers:
|
||||
- name: vllm
|
||||
image: vllm/vllm-openai:latest
|
||||
args:
|
||||
- "--model=meta-llama/Llama-3-8B-Instruct"
|
||||
- "--gpu-memory-utilization=0.9"
|
||||
- "--enable-prefix-caching"
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
ports:
|
||||
- containerPort: 8000
|
||||
name: http
|
||||
- containerPort: 9090
|
||||
name: metrics
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 30
|
||||
periodSeconds: 10
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: 8000
|
||||
initialDelaySeconds: 60
|
||||
periodSeconds: 30
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: vllm-service
|
||||
spec:
|
||||
selector:
|
||||
app: vllm
|
||||
ports:
|
||||
- port: 8000
|
||||
targetPort: 8000
|
||||
name: http
|
||||
- port: 9090
|
||||
targetPort: 9090
|
||||
name: metrics
|
||||
type: LoadBalancer
|
||||
```
|
||||
|
||||
## Load balancing with Nginx
|
||||
|
||||
**Nginx configuration**:
|
||||
```nginx
|
||||
upstream vllm_backend {
|
||||
least_conn; # Route to least-loaded server
|
||||
server localhost:8001;
|
||||
server localhost:8002;
|
||||
server localhost:8003;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 80;
|
||||
|
||||
location / {
|
||||
proxy_pass http://vllm_backend;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
|
||||
# Timeouts for long-running inference
|
||||
proxy_read_timeout 300s;
|
||||
proxy_connect_timeout 75s;
|
||||
}
|
||||
|
||||
# Metrics endpoint
|
||||
location /metrics {
|
||||
proxy_pass http://localhost:9090/metrics;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Start multiple vLLM instances**:
|
||||
```bash
|
||||
# Terminal 1
|
||||
vllm serve MODEL --port 8001 --tensor-parallel-size 1
|
||||
|
||||
# Terminal 2
|
||||
vllm serve MODEL --port 8002 --tensor-parallel-size 1
|
||||
|
||||
# Terminal 3
|
||||
vllm serve MODEL --port 8003 --tensor-parallel-size 1
|
||||
|
||||
# Start Nginx
|
||||
nginx -c /path/to/nginx.conf
|
||||
```
|
||||
|
||||
## Multi-node distributed serving
|
||||
|
||||
For models too large for single node:
|
||||
|
||||
**Node 1** (master):
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.10
|
||||
export MASTER_PORT=29500
|
||||
export RANK=0
|
||||
export WORLD_SIZE=2
|
||||
|
||||
vllm serve meta-llama/Llama-2-70b-hf \
|
||||
--tensor-parallel-size 8 \
|
||||
--pipeline-parallel-size 2
|
||||
```
|
||||
|
||||
**Node 2** (worker):
|
||||
```bash
|
||||
export MASTER_ADDR=192.168.1.10
|
||||
export MASTER_PORT=29500
|
||||
export RANK=1
|
||||
export WORLD_SIZE=2
|
||||
|
||||
vllm serve meta-llama/Llama-2-70b-hf \
|
||||
--tensor-parallel-size 8 \
|
||||
--pipeline-parallel-size 2
|
||||
```
|
||||
|
||||
## Production configuration examples
|
||||
|
||||
**High throughput** (batch-heavy workload):
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--max-num-seqs 512 \
|
||||
--gpu-memory-utilization 0.95 \
|
||||
--enable-prefix-caching \
|
||||
--trust-remote-code
|
||||
```
|
||||
|
||||
**Low latency** (interactive workload):
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--max-num-seqs 64 \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
|
||||
**Memory-constrained** (40GB GPU for 70B model):
|
||||
```bash
|
||||
vllm serve TheBloke/Llama-2-70B-AWQ \
|
||||
--quantization awq \
|
||||
--tensor-parallel-size 1 \
|
||||
--gpu-memory-utilization 0.95 \
|
||||
--max-model-len 4096
|
||||
```
|
||||
|
||||
## Health checks and monitoring
|
||||
|
||||
**Health check endpoint**:
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
# Returns: {"status": "ok"}
|
||||
```
|
||||
|
||||
**Readiness check** (wait for model loaded):
|
||||
```bash
|
||||
#!/bin/bash
|
||||
until curl -f http://localhost:8000/health; do
|
||||
echo "Waiting for vLLM to be ready..."
|
||||
sleep 5
|
||||
done
|
||||
echo "vLLM is ready!"
|
||||
```
|
||||
|
||||
**Prometheus scraping**:
|
||||
```yaml
|
||||
# prometheus.yml
|
||||
scrape_configs:
|
||||
- job_name: 'vllm'
|
||||
static_configs:
|
||||
- targets: ['localhost:9090']
|
||||
metrics_path: '/metrics'
|
||||
scrape_interval: 15s
|
||||
```
|
||||
|
||||
**Grafana dashboard** (key metrics):
|
||||
- Requests per second: `rate(vllm_request_success_total[5m])`
|
||||
- TTFT p50: `histogram_quantile(0.5, vllm_time_to_first_token_seconds_bucket)`
|
||||
- TTFT p99: `histogram_quantile(0.99, vllm_time_to_first_token_seconds_bucket)`
|
||||
- GPU cache usage: `vllm_gpu_cache_usage_perc`
|
||||
- Active requests: `vllm_num_requests_running`
|
||||
@@ -0,0 +1,447 @@
|
||||
# Troubleshooting Guide
|
||||
|
||||
## Contents
|
||||
- Out of memory (OOM) errors
|
||||
- Performance issues
|
||||
- Model loading errors
|
||||
- Network and connection issues
|
||||
- Quantization problems
|
||||
- Distributed serving issues
|
||||
- Debugging tools and commands
|
||||
|
||||
## Out of memory (OOM) errors
|
||||
|
||||
### Symptom: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Cause**: Model + KV cache exceeds available VRAM
|
||||
|
||||
**Solutions (try in order)**:
|
||||
|
||||
1. **Reduce GPU memory utilization**:
|
||||
```bash
|
||||
vllm serve MODEL --gpu-memory-utilization 0.7 # Try 0.7, 0.75, 0.8
|
||||
```
|
||||
|
||||
2. **Reduce max sequence length**:
|
||||
```bash
|
||||
vllm serve MODEL --max-model-len 4096 # Instead of 8192
|
||||
```
|
||||
|
||||
3. **Enable quantization**:
|
||||
```bash
|
||||
vllm serve MODEL --quantization awq # 4x memory reduction
|
||||
```
|
||||
|
||||
4. **Use tensor parallelism** (multiple GPUs):
|
||||
```bash
|
||||
vllm serve MODEL --tensor-parallel-size 2 # Split across 2 GPUs
|
||||
```
|
||||
|
||||
5. **Reduce max concurrent sequences**:
|
||||
```bash
|
||||
vllm serve MODEL --max-num-seqs 128 # Default is 256
|
||||
```
|
||||
|
||||
### Symptom: OOM during inference (not model loading)
|
||||
|
||||
**Cause**: KV cache fills up during generation
|
||||
|
||||
**Solutions**:
|
||||
|
||||
```bash
|
||||
# Reduce KV cache allocation
|
||||
vllm serve MODEL --gpu-memory-utilization 0.85
|
||||
|
||||
# Reduce batch size
|
||||
vllm serve MODEL --max-num-seqs 64
|
||||
|
||||
# Reduce max tokens per request
|
||||
# Set in client request: max_tokens=512
|
||||
```
|
||||
|
||||
### Symptom: OOM with quantized model
|
||||
|
||||
**Cause**: Quantization overhead or incorrect configuration
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Ensure quantization flag matches model
|
||||
vllm serve TheBloke/Llama-2-70B-AWQ --quantization awq # Must specify
|
||||
|
||||
# Try different dtype
|
||||
vllm serve MODEL --quantization awq --dtype float16
|
||||
```
|
||||
|
||||
## Performance issues
|
||||
|
||||
### Symptom: Low throughput (<50 req/sec expected >100)
|
||||
|
||||
**Diagnostic steps**:
|
||||
|
||||
1. **Check GPU utilization**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi
|
||||
# GPU utilization should be >80%
|
||||
```
|
||||
|
||||
If <80%, increase concurrent requests:
|
||||
```bash
|
||||
vllm serve MODEL --max-num-seqs 512 # Increase from 256
|
||||
```
|
||||
|
||||
2. **Check if memory-bound**:
|
||||
```bash
|
||||
# If memory at 100% but GPU <80%, reduce sequence length
|
||||
vllm serve MODEL --max-model-len 4096
|
||||
```
|
||||
|
||||
3. **Enable optimizations**:
|
||||
```bash
|
||||
vllm serve MODEL \
|
||||
--enable-prefix-caching \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-seqs 512
|
||||
```
|
||||
|
||||
4. **Check tensor parallelism settings**:
|
||||
```bash
|
||||
# Must use power-of-2 GPUs
|
||||
vllm serve MODEL --tensor-parallel-size 4 # Not 3 or 5
|
||||
```
|
||||
|
||||
### Symptom: High TTFT (time to first token >1 second)
|
||||
|
||||
**Causes and solutions**:
|
||||
|
||||
**Long prompts**:
|
||||
```bash
|
||||
vllm serve MODEL --enable-chunked-prefill
|
||||
```
|
||||
|
||||
**No prefix caching**:
|
||||
```bash
|
||||
vllm serve MODEL --enable-prefix-caching # For repeated prompts
|
||||
```
|
||||
|
||||
**Too many concurrent requests**:
|
||||
```bash
|
||||
vllm serve MODEL --max-num-seqs 64 # Reduce to prioritize latency
|
||||
```
|
||||
|
||||
**Model too large for single GPU**:
|
||||
```bash
|
||||
vllm serve MODEL --tensor-parallel-size 2 # Parallelize prefill
|
||||
```
|
||||
|
||||
### Symptom: Slow token generation (low tokens/sec)
|
||||
|
||||
**Diagnostic**:
|
||||
```bash
|
||||
# Check if model is correct size
|
||||
vllm serve MODEL # Should see model size in logs
|
||||
|
||||
# Check speculative decoding
|
||||
vllm serve MODEL --speculative-model DRAFT_MODEL
|
||||
```
|
||||
|
||||
**For H100 GPUs**, enable FP8:
|
||||
```bash
|
||||
vllm serve MODEL --quantization fp8
|
||||
```
|
||||
|
||||
## Model loading errors
|
||||
|
||||
### Symptom: `OSError: MODEL not found`
|
||||
|
||||
**Causes**:
|
||||
|
||||
1. **Model name typo**:
|
||||
```bash
|
||||
# Check exact model name on HuggingFace
|
||||
vllm serve meta-llama/Llama-3-8B-Instruct # Correct capitalization
|
||||
```
|
||||
|
||||
2. **Private/gated model**:
|
||||
```bash
|
||||
# Login to HuggingFace first
|
||||
huggingface-cli login
|
||||
# Then run vLLM
|
||||
vllm serve meta-llama/Llama-3-70B-Instruct
|
||||
```
|
||||
|
||||
3. **Custom model needs trust flag**:
|
||||
```bash
|
||||
vllm serve MODEL --trust-remote-code
|
||||
```
|
||||
|
||||
### Symptom: `ValueError: Tokenizer not found`
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Download model manually first
|
||||
python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('MODEL')"
|
||||
|
||||
# Then launch vLLM
|
||||
vllm serve MODEL
|
||||
```
|
||||
|
||||
### Symptom: `ImportError: No module named 'flash_attn'`
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Install flash attention
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Or disable flash attention
|
||||
vllm serve MODEL --disable-flash-attn
|
||||
```
|
||||
|
||||
## Network and connection issues
|
||||
|
||||
### Symptom: `Connection refused` when querying server
|
||||
|
||||
**Diagnostic**:
|
||||
|
||||
1. **Check server is running**:
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
2. **Check port binding**:
|
||||
```bash
|
||||
# Bind to all interfaces for remote access
|
||||
vllm serve MODEL --host 0.0.0.0 --port 8000
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8000
|
||||
```
|
||||
|
||||
3. **Check firewall**:
|
||||
```bash
|
||||
# Allow port through firewall
|
||||
sudo ufw allow 8000
|
||||
```
|
||||
|
||||
### Symptom: Slow response times over network
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Increase timeout**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="EMPTY",
|
||||
timeout=300.0 # 5 minute timeout
|
||||
)
|
||||
```
|
||||
|
||||
2. **Check network latency**:
|
||||
```bash
|
||||
ping SERVER_IP # Should be <10ms for local network
|
||||
```
|
||||
|
||||
3. **Use connection pooling**:
|
||||
```python
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
session = requests.Session()
|
||||
retries = Retry(total=3, backoff_factor=1)
|
||||
session.mount('http://', HTTPAdapter(max_retries=retries))
|
||||
```
|
||||
|
||||
## Quantization problems
|
||||
|
||||
### Symptom: `RuntimeError: Quantization format not supported`
|
||||
|
||||
**Solution**:
|
||||
```bash
|
||||
# Ensure correct quantization method
|
||||
vllm serve MODEL --quantization awq # For AWQ models
|
||||
vllm serve MODEL --quantization gptq # For GPTQ models
|
||||
|
||||
# Check model card for quantization type
|
||||
```
|
||||
|
||||
### Symptom: Poor quality outputs after quantization
|
||||
|
||||
**Diagnostic**:
|
||||
|
||||
1. **Verify model is correctly quantized**:
|
||||
```bash
|
||||
# Check model config.json for quantization_config
|
||||
cat ~/.cache/huggingface/hub/models--MODEL/config.json
|
||||
```
|
||||
|
||||
2. **Try different quantization method**:
|
||||
```bash
|
||||
# If AWQ quality issues, try FP8 (H100 only)
|
||||
vllm serve MODEL --quantization fp8
|
||||
|
||||
# Or use less aggressive quantization
|
||||
vllm serve MODEL # No quantization
|
||||
```
|
||||
|
||||
3. **Increase temperature for better diversity**:
|
||||
```python
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
```
|
||||
|
||||
## Distributed serving issues
|
||||
|
||||
### Symptom: `RuntimeError: Distributed init failed`
|
||||
|
||||
**Diagnostic**:
|
||||
|
||||
1. **Check environment variables**:
|
||||
```bash
|
||||
# On all nodes
|
||||
echo $MASTER_ADDR # Should be same
|
||||
echo $MASTER_PORT # Should be same
|
||||
echo $RANK # Should be unique per node (0, 1, 2, ...)
|
||||
echo $WORLD_SIZE # Should be same (total nodes)
|
||||
```
|
||||
|
||||
2. **Check network connectivity**:
|
||||
```bash
|
||||
# From node 1 to node 2
|
||||
ping NODE2_IP
|
||||
nc -zv NODE2_IP 29500 # Check port accessibility
|
||||
```
|
||||
|
||||
3. **Check NCCL settings**:
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Or your network interface
|
||||
vllm serve MODEL --tensor-parallel-size 8
|
||||
```
|
||||
|
||||
### Symptom: `NCCL error: unhandled cuda error`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
```bash
|
||||
# Set NCCL to use correct network interface
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Replace with your interface
|
||||
|
||||
# Increase timeout
|
||||
export NCCL_TIMEOUT=1800 # 30 minutes
|
||||
|
||||
# Force P2P for debugging
|
||||
export NCCL_P2P_DISABLE=1
|
||||
```
|
||||
|
||||
## Debugging tools and commands
|
||||
|
||||
### Enable debug logging
|
||||
|
||||
```bash
|
||||
export VLLM_LOGGING_LEVEL=DEBUG
|
||||
vllm serve MODEL
|
||||
```
|
||||
|
||||
### Monitor GPU usage
|
||||
|
||||
```bash
|
||||
# Real-time GPU monitoring
|
||||
watch -n 1 nvidia-smi
|
||||
|
||||
# Memory breakdown
|
||||
nvidia-smi --query-gpu=memory.used,memory.free --format=csv -l 1
|
||||
```
|
||||
|
||||
### Profile performance
|
||||
|
||||
```bash
|
||||
# Built-in benchmarking
|
||||
vllm bench throughput \
|
||||
--model MODEL \
|
||||
--input-tokens 128 \
|
||||
--output-tokens 256 \
|
||||
--num-prompts 100
|
||||
|
||||
vllm bench latency \
|
||||
--model MODEL \
|
||||
--input-tokens 128 \
|
||||
--output-tokens 256 \
|
||||
--batch-size 8
|
||||
```
|
||||
|
||||
### Check metrics
|
||||
|
||||
```bash
|
||||
# Prometheus metrics
|
||||
curl http://localhost:9090/metrics
|
||||
|
||||
# Filter for specific metrics
|
||||
curl http://localhost:9090/metrics | grep vllm_time_to_first_token
|
||||
|
||||
# Key metrics to monitor:
|
||||
# - vllm_time_to_first_token_seconds
|
||||
# - vllm_time_per_output_token_seconds
|
||||
# - vllm_num_requests_running
|
||||
# - vllm_gpu_cache_usage_perc
|
||||
# - vllm_request_success_total
|
||||
```
|
||||
|
||||
### Test server health
|
||||
|
||||
```bash
|
||||
# Health check
|
||||
curl http://localhost:8000/health
|
||||
|
||||
# Model info
|
||||
curl http://localhost:8000/v1/models
|
||||
|
||||
# Test completion
|
||||
curl http://localhost:8000/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "MODEL",
|
||||
"prompt": "Hello",
|
||||
"max_tokens": 10
|
||||
}'
|
||||
```
|
||||
|
||||
### Common environment variables
|
||||
|
||||
```bash
|
||||
# CUDA settings
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3 # Limit to specific GPUs
|
||||
|
||||
# vLLM settings
|
||||
export VLLM_LOGGING_LEVEL=DEBUG
|
||||
export VLLM_TRACE_FUNCTION=1 # Profile functions
|
||||
export VLLM_USE_V1=1 # Use v1.0 engine (faster)
|
||||
|
||||
# NCCL settings (distributed)
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_SOCKET_IFNAME=eth0
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
```
|
||||
|
||||
### Collect diagnostic info for bug reports
|
||||
|
||||
```bash
|
||||
# System info
|
||||
nvidia-smi
|
||||
python --version
|
||||
pip show vllm
|
||||
|
||||
# vLLM version and config
|
||||
vllm --version
|
||||
python -c "import vllm; print(vllm.__version__)"
|
||||
|
||||
# Run with debug logging
|
||||
export VLLM_LOGGING_LEVEL=DEBUG
|
||||
vllm serve MODEL 2>&1 | tee vllm_debug.log
|
||||
|
||||
# Include in bug report:
|
||||
# - vllm_debug.log
|
||||
# - nvidia-smi output
|
||||
# - Full command used
|
||||
# - Expected vs actual behavior
|
||||
```
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Specific model architectures and tools — image segmentation (Segment Anything / SAM) and audio generation (AudioCraft / MusicGen). Additional model skills (CLIP, Stable Diffusion, Whisper, LLaVA) are available as optional skills.
|
||||
---
|
||||
@@ -0,0 +1,567 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: "AudioCraft: MusicGen text-to-music, AudioGen text-to-sound."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
@@ -0,0 +1,505 @@
|
||||
---
|
||||
name: segment-anything-model
|
||||
description: "SAM: zero-shot image segmentation via points, boxes, masks."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
|
||||
|
||||
---
|
||||
|
||||
# Segment Anything Model (SAM)
|
||||
|
||||
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
|
||||
|
||||
## When to use SAM
|
||||
|
||||
**Use SAM when:**
|
||||
- Need to segment any object in images without task-specific training
|
||||
- Building interactive annotation tools with point/box prompts
|
||||
- Generating training data for other vision models
|
||||
- Need zero-shot transfer to new image domains
|
||||
- Building object detection/segmentation pipelines
|
||||
- Processing medical, satellite, or domain-specific images
|
||||
|
||||
**Key features:**
|
||||
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
|
||||
- **Flexible prompts**: Points, bounding boxes, or previous masks
|
||||
- **Automatic segmentation**: Generate all object masks automatically
|
||||
- **High quality**: Trained on 1.1 billion masks from 11 million images
|
||||
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
|
||||
- **ONNX export**: Deploy in browsers and edge devices
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **YOLO/Detectron2**: For real-time object detection with classes
|
||||
- **Mask2Former**: For semantic/panoptic segmentation with categories
|
||||
- **GroundingDINO + SAM**: For text-prompted segmentation
|
||||
- **SAM 2**: For video segmentation tasks
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib
|
||||
|
||||
# Or use HuggingFace transformers
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
### Download checkpoints
|
||||
|
||||
```bash
|
||||
# ViT-H (largest, most accurate) - 2.4GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# ViT-L (medium) - 1.2GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
|
||||
|
||||
# ViT-B (smallest, fastest) - 375MB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
||||
```
|
||||
|
||||
### Basic usage with SamPredictor
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to(device="cuda")
|
||||
|
||||
# Create predictor
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
# Set image (computes embeddings once)
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Predict with point prompts
|
||||
input_point = np.array([[500, 375]]) # (x, y) coordinates
|
||||
input_label = np.array([1]) # 1 = foreground, 0 = background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True # Returns 3 mask options
|
||||
)
|
||||
|
||||
# Select best mask
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### HuggingFace Transformers
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
# Load model and processor
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
model.to("cuda")
|
||||
|
||||
# Process image with point prompt
|
||||
image = Image.open("image.jpg")
|
||||
input_points = [[[450, 600]]] # Batch of points
|
||||
|
||||
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
||||
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||
|
||||
# Generate masks
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-process masks to original size
|
||||
masks = processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu()
|
||||
)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Model architecture
|
||||
|
||||
<!-- ascii-guard-ignore -->
|
||||
```
|
||||
SAM Architecture:
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │
|
||||
│ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
Image Embeddings Prompt Embeddings Masks + IoU
|
||||
(computed once) (per prompt) predictions
|
||||
```
|
||||
<!-- ascii-guard-ignore-end -->
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Checkpoint | Size | Speed | Accuracy |
|
||||
|-------|------------|------|-------|----------|
|
||||
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
|
||||
| ViT-L | `vit_l` | 1.2 GB | Medium | Good |
|
||||
| ViT-B | `vit_b` | 375 MB | Fastest | Good |
|
||||
|
||||
### Prompt types
|
||||
|
||||
| Prompt | Description | Use Case |
|
||||
|--------|-------------|----------|
|
||||
| Point (foreground) | Click on object | Single object selection |
|
||||
| Point (background) | Click outside object | Exclude regions |
|
||||
| Bounding box | Rectangle around object | Larger objects |
|
||||
| Previous mask | Low-res mask input | Iterative refinement |
|
||||
|
||||
## Interactive segmentation
|
||||
|
||||
### Point prompts
|
||||
|
||||
```python
|
||||
# Single foreground point
|
||||
input_point = np.array([[500, 375]])
|
||||
input_label = np.array([1])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Multiple points (foreground + background)
|
||||
input_points = np.array([[500, 375], [600, 400], [450, 300]])
|
||||
input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False # Single mask when prompts are clear
|
||||
)
|
||||
```
|
||||
|
||||
### Box prompts
|
||||
|
||||
```python
|
||||
# Bounding box [x1, y1, x2, y2]
|
||||
input_box = np.array([425, 600, 700, 875])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
box=input_box,
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Combined prompts
|
||||
|
||||
```python
|
||||
# Box + points for precise control
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([400, 300, 700, 600]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Iterative refinement
|
||||
|
||||
```python
|
||||
# Initial prediction
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Refine with additional point using previous mask
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375], [550, 400]]),
|
||||
point_labels=np.array([1, 0]), # Add background point
|
||||
mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Automatic mask generation
|
||||
|
||||
### Basic automatic segmentation
|
||||
|
||||
```python
|
||||
from segment_anything import SamAutomaticMaskGenerator
|
||||
|
||||
# Create generator
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
|
||||
# Generate all masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Each mask contains:
|
||||
# - segmentation: binary mask
|
||||
# - bbox: [x, y, w, h]
|
||||
# - area: pixel count
|
||||
# - predicted_iou: quality score
|
||||
# - stability_score: robustness score
|
||||
# - point_coords: generating point
|
||||
```
|
||||
|
||||
### Customized generation
|
||||
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=32, # Grid density (more = more masks)
|
||||
pred_iou_thresh=0.88, # Quality threshold
|
||||
stability_score_thresh=0.95, # Stability threshold
|
||||
crop_n_layers=1, # Multi-scale crops
|
||||
crop_n_points_downscale_factor=2,
|
||||
min_mask_region_area=100, # Remove tiny masks
|
||||
)
|
||||
|
||||
masks = mask_generator.generate(image)
|
||||
```
|
||||
|
||||
### Filtering masks
|
||||
|
||||
```python
|
||||
# Sort by area (largest first)
|
||||
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
||||
|
||||
# Filter by predicted IoU
|
||||
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
|
||||
|
||||
# Filter by stability score
|
||||
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
|
||||
```
|
||||
|
||||
## Batched inference
|
||||
|
||||
### Multiple images
|
||||
|
||||
```python
|
||||
# Process multiple images efficiently
|
||||
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
|
||||
|
||||
all_masks = []
|
||||
for image in images:
|
||||
predictor.set_image(image)
|
||||
masks, _, _ = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks)
|
||||
```
|
||||
|
||||
### Multiple prompts per image
|
||||
|
||||
```python
|
||||
# Process multiple prompts efficiently (one image encoding)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Batch of point prompts
|
||||
points = [
|
||||
np.array([[100, 100]]),
|
||||
np.array([[200, 200]]),
|
||||
np.array([[300, 300]])
|
||||
]
|
||||
|
||||
all_masks = []
|
||||
for point in points:
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
## ONNX deployment
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
python scripts/export_onnx_model.py \
|
||||
--checkpoint sam_vit_h_4b8939.pth \
|
||||
--model-type vit_h \
|
||||
--output sam_onnx.onnx \
|
||||
--return-single-mask
|
||||
```
|
||||
|
||||
### Use ONNX model
|
||||
|
||||
```python
|
||||
import onnxruntime
|
||||
|
||||
# Load ONNX model
|
||||
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
|
||||
|
||||
# Run inference (image embeddings computed separately)
|
||||
masks = ort_session.run(
|
||||
None,
|
||||
{
|
||||
"image_embeddings": image_embeddings,
|
||||
"point_coords": point_coords,
|
||||
"point_labels": point_labels,
|
||||
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
|
||||
"has_mask_input": np.array([0], dtype=np.float32),
|
||||
"orig_im_size": np.array([h, w], dtype=np.float32)
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Annotation tool
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Load model
|
||||
predictor = SamPredictor(sam)
|
||||
predictor.set_image(image)
|
||||
|
||||
def on_click(event, x, y, flags, param):
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Foreground point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[x, y]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
# Display best mask
|
||||
display_mask(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
### Workflow 2: Object extraction
|
||||
|
||||
```python
|
||||
def extract_object(image, point):
|
||||
"""Extract object at point with transparent background."""
|
||||
predictor.set_image(image)
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([point]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Create RGBA output
|
||||
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = image
|
||||
rgba[:, :, 3] = best_mask * 255
|
||||
|
||||
return rgba
|
||||
```
|
||||
|
||||
### Workflow 3: Medical image segmentation
|
||||
|
||||
```python
|
||||
# Process medical images (grayscale to RGB)
|
||||
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment region of interest
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]), # ROI bounding box
|
||||
multimask_output=True
|
||||
)
|
||||
```
|
||||
|
||||
## Output format
|
||||
|
||||
### Mask data structure
|
||||
|
||||
```python
|
||||
# SamAutomaticMaskGenerator output
|
||||
{
|
||||
"segmentation": np.ndarray, # H×W binary mask
|
||||
"bbox": [x, y, w, h], # Bounding box
|
||||
"area": int, # Pixel count
|
||||
"predicted_iou": float, # 0-1 quality score
|
||||
"stability_score": float, # 0-1 robustness score
|
||||
"crop_box": [x, y, w, h], # Generation crop region
|
||||
"point_coords": [[x, y]], # Input point
|
||||
}
|
||||
```
|
||||
|
||||
### COCO RLE format
|
||||
|
||||
```python
|
||||
from pycocotools import mask as mask_utils
|
||||
|
||||
# Encode mask to RLE
|
||||
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
||||
# Decode RLE to mask
|
||||
decoded_mask = mask_utils.decode(rle)
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### GPU memory
|
||||
|
||||
```python
|
||||
# Use smaller model for limited VRAM
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Process images in batches
|
||||
# Clear CUDA cache between large batches
|
||||
torch.cuda.empty_cache()
|
||||
```
|
||||
|
||||
### Speed optimization
|
||||
|
||||
```python
|
||||
# Use half precision
|
||||
sam = sam.half()
|
||||
|
||||
# Reduce points for automatic generation
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export with --return-single-mask for faster inference
|
||||
```
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Out of memory | Use ViT-B model, reduce image size |
|
||||
| Slow inference | Use ViT-B, reduce points_per_side |
|
||||
| Poor mask quality | Try different prompts, use box + points |
|
||||
| Edge artifacts | Use stability_score filtering |
|
||||
| Small objects missed | Increase points_per_side |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/segment-anything
|
||||
- **Paper**: https://arxiv.org/abs/2304.02643
|
||||
- **Demo**: https://segment-anything.com
|
||||
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
|
||||
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge
|
||||
@@ -0,0 +1,589 @@
|
||||
# Segment Anything Advanced Usage Guide
|
||||
|
||||
## SAM 2 (Video Segmentation)
|
||||
|
||||
### Overview
|
||||
|
||||
SAM 2 extends SAM to video segmentation with streaming memory architecture:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/facebookresearch/segment-anything-2.git
|
||||
```
|
||||
|
||||
### Video segmentation
|
||||
|
||||
```python
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
|
||||
|
||||
# Initialize with video
|
||||
predictor.init_state(video_path="video.mp4")
|
||||
|
||||
# Add prompt on first frame
|
||||
predictor.add_new_points(
|
||||
frame_idx=0,
|
||||
obj_id=1,
|
||||
points=[[100, 200]],
|
||||
labels=[1]
|
||||
)
|
||||
|
||||
# Propagate through video
|
||||
for frame_idx, masks in predictor.propagate_in_video():
|
||||
# masks contains segmentation for all tracked objects
|
||||
process_frame(frame_idx, masks)
|
||||
```
|
||||
|
||||
### SAM 2 vs SAM comparison
|
||||
|
||||
| Feature | SAM | SAM 2 |
|
||||
|---------|-----|-------|
|
||||
| Input | Images only | Images + Videos |
|
||||
| Architecture | ViT + Decoder | Hiera + Memory |
|
||||
| Memory | Per-image | Streaming memory bank |
|
||||
| Tracking | No | Yes, across frames |
|
||||
| Models | ViT-B/L/H | Hiera-T/S/B+/L |
|
||||
|
||||
## Grounded SAM (Text-Prompted Segmentation)
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
pip install groundingdino-py
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
```
|
||||
|
||||
### Text-to-mask pipeline
|
||||
|
||||
```python
|
||||
from groundingdino.util.inference import load_model, predict
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import cv2
|
||||
|
||||
# Load Grounding DINO
|
||||
grounding_model = load_model("groundingdino_swint_ogc.pth", "GroundingDINO_SwinT_OGC.py")
|
||||
|
||||
# Load SAM
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
|
||||
"""Generate masks from text description."""
|
||||
# Get bounding boxes from text
|
||||
boxes, logits, phrases = predict(
|
||||
model=grounding_model,
|
||||
image=image,
|
||||
caption=text_prompt,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold
|
||||
)
|
||||
|
||||
# Generate masks with SAM
|
||||
predictor.set_image(image)
|
||||
|
||||
masks = []
|
||||
for box in boxes:
|
||||
# Convert normalized box to pixel coordinates
|
||||
h, w = image.shape[:2]
|
||||
box_pixels = box * np.array([w, h, w, h])
|
||||
|
||||
mask, score, _ = predictor.predict(
|
||||
box=box_pixels,
|
||||
multimask_output=False
|
||||
)
|
||||
masks.append(mask[0])
|
||||
|
||||
return masks, boxes, phrases
|
||||
|
||||
# Usage
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
|
||||
```
|
||||
|
||||
## Batched Processing
|
||||
|
||||
### Efficient multi-image processing
|
||||
|
||||
```python
|
||||
import torch
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
class BatchedSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.to(device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
self.device = device
|
||||
|
||||
def process_batch(self, images, prompts):
|
||||
"""Process multiple images with corresponding prompts."""
|
||||
results = []
|
||||
|
||||
for image, prompt in zip(images, prompts):
|
||||
self.predictor.set_image(image)
|
||||
|
||||
if "point" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=prompt["point"],
|
||||
point_labels=prompt["label"],
|
||||
multimask_output=True
|
||||
)
|
||||
elif "box" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
box=prompt["box"],
|
||||
multimask_output=False
|
||||
)
|
||||
|
||||
results.append({
|
||||
"masks": masks,
|
||||
"scores": scores,
|
||||
"best_mask": masks[np.argmax(scores)]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
|
||||
|
||||
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
|
||||
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
|
||||
|
||||
results = batch_sam.process_batch(images, prompts)
|
||||
```
|
||||
|
||||
### Parallel automatic mask generation
|
||||
|
||||
```python
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from segment_anything import SamAutomaticMaskGenerator
|
||||
|
||||
def generate_masks_parallel(images, num_workers=4):
|
||||
"""Generate masks for multiple images in parallel."""
|
||||
# Note: Each worker needs its own model instance
|
||||
def worker_init():
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
return SamAutomaticMaskGenerator(sam)
|
||||
|
||||
generators = [worker_init() for _ in range(num_workers)]
|
||||
|
||||
def process_image(args):
|
||||
idx, image = args
|
||||
generator = generators[idx % num_workers]
|
||||
return generator.generate(image)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
results = list(executor.map(process_image, enumerate(images)))
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Custom Integration
|
||||
|
||||
### FastAPI service
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
import cv2
|
||||
import io
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model once
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
class PointPrompt(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
label: int = 1
|
||||
|
||||
@app.post("/segment/point")
|
||||
async def segment_with_point(
|
||||
file: UploadFile = File(...),
|
||||
points: list[PointPrompt] = []
|
||||
):
|
||||
# Read image
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Set image
|
||||
predictor.set_image(image)
|
||||
|
||||
# Prepare prompts
|
||||
point_coords = np.array([[p.x, p.y] for p in points])
|
||||
point_labels = np.array([p.label for p in points])
|
||||
|
||||
# Generate masks
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_idx = np.argmax(scores)
|
||||
|
||||
return {
|
||||
"mask": masks[best_idx].tolist(),
|
||||
"score": float(scores[best_idx]),
|
||||
"all_scores": scores.tolist()
|
||||
}
|
||||
|
||||
@app.post("/segment/auto")
|
||||
async def segment_automatic(file: UploadFile = File(...)):
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
return {
|
||||
"num_masks": len(masks),
|
||||
"masks": [
|
||||
{
|
||||
"bbox": m["bbox"],
|
||||
"area": m["area"],
|
||||
"predicted_iou": m["predicted_iou"],
|
||||
"stability_score": m["stability_score"]
|
||||
}
|
||||
for m in masks
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Gradio interface
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def segment_image(image, evt: gr.SelectData):
|
||||
"""Segment object at clicked point."""
|
||||
predictor.set_image(image)
|
||||
|
||||
point = np.array([[evt.index[0], evt.index[1]]])
|
||||
label = np.array([1])
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Overlay mask on image
|
||||
overlay = image.copy()
|
||||
overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
|
||||
|
||||
return overlay
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# SAM Interactive Segmentation")
|
||||
gr.Markdown("Click on an object to segment it")
|
||||
|
||||
with gr.Row():
|
||||
input_image = gr.Image(label="Input Image", interactive=True)
|
||||
output_image = gr.Image(label="Segmented Image")
|
||||
|
||||
input_image.select(segment_image, inputs=[input_image], outputs=[output_image])
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Fine-Tuning SAM
|
||||
|
||||
### LoRA fine-tuning (experimental)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from transformers import SamModel
|
||||
|
||||
# Load model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["qkv"], # Attention layers
|
||||
lora_dropout=0.1,
|
||||
bias="none",
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
# Training loop (simplified)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(
|
||||
pixel_values=batch["pixel_values"],
|
||||
input_points=batch["input_points"],
|
||||
input_labels=batch["input_labels"]
|
||||
)
|
||||
|
||||
# Custom loss (e.g., IoU loss with ground truth)
|
||||
loss = compute_loss(outputs.pred_masks, batch["gt_masks"])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### MedSAM (Medical imaging)
|
||||
|
||||
```python
|
||||
# MedSAM is a fine-tuned SAM for medical images
|
||||
# https://github.com/bowang-lab/MedSAM
|
||||
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import torch
|
||||
|
||||
# Load MedSAM checkpoint
|
||||
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
|
||||
medsam.to("cuda")
|
||||
|
||||
predictor = SamPredictor(medsam)
|
||||
|
||||
# Process medical image
|
||||
# Convert grayscale to RGB if needed
|
||||
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = np.stack([medical_image] * 3, axis=-1)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment with box prompt (common for medical imaging)
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Mask Processing
|
||||
|
||||
### Mask refinement
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
|
||||
def refine_mask(mask, kernel_size=5, iterations=2):
|
||||
"""Refine mask with morphological operations."""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
|
||||
# Close small holes
|
||||
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations)
|
||||
|
||||
# Remove small noise
|
||||
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)
|
||||
|
||||
return opened.astype(bool)
|
||||
|
||||
def fill_holes(mask):
|
||||
"""Fill holes in mask."""
|
||||
filled = ndimage.binary_fill_holes(mask)
|
||||
return filled
|
||||
|
||||
def remove_small_regions(mask, min_area=100):
|
||||
"""Remove small disconnected regions."""
|
||||
labeled, num_features = ndimage.label(mask)
|
||||
sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
|
||||
|
||||
# Keep only regions larger than min_area
|
||||
mask_clean = np.zeros_like(mask)
|
||||
for i, size in enumerate(sizes, 1):
|
||||
if size >= min_area:
|
||||
mask_clean[labeled == i] = True
|
||||
|
||||
return mask_clean
|
||||
```
|
||||
|
||||
### Mask to polygon conversion
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
def mask_to_polygons(mask, epsilon_factor=0.01):
|
||||
"""Convert binary mask to polygon coordinates."""
|
||||
contours, _ = cv2.findContours(
|
||||
mask.astype(np.uint8),
|
||||
cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for contour in contours:
|
||||
epsilon = epsilon_factor * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
polygon = approx.squeeze().tolist()
|
||||
if len(polygon) >= 3: # Valid polygon
|
||||
polygons.append(polygon)
|
||||
|
||||
return polygons
|
||||
|
||||
def polygons_to_mask(polygons, height, width):
|
||||
"""Convert polygons back to binary mask."""
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask.astype(bool)
|
||||
```
|
||||
|
||||
### Multi-scale segmentation
|
||||
|
||||
```python
|
||||
def multiscale_segment(image, predictor, point, scales=[0.5, 1.0, 2.0]):
|
||||
"""Generate masks at multiple scales and combine."""
|
||||
h, w = image.shape[:2]
|
||||
masks_all = []
|
||||
|
||||
for scale in scales:
|
||||
# Resize image
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
scaled_image = cv2.resize(image, (new_w, new_h))
|
||||
scaled_point = (point * scale).astype(int)
|
||||
|
||||
# Segment
|
||||
predictor.set_image(scaled_image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=scaled_point.reshape(1, 2),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Resize mask back
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
original_mask = cv2.resize(best_mask.astype(np.uint8), (w, h)) > 0.5
|
||||
|
||||
masks_all.append(original_mask)
|
||||
|
||||
# Combine masks (majority voting)
|
||||
combined = np.stack(masks_all, axis=0)
|
||||
final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1
|
||||
|
||||
return final_mask
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### TensorRT acceleration
|
||||
|
||||
```python
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
|
||||
def export_to_tensorrt(onnx_path, engine_path, fp16=True):
|
||||
"""Convert ONNX model to TensorRT engine."""
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
with open(onnx_path, 'rb') as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
return None
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = 1 << 30 # 1GB
|
||||
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
engine = builder.build_engine(network, config)
|
||||
|
||||
with open(engine_path, 'wb') as f:
|
||||
f.write(engine.serialize())
|
||||
|
||||
return engine
|
||||
```
|
||||
|
||||
### Memory-efficient inference
|
||||
|
||||
```python
|
||||
class MemoryEfficientSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_b"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.eval()
|
||||
self.predictor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.sam.to("cuda")
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def segment(self, image, points, labels):
|
||||
self.predictor.set_image(image)
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=points,
|
||||
point_labels=labels,
|
||||
multimask_output=True
|
||||
)
|
||||
return masks, scores
|
||||
|
||||
# Usage with context manager (auto-cleanup)
|
||||
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
|
||||
masks, scores = sam.segment(image, points, labels)
|
||||
# CUDA memory freed automatically
|
||||
```
|
||||
|
||||
## Dataset Generation
|
||||
|
||||
### Create segmentation dataset
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
def generate_dataset(images_dir, output_dir, mask_generator):
|
||||
"""Generate segmentation dataset from images."""
|
||||
annotations = []
|
||||
|
||||
for img_path in Path(images_dir).glob("*.jpg"):
|
||||
image = cv2.imread(str(img_path))
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Generate masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Filter high-quality masks
|
||||
good_masks = [m for m in masks if m["predicted_iou"] > 0.9]
|
||||
|
||||
# Save annotations
|
||||
for i, mask_data in enumerate(good_masks):
|
||||
annotation = {
|
||||
"image_id": img_path.stem,
|
||||
"mask_id": i,
|
||||
"bbox": mask_data["bbox"],
|
||||
"area": mask_data["area"],
|
||||
"segmentation": mask_to_rle(mask_data["segmentation"]),
|
||||
"predicted_iou": mask_data["predicted_iou"],
|
||||
"stability_score": mask_data["stability_score"]
|
||||
}
|
||||
annotations.append(annotation)
|
||||
|
||||
# Save dataset
|
||||
with open(output_dir / "annotations.json", "w") as f:
|
||||
json.dump(annotations, f)
|
||||
|
||||
return annotations
|
||||
```
|
||||
@@ -0,0 +1,484 @@
|
||||
# Segment Anything Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### CUDA not available
|
||||
|
||||
**Error**: `RuntimeError: CUDA not available`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check CUDA availability
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
print(torch.version.cuda)
|
||||
|
||||
# Install PyTorch with CUDA
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# If CUDA works but SAM doesn't use it
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda") # Explicitly move to GPU
|
||||
```
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'segment_anything'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Or clone and install
|
||||
git clone https://github.com/facebookresearch/segment-anything.git
|
||||
cd segment-anything
|
||||
pip install -e .
|
||||
|
||||
# Verify installation
|
||||
python -c "from segment_anything import sam_model_registry; print('OK')"
|
||||
```
|
||||
|
||||
### Missing dependencies
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'cv2'` or similar
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install all optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
||||
|
||||
# For pycocotools on Windows
|
||||
pip install pycocotools-windows
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Checkpoint not found
|
||||
|
||||
**Error**: `FileNotFoundError: checkpoint file not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Download correct checkpoint
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# Verify file integrity
|
||||
md5sum sam_vit_h_4b8939.pth
|
||||
# Expected: a7bf3b02f3ebf1267aba913ff637d9a2
|
||||
|
||||
# Use absolute path
|
||||
sam = sam_model_registry["vit_h"](checkpoint="/full/path/to/sam_vit_h_4b8939.pth")
|
||||
```
|
||||
|
||||
### Model type mismatch
|
||||
|
||||
**Error**: `KeyError: 'unexpected key in state_dict'`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure model type matches checkpoint
|
||||
# vit_h checkpoint → vit_h model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
|
||||
# vit_l checkpoint → vit_l model
|
||||
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
||||
|
||||
# vit_b checkpoint → vit_b model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
```
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `CUDA out of memory` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Load to CPU first, then move
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
sam.to("cuda")
|
||||
|
||||
# Use half precision
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam = sam.half()
|
||||
sam.to("cuda")
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Image format errors
|
||||
|
||||
**Error**: `ValueError: expected input to have 3 channels`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Ensure RGB format
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR to RGB
|
||||
|
||||
# Convert grayscale to RGB
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# Handle RGBA
|
||||
if image.shape[2] == 4:
|
||||
image = image[:, :, :3] # Drop alpha channel
|
||||
```
|
||||
|
||||
### Coordinate errors
|
||||
|
||||
**Error**: `IndexError: index out of bounds` or incorrect mask location
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure points are (x, y) not (row, col)
|
||||
# x = column index, y = row index
|
||||
point = np.array([[x, y]]) # Correct
|
||||
|
||||
# Verify coordinates are within image bounds
|
||||
h, w = image.shape[:2]
|
||||
assert 0 <= x < w and 0 <= y < h, "Point outside image"
|
||||
|
||||
# For bounding boxes: [x1, y1, x2, y2]
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
assert x1 < x2 and y1 < y2, "Invalid box coordinates"
|
||||
```
|
||||
|
||||
### Empty or incorrect masks
|
||||
|
||||
**Problem**: Masks don't match expected object
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Try multiple prompts
|
||||
input_points = np.array([[x1, y1], [x2, y2]])
|
||||
input_labels = np.array([1, 1]) # Multiple foreground points
|
||||
|
||||
# Add background points
|
||||
input_points = np.array([[obj_x, obj_y], [bg_x, bg_y]])
|
||||
input_labels = np.array([1, 0]) # 1=foreground, 0=background
|
||||
|
||||
# Use box prompt for large objects
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
|
||||
|
||||
# Combine box and point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[center_x, center_y]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Check scores and select best
|
||||
print(f"Scores: {scores}")
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### Slow inference
|
||||
|
||||
**Problem**: Prediction takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Reuse image embeddings
|
||||
predictor.set_image(image) # Compute once
|
||||
for point in points:
|
||||
masks, _, _ = predictor.predict(...) # Fast, reuses embeddings
|
||||
|
||||
# Reduce automatic generation points
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export: python scripts/export_onnx_model.py --return-single-mask
|
||||
```
|
||||
|
||||
## Automatic Mask Generation Issues
|
||||
|
||||
### Too many masks
|
||||
|
||||
**Problem**: Generating thousands of overlapping masks
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Reduce from 32
|
||||
pred_iou_thresh=0.92, # Increase from 0.88
|
||||
stability_score_thresh=0.98, # Increase from 0.95
|
||||
box_nms_thresh=0.5, # More aggressive NMS
|
||||
min_mask_region_area=500, # Remove small masks
|
||||
)
|
||||
```
|
||||
|
||||
### Too few masks
|
||||
|
||||
**Problem**: Missing objects in automatic generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=64, # Increase density
|
||||
pred_iou_thresh=0.80, # Lower threshold
|
||||
stability_score_thresh=0.85, # Lower threshold
|
||||
crop_n_layers=2, # Add multi-scale
|
||||
min_mask_region_area=0, # Keep all masks
|
||||
)
|
||||
```
|
||||
|
||||
### Small objects missed
|
||||
|
||||
**Problem**: Automatic generation misses small objects
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use crop layers for multi-scale detection
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
crop_n_layers=2,
|
||||
crop_n_points_downscale_factor=1, # Don't reduce points in crops
|
||||
min_mask_region_area=10, # Very small minimum
|
||||
)
|
||||
|
||||
# Or process image patches
|
||||
def segment_with_patches(image, patch_size=512, overlap=64):
|
||||
h, w = image.shape[:2]
|
||||
all_masks = []
|
||||
|
||||
for y in range(0, h, patch_size - overlap):
|
||||
for x in range(0, w, patch_size - overlap):
|
||||
patch = image[y:y+patch_size, x:x+patch_size]
|
||||
masks = mask_generator.generate(patch)
|
||||
|
||||
# Offset masks to original coordinates
|
||||
for m in masks:
|
||||
m['bbox'][0] += x
|
||||
m['bbox'][1] += y
|
||||
# Offset segmentation mask too
|
||||
|
||||
all_masks.extend(masks)
|
||||
|
||||
return all_masks
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Clear cache between images
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Process images sequentially, not batched
|
||||
for image in images:
|
||||
predictor.set_image(image)
|
||||
masks, _, _ = predictor.predict(...)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Reduce image size
|
||||
max_size = 1024
|
||||
h, w = image.shape[:2]
|
||||
if max(h, w) > max_size:
|
||||
scale = max_size / max(h, w)
|
||||
image = cv2.resize(image, (int(w*scale), int(h*scale)))
|
||||
|
||||
# Use CPU for large batch processing
|
||||
sam.to("cpu")
|
||||
```
|
||||
|
||||
### RAM out of memory
|
||||
|
||||
**Problem**: System runs out of RAM
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Process images one at a time
|
||||
for img_path in image_paths:
|
||||
image = cv2.imread(img_path)
|
||||
masks = process_image(image)
|
||||
save_results(masks)
|
||||
del image, masks
|
||||
gc.collect()
|
||||
|
||||
# Use generators instead of lists
|
||||
def generate_masks_lazy(image_paths):
|
||||
for path in image_paths:
|
||||
image = cv2.imread(path)
|
||||
masks = mask_generator.generate(image)
|
||||
yield path, masks
|
||||
```
|
||||
|
||||
## ONNX Export Issues
|
||||
|
||||
### Export fails
|
||||
|
||||
**Error**: Various export errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install correct ONNX version
|
||||
pip install onnx==1.14.0 onnxruntime==1.15.0
|
||||
|
||||
# Use correct opset version
|
||||
python scripts/export_onnx_model.py \
|
||||
--checkpoint sam_vit_h_4b8939.pth \
|
||||
--model-type vit_h \
|
||||
--output sam.onnx \
|
||||
--opset 17
|
||||
```
|
||||
|
||||
### ONNX runtime errors
|
||||
|
||||
**Error**: `ONNXRuntimeError` during inference
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import onnxruntime
|
||||
|
||||
# Check available providers
|
||||
print(onnxruntime.get_available_providers())
|
||||
|
||||
# Use CPU provider if GPU fails
|
||||
session = onnxruntime.InferenceSession(
|
||||
"sam.onnx",
|
||||
providers=['CPUExecutionProvider']
|
||||
)
|
||||
|
||||
# Verify input shapes
|
||||
for input in session.get_inputs():
|
||||
print(f"{input.name}: {input.shape}")
|
||||
```
|
||||
|
||||
## HuggingFace Integration Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with SamProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
# Use matching processor and model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
|
||||
# Ensure input format
|
||||
input_points = [[[x, y]]] # Nested list for batch dimension
|
||||
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
||||
|
||||
# Post-process correctly
|
||||
masks = processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu()
|
||||
)
|
||||
```
|
||||
|
||||
## Quality Issues
|
||||
|
||||
### Jagged mask edges
|
||||
|
||||
**Problem**: Masks have rough, pixelated edges
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
|
||||
def smooth_mask(mask, sigma=2):
|
||||
"""Smooth mask edges."""
|
||||
# Gaussian blur
|
||||
smooth = ndimage.gaussian_filter(mask.astype(float), sigma=sigma)
|
||||
return smooth > 0.5
|
||||
|
||||
def refine_edges(mask, kernel_size=5):
|
||||
"""Refine mask edges with morphological operations."""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
# Close small gaps
|
||||
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
|
||||
# Open to remove noise
|
||||
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
|
||||
return opened.astype(bool)
|
||||
```
|
||||
|
||||
### Incomplete segmentation
|
||||
|
||||
**Problem**: Mask doesn't cover entire object
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Add multiple points
|
||||
input_points = np.array([
|
||||
[obj_center_x, obj_center_y],
|
||||
[obj_left_x, obj_center_y],
|
||||
[obj_right_x, obj_center_y],
|
||||
[obj_center_x, obj_top_y],
|
||||
[obj_center_x, obj_bottom_y]
|
||||
])
|
||||
input_labels = np.array([1, 1, 1, 1, 1])
|
||||
|
||||
# Use bounding box
|
||||
masks, _, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=False
|
||||
)
|
||||
|
||||
# Iterative refinement
|
||||
mask_input = None
|
||||
for point in points:
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=point.reshape(1, 2),
|
||||
point_labels=np.array([1]),
|
||||
mask_input=mask_input,
|
||||
multimask_output=False
|
||||
)
|
||||
mask_input = logits
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | GPU memory full | Use smaller model, clear cache |
|
||||
| `expected 3 channels` | Wrong image format | Convert to RGB |
|
||||
| `index out of bounds` | Invalid coordinates | Check point/box bounds |
|
||||
| `checkpoint not found` | Wrong path | Use absolute path |
|
||||
| `unexpected key` | Model/checkpoint mismatch | Match model type |
|
||||
| `invalid box coordinates` | x1 > x2 or y1 > y2 | Fix box format |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/segment-anything/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2304.02643
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
|
||||
- CUDA version: `python -c "import torch; print(torch.version.cuda)"`
|
||||
- SAM model type (vit_b/l/h)
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: ML research frameworks for building and optimizing AI systems with declarative programming.
|
||||
---
|
||||
@@ -0,0 +1,593 @@
|
||||
---
|
||||
name: dspy
|
||||
description: "DSPy: declarative LM programs, auto-optimize prompts, RAG."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [dspy, openai, anthropic]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, DSPy, Declarative Programming, RAG, Agents, Prompt Optimization, LM Programming, Stanford NLP, Automatic Optimization, Modular AI]
|
||||
|
||||
---
|
||||
|
||||
# DSPy: Declarative Language Model Programming
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use DSPy when you need to:
|
||||
- **Build complex AI systems** with multiple components and workflows
|
||||
- **Program LMs declaratively** instead of manual prompt engineering
|
||||
- **Optimize prompts automatically** using data-driven methods
|
||||
- **Create modular AI pipelines** that are maintainable and portable
|
||||
- **Improve model outputs systematically** with optimizers
|
||||
- **Build RAG systems, agents, or classifiers** with better reliability
|
||||
|
||||
**GitHub Stars**: 22,000+ | **Created By**: Stanford NLP
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Stable release
|
||||
pip install dspy
|
||||
|
||||
# Latest development version
|
||||
pip install git+https://github.com/stanfordnlp/dspy.git
|
||||
|
||||
# With specific LM providers
|
||||
pip install dspy[openai] # OpenAI
|
||||
pip install dspy[anthropic] # Anthropic Claude
|
||||
pip install dspy[all] # All providers
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Question Answering
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
# Configure your language model
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Define a signature (input → output)
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions with short factual answers."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||
|
||||
# Create a module
|
||||
qa = dspy.Predict(QA)
|
||||
|
||||
# Use it
|
||||
response = qa(question="What is the capital of France?")
|
||||
print(response.answer) # "Paris"
|
||||
```
|
||||
|
||||
### Chain of Thought Reasoning
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Use ChainOfThought for better reasoning
|
||||
class MathProblem(dspy.Signature):
|
||||
"""Solve math word problems."""
|
||||
problem = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="numerical answer")
|
||||
|
||||
# ChainOfThought generates reasoning steps automatically
|
||||
cot = dspy.ChainOfThought(MathProblem)
|
||||
|
||||
response = cot(problem="If John has 5 apples and gives 2 to Mary, how many does he have?")
|
||||
print(response.rationale) # Shows reasoning steps
|
||||
print(response.answer) # "3"
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Signatures
|
||||
|
||||
Signatures define the structure of your AI task (inputs → outputs):
|
||||
|
||||
```python
|
||||
# Inline signature (simple)
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Class signature (detailed)
|
||||
class Summarize(dspy.Signature):
|
||||
"""Summarize text into key points."""
|
||||
text = dspy.InputField()
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
|
||||
summarizer = dspy.ChainOfThought(Summarize)
|
||||
```
|
||||
|
||||
**When to use each:**
|
||||
- **Inline**: Quick prototyping, simple tasks
|
||||
- **Class**: Complex tasks, type hints, better documentation
|
||||
|
||||
### 2. Modules
|
||||
|
||||
Modules are reusable components that transform inputs to outputs:
|
||||
|
||||
#### dspy.Predict
|
||||
Basic prediction module:
|
||||
|
||||
```python
|
||||
predictor = dspy.Predict("context, question -> answer")
|
||||
result = predictor(context="Paris is the capital of France",
|
||||
question="What is the capital?")
|
||||
```
|
||||
|
||||
#### dspy.ChainOfThought
|
||||
Generates reasoning steps before answering:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="Why is the sky blue?")
|
||||
print(result.rationale) # Reasoning steps
|
||||
print(result.answer) # Final answer
|
||||
```
|
||||
|
||||
#### dspy.ReAct
|
||||
Agent-like reasoning with tools:
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
class SearchQA(dspy.Signature):
|
||||
"""Answer questions using search."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search Wikipedia."""
|
||||
# Your search implementation
|
||||
return results
|
||||
|
||||
react = ReAct(SearchQA, tools=[search_tool])
|
||||
result = react(question="When was Python created?")
|
||||
```
|
||||
|
||||
#### dspy.ProgramOfThought
|
||||
Generates and executes code for reasoning:
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Generates: answer = 240 * 0.15
|
||||
```
|
||||
|
||||
### 3. Optimizers
|
||||
|
||||
Optimizers improve your modules automatically using training data:
|
||||
|
||||
#### BootstrapFewShot
|
||||
Learns from examples:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa performs better!
|
||||
```
|
||||
|
||||
#### MIPRO (Most Important Prompt Optimization)
|
||||
Iteratively improves prompts:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
optimizer = MIPRO(
|
||||
metric=validate_answer,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0
|
||||
)
|
||||
|
||||
optimized_cot = optimizer.compile(
|
||||
cot,
|
||||
trainset=trainset,
|
||||
num_trials=100
|
||||
)
|
||||
```
|
||||
|
||||
#### BootstrapFinetune
|
||||
Creates datasets for model fine-tuning:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
optimizer = BootstrapFinetune(metric=validate_answer)
|
||||
optimized_module = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data for fine-tuning
|
||||
```
|
||||
|
||||
### 4. Building Complex Systems
|
||||
|
||||
#### Multi-Stage Pipeline
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class MultiHopQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Stage 1: Generate search query
|
||||
search_query = self.generate_query(question=question).search_query
|
||||
|
||||
# Stage 2: Retrieve context
|
||||
passages = self.retrieve(search_query).passages
|
||||
context = "\n".join(passages)
|
||||
|
||||
# Stage 3: Generate answer
|
||||
answer = self.generate_answer(context=context, question=question).answer
|
||||
return dspy.Prediction(answer=answer, context=context)
|
||||
|
||||
# Use the pipeline
|
||||
qa_system = MultiHopQA()
|
||||
result = qa_system(question="Who wrote the book that inspired the movie Blade Runner?")
|
||||
```
|
||||
|
||||
#### RAG System with Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
# Configure retriever
|
||||
retriever = ChromadbRM(
|
||||
collection_name="documents",
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Create and optimize
|
||||
rag = RAG()
|
||||
|
||||
# Optimize with training data
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
optimizer = BootstrapFewShot(metric=validate_answer)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
```
|
||||
|
||||
## LM Provider Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key", # Or set ANTHROPIC_API_KEY env var
|
||||
max_tokens=1000,
|
||||
temperature=0.7
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = dspy.OpenAI(
|
||||
model="gpt-4",
|
||||
api_key="your-api-key",
|
||||
max_tokens=1000
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Local Models (Ollama)
|
||||
|
||||
```python
|
||||
lm = dspy.OllamaLocal(
|
||||
model="llama3.1",
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Multiple Models
|
||||
|
||||
```python
|
||||
# Different models for different tasks
|
||||
cheap_lm = dspy.OpenAI(model="gpt-3.5-turbo")
|
||||
strong_lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use cheap model for retrieval, strong model for reasoning
|
||||
with dspy.settings.context(lm=cheap_lm):
|
||||
context = retriever(question)
|
||||
|
||||
with dspy.settings.context(lm=strong_lm):
|
||||
answer = generator(context=context, question=question)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Structured Output
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
```
|
||||
|
||||
### Pattern 2: Assertion-Driven Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class MathQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.solve = dspy.ChainOfThought("problem -> solution: float")
|
||||
|
||||
def forward(self, problem):
|
||||
solution = self.solve(problem=problem).solution
|
||||
|
||||
# Assert solution is numeric
|
||||
dspy.Assert(
|
||||
isinstance(float(solution), float),
|
||||
"Solution must be a number",
|
||||
backtrack=backtrack_handler
|
||||
)
|
||||
|
||||
return dspy.Prediction(solution=solution)
|
||||
```
|
||||
|
||||
### Pattern 3: Self-Consistency
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from collections import Counter
|
||||
|
||||
class ConsistentQA(dspy.Module):
|
||||
def __init__(self, num_samples=5):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer")
|
||||
self.num_samples = num_samples
|
||||
|
||||
def forward(self, question):
|
||||
# Generate multiple answers
|
||||
answers = []
|
||||
for _ in range(self.num_samples):
|
||||
result = self.qa(question=question)
|
||||
answers.append(result.answer)
|
||||
|
||||
# Return most common answer
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
return dspy.Prediction(answer=most_common)
|
||||
```
|
||||
|
||||
### Pattern 4: Retrieval with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10)
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(question=question, passage=passage).relevance_score)
|
||||
scored.append((score, passage))
|
||||
|
||||
# Take top 3
|
||||
top_passages = [p for _, p in sorted(scored, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Evaluation and Metrics
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Exact match metric."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score for text overlap."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```python
|
||||
from dspy.evaluate import Evaluate
|
||||
|
||||
# Create evaluator
|
||||
evaluator = Evaluate(
|
||||
devset=testset,
|
||||
metric=exact_match,
|
||||
num_threads=4,
|
||||
display_progress=True
|
||||
)
|
||||
|
||||
# Evaluate model
|
||||
score = evaluator(qa_system)
|
||||
print(f"Accuracy: {score}")
|
||||
|
||||
# Compare optimized vs unoptimized
|
||||
score_before = evaluator(qa)
|
||||
score_after = evaluator(optimized_qa)
|
||||
print(f"Improvement: {score_after - score_before:.2%}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Simple, Iterate
|
||||
|
||||
```python
|
||||
# Start with Predict
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Add reasoning if needed
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Add optimization when you have data
|
||||
optimized_qa = optimizer.compile(qa, trainset=data)
|
||||
```
|
||||
|
||||
### 2. Use Descriptive Signatures
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague
|
||||
class Task(dspy.Signature):
|
||||
input = dspy.InputField()
|
||||
output = dspy.OutputField()
|
||||
|
||||
# ✅ Good: Descriptive
|
||||
class SummarizeArticle(dspy.Signature):
|
||||
"""Summarize news articles into 3-5 key points."""
|
||||
article = dspy.InputField(desc="full article text")
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
```
|
||||
|
||||
### 3. Optimize with Representative Data
|
||||
|
||||
```python
|
||||
# Create diverse training examples
|
||||
trainset = [
|
||||
dspy.Example(question="factual", answer="...).with_inputs("question"),
|
||||
dspy.Example(question="reasoning", answer="...").with_inputs("question"),
|
||||
dspy.Example(question="calculation", answer="...").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Use validation set for metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer in pred.answer
|
||||
```
|
||||
|
||||
### 4. Save and Load Optimized Models
|
||||
|
||||
```python
|
||||
# Save
|
||||
optimized_qa.save("models/qa_v1.json")
|
||||
|
||||
# Load
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
### 5. Monitor and Debug
|
||||
|
||||
```python
|
||||
# Enable tracing
|
||||
dspy.settings.configure(lm=lm, trace=[])
|
||||
|
||||
# Run prediction
|
||||
result = qa(question="...")
|
||||
|
||||
# Inspect trace
|
||||
for call in dspy.settings.trace:
|
||||
print(f"Prompt: {call['prompt']}")
|
||||
print(f"Response: {call['response']}")
|
||||
```
|
||||
|
||||
## Comparison to Other Approaches
|
||||
|
||||
| Feature | Manual Prompting | LangChain | DSPy |
|
||||
|---------|-----------------|-----------|------|
|
||||
| Prompt Engineering | Manual | Manual | Automatic |
|
||||
| Optimization | Trial & error | None | Data-driven |
|
||||
| Modularity | Low | Medium | High |
|
||||
| Type Safety | No | Limited | Yes (Signatures) |
|
||||
| Portability | Low | Medium | High |
|
||||
| Learning Curve | Low | Medium | Medium-High |
|
||||
|
||||
**When to choose DSPy:**
|
||||
- You have training data or can generate it
|
||||
- You need systematic prompt improvement
|
||||
- You're building complex multi-stage systems
|
||||
- You want to optimize across different LMs
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Quick prototypes (manual prompting)
|
||||
- Simple chains with existing tools (LangChain)
|
||||
- Custom optimization logic needed
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy (22k+ stars)
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
- **Twitter**: @DSPyOSS
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/modules.md` - Detailed module guide (Predict, ChainOfThought, ReAct, ProgramOfThought)
|
||||
- `references/optimizers.md` - Optimization algorithms (BootstrapFewShot, MIPRO, BootstrapFinetune)
|
||||
- `references/examples.md` - Real-world examples (RAG, agents, classifiers)
|
||||
|
||||
|
||||
@@ -0,0 +1,663 @@
|
||||
# DSPy Real-World Examples
|
||||
|
||||
Practical examples of building production systems with DSPy.
|
||||
|
||||
## Table of Contents
|
||||
- RAG Systems
|
||||
- Agent Systems
|
||||
- Classification
|
||||
- Data Processing
|
||||
- Multi-Stage Pipelines
|
||||
|
||||
## RAG Systems
|
||||
|
||||
### Basic RAG
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class BasicRAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n\n".join(passages)
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Configure retriever (example with Chroma)
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
retriever = ChromadbRM(
|
||||
collection_name="my_docs",
|
||||
persist_directory="./chroma_db",
|
||||
k=3
|
||||
)
|
||||
dspy.settings.configure(rm=retriever)
|
||||
|
||||
# Use RAG
|
||||
rag = BasicRAG()
|
||||
result = rag(question="What is DSPy?")
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
### Optimized RAG
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data with question-answer pairs
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
question="What is retrieval augmented generation?",
|
||||
answer="RAG combines retrieval of relevant documents with generation..."
|
||||
).with_inputs("question"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def answer_correctness(example, pred, trace=None):
|
||||
# Check if answer contains key information
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
|
||||
# Optimize RAG
|
||||
optimizer = BootstrapFewShot(metric=answer_correctness)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
|
||||
# Optimized RAG performs better on similar questions
|
||||
result = optimized_rag(question="Explain RAG systems")
|
||||
```
|
||||
|
||||
### Multi-Hop RAG
|
||||
|
||||
```python
|
||||
class MultiHopRAG(dspy.Module):
|
||||
"""RAG that follows chains of reasoning across documents."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# First retrieval
|
||||
query1 = self.generate_query(question=question).search_query
|
||||
passages1 = self.retrieve(query1).passages
|
||||
|
||||
# Generate follow-up query based on first results
|
||||
context1 = "\n".join(passages1)
|
||||
query2 = self.generate_query(
|
||||
question=f"Based on: {context1}\nFollow-up: {question}"
|
||||
).search_query
|
||||
|
||||
# Second retrieval
|
||||
passages2 = self.retrieve(query2).passages
|
||||
|
||||
# Combine all context
|
||||
all_context = "\n\n".join(passages1 + passages2)
|
||||
|
||||
# Generate final answer
|
||||
return self.generate_answer(context=all_context, question=question)
|
||||
|
||||
# Use multi-hop RAG
|
||||
multi_rag = MultiHopRAG()
|
||||
result = multi_rag(question="Who wrote the book that inspired Blade Runner?")
|
||||
# Hop 1: Find "Blade Runner was based on..."
|
||||
# Hop 2: Find author of that book
|
||||
```
|
||||
|
||||
### RAG with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
"""RAG with learned reranking of retrieved passages."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10) # Get more candidates
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored_passages = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(
|
||||
question=question,
|
||||
passage=passage
|
||||
).relevance_score)
|
||||
scored_passages.append((score, passage))
|
||||
|
||||
# Take top 3 after reranking
|
||||
top_passages = [p for _, p in sorted(scored_passages, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer from reranked context
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
import wikipedia
|
||||
try:
|
||||
return wikipedia.summary(query, sentences=3)
|
||||
except:
|
||||
return "No results found"
|
||||
|
||||
def calculate(expression: str) -> str:
|
||||
"""Evaluate mathematical expression safely."""
|
||||
try:
|
||||
# Use safe eval
|
||||
result = eval(expression, {"__builtins__": {}}, {})
|
||||
return str(result)
|
||||
except:
|
||||
return "Invalid expression"
|
||||
|
||||
def search_web(query: str) -> str:
|
||||
"""Search the web."""
|
||||
# Your web search implementation
|
||||
return results
|
||||
|
||||
# Create agent signature
|
||||
class ResearchAgent(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
# Create ReAct agent
|
||||
agent = ReAct(ResearchAgent, tools=[search_wikipedia, calculate, search_web])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = agent(question="What is the population of France divided by 10?")
|
||||
# Agent:
|
||||
# 1. Thinks: "Need population of France"
|
||||
# 2. Acts: search_wikipedia("France population")
|
||||
# 3. Thinks: "Got 67 million, need to divide"
|
||||
# 4. Acts: calculate("67000000 / 10")
|
||||
# 5. Returns: "6,700,000"
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
class MultiAgentSystem(dspy.Module):
|
||||
"""System with specialized agents for different tasks."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Router agent
|
||||
self.router = dspy.Predict("question -> agent_type: str")
|
||||
|
||||
# Specialized agents
|
||||
self.research_agent = ReAct(
|
||||
ResearchAgent,
|
||||
tools=[search_wikipedia, search_web]
|
||||
)
|
||||
self.math_agent = dspy.ProgramOfThought("problem -> answer")
|
||||
self.reasoning_agent = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Route to appropriate agent
|
||||
agent_type = self.router(question=question).agent_type
|
||||
|
||||
if agent_type == "research":
|
||||
return self.research_agent(question=question)
|
||||
elif agent_type == "math":
|
||||
return self.math_agent(problem=question)
|
||||
else:
|
||||
return self.reasoning_agent(question=question)
|
||||
|
||||
# Use multi-agent system
|
||||
mas = MultiAgentSystem()
|
||||
result = mas(question="What is 15% of the GDP of France?")
|
||||
# Routes to research_agent for GDP, then to math_agent for calculation
|
||||
```
|
||||
|
||||
## Classification
|
||||
|
||||
### Binary Classifier
|
||||
|
||||
```python
|
||||
class SentimentClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.Predict("text -> sentiment: str")
|
||||
|
||||
def forward(self, text):
|
||||
return self.classify(text=text)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(text="I love this!", sentiment="positive").with_inputs("text"),
|
||||
dspy.Example(text="Terrible experience", sentiment="negative").with_inputs("text"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Optimize
|
||||
def accuracy(example, pred, trace=None):
|
||||
return example.sentiment == pred.sentiment
|
||||
|
||||
optimizer = BootstrapFewShot(metric=accuracy, max_bootstrapped_demos=5)
|
||||
classifier = SentimentClassifier()
|
||||
optimized_classifier = optimizer.compile(classifier, trainset=trainset)
|
||||
|
||||
# Use classifier
|
||||
result = optimized_classifier(text="This product is amazing!")
|
||||
print(result.sentiment) # "positive"
|
||||
```
|
||||
|
||||
### Multi-Class Classifier
|
||||
|
||||
```python
|
||||
class TopicClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.ChainOfThought(
|
||||
"text -> category: str, confidence: float"
|
||||
)
|
||||
|
||||
def forward(self, text):
|
||||
result = self.classify(text=text)
|
||||
return dspy.Prediction(
|
||||
category=result.category,
|
||||
confidence=float(result.confidence)
|
||||
)
|
||||
|
||||
# Define categories in signature
|
||||
class TopicSignature(dspy.Signature):
|
||||
"""Classify text into one of: technology, sports, politics, entertainment."""
|
||||
text = dspy.InputField()
|
||||
category = dspy.OutputField(desc="one of: technology, sports, politics, entertainment")
|
||||
confidence = dspy.OutputField(desc="0.0 to 1.0")
|
||||
|
||||
classifier = dspy.ChainOfThought(TopicSignature)
|
||||
result = classifier(text="The Lakers won the championship")
|
||||
print(result.category) # "sports"
|
||||
print(result.confidence) # 0.95
|
||||
```
|
||||
|
||||
### Hierarchical Classifier
|
||||
|
||||
```python
|
||||
class HierarchicalClassifier(dspy.Module):
|
||||
"""Two-stage classification: coarse then fine-grained."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.coarse = dspy.Predict("text -> broad_category: str")
|
||||
self.fine_tech = dspy.Predict("text -> tech_subcategory: str")
|
||||
self.fine_sports = dspy.Predict("text -> sports_subcategory: str")
|
||||
|
||||
def forward(self, text):
|
||||
# Stage 1: Broad category
|
||||
broad = self.coarse(text=text).broad_category
|
||||
|
||||
# Stage 2: Fine-grained based on broad
|
||||
if broad == "technology":
|
||||
fine = self.fine_tech(text=text).tech_subcategory
|
||||
elif broad == "sports":
|
||||
fine = self.fine_sports(text=text).sports_subcategory
|
||||
else:
|
||||
fine = "other"
|
||||
|
||||
return dspy.Prediction(broad_category=broad, fine_category=fine)
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Text Summarization
|
||||
|
||||
```python
|
||||
class AdaptiveSummarizer(dspy.Module):
|
||||
"""Summarizes text to target length."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.summarize = dspy.ChainOfThought("text, target_length -> summary")
|
||||
|
||||
def forward(self, text, target_length="3 sentences"):
|
||||
return self.summarize(text=text, target_length=target_length)
|
||||
|
||||
# Use summarizer
|
||||
summarizer = AdaptiveSummarizer()
|
||||
long_text = "..." # Long article
|
||||
|
||||
short_summary = summarizer(long_text, target_length="1 sentence")
|
||||
medium_summary = summarizer(long_text, target_length="3 sentences")
|
||||
detailed_summary = summarizer(long_text, target_length="1 paragraph")
|
||||
```
|
||||
|
||||
### Information Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Job title")
|
||||
location: str = Field(description="City and country")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
|
||||
text = "Dr. Jane Smith, 42, is a neuroscientist at Stanford University in Palo Alto, California."
|
||||
result = extractor(text=text)
|
||||
|
||||
print(result.person.name) # "Dr. Jane Smith"
|
||||
print(result.person.age) # 42
|
||||
print(result.person.occupation) # "neuroscientist"
|
||||
print(result.person.location) # "Palo Alto, California"
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
class BatchProcessor(dspy.Module):
|
||||
"""Process large datasets efficiently."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.Predict("text -> processed_text")
|
||||
|
||||
def forward(self, texts):
|
||||
# Batch processing for efficiency
|
||||
return self.process.batch([{"text": t} for t in texts])
|
||||
|
||||
# Process 1000 documents
|
||||
processor = BatchProcessor()
|
||||
results = processor(texts=large_dataset)
|
||||
|
||||
# Results are returned in order
|
||||
for original, result in zip(large_dataset, results):
|
||||
print(f"{original} -> {result.processed_text}")
|
||||
```
|
||||
|
||||
## Multi-Stage Pipelines
|
||||
|
||||
### Document Processing Pipeline
|
||||
|
||||
```python
|
||||
class DocumentPipeline(dspy.Module):
|
||||
"""Multi-stage document processing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.extract = dspy.Predict("document -> key_points")
|
||||
self.classify = dspy.Predict("key_points -> category")
|
||||
self.summarize = dspy.ChainOfThought("key_points, category -> summary")
|
||||
self.tag = dspy.Predict("summary -> tags")
|
||||
|
||||
def forward(self, document):
|
||||
# Stage 1: Extract key points
|
||||
key_points = self.extract(document=document).key_points
|
||||
|
||||
# Stage 2: Classify
|
||||
category = self.classify(key_points=key_points).category
|
||||
|
||||
# Stage 3: Summarize
|
||||
summary = self.summarize(
|
||||
key_points=key_points,
|
||||
category=category
|
||||
).summary
|
||||
|
||||
# Stage 4: Generate tags
|
||||
tags = self.tag(summary=summary).tags
|
||||
|
||||
return dspy.Prediction(
|
||||
key_points=key_points,
|
||||
category=category,
|
||||
summary=summary,
|
||||
tags=tags
|
||||
)
|
||||
```
|
||||
|
||||
### Quality Control Pipeline
|
||||
|
||||
```python
|
||||
class QualityControlPipeline(dspy.Module):
|
||||
"""Generate output and verify quality."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.generate = dspy.ChainOfThought("prompt -> output")
|
||||
self.verify = dspy.Predict("output -> is_valid: bool, issues: str")
|
||||
self.improve = dspy.ChainOfThought("output, issues -> improved_output")
|
||||
|
||||
def forward(self, prompt, max_iterations=3):
|
||||
output = self.generate(prompt=prompt).output
|
||||
|
||||
for _ in range(max_iterations):
|
||||
# Verify output
|
||||
verification = self.verify(output=output)
|
||||
|
||||
if verification.is_valid:
|
||||
return dspy.Prediction(output=output, iterations=_ + 1)
|
||||
|
||||
# Improve based on issues
|
||||
output = self.improve(
|
||||
output=output,
|
||||
issues=verification.issues
|
||||
).improved_output
|
||||
|
||||
return dspy.Prediction(output=output, iterations=max_iterations)
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### 1. Caching for Performance
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
class CachedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n".join(passages)
|
||||
return self.generate(context=context, question=question).answer
|
||||
```
|
||||
|
||||
### 2. Error Handling
|
||||
|
||||
```python
|
||||
class RobustModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
# Log error
|
||||
print(f"Error processing {input}: {e}")
|
||||
# Return fallback
|
||||
return dspy.Prediction(output="Error: could not process input")
|
||||
```
|
||||
|
||||
### 3. Monitoring
|
||||
|
||||
```python
|
||||
class MonitoredModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
self.call_count = 0
|
||||
self.errors = 0
|
||||
|
||||
def forward(self, input):
|
||||
self.call_count += 1
|
||||
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.errors += 1
|
||||
raise
|
||||
|
||||
def get_stats(self):
|
||||
return {
|
||||
"calls": self.call_count,
|
||||
"errors": self.errors,
|
||||
"error_rate": self.errors / max(self.call_count, 1)
|
||||
}
|
||||
```
|
||||
|
||||
### 4. A/B Testing
|
||||
|
||||
```python
|
||||
class ABTestModule(dspy.Module):
|
||||
"""Run two variants and compare."""
|
||||
|
||||
def __init__(self, variant_a, variant_b):
|
||||
super().__init__()
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.a_calls = 0
|
||||
self.b_calls = 0
|
||||
|
||||
def forward(self, input, variant="a"):
|
||||
if variant == "a":
|
||||
self.a_calls += 1
|
||||
return self.variant_a(input=input)
|
||||
else:
|
||||
self.b_calls += 1
|
||||
return self.variant_b(input=input)
|
||||
|
||||
# Compare two optimizers
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
optimized = BootstrapFewShot(...).compile(baseline, trainset=trainset)
|
||||
|
||||
ab_test = ABTestModule(variant_a=baseline, variant_b=optimized)
|
||||
|
||||
# Route 50% to each
|
||||
import random
|
||||
variant = "a" if random.random() < 0.5 else "b"
|
||||
result = ab_test(input=question, variant=variant)
|
||||
```
|
||||
|
||||
## Complete Example: Customer Support Bot
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
class CustomerSupportBot(dspy.Module):
|
||||
"""Complete customer support system."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Classify intent
|
||||
self.classify_intent = dspy.Predict("message -> intent: str")
|
||||
|
||||
# Specialized handlers
|
||||
self.technical_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.billing_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.general_handler = dspy.Predict("message, history -> response")
|
||||
|
||||
# Retrieve relevant docs
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
|
||||
# Conversation history
|
||||
self.history = []
|
||||
|
||||
def forward(self, message):
|
||||
# Classify intent
|
||||
intent = self.classify_intent(message=message).intent
|
||||
|
||||
# Retrieve relevant documentation
|
||||
docs = self.retrieve(message).passages
|
||||
context = "\n".join(docs)
|
||||
|
||||
# Add context to history
|
||||
history_str = "\n".join(self.history)
|
||||
full_message = f"Context: {context}\n\nMessage: {message}"
|
||||
|
||||
# Route to appropriate handler
|
||||
if intent == "technical":
|
||||
response = self.technical_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
elif intent == "billing":
|
||||
response = self.billing_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
else:
|
||||
response = self.general_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
|
||||
# Update history
|
||||
self.history.append(f"User: {message}")
|
||||
self.history.append(f"Bot: {response}")
|
||||
|
||||
return dspy.Prediction(response=response, intent=intent)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
message="My account isn't working",
|
||||
intent="technical",
|
||||
response="I'd be happy to help. What error are you seeing?"
|
||||
).with_inputs("message"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def response_quality(example, pred, trace=None):
|
||||
# Check if response is helpful
|
||||
if len(pred.response) < 20:
|
||||
return 0.0
|
||||
if example.intent != pred.intent:
|
||||
return 0.3
|
||||
return 1.0
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=response_quality)
|
||||
bot = CustomerSupportBot()
|
||||
optimized_bot = optimizer.compile(bot, trainset=trainset)
|
||||
|
||||
# Use in production
|
||||
optimized_bot.save("models/support_bot_v1.json")
|
||||
|
||||
# Later, load and use
|
||||
loaded_bot = CustomerSupportBot()
|
||||
loaded_bot.load("models/support_bot_v1.json")
|
||||
response = loaded_bot(message="I can't log in")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **Examples Repo**: https://github.com/stanfordnlp/dspy/tree/main/examples
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
@@ -0,0 +1,475 @@
|
||||
# DSPy Modules
|
||||
|
||||
Complete guide to DSPy's built-in modules for language model programming.
|
||||
|
||||
## Module Basics
|
||||
|
||||
DSPy modules are composable building blocks inspired by PyTorch's NN modules:
|
||||
- Have learnable parameters (prompts, few-shot examples)
|
||||
- Can be composed using Python control flow
|
||||
- Generalized to handle any signature
|
||||
- Optimizable with DSPy optimizers
|
||||
|
||||
### Base Module Pattern
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class CustomModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Initialize sub-modules
|
||||
self.predictor = dspy.Predict("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
# Module logic
|
||||
result = self.predictor(input=input)
|
||||
return result
|
||||
```
|
||||
|
||||
## Core Modules
|
||||
|
||||
### dspy.Predict
|
||||
|
||||
**Basic prediction module** - Makes LM calls without reasoning steps.
|
||||
|
||||
```python
|
||||
# Inline signature
|
||||
qa = dspy.Predict("question -> answer")
|
||||
result = qa(question="What is 2+2?")
|
||||
|
||||
# Class signature
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions concisely."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="short, factual answer")
|
||||
|
||||
qa = dspy.Predict(QA)
|
||||
result = qa(question="What is the capital of France?")
|
||||
print(result.answer) # "Paris"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Simple, direct predictions
|
||||
- No reasoning steps needed
|
||||
- Fast responses required
|
||||
|
||||
### dspy.ChainOfThought
|
||||
|
||||
**Step-by-step reasoning** - Generates rationale before answer.
|
||||
|
||||
**Parameters:**
|
||||
- `signature`: Task signature
|
||||
- `rationale_field`: Custom reasoning field (optional)
|
||||
- `rationale_field_type`: Type for rationale (default: `str`)
|
||||
|
||||
```python
|
||||
# Basic usage
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="If I have 5 apples and give away 2, how many remain?")
|
||||
print(result.rationale) # "Let's think step by step..."
|
||||
print(result.answer) # "3"
|
||||
|
||||
# Custom rationale field
|
||||
cot = dspy.ChainOfThought(
|
||||
signature="problem -> solution",
|
||||
rationale_field=dspy.OutputField(
|
||||
prefix="Reasoning: Let's break this down step by step to"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Complex reasoning tasks
|
||||
- Math word problems
|
||||
- Logical deduction
|
||||
- Quality > speed
|
||||
|
||||
**Performance:**
|
||||
- ~2x slower than Predict
|
||||
- Significantly better accuracy on reasoning tasks
|
||||
|
||||
### dspy.ProgramOfThought
|
||||
|
||||
**Code-based reasoning** - Generates and executes Python code.
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Internally generates: answer = 240 * 0.15
|
||||
# Executes code and returns result
|
||||
print(result.answer) # 36.0
|
||||
|
||||
result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
# Generates: distance = 60 * 2.5
|
||||
print(result.answer) # 150.0
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Arithmetic calculations
|
||||
- Symbolic math
|
||||
- Data transformations
|
||||
- Deterministic computations
|
||||
|
||||
**Benefits:**
|
||||
- More reliable than text-based math
|
||||
- Handles complex calculations
|
||||
- Transparent (shows generated code)
|
||||
|
||||
### dspy.ReAct
|
||||
|
||||
**Reasoning + Acting** - Agent that uses tools iteratively.
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
# Your search implementation
|
||||
return search_results
|
||||
|
||||
def calculate(expression: str) -> float:
|
||||
"""Evaluate a mathematical expression."""
|
||||
return eval(expression)
|
||||
|
||||
# Create ReAct agent
|
||||
class ResearchQA(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = react(question="How old was Einstein when he published special relativity?")
|
||||
# Internally:
|
||||
# 1. Thinks: "Need birth year and publication year"
|
||||
# 2. Acts: search_wikipedia("Albert Einstein")
|
||||
# 3. Acts: search_wikipedia("Special relativity 1905")
|
||||
# 4. Acts: calculate("1905 - 1879")
|
||||
# 5. Returns: "26 years old"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Multi-step research tasks
|
||||
- Tool-using agents
|
||||
- Complex information retrieval
|
||||
- Tasks requiring multiple API calls
|
||||
|
||||
**Best practices:**
|
||||
- Keep tool descriptions clear and specific
|
||||
- Limit to 5-7 tools (too many = confusion)
|
||||
- Provide tool usage examples in docstrings
|
||||
|
||||
### dspy.MultiChainComparison
|
||||
|
||||
**Generate multiple outputs and compare** - Self-consistency pattern.
|
||||
|
||||
```python
|
||||
mcc = dspy.MultiChainComparison("question -> answer", M=5)
|
||||
|
||||
result = mcc(question="What is the capital of France?")
|
||||
# Generates 5 candidate answers
|
||||
# Compares and selects most consistent
|
||||
print(result.answer) # "Paris"
|
||||
print(result.candidates) # All 5 generated answers
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Number of candidates to generate (default: 5)
|
||||
- `temperature`: Sampling temperature for diversity
|
||||
|
||||
**When to use:**
|
||||
- High-stakes decisions
|
||||
- Ambiguous questions
|
||||
- When single answer may be unreliable
|
||||
|
||||
**Tradeoff:**
|
||||
- M times slower (M parallel calls)
|
||||
- Higher accuracy on ambiguous tasks
|
||||
|
||||
### dspy.majority
|
||||
|
||||
**Majority voting over multiple predictions.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import majority
|
||||
|
||||
# Generate multiple predictions
|
||||
predictor = dspy.Predict("question -> answer")
|
||||
predictions = [predictor(question="What is 2+2?") for _ in range(5)]
|
||||
|
||||
# Take majority vote
|
||||
answer = majority([p.answer for p in predictions])
|
||||
print(answer) # "4"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Combining multiple model outputs
|
||||
- Reducing variance in predictions
|
||||
- Ensemble approaches
|
||||
|
||||
## Advanced Modules
|
||||
|
||||
### dspy.TypedPredictor
|
||||
|
||||
**Structured output with Pydantic models.**
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
print(result.person.occupation) # "software engineer"
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Type safety
|
||||
- Automatic validation
|
||||
- JSON schema generation
|
||||
- IDE autocomplete
|
||||
|
||||
### dspy.Retry
|
||||
|
||||
**Automatic retry with validation.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import Retry
|
||||
|
||||
def validate_number(example, pred, trace=None):
|
||||
"""Validate output is a number."""
|
||||
try:
|
||||
float(pred.answer)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Retry up to 3 times if validation fails
|
||||
qa = Retry(
|
||||
dspy.ChainOfThought("question -> answer"),
|
||||
validate=validate_number,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
result = qa(question="What is 15% of 80?")
|
||||
# If first attempt returns non-numeric, retries automatically
|
||||
```
|
||||
|
||||
### dspy.Assert
|
||||
|
||||
**Assertion-driven optimization.**
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class ValidatedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer: float")
|
||||
|
||||
def forward(self, question):
|
||||
answer = self.qa(question=question).answer
|
||||
|
||||
# Assert answer is numeric
|
||||
dspy.Assert(
|
||||
isinstance(float(answer), float),
|
||||
"Answer must be a number",
|
||||
backtrack=backtrack_handler
|
||||
)
|
||||
|
||||
return dspy.Prediction(answer=answer)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Catches errors during optimization
|
||||
- Guides LM toward valid outputs
|
||||
- Better than post-hoc filtering
|
||||
|
||||
## Module Composition
|
||||
|
||||
### Sequential Pipeline
|
||||
|
||||
```python
|
||||
class Pipeline(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stage1 = dspy.Predict("input -> intermediate")
|
||||
self.stage2 = dspy.ChainOfThought("intermediate -> output")
|
||||
|
||||
def forward(self, input):
|
||||
intermediate = self.stage1(input=input).intermediate
|
||||
output = self.stage2(intermediate=intermediate).output
|
||||
return dspy.Prediction(output=output)
|
||||
```
|
||||
|
||||
### Conditional Logic
|
||||
|
||||
```python
|
||||
class ConditionalModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.router = dspy.Predict("question -> category: str")
|
||||
self.simple_qa = dspy.Predict("question -> answer")
|
||||
self.complex_qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
category = self.router(question=question).category
|
||||
|
||||
if category == "simple":
|
||||
return self.simple_qa(question=question)
|
||||
else:
|
||||
return self.complex_qa(question=question)
|
||||
```
|
||||
|
||||
### Parallel Execution
|
||||
|
||||
```python
|
||||
class ParallelModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.approach1 = dspy.ChainOfThought("question -> answer")
|
||||
self.approach2 = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Run both approaches
|
||||
answer1 = self.approach1(question=question).answer
|
||||
answer2 = self.approach2(question=question).answer
|
||||
|
||||
# Compare or combine results
|
||||
if answer1 == answer2:
|
||||
return dspy.Prediction(answer=answer1, confidence="high")
|
||||
else:
|
||||
return dspy.Prediction(answer=answer1, confidence="low")
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
All modules support batch processing for efficiency:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
questions = [
|
||||
"What is 2+2?",
|
||||
"What is 3+3?",
|
||||
"What is 4+4?"
|
||||
]
|
||||
|
||||
# Process all at once
|
||||
results = cot.batch([{"question": q} for q in questions])
|
||||
|
||||
for result in results:
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
## Saving and Loading
|
||||
|
||||
```python
|
||||
# Save module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
qa.save("models/qa_v1.json")
|
||||
|
||||
# Load module
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
- Few-shot examples
|
||||
- Prompt instructions
|
||||
- Module configuration
|
||||
|
||||
**What doesn't get saved:**
|
||||
- Model weights (DSPy doesn't fine-tune by default)
|
||||
- LM provider configuration
|
||||
|
||||
## Module Selection Guide
|
||||
|
||||
| Task | Module | Reason |
|
||||
|------|--------|--------|
|
||||
| Simple classification | Predict | Fast, direct |
|
||||
| Math word problems | ProgramOfThought | Reliable calculations |
|
||||
| Logical reasoning | ChainOfThought | Better with steps |
|
||||
| Multi-step research | ReAct | Tool usage |
|
||||
| High-stakes decisions | MultiChainComparison | Self-consistency |
|
||||
| Structured extraction | TypedPredictor | Type safety |
|
||||
| Ambiguous questions | MultiChainComparison | Multiple perspectives |
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start with Predict**, add reasoning only if needed
|
||||
2. **Use batch processing** for multiple inputs
|
||||
3. **Cache predictions** for repeated queries
|
||||
4. **Profile token usage** with `track_usage=True`
|
||||
5. **Optimize after prototyping** with teleprompters
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern: Retrieval + Generation
|
||||
|
||||
```python
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, k=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=k)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
```
|
||||
|
||||
### Pattern: Verification Loop
|
||||
|
||||
```python
|
||||
class VerifiedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.answer = dspy.ChainOfThought("question -> answer")
|
||||
self.verify = dspy.Predict("question, answer -> is_correct: bool")
|
||||
|
||||
def forward(self, question, max_attempts=3):
|
||||
for _ in range(max_attempts):
|
||||
answer = self.answer(question=question).answer
|
||||
is_correct = self.verify(question=question, answer=answer).is_correct
|
||||
|
||||
if is_correct:
|
||||
return dspy.Prediction(answer=answer)
|
||||
|
||||
return dspy.Prediction(answer="Unable to verify answer")
|
||||
```
|
||||
|
||||
### Pattern: Multi-Turn Dialog
|
||||
|
||||
```python
|
||||
class DialogAgent(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.respond = dspy.Predict("history, user_message -> assistant_message")
|
||||
self.history = []
|
||||
|
||||
def forward(self, user_message):
|
||||
history_str = "\n".join(self.history)
|
||||
response = self.respond(history=history_str, user_message=user_message)
|
||||
|
||||
self.history.append(f"User: {user_message}")
|
||||
self.history.append(f"Assistant: {response.assistant_message}")
|
||||
|
||||
return response
|
||||
```
|
||||
@@ -0,0 +1,566 @@
|
||||
# DSPy Optimizers (Teleprompters)
|
||||
|
||||
Complete guide to DSPy's optimization algorithms for improving prompts and model weights.
|
||||
|
||||
## What are Optimizers?
|
||||
|
||||
DSPy optimizers (called "teleprompters") automatically improve your modules by:
|
||||
- **Synthesizing few-shot examples** from training data
|
||||
- **Proposing better instructions** through search
|
||||
- **Fine-tuning model weights** (optional)
|
||||
|
||||
**Key idea**: Instead of manually tuning prompts, define a metric and let DSPy optimize.
|
||||
|
||||
## Optimizer Selection Guide
|
||||
|
||||
| Optimizer | Best For | Speed | Quality | Data Needed |
|
||||
|-----------|----------|-------|---------|-------------|
|
||||
| BootstrapFewShot | General purpose | Fast | Good | 10-50 examples |
|
||||
| MIPRO | Instruction tuning | Medium | Excellent | 50-200 examples |
|
||||
| BootstrapFinetune | Fine-tuning | Slow | Excellent | 100+ examples |
|
||||
| COPRO | Prompt optimization | Medium | Good | 20-100 examples |
|
||||
| KNNFewShot | Quick baseline | Very fast | Fair | 10+ examples |
|
||||
|
||||
## Core Optimizers
|
||||
|
||||
### BootstrapFewShot
|
||||
|
||||
**Most popular optimizer** - Generates few-shot demonstrations from training data.
|
||||
|
||||
**How it works:**
|
||||
1. Takes your training examples
|
||||
2. Uses your module to generate predictions
|
||||
3. Selects high-quality predictions (based on metric)
|
||||
4. Uses these as few-shot examples in future prompts
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Function that scores predictions (required)
|
||||
- `max_bootstrapped_demos`: Max demonstrations to generate (default: 4)
|
||||
- `max_labeled_demos`: Max labeled examples to use (default: 16)
|
||||
- `max_rounds`: Optimization iterations (default: 1)
|
||||
- `metric_threshold`: Minimum score to accept (optional)
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
"""Return True if prediction matches gold answer."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
dspy.Example(question="What is 10-3?", answer="7").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(
|
||||
metric=validate_answer,
|
||||
max_bootstrapped_demos=3,
|
||||
max_rounds=2
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa has learned few-shot examples!
|
||||
result = optimized_qa(question="What is 5+7?")
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Start with 10-50 training examples
|
||||
- Use diverse examples covering edge cases
|
||||
- Set `max_bootstrapped_demos=3-5` for most tasks
|
||||
- Increase `max_rounds=2-3` for better quality
|
||||
|
||||
**When to use:**
|
||||
- First optimizer to try
|
||||
- You have 10+ labeled examples
|
||||
- Want quick improvements
|
||||
- General-purpose tasks
|
||||
|
||||
### MIPRO (Most Important Prompt Optimization)
|
||||
|
||||
**State-of-the-art optimizer** - Iteratively searches for better instructions.
|
||||
|
||||
**How it works:**
|
||||
1. Generates candidate instructions
|
||||
2. Tests each on validation set
|
||||
3. Selects best-performing instructions
|
||||
4. Iterates to refine further
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `num_candidates`: Instructions to try per iteration (default: 10)
|
||||
- `init_temperature`: Sampling temperature (default: 1.0)
|
||||
- `verbose`: Show progress (default: False)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
# Define metric with more nuance
|
||||
def answer_quality(example, pred, trace=None):
|
||||
"""Score answer quality 0-1."""
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
return 1.0
|
||||
# Partial credit for similar answers
|
||||
return 0.5 if len(set(example.answer.split()) & set(pred.answer.split())) > 0 else 0.0
|
||||
|
||||
# Larger training set (MIPRO benefits from more data)
|
||||
trainset = [...] # 50-200 examples
|
||||
valset = [...] # 20-50 examples
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with MIPRO
|
||||
optimizer = MIPRO(
|
||||
metric=answer_quality,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(
|
||||
student=qa,
|
||||
trainset=trainset,
|
||||
valset=valset, # MIPRO uses separate validation set
|
||||
num_trials=100 # More trials = better quality
|
||||
)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 50-200 training examples
|
||||
- Separate validation set (20-50 examples)
|
||||
- Run 100-200 trials for best results
|
||||
- Takes 10-30 minutes typically
|
||||
|
||||
**When to use:**
|
||||
- You have 50+ labeled examples
|
||||
- Want state-of-the-art performance
|
||||
- Willing to wait for optimization
|
||||
- Complex reasoning tasks
|
||||
|
||||
### BootstrapFinetune
|
||||
|
||||
**Fine-tune model weights** - Creates training dataset for fine-tuning.
|
||||
|
||||
**How it works:**
|
||||
1. Generates synthetic training data
|
||||
2. Exports data in fine-tuning format
|
||||
3. You fine-tune model separately
|
||||
4. Load fine-tuned model back
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `max_bootstrapped_demos`: Demonstrations to generate (default: 4)
|
||||
- `max_rounds`: Data generation rounds (default: 1)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
# Training data
|
||||
trainset = [...] # 100+ examples recommended
|
||||
|
||||
# Define metric
|
||||
def validate(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Generate fine-tuning data
|
||||
optimizer = BootstrapFinetune(metric=validate)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data to file
|
||||
# You then fine-tune using your LM provider's API
|
||||
|
||||
# After fine-tuning, load your model:
|
||||
finetuned_lm = dspy.OpenAI(model="ft:gpt-3.5-turbo:your-model-id")
|
||||
dspy.settings.configure(lm=finetuned_lm)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 100+ training examples
|
||||
- Validate on held-out test set
|
||||
- Monitor for overfitting
|
||||
- Compare with prompt-based methods first
|
||||
|
||||
**When to use:**
|
||||
- You have 100+ examples
|
||||
- Latency is critical (fine-tuned models faster)
|
||||
- Task is narrow and well-defined
|
||||
- Prompt optimization isn't enough
|
||||
|
||||
### COPRO (Coordinate Prompt Optimization)
|
||||
|
||||
**Optimize prompts via gradient-free search.**
|
||||
|
||||
**How it works:**
|
||||
1. Generates prompt variants
|
||||
2. Evaluates each variant
|
||||
3. Selects best prompts
|
||||
4. Iterates to refine
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import COPRO
|
||||
|
||||
# Training data
|
||||
trainset = [...]
|
||||
|
||||
# Define metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with COPRO
|
||||
optimizer = COPRO(
|
||||
metric=metric,
|
||||
breadth=10, # Candidates per iteration
|
||||
depth=3 # Optimization rounds
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Want prompt optimization
|
||||
- Have 20-100 examples
|
||||
- MIPRO too slow
|
||||
|
||||
### KNNFewShot
|
||||
|
||||
**Simple k-nearest neighbors** - Selects similar examples for each query.
|
||||
|
||||
**How it works:**
|
||||
1. Embeds all training examples
|
||||
2. For each query, finds k most similar examples
|
||||
3. Uses these as few-shot demonstrations
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import KNNFewShot
|
||||
|
||||
trainset = [...]
|
||||
|
||||
# No metric needed - just selects similar examples
|
||||
optimizer = KNNFewShot(k=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# For each query, uses 3 most similar examples from trainset
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Quick baseline
|
||||
- Have diverse training examples
|
||||
- Similarity is good proxy for helpfulness
|
||||
|
||||
## Writing Metrics
|
||||
|
||||
Metrics are functions that score predictions. They're critical for optimization.
|
||||
|
||||
### Binary Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Return True if prediction exactly matches gold."""
|
||||
return example.answer == pred.answer
|
||||
|
||||
def contains_answer(example, pred, trace=None):
|
||||
"""Return True if prediction contains gold answer."""
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
```
|
||||
|
||||
### Continuous Metrics
|
||||
|
||||
```python
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score between prediction and gold."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
|
||||
def semantic_similarity(example, pred, trace=None):
|
||||
"""Embedding similarity between prediction and gold."""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
emb1 = model.encode(example.answer)
|
||||
emb2 = model.encode(pred.answer)
|
||||
|
||||
similarity = cosine_similarity(emb1, emb2)
|
||||
return similarity
|
||||
```
|
||||
|
||||
### Multi-Factor Metrics
|
||||
|
||||
```python
|
||||
def comprehensive_metric(example, pred, trace=None):
|
||||
"""Combine multiple factors."""
|
||||
score = 0.0
|
||||
|
||||
# Correctness (50%)
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
score += 0.5
|
||||
|
||||
# Conciseness (25%)
|
||||
if len(pred.answer.split()) <= 20:
|
||||
score += 0.25
|
||||
|
||||
# Citation (25%)
|
||||
if "source:" in pred.answer.lower():
|
||||
score += 0.25
|
||||
|
||||
return score
|
||||
```
|
||||
|
||||
### Using Trace for Debugging
|
||||
|
||||
```python
|
||||
def metric_with_trace(example, pred, trace=None):
|
||||
"""Metric that uses trace for debugging."""
|
||||
is_correct = example.answer == pred.answer
|
||||
|
||||
if trace is not None and not is_correct:
|
||||
# Log failures for analysis
|
||||
print(f"Failed on: {example.question}")
|
||||
print(f"Expected: {example.answer}")
|
||||
print(f"Got: {pred.answer}")
|
||||
|
||||
return is_correct
|
||||
```
|
||||
|
||||
## Evaluation Best Practices
|
||||
|
||||
### Train/Val/Test Split
|
||||
|
||||
```python
|
||||
# Split data
|
||||
trainset = data[:100] # 70%
|
||||
valset = data[100:120] # 15%
|
||||
testset = data[120:] # 15%
|
||||
|
||||
# Optimize on train
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
# Validate during optimization (for MIPRO)
|
||||
optimized = optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
|
||||
# Evaluate on test
|
||||
from dspy.evaluate import Evaluate
|
||||
evaluator = Evaluate(devset=testset, metric=metric)
|
||||
score = evaluator(optimized)
|
||||
```
|
||||
|
||||
### Cross-Validation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
kfold = KFold(n_splits=5)
|
||||
scores = []
|
||||
|
||||
for train_idx, val_idx in kfold.split(data):
|
||||
trainset = [data[i] for i in train_idx]
|
||||
valset = [data[i] for i in val_idx]
|
||||
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
score = evaluator(optimized, devset=valset)
|
||||
scores.append(score)
|
||||
|
||||
print(f"Average score: {sum(scores) / len(scores):.2f}")
|
||||
```
|
||||
|
||||
### Comparing Optimizers
|
||||
|
||||
```python
|
||||
results = {}
|
||||
|
||||
for opt_name, optimizer in [
|
||||
("baseline", None),
|
||||
("fewshot", BootstrapFewShot(metric=metric)),
|
||||
("mipro", MIPRO(metric=metric)),
|
||||
]:
|
||||
if optimizer is None:
|
||||
module_opt = module
|
||||
else:
|
||||
module_opt = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
score = evaluator(module_opt, devset=testset)
|
||||
results[opt_name] = score
|
||||
|
||||
print(results)
|
||||
# {'baseline': 0.65, 'fewshot': 0.78, 'mipro': 0.85}
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Optimizer
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import Teleprompter
|
||||
|
||||
class CustomOptimizer(Teleprompter):
|
||||
def __init__(self, metric):
|
||||
self.metric = metric
|
||||
|
||||
def compile(self, student, trainset, **kwargs):
|
||||
# Your optimization logic here
|
||||
# Return optimized student module
|
||||
return student
|
||||
```
|
||||
|
||||
### Multi-Stage Optimization
|
||||
|
||||
```python
|
||||
# Stage 1: Bootstrap few-shot
|
||||
stage1 = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized1 = stage1.compile(module, trainset=trainset)
|
||||
|
||||
# Stage 2: Instruction tuning
|
||||
stage2 = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized2 = stage2.compile(optimized1, trainset=trainset, valset=valset)
|
||||
|
||||
# Final optimized module
|
||||
final_module = optimized2
|
||||
```
|
||||
|
||||
### Ensemble Optimization
|
||||
|
||||
```python
|
||||
class EnsembleModule(dspy.Module):
|
||||
def __init__(self, modules):
|
||||
super().__init__()
|
||||
self.modules = modules
|
||||
|
||||
def forward(self, question):
|
||||
predictions = [m(question=question).answer for m in self.modules]
|
||||
# Vote or average
|
||||
return dspy.Prediction(answer=max(set(predictions), key=predictions.count))
|
||||
|
||||
# Optimize multiple modules
|
||||
opt1 = BootstrapFewShot(metric=metric).compile(module, trainset=trainset)
|
||||
opt2 = MIPRO(metric=metric).compile(module, trainset=trainset)
|
||||
opt3 = COPRO(metric=metric).compile(module, trainset=trainset)
|
||||
|
||||
# Ensemble
|
||||
ensemble = EnsembleModule([opt1, opt2, opt3])
|
||||
```
|
||||
|
||||
## Optimization Workflow
|
||||
|
||||
### 1. Start with Baseline
|
||||
|
||||
```python
|
||||
# No optimization
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
baseline_score = evaluator(baseline, devset=testset)
|
||||
print(f"Baseline: {baseline_score}")
|
||||
```
|
||||
|
||||
### 2. Try BootstrapFewShot
|
||||
|
||||
```python
|
||||
# Quick optimization
|
||||
fewshot = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized = fewshot.compile(baseline, trainset=trainset)
|
||||
fewshot_score = evaluator(optimized, devset=testset)
|
||||
print(f"Few-shot: {fewshot_score} (+{fewshot_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 3. If More Data Available, Try MIPRO
|
||||
|
||||
```python
|
||||
# State-of-the-art optimization
|
||||
mipro = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized_mipro = mipro.compile(baseline, trainset=trainset, valset=valset)
|
||||
mipro_score = evaluator(optimized_mipro, devset=testset)
|
||||
print(f"MIPRO: {mipro_score} (+{mipro_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 4. Save Best Model
|
||||
|
||||
```python
|
||||
if mipro_score > fewshot_score:
|
||||
optimized_mipro.save("models/best_model.json")
|
||||
else:
|
||||
optimized.save("models/best_model.json")
|
||||
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### 1. Overfitting to Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too many demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=20) # Overfits!
|
||||
|
||||
# ✅ Good: Moderate demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=3-5)
|
||||
```
|
||||
|
||||
### 2. Metric Doesn't Match Task
|
||||
|
||||
```python
|
||||
# ❌ Bad: Binary metric for nuanced task
|
||||
def bad_metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer # Too strict!
|
||||
|
||||
# ✅ Good: Graded metric
|
||||
def good_metric(example, pred, trace=None):
|
||||
return f1_score(example.answer, pred.answer) # Allows partial credit
|
||||
```
|
||||
|
||||
### 3. Insufficient Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too little data
|
||||
trainset = data[:5] # Not enough!
|
||||
|
||||
# ✅ Good: Sufficient data
|
||||
trainset = data[:50] # Better
|
||||
```
|
||||
|
||||
### 4. No Validation Set
|
||||
|
||||
```python
|
||||
# ❌ Bad: Optimizing on test set
|
||||
optimizer.compile(module, trainset=testset) # Cheating!
|
||||
|
||||
# ✅ Good: Proper splits
|
||||
optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
evaluator(optimized, devset=testset)
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start simple**: BootstrapFewShot first
|
||||
2. **Use representative data**: Cover edge cases
|
||||
3. **Monitor overfitting**: Validate on held-out set
|
||||
4. **Iterate metrics**: Refine based on failures
|
||||
5. **Save checkpoints**: Don't lose progress
|
||||
6. **Compare to baseline**: Measure improvement
|
||||
7. **Test multiple optimizers**: Find best fit
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Fine-tuning, RLHF/DPO/GRPO training, distributed training frameworks, and optimization tools for training LLMs and other models.
|
||||
---
|
||||
@@ -0,0 +1,165 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: "Axolotl: YAML LLM fine-tuning (LoRA, DPO, GRPO)."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
## What's inside
|
||||
|
||||
Expert guidance for fine-tuning LLMs with Axolotl — YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support.
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
|
||||
|
||||
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,15 @@
|
||||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,462 @@
|
||||
---
|
||||
name: fine-tuning-with-trl
|
||||
description: "TRL: SFT, DPO, PPO, GRPO, reward modeling for LLM RLHF."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [trl, transformers, datasets, peft, accelerate, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, TRL, Reinforcement Learning, Fine-Tuning, SFT, DPO, PPO, GRPO, RLHF, Preference Alignment, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
## Quick start
|
||||
|
||||
TRL provides post-training methods for aligning language models with human preferences.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install trl transformers datasets peft accelerate
|
||||
```
|
||||
|
||||
**Supervised Fine-Tuning** (instruction tuning):
|
||||
```python
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset, # Prompt-completion pairs
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**DPO** (align with preferences):
|
||||
```python
|
||||
from trl import DPOTrainer, DPOConfig
|
||||
|
||||
config = DPOConfig(output_dir="model-dpo", beta=0.1)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=preference_dataset, # chosen/rejected pairs
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Full RLHF pipeline (SFT → Reward Model → PPO)
|
||||
|
||||
Complete pipeline from base model to human-aligned model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
RLHF Training:
|
||||
- [ ] Step 1: Supervised fine-tuning (SFT)
|
||||
- [ ] Step 2: Train reward model
|
||||
- [ ] Step 3: PPO reinforcement learning
|
||||
- [ ] Step 4: Evaluate aligned model
|
||||
```
|
||||
|
||||
**Step 1: Supervised fine-tuning**
|
||||
|
||||
Train base model on instruction-following data:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load instruction dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = SFTConfig(
|
||||
output_dir="Qwen2.5-0.5B-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 2: Train reward model**
|
||||
|
||||
Train model to predict human preferences:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
# Load SFT model as base
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen2.5-0.5B-SFT",
|
||||
num_labels=1 # Single reward score
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-SFT")
|
||||
|
||||
# Load preference data (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure training
|
||||
training_args = RewardConfig(
|
||||
output_dir="Qwen2.5-0.5B-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train reward model
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**Step 3: PPO reinforcement learning**
|
||||
|
||||
Optimize policy using reward model:
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen2.5-0.5B-SFT \
|
||||
--reward_model_path Qwen2.5-0.5B-Reward \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir Qwen2.5-0.5B-PPO \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000
|
||||
```
|
||||
|
||||
**Step 4: Evaluate**
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load aligned model
|
||||
generator = pipeline("text-generation", model="Qwen2.5-0.5B-PPO")
|
||||
|
||||
# Test
|
||||
prompt = "Explain quantum computing to a 10-year-old"
|
||||
output = generator(prompt, max_length=200)[0]["generated_text"]
|
||||
print(output)
|
||||
```
|
||||
|
||||
### Workflow 2: Simple preference alignment with DPO
|
||||
|
||||
Align model with preferences without reward model.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
DPO Training:
|
||||
- [ ] Step 1: Prepare preference dataset
|
||||
- [ ] Step 2: Configure DPO
|
||||
- [ ] Step 3: Train with DPOTrainer
|
||||
- [ ] Step 4: Evaluate alignment
|
||||
```
|
||||
|
||||
**Step 1: Prepare preference dataset**
|
||||
|
||||
Dataset format:
|
||||
```json
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"chosen": "The capital of France is Paris.",
|
||||
"rejected": "I don't know."
|
||||
}
|
||||
```
|
||||
|
||||
Load dataset:
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
# Or load your own
|
||||
# dataset = load_dataset("json", data_files="preferences.json")
|
||||
```
|
||||
|
||||
**Step 2: Configure DPO**
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
config = DPOConfig(
|
||||
output_dir="Qwen2.5-0.5B-DPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=5e-7,
|
||||
beta=0.1, # KL penalty strength
|
||||
max_prompt_length=512,
|
||||
max_length=1024,
|
||||
logging_steps=10
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with DPOTrainer**
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
**CLI alternative**:
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO \
|
||||
--per_device_train_batch_size 4 \
|
||||
--learning_rate 5e-7 \
|
||||
--beta 0.1
|
||||
```
|
||||
|
||||
### Workflow 3: Memory-efficient online RL with GRPO
|
||||
|
||||
Train with reinforcement learning using minimal memory.
|
||||
|
||||
For in-depth GRPO guidance — reward function design, critical training insights (loss behavior, mode collapse, tuning), and advanced multi-stage patterns — see **[references/grpo-training.md](references/grpo-training.md)**. A production-ready training script is in **[templates/basic_grpo_training.py](templates/basic_grpo_training.py)**.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
GRPO Training:
|
||||
- [ ] Step 1: Define reward function
|
||||
- [ ] Step 2: Configure GRPO
|
||||
- [ ] Step 3: Train with GRPOTrainer
|
||||
```
|
||||
|
||||
**Step 1: Define reward function**
|
||||
|
||||
```python
|
||||
def reward_function(completions, **kwargs):
|
||||
"""
|
||||
Compute rewards for completions.
|
||||
|
||||
Args:
|
||||
completions: List of generated texts
|
||||
|
||||
Returns:
|
||||
List of reward scores (floats)
|
||||
"""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
# Example: reward based on length and unique words
|
||||
score = len(completion.split()) # Favor longer responses
|
||||
score += len(set(completion.lower().split())) # Reward unique words
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
Or use a reward model:
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="reward-model-path")
|
||||
|
||||
def reward_from_model(completions, prompts, **kwargs):
|
||||
# Combine prompt + completion
|
||||
full_texts = [p + c for p, c in zip(prompts, completions)]
|
||||
# Get reward scores
|
||||
results = reward_model(full_texts)
|
||||
return [r["score"] for r in results]
|
||||
```
|
||||
|
||||
**Step 2: Configure GRPO**
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="Qwen2-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5,
|
||||
num_generations=4, # Generate 4 completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
```
|
||||
|
||||
**Step 3: Train with GRPOTrainer**
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load prompt-only dataset
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_function, # Your reward function
|
||||
args=config,
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
**CLI**:
|
||||
```bash
|
||||
trl grpo \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--output_dir Qwen2-GRPO \
|
||||
--num_generations 4
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use TRL when:**
|
||||
- Need to align model with human preferences
|
||||
- Have preference data (chosen/rejected pairs)
|
||||
- Want to use reinforcement learning (PPO, GRPO)
|
||||
- Need reward model training
|
||||
- Doing RLHF (full pipeline)
|
||||
|
||||
**Method selection**:
|
||||
- **SFT**: Have prompt-completion pairs, want basic instruction following
|
||||
- **DPO**: Have preferences, want simple alignment (no reward model needed)
|
||||
- **PPO**: Have reward model, need maximum control over RL
|
||||
- **GRPO**: Memory-constrained, want online RL
|
||||
- **Reward Model**: Building RLHF pipeline, need to score generations
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **HuggingFace Trainer**: Basic fine-tuning without RL
|
||||
- **Axolotl**: YAML-based training configuration
|
||||
- **LitGPT**: Educational, minimal fine-tuning
|
||||
- **Unsloth**: Fast LoRA training
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: OOM during DPO training**
|
||||
|
||||
Reduce batch size and sequence length:
|
||||
```python
|
||||
config = DPOConfig(
|
||||
per_device_train_batch_size=1, # Reduce from 4
|
||||
max_length=512, # Reduce from 1024
|
||||
gradient_accumulation_steps=8 # Maintain effective batch
|
||||
)
|
||||
```
|
||||
|
||||
Or use gradient checkpointing:
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Issue: Poor alignment quality**
|
||||
|
||||
Tune beta parameter:
|
||||
```python
|
||||
# Higher beta = more conservative (stays closer to reference)
|
||||
config = DPOConfig(beta=0.5) # Default 0.1
|
||||
|
||||
# Lower beta = more aggressive alignment
|
||||
config = DPOConfig(beta=0.01)
|
||||
```
|
||||
|
||||
**Issue: Reward model not learning**
|
||||
|
||||
Check loss type and learning rate:
|
||||
```python
|
||||
config = RewardConfig(
|
||||
learning_rate=1e-5, # Try different LR
|
||||
num_train_epochs=3 # Train longer
|
||||
)
|
||||
```
|
||||
|
||||
Ensure preference dataset has clear winners:
|
||||
```python
|
||||
# Verify dataset
|
||||
print(dataset[0])
|
||||
# Should have clear chosen > rejected
|
||||
```
|
||||
|
||||
**Issue: PPO training unstable**
|
||||
|
||||
Adjust KL coefficient:
|
||||
```python
|
||||
config = PPOConfig(
|
||||
kl_coef=0.1, # Increase from 0.05
|
||||
cliprange=0.1 # Reduce from 0.2
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**SFT training guide**: See [references/sft-training.md](references/sft-training.md) for dataset formats, chat templates, packing strategies, and multi-GPU training.
|
||||
|
||||
**DPO variants**: See [references/dpo-variants.md](references/dpo-variants.md) for IPO, cDPO, RPO, and other DPO loss functions with recommended hyperparameters.
|
||||
|
||||
**Reward modeling**: See [references/reward-modeling.md](references/reward-modeling.md) for outcome vs process rewards, Bradley-Terry loss, and reward model evaluation.
|
||||
|
||||
**Online RL methods**: See [references/online-rl.md](references/online-rl.md) for PPO, GRPO, RLOO, and OnlineDPO with detailed configurations.
|
||||
|
||||
**GRPO deep dive**: See [references/grpo-training.md](references/grpo-training.md) for expert-level GRPO patterns — reward function design philosophy, training insights (why loss increases, mode collapse detection), hyperparameter tuning, multi-stage training, and troubleshooting. Production-ready template in [templates/basic_grpo_training.py](templates/basic_grpo_training.py).
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA (CUDA required)
|
||||
- **VRAM**: Depends on model and method
|
||||
- SFT 7B: 16GB (with LoRA)
|
||||
- DPO 7B: 24GB (stores reference model)
|
||||
- PPO 7B: 40GB (policy + reward model)
|
||||
- GRPO 7B: 24GB (more memory efficient)
|
||||
- **Multi-GPU**: Supported via `accelerate`
|
||||
- **Mixed precision**: BF16 recommended (A100/H100)
|
||||
|
||||
**Memory optimization**:
|
||||
- Use LoRA/QLoRA for all methods
|
||||
- Enable gradient checkpointing
|
||||
- Use smaller batch sizes with gradient accumulation
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/trl/
|
||||
- GitHub: https://github.com/huggingface/trl
|
||||
- Papers:
|
||||
- "Training language models to follow instructions with human feedback" (InstructGPT, 2022)
|
||||
- "Direct Preference Optimization: Your Language Model is Secretly a Reward Model" (DPO, 2023)
|
||||
- "Group Relative Policy Optimization" (GRPO, 2024)
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
# DPO Variants
|
||||
|
||||
Complete guide to Direct Preference Optimization loss variants in TRL.
|
||||
|
||||
## Overview
|
||||
|
||||
DPO optimizes models using preference data (chosen/rejected pairs). TRL supports 10+ loss variants for different scenarios.
|
||||
|
||||
## Loss Types
|
||||
|
||||
### 1. Sigmoid (Standard DPO)
|
||||
|
||||
**Formula**: `-log(sigmoid(β * logits))`
|
||||
|
||||
**When to use**: Default choice, general preference alignment
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sigmoid",
|
||||
beta=0.1, # KL penalty
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=1e-6
|
||||
)
|
||||
```
|
||||
|
||||
### 2. IPO (Identity Policy Optimization)
|
||||
|
||||
**Formula**: `(logits - 1/(2β))²`
|
||||
|
||||
**When to use**: Better theoretical foundation, reduce overfitting
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="ipo",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=90,
|
||||
learning_rate=1e-2
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Hinge (SLiC)
|
||||
|
||||
**Formula**: `ReLU(1 - β * logits)`
|
||||
|
||||
**When to use**: Margin-based objective
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="hinge",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=512,
|
||||
learning_rate=1e-4
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Robust DPO
|
||||
|
||||
**Formula**: Sigmoid with label smoothing for noise robustness
|
||||
|
||||
**When to use**: Noisy preference labels
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="robust",
|
||||
beta=0.01,
|
||||
label_smoothing=0.1, # Noise probability
|
||||
per_device_train_batch_size=16,
|
||||
learning_rate=1e-3,
|
||||
max_prompt_length=128,
|
||||
max_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 5. BCO Pair (Binary Classification)
|
||||
|
||||
**Formula**: Train binary classifier (chosen=1, rejected=0)
|
||||
|
||||
**When to use**: Pairwise preference data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="bco_pair",
|
||||
beta=0.01,
|
||||
per_device_train_batch_size=128,
|
||||
learning_rate=5e-7,
|
||||
max_prompt_length=1536,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 6. SPPO Hard
|
||||
|
||||
**Formula**: Push chosen→0.5, rejected→-0.5
|
||||
|
||||
**When to use**: Nash equilibrium, sparse data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="sppo_hard",
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
### 7. DiscoPOP
|
||||
|
||||
**Formula**: Log-Ratio Modulated Loss
|
||||
|
||||
**When to use**: Automated loss discovery
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="discopop",
|
||||
beta=0.05,
|
||||
discopop_tau=0.05,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=5e-7
|
||||
)
|
||||
```
|
||||
|
||||
### 8. APO Zero
|
||||
|
||||
**Formula**: Increase chosen, decrease rejected likelihood
|
||||
|
||||
**When to use**: Model worse than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_zero",
|
||||
beta=0.1,
|
||||
per_device_train_batch_size=64,
|
||||
learning_rate=2e-7,
|
||||
max_prompt_length=512,
|
||||
max_completion_length=512
|
||||
)
|
||||
```
|
||||
|
||||
### 9. APO Down
|
||||
|
||||
**Formula**: Decrease both, emphasize rejected reduction
|
||||
|
||||
**When to use**: Model better than winning outputs
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="apo_down",
|
||||
beta=0.1,
|
||||
# Same hyperparameters as apo_zero
|
||||
)
|
||||
```
|
||||
|
||||
### 10. AOT & AOT Pair
|
||||
|
||||
**Formula**: Distributional alignment via stochastic dominance
|
||||
|
||||
**When to use**:
|
||||
- `aot_pair`: Paired preference data
|
||||
- `aot`: Unpaired data
|
||||
|
||||
**Config**:
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type="aot_pair", # or "aot"
|
||||
beta=0.1,
|
||||
label_smoothing=0.0
|
||||
)
|
||||
```
|
||||
|
||||
## Multi-Loss Training
|
||||
|
||||
Combine multiple losses:
|
||||
|
||||
```python
|
||||
DPOConfig(
|
||||
loss_type=["sigmoid", "ipo"],
|
||||
loss_weights=[0.7, 0.3], # Weighted combination
|
||||
beta=0.1
|
||||
)
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
### Beta (β)
|
||||
|
||||
Controls deviation from reference model:
|
||||
- **Higher** (0.5): More conservative, stays close to reference
|
||||
- **Lower** (0.01): More aggressive alignment
|
||||
- **Default**: 0.1
|
||||
|
||||
### Label Smoothing
|
||||
|
||||
For robust DPO:
|
||||
- **0.0**: No smoothing (default)
|
||||
- **0.1-0.3**: Moderate noise robustness
|
||||
- **0.5**: Maximum noise tolerance
|
||||
|
||||
### Max Lengths
|
||||
|
||||
- `max_prompt_length`: 128-1536
|
||||
- `max_completion_length`: 128-512
|
||||
- `max_length`: Total sequence (1024-2048)
|
||||
|
||||
## Comparison Table
|
||||
|
||||
| Loss | Speed | Stability | Best For |
|
||||
|------|-------|-----------|----------|
|
||||
| Sigmoid | Fast | Good | **General use** |
|
||||
| IPO | Fast | Better | Overfitting issues |
|
||||
| Hinge | Fast | Good | Margin objectives |
|
||||
| Robust | Fast | Best | Noisy data |
|
||||
| BCO | Medium | Good | Binary classification |
|
||||
| DiscoPOP | Fast | Good | New architectures |
|
||||
| APO | Fast | Good | Model quality matching |
|
||||
|
||||
## References
|
||||
|
||||
- DPO paper: https://arxiv.org/abs/2305.18290
|
||||
- IPO paper: https://arxiv.org/abs/2310.12036
|
||||
- TRL docs: https://huggingface.co/docs/trl/dpo_trainer
|
||||
@@ -0,0 +1,504 @@
|
||||
# GRPO (Group Relative Policy Optimization) — Deep Guide
|
||||
|
||||
Expert-level patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions using TRL's `GRPOTrainer`. This is the deep reference for the GRPO workflow summarized in the main skill.
|
||||
|
||||
## When to use GRPO
|
||||
|
||||
Use GRPO when you need to:
|
||||
- **Enforce specific output formats** (XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks → use SFT
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs → use DPO/PPO
|
||||
|
||||
## Core concepts
|
||||
|
||||
### 1. GRPO algorithm fundamentals
|
||||
|
||||
**Key mechanism:**
|
||||
- Generates **multiple completions** per prompt (group size: 4–16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical differences from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward function design philosophy
|
||||
|
||||
**Golden rules:**
|
||||
1. **Compose multiple reward functions** — each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** — higher weight = stronger signal
|
||||
3. **Use incremental rewards** — partial credit for partial compliance
|
||||
4. **Test rewards independently** — debug each reward function in isolation
|
||||
|
||||
**Reward function types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5–1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1–0.5 |
|
||||
| **Style** | Penalize unwanted patterns | −0.5 to 0.5 |
|
||||
|
||||
## Implementation workflow
|
||||
|
||||
### Step 1: Dataset preparation
|
||||
|
||||
**Critical requirements:**
|
||||
- Prompts in chat format (list of dicts with `role` and `content`)
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro tips:**
|
||||
- Use one-shot or few-shot examples in the system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256–512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward function implementation
|
||||
|
||||
**Template structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""Evaluate completions and return rewards (one per completion)."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: correctness reward (math/coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: format reward (structured output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: incremental format reward (partial credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r: score += 0.25
|
||||
if '</reasoning>' in r: score += 0.25
|
||||
if '<answer>' in r: score += 0.25
|
||||
if '</answer>' in r: score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical insight:** Combine 3–5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training configuration
|
||||
|
||||
**Memory-optimized config (small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8–16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None,
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
**High-performance config (large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 reasoning, 256 short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model setup and training
|
||||
|
||||
**Standard setup (Transformers + TRL)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2–3× faster
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth setup (2–3× faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to the standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Critical training insights
|
||||
|
||||
### 1. Loss behavior (EXPECTED pattern)
|
||||
- **Loss starts near 0 and INCREASES during training** — this is CORRECT
|
||||
- Loss measures KL divergence from initial policy; the model is learning (diverging from original behavior to optimize rewards)
|
||||
- **Monitor reward metrics, not loss, for progress**
|
||||
|
||||
### 2. Reward tracking
|
||||
|
||||
Key metrics to watch:
|
||||
- `reward` — average across all completions
|
||||
- `reward_std` — diversity within groups (should remain > 0)
|
||||
- `kl` — KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning signs:**
|
||||
- `reward_std` → 0 (model collapsing to a single response)
|
||||
- `kl` exploding (> 0.5) — diverging too much, reduce LR
|
||||
- Reward stuck — reward functions too harsh or model capacity issue
|
||||
|
||||
### 3. Common pitfalls and solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
## Advanced patterns
|
||||
|
||||
### 1. Multi-stage training
|
||||
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive reward scaling
|
||||
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom dataset integration
|
||||
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
return Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
```
|
||||
|
||||
## Deployment and inference
|
||||
|
||||
### Save and merge LoRA
|
||||
```python
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline("text-generation", model="production_model", tokenizer=tokenizer)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"},
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
## Best practices checklist
|
||||
|
||||
**Before training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected `max_prompt_length` from data
|
||||
- [ ] Choose `num_generations` based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check `reward_std` (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50–100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Debugging workflow
|
||||
1. **Isolate reward functions** — test each independently
|
||||
2. **Check data distribution** — ensure diversity in prompts
|
||||
3. **Reduce complexity** — start with single reward, add gradually
|
||||
4. **Monitor generations** — print samples every N steps
|
||||
5. **Validate extraction logic** — ensure answer parsing works
|
||||
|
||||
### Quick debug reward
|
||||
```python
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]):
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses)
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1])
|
||||
```
|
||||
|
||||
## Template
|
||||
|
||||
A production-ready training script lives at **`../templates/basic_grpo_training.py`**. It uses Qwen 2.5-1.5B-Instruct with LoRA and three reward functions (incremental format, strict format, correctness) on GSM8K. Copy and adapt:
|
||||
1. `get_dataset()` — swap in your data loader
|
||||
2. Reward functions — tune to your task
|
||||
3. `SYSTEM_PROMPT` — match your output format
|
||||
4. `GRPOConfig` — adjust hyperparameters for your GPU
|
||||
|
||||
## References and resources
|
||||
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- GRPO paper (DeepSeek): https://arxiv.org/abs/2402.03300
|
||||
- DeepSeek R1 paper: https://arxiv.org/abs/2501.12948
|
||||
- Open R1 implementation: https://github.com/huggingface/open-r1
|
||||
- TRL examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
- Unsloth (faster training): https://docs.unsloth.ai/
|
||||
|
||||
## Critical reminders
|
||||
|
||||
- **Loss goes UP during training** — this is normal (it's KL divergence)
|
||||
- **Use 3–5 reward functions** — single rewards often fail
|
||||
- **Test rewards before training** — debug each function independently
|
||||
- **Monitor `reward_std`** — should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with `num_generations=4–8`** — scale up if GPU allows
|
||||
@@ -0,0 +1,82 @@
|
||||
# Online RL Methods
|
||||
|
||||
Guide to online reinforcement learning with PPO, GRPO, RLOO, and OnlineDPO.
|
||||
|
||||
## Overview
|
||||
|
||||
Online RL generates completions during training and optimizes based on rewards.
|
||||
|
||||
## PPO (Proximal Policy Optimization)
|
||||
|
||||
Classic RL algorithm for LLM alignment.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
python -m trl.scripts.ppo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--reward_model_path reward-model \
|
||||
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
||||
--output_dir model-ppo \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 64 \
|
||||
--total_episodes 10000 \
|
||||
--num_ppo_epochs 4 \
|
||||
--kl_coef 0.05
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `kl_coef`: KL penalty (0.05-0.2)
|
||||
- `num_ppo_epochs`: Epochs per batch (2-4)
|
||||
- `cliprange`: PPO clip (0.1-0.3)
|
||||
- `vf_coef`: Value function coef (0.1)
|
||||
|
||||
## GRPO (Group Relative Policy Optimization)
|
||||
|
||||
Memory-efficient online RL.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Define reward function
|
||||
def reward_func(completions, **kwargs):
|
||||
return [len(set(c.split())) for c in completions]
|
||||
|
||||
config = GRPOConfig(
|
||||
output_dir="model-grpo",
|
||||
num_generations=4, # Completions per prompt
|
||||
max_new_tokens=128
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_func,
|
||||
args=config,
|
||||
train_dataset=load_dataset("trl-lib/tldr", split="train")
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `num_generations`: 2-8 completions
|
||||
- `max_new_tokens`: 64-256
|
||||
- Learning rate: 1e-5 to 1e-4
|
||||
|
||||
## Memory Comparison
|
||||
|
||||
| Method | Memory (7B) | Speed | Use Case |
|
||||
|--------|-------------|-------|----------|
|
||||
| PPO | 40GB | Medium | Maximum control |
|
||||
| GRPO | 24GB | Fast | **Memory-constrained** |
|
||||
| OnlineDPO | 28GB | Fast | No reward model |
|
||||
|
||||
## References
|
||||
|
||||
- PPO paper: https://arxiv.org/abs/1707.06347
|
||||
- GRPO paper: https://arxiv.org/abs/2402.03300
|
||||
- TRL docs: https://huggingface.co/docs/trl/
|
||||
@@ -0,0 +1,122 @@
|
||||
# Reward Modeling
|
||||
|
||||
Guide to training reward models with TRL for RLHF pipelines.
|
||||
|
||||
## Overview
|
||||
|
||||
Reward models score completions based on human preferences. Used in:
|
||||
- PPO training (RL feedback)
|
||||
- GRPO online RL
|
||||
- Completion ranking
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model (num_labels=1 for single reward score)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
num_labels=1
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
|
||||
# Load preference dataset (chosen/rejected pairs)
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
# Configure
|
||||
config = RewardConfig(
|
||||
output_dir="Qwen2.5-Reward",
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=1,
|
||||
learning_rate=1e-5
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
Required fields:
|
||||
```json
|
||||
{
|
||||
"prompt": "Question or instruction",
|
||||
"chosen": "Better response",
|
||||
"rejected": "Worse response"
|
||||
}
|
||||
```
|
||||
|
||||
## Bradley-Terry Loss
|
||||
|
||||
Default loss function:
|
||||
```
|
||||
loss = -log(sigmoid(reward_chosen - reward_rejected))
|
||||
```
|
||||
|
||||
Learns to score chosen > rejected.
|
||||
|
||||
## Using Reward Models
|
||||
|
||||
### Inference
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
# Load trained reward model
|
||||
reward_pipe = pipeline("text-classification", model="Qwen2.5-Reward")
|
||||
|
||||
# Score completions
|
||||
texts = ["Good answer", "Bad answer"]
|
||||
scores = reward_pipe(texts)
|
||||
print(scores) # Higher score = better
|
||||
```
|
||||
|
||||
### In PPO
|
||||
|
||||
```python
|
||||
from trl import PPOTrainer, PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
reward_model_path="Qwen2.5-Reward" # Use trained reward model
|
||||
)
|
||||
|
||||
trainer = PPOTrainer(
|
||||
model=policy_model,
|
||||
config=config,
|
||||
# Reward model loaded automatically
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 2e-5 | 4-8 | 1-2 |
|
||||
| 1-7B | 1e-5 | 2-4 | 1 |
|
||||
| 7-13B | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## Evaluation
|
||||
|
||||
Check reward separation:
|
||||
```python
|
||||
# Chosen should score higher than rejected
|
||||
chosen_rewards = model(**chosen_inputs).logits
|
||||
rejected_rewards = model(**rejected_inputs).logits
|
||||
|
||||
accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
||||
print(f"Accuracy: {accuracy:.2%}") # Target: >80%
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- InstructGPT paper: https://arxiv.org/abs/2203.02155
|
||||
- TRL docs: https://huggingface.co/docs/trl/reward_trainer
|
||||
@@ -0,0 +1,168 @@
|
||||
# SFT Training Guide
|
||||
|
||||
Complete guide to Supervised Fine-Tuning (SFT) with TRL for instruction tuning and task-specific fine-tuning.
|
||||
|
||||
## Overview
|
||||
|
||||
SFT trains models on input-output pairs to minimize cross-entropy loss. Use for:
|
||||
- Instruction following
|
||||
- Task-specific fine-tuning
|
||||
- Chatbot training
|
||||
- Domain adaptation
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
### Format 1: Prompt-Completion
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"prompt": "What is the capital of France?",
|
||||
"completion": "The capital of France is Paris."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 2: Conversational (ChatML)
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is Python?"},
|
||||
{"role": "assistant", "content": "Python is a programming language."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Format 3: Text-only
|
||||
|
||||
```json
|
||||
[
|
||||
{"text": "User: Hello\nAssistant: Hi! How can I help?"}
|
||||
]
|
||||
```
|
||||
|
||||
## Basic Training
|
||||
|
||||
```python
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
|
||||
# Load model
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
# Load dataset
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
# Configure
|
||||
config = SFTConfig(
|
||||
output_dir="Qwen2.5-SFT",
|
||||
per_device_train_batch_size=4,
|
||||
num_train_epochs=1,
|
||||
learning_rate=2e-5,
|
||||
save_strategy="epoch"
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Chat Templates
|
||||
|
||||
Apply chat templates automatically:
|
||||
|
||||
```python
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset, # Messages format
|
||||
tokenizer=tokenizer
|
||||
# Chat template applied automatically
|
||||
)
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```python
|
||||
def format_chat(example):
|
||||
messages = example["messages"]
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
return {"text": text}
|
||||
|
||||
dataset = dataset.map(format_chat)
|
||||
```
|
||||
|
||||
## Packing for Efficiency
|
||||
|
||||
Pack multiple sequences into one to maximize GPU utilization:
|
||||
|
||||
```python
|
||||
config = SFTConfig(
|
||||
packing=True, # Enable packing
|
||||
max_seq_length=2048,
|
||||
dataset_text_field="text"
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**: 2-3× faster training
|
||||
**Trade-off**: Slightly more complex batching
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
```bash
|
||||
accelerate launch --num_processes 4 train_sft.py
|
||||
```
|
||||
|
||||
Or with config:
|
||||
```python
|
||||
config = SFTConfig(
|
||||
output_dir="model-sft",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Fine-Tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear",
|
||||
lora_dropout=0.05,
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=config,
|
||||
train_dataset=dataset,
|
||||
peft_config=lora_config # Add LoRA
|
||||
)
|
||||
```
|
||||
|
||||
## Hyperparameters
|
||||
|
||||
| Model Size | Learning Rate | Batch Size | Epochs |
|
||||
|------------|---------------|------------|--------|
|
||||
| <1B | 5e-5 | 8-16 | 1-3 |
|
||||
| 1-7B | 2e-5 | 4-8 | 1-2 |
|
||||
| 7-13B | 1e-5 | 2-4 | 1 |
|
||||
| 13B+ | 5e-6 | 1-2 | 1 |
|
||||
|
||||
## References
|
||||
|
||||
- TRL docs: https://huggingface.co/docs/trl/sft_trainer
|
||||
- Examples: https://github.com/huggingface/trl/tree/main/examples/scripts
|
||||
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,83 @@
|
||||
---
|
||||
name: unsloth
|
||||
description: "Unsloth: 2-5x faster LoRA/QLoRA fine-tuning, less VRAM."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [unsloth, torch, transformers, trl, datasets, peft]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, Unsloth, Fast Training, LoRA, QLoRA, Memory-Efficient, Optimization, Llama, Mistral, Gemma, Qwen]
|
||||
|
||||
---
|
||||
|
||||
# Unsloth Skill
|
||||
|
||||
Comprehensive assistance with unsloth development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with unsloth
|
||||
- Asking about unsloth features or APIs
|
||||
- Implementing unsloth solutions
|
||||
- Debugging unsloth code
|
||||
- Learning unsloth best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
*Quick reference patterns will be added as you use the skill.*
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **llms-txt.md** - Llms-Txt documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
<!-- Trigger re-upload 1763621536 -->
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# Unsloth Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Llms-Txt
|
||||
**File:** `llms-txt.md`
|
||||
**Pages:** 136
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,82 @@
|
||||
# Unsloth Documentation
|
||||
|
||||
## Unsloth Documentation
|
||||
|
||||
- [Unsloth Docs](/get-started/unsloth-docs.md): Train your own model with Unsloth, an open-source framework for LLM fine-tuning and reinforcement learning.
|
||||
- [Beginner? Start here!](/get-started/beginner-start-here.md)
|
||||
- [Unsloth Requirements](/get-started/beginner-start-here/unsloth-requirements.md): Here are Unsloth's requirements including system and GPU VRAM requirements.
|
||||
- [FAQ + Is Fine-tuning Right For Me?](/get-started/beginner-start-here/faq-+-is-fine-tuning-right-for-me.md): If you're stuck on if fine-tuning is right for you, see here! Learn about fine-tuning misconceptions, how it compared to RAG and more:
|
||||
- [Unsloth Notebooks](/get-started/unsloth-notebooks.md): Explore our catalog of Unsloth notebooks:
|
||||
- [All Our Models](/get-started/all-our-models.md)
|
||||
- [Install & Update](/get-started/install-and-update.md): Learn to install Unsloth locally or online.
|
||||
- [Updating](/get-started/install-and-update/updating.md): To update or use an old version of Unsloth, follow the steps below:
|
||||
- [Pip Install](/get-started/install-and-update/pip-install.md): To install Unsloth locally via Pip, follow the steps below:
|
||||
- [Docker](/get-started/install-and-update/docker.md): Install Unsloth using our official Docker container
|
||||
- [Windows Installation](/get-started/install-and-update/windows-installation.md): See how to install Unsloth on Windows with or without WSL.
|
||||
- [AMD](/get-started/install-and-update/amd.md): Fine-tune with Unsloth on AMD GPUs.
|
||||
- [Conda Install](/get-started/install-and-update/conda-install.md): To install Unsloth locally on Conda, follow the steps below:
|
||||
- [Google Colab](/get-started/install-and-update/google-colab.md): To install and run Unsloth on Google Colab, follow the steps below:
|
||||
- [Fine-tuning LLMs Guide](/get-started/fine-tuning-llms-guide.md): Learn all the basics and best practices of fine-tuning. Beginner-friendly.
|
||||
- [What Model Should I Use?](/get-started/fine-tuning-llms-guide/what-model-should-i-use.md)
|
||||
- [Datasets Guide](/get-started/fine-tuning-llms-guide/datasets-guide.md): Learn how to create & prepare a dataset for fine-tuning.
|
||||
- [LoRA Hyperparameters Guide](/get-started/fine-tuning-llms-guide/lora-hyperparameters-guide.md): Optimal lora rank. alpha, number of epochs, batch size & gradient accumulation, QLoRA vs LoRA, target modules and more!
|
||||
- [Tutorial: How to Finetune Llama-3 and Use In Ollama](/get-started/fine-tuning-llms-guide/tutorial-how-to-finetune-llama-3-and-use-in-ollama.md): Beginner's Guide for creating a customized personal assistant (like ChatGPT) to run locally on Ollama
|
||||
- [Reinforcement Learning (RL) Guide](/get-started/reinforcement-learning-rl-guide.md): Learn all about Reinforcement Learning (RL) and how to train your own DeepSeek-R1 reasoning model with Unsloth using GRPO. A complete guide from beginner to advanced.
|
||||
- [Tutorial: Train your own Reasoning model with GRPO](/get-started/reinforcement-learning-rl-guide/tutorial-train-your-own-reasoning-model-with-grpo.md): Beginner's Guide to transforming a model like Llama 3.1 (8B) into a reasoning model by using Unsloth and GRPO.
|
||||
- [Advanced RL Documentation](/get-started/reinforcement-learning-rl-guide/advanced-rl-documentation.md): Advanced documentation settings when using Unsloth with GRPO.
|
||||
- [Memory Efficient RL](/get-started/reinforcement-learning-rl-guide/memory-efficient-rl.md)
|
||||
- [RL Reward Hacking](/get-started/reinforcement-learning-rl-guide/rl-reward-hacking.md): Learn what is Reward Hacking in Reinforcement Learning and how to counter it.
|
||||
- [GSPO Reinforcement Learning](/get-started/reinforcement-learning-rl-guide/gspo-reinforcement-learning.md): Train with GSPO (Group Sequence Policy Optimization) RL in Unsloth.
|
||||
- [Reinforcement Learning - DPO, ORPO & KTO](/get-started/reinforcement-learning-rl-guide/reinforcement-learning-dpo-orpo-and-kto.md): To use the reward modelling functions for DPO, GRPO, ORPO or KTO with Unsloth, follow the steps below:
|
||||
- [DeepSeek-OCR: How to Run & Fine-tune](/new/deepseek-ocr-how-to-run-and-fine-tune.md): Guide on how to run and fine-tune DeepSeek-OCR locally.
|
||||
- [How to Fine-tune LLMs with Unsloth & Docker](/new/how-to-fine-tune-llms-with-unsloth-and-docker.md): Learn how to fine-tune LLMs or do Reinforcement Learning (RL) with Unsloth's Docker image.
|
||||
- [Vision Reinforcement Learning (VLM RL)](/new/vision-reinforcement-learning-vlm-rl.md): Train Vision/multimodal models via GRPO and RL with Unsloth!
|
||||
- [gpt-oss Reinforcement Learning](/new/gpt-oss-reinforcement-learning.md)
|
||||
- [Tutorial: How to Train gpt-oss with RL](/new/gpt-oss-reinforcement-learning/tutorial-how-to-train-gpt-oss-with-rl.md): Learn to train OpenAI gpt-oss with GRPO to autonomously beat 2048 locally or on Colab.
|
||||
- [Unsloth Dynamic GGUFs on Aider Polyglot](/new/unsloth-dynamic-ggufs-on-aider-polyglot.md): Performance of Unsloth Dynamic GGUFs on Aider Polyglot Benchmarks
|
||||
- [Qwen3-VL: How to Run & Fine-tune](/models/qwen3-vl-how-to-run-and-fine-tune.md): Learn to fine-tune and run Qwen3-VL locally with Unsloth.
|
||||
- [gpt-oss: How to Run & Fine-tune](/models/gpt-oss-how-to-run-and-fine-tune.md): Run & fine-tune OpenAI's new open-source models!
|
||||
- [Tutorial: How to Fine-tune gpt-oss](/models/gpt-oss-how-to-run-and-fine-tune/tutorial-how-to-fine-tune-gpt-oss.md): Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
|
||||
- [Long Context gpt-oss Training](/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training.md)
|
||||
- [GLM-4.6: How to Run Locally](/models/glm-4.6-how-to-run-locally.md): A guide on how to run Z.ai's new GLM-4.6 model on your own local device!
|
||||
- [IBM Granite 4.0](/models/ibm-granite-4.0.md): How to run IBM Granite-4.0 with Unsloth GGUFs on llama.cpp, Ollama and how to fine-tune!
|
||||
- [DeepSeek-V3.1: How to Run Locally](/models/deepseek-v3.1-how-to-run-locally.md): A guide on how to run DeepSeek-V3.1 and Terminus on your own local device!
|
||||
- [Qwen3-Coder: How to Run Locally](/models/qwen3-coder-how-to-run-locally.md): Run Qwen3-Coder-30B-A3B-Instruct and 480B-A35B locally with Unsloth Dynamic quants.
|
||||
- [Gemma 3: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune.md): How to run Gemma 3 effectively with our GGUFs on llama.cpp, Ollama, Open WebUI and how to fine-tune with Unsloth!
|
||||
- [Gemma 3n: How to Run & Fine-tune](/models/gemma-3-how-to-run-and-fine-tune/gemma-3n-how-to-run-and-fine-tune.md): Run Google's new Gemma 3n locally with Dynamic GGUFs on llama.cpp, Ollama, Open WebUI and fine-tune with Unsloth!
|
||||
- [Qwen3: How to Run & Fine-tune](/models/qwen3-how-to-run-and-fine-tune.md): Learn to run & fine-tune Qwen3 locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Qwen3-2507](/models/qwen3-how-to-run-and-fine-tune/qwen3-2507.md): Run Qwen3-30B-A3B-2507 and 235B-A22B Thinking and Instruct versions locally on your device!
|
||||
- [Tutorials: How To Fine-tune & Run LLMs](/models/tutorials-how-to-fine-tune-and-run-llms.md): Learn how to run and fine-tune models for optimal performance 100% locally with Unsloth.
|
||||
- [DeepSeek-R1-0528: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-0528-how-to-run-locally.md): A guide on how to run DeepSeek-R1-0528 including Qwen3 on your own local device!
|
||||
- [Magistral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/magistral-how-to-run-and-fine-tune.md): Meet Magistral - Mistral's new reasoning models.
|
||||
- [Llama 4: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/llama-4-how-to-run-and-fine-tune.md): How to run Llama 4 locally using our dynamic GGUFs which recovers accuracy compared to standard quantization.
|
||||
- [Kimi K2: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/kimi-k2-how-to-run-locally.md): Guide on running Kimi K2 and Kimi-K2-Instruct-0905 on your own local device!
|
||||
- [Grok 2](/models/tutorials-how-to-fine-tune-and-run-llms/grok-2.md): Run xAI's Grok 2 model locally!
|
||||
- [Devstral: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/devstral-how-to-run-and-fine-tune.md): Run and fine-tune Mistral Devstral 1.1, including Small-2507 and 2505.
|
||||
- [DeepSeek-V3-0324: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-v3-0324-how-to-run-locally.md): How to run DeepSeek-V3-0324 locally using our dynamic quants which recovers accuracy
|
||||
- [DeepSeek-R1: How to Run Locally](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally.md): A guide on how you can run our 1.58-bit Dynamic Quants for DeepSeek-R1 using llama.cpp.
|
||||
- [DeepSeek-R1 Dynamic 1.58-bit](/models/tutorials-how-to-fine-tune-and-run-llms/deepseek-r1-how-to-run-locally/deepseek-r1-dynamic-1.58-bit.md): See performance comparison tables for Unsloth's Dynamic GGUF Quants vs Standard IMatrix Quants.
|
||||
- [QwQ-32B: How to Run effectively](/models/tutorials-how-to-fine-tune-and-run-llms/qwq-32b-how-to-run-effectively.md): How to run QwQ-32B effectively with our bug fixes and without endless generations + GGUFs.
|
||||
- [Phi-4 Reasoning: How to Run & Fine-tune](/models/tutorials-how-to-fine-tune-and-run-llms/phi-4-reasoning-how-to-run-and-fine-tune.md): Learn to run & fine-tune Phi-4 reasoning models locally with Unsloth + our Dynamic 2.0 quants
|
||||
- [Running & Saving Models](/basics/running-and-saving-models.md): Learn how to save your finetuned model so you can run it in your favorite inference engine.
|
||||
- [Saving to GGUF](/basics/running-and-saving-models/saving-to-gguf.md): Saving models to 16bit for GGUF so you can use it for Ollama, Jan AI, Open WebUI and more!
|
||||
- [Saving to Ollama](/basics/running-and-saving-models/saving-to-ollama.md)
|
||||
- [Saving to vLLM for deployment](/basics/running-and-saving-models/saving-to-vllm-for-deployment.md): Saving models to 16bit for vLLM deployment and serving
|
||||
- [Saving to SGLang for deployment](/basics/running-and-saving-models/saving-to-sglang-for-deployment.md): Saving models to 16bit for SGLang for deployment and serving
|
||||
- [Unsloth Inference](/basics/running-and-saving-models/unsloth-inference.md): Learn how to run your finetuned model with Unsloth's faster inference.
|
||||
- [Troubleshooting Inference](/basics/running-and-saving-models/troubleshooting-inference.md): If you're experiencing issues when running or saving your model.
|
||||
- [vLLM Engine Arguments](/basics/running-and-saving-models/vllm-engine-arguments.md)
|
||||
- [LoRA Hot Swapping Guide](/basics/running-and-saving-models/lora-hot-swapping-guide.md)
|
||||
- [Text-to-Speech (TTS) Fine-tuning](/basics/text-to-speech-tts-fine-tuning.md): Learn how to fine-tune TTS & STT voice models with Unsloth.
|
||||
- [Unsloth Dynamic 2.0 GGUFs](/basics/unsloth-dynamic-2.0-ggufs.md): A big new upgrade to our Dynamic Quants!
|
||||
- [Vision Fine-tuning](/basics/vision-fine-tuning.md): Learn how to fine-tune vision/multimodal LLMs with Unsloth
|
||||
- [Fine-tuning LLMs with NVIDIA DGX Spark and Unsloth](/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth.md): Tutorial on how to fine-tune and do reinforcement learning (RL) with OpenAI gpt-oss on NVIDIA DGX Spark.
|
||||
- [Fine-tuning LLMs with Blackwell, RTX 50 series & Unsloth](/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth.md): Learn how to fine-tune LLMs on NVIDIA's Blackwell RTX 50 series and B200 GPUs with our step-by-step guide.
|
||||
- [Multi-GPU Training with Unsloth](/basics/multi-gpu-training-with-unsloth.md): Learn how to fine-tune LLMs on multiple GPUs and parallelism with Unsloth.
|
||||
- [Finetuning from Last Checkpoint](/basics/finetuning-from-last-checkpoint.md): Checkpointing allows you to save your finetuning progress so you can pause it and then continue.
|
||||
- [Troubleshooting & FAQs](/basics/troubleshooting-and-faqs.md): Tips to solve issues, and frequently asked questions.
|
||||
- [Chat Templates](/basics/chat-templates.md): Learn the fundamentals and customization options of chat templates, including Conversational, ChatML, ShareGPT, Alpaca formats, and more!
|
||||
- [Quantization-Aware Training (QAT)](/basics/quantization-aware-training-qat.md): Quantize models to 4-bit with Unsloth and PyTorch to recover accuracy.
|
||||
- [Unsloth Environment Flags](/basics/unsloth-environment-flags.md): Advanced flags which might be useful if you see breaking finetunes, or you want to turn stuff off.
|
||||
- [Continued Pretraining](/basics/continued-pretraining.md): AKA as Continued Finetuning. Unsloth allows you to continually pretrain so a model can learn a new language.
|
||||
- [Unsloth Benchmarks](/basics/unsloth-benchmarks.md): Unsloth recorded benchmarks on NVIDIA GPUs.
|
||||
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Vector similarity search and embedding databases for RAG, semantic search, and AI application backends.
|
||||
---
|
||||
Reference in New Issue
Block a user