Adding a lot of skills for Hermes Gerhard
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: evaluating-llms-harness
|
||||
description: Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
|
||||
description: "lm-eval-harness: benchmark LLMs (MMLU, GSM8K, etc.)."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
@@ -13,6 +13,10 @@ metadata:
|
||||
|
||||
# lm-evaluation-harness - LLM Benchmarking
|
||||
|
||||
## What's inside
|
||||
|
||||
Evaluates LLMs across 60+ academic benchmarks (MMLU, HumanEval, GSM8K, TruthfulQA, HellaSwag). Use when benchmarking model quality, comparing models, reporting academic results, or tracking training progress. Industry standard used by EleutherAI, HuggingFace, and major labs. Supports HuggingFace, vLLM, APIs.
|
||||
|
||||
## Quick start
|
||||
|
||||
lm-evaluation-harness evaluates LLMs across 60+ academic benchmarks using standardized prompts and metrics.
|
||||
|
||||
@@ -0,0 +1,593 @@
|
||||
---
|
||||
name: weights-and-biases
|
||||
description: Track ML experiments with automatic logging, visualize training in real-time, optimize hyperparameters with sweeps, and manage model registry with W&B - collaborative MLOps platform
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [wandb]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [MLOps, Weights And Biases, WandB, Experiment Tracking, Hyperparameter Tuning, Model Registry, Collaboration, Real-Time Visualization, PyTorch, TensorFlow, HuggingFace]
|
||||
|
||||
---
|
||||
|
||||
# Weights & Biases: ML Experiment Tracking & MLOps
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Weights & Biases (W&B) when you need to:
|
||||
- **Track ML experiments** with automatic metric logging
|
||||
- **Visualize training** in real-time dashboards
|
||||
- **Compare runs** across hyperparameters and configurations
|
||||
- **Optimize hyperparameters** with automated sweeps
|
||||
- **Manage model registry** with versioning and lineage
|
||||
- **Collaborate on ML projects** with team workspaces
|
||||
- **Track artifacts** (datasets, models, code) with lineage
|
||||
|
||||
**Users**: 200,000+ ML practitioners | **GitHub Stars**: 10.5k+ | **Integrations**: 100+
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install W&B
|
||||
pip install wandb
|
||||
|
||||
# Login (creates API key)
|
||||
wandb login
|
||||
|
||||
# Or set API key programmatically
|
||||
export WANDB_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Experiment Tracking
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
# Initialize a run
|
||||
run = wandb.init(
|
||||
project="my-project",
|
||||
config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32,
|
||||
"architecture": "ResNet50"
|
||||
}
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(run.config.epochs):
|
||||
# Your training code
|
||||
train_loss = train_epoch()
|
||||
val_loss = validate()
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
"epoch": epoch,
|
||||
"train/loss": train_loss,
|
||||
"val/loss": val_loss,
|
||||
"train/accuracy": train_acc,
|
||||
"val/accuracy": val_acc
|
||||
})
|
||||
|
||||
# Finish the run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### With PyTorch
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
# Initialize
|
||||
wandb.init(project="pytorch-demo", config={
|
||||
"lr": 0.001,
|
||||
"epochs": 10
|
||||
})
|
||||
|
||||
# Access config
|
||||
config = wandb.config
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
# Forward pass
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Backward pass
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Log every 100 batches
|
||||
if batch_idx % 100 == 0:
|
||||
wandb.log({
|
||||
"loss": loss.item(),
|
||||
"epoch": epoch,
|
||||
"batch": batch_idx
|
||||
})
|
||||
|
||||
# Save model
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
wandb.save("model.pth") # Upload to W&B
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Projects and Runs
|
||||
|
||||
**Project**: Collection of related experiments
|
||||
**Run**: Single execution of your training script
|
||||
|
||||
```python
|
||||
# Create/use project
|
||||
run = wandb.init(
|
||||
project="image-classification",
|
||||
name="resnet50-experiment-1", # Optional run name
|
||||
tags=["baseline", "resnet"], # Organize with tags
|
||||
notes="First baseline run" # Add notes
|
||||
)
|
||||
|
||||
# Each run has unique ID
|
||||
print(f"Run ID: {run.id}")
|
||||
print(f"Run URL: {run.url}")
|
||||
```
|
||||
|
||||
### 2. Configuration Tracking
|
||||
|
||||
Track hyperparameters automatically:
|
||||
|
||||
```python
|
||||
config = {
|
||||
# Model architecture
|
||||
"model": "ResNet50",
|
||||
"pretrained": True,
|
||||
|
||||
# Training params
|
||||
"learning_rate": 0.001,
|
||||
"batch_size": 32,
|
||||
"epochs": 50,
|
||||
"optimizer": "Adam",
|
||||
|
||||
# Data params
|
||||
"dataset": "ImageNet",
|
||||
"augmentation": "standard"
|
||||
}
|
||||
|
||||
wandb.init(project="my-project", config=config)
|
||||
|
||||
# Access config during training
|
||||
lr = wandb.config.learning_rate
|
||||
batch_size = wandb.config.batch_size
|
||||
```
|
||||
|
||||
### 3. Metric Logging
|
||||
|
||||
```python
|
||||
# Log scalars
|
||||
wandb.log({"loss": 0.5, "accuracy": 0.92})
|
||||
|
||||
# Log multiple metrics
|
||||
wandb.log({
|
||||
"train/loss": train_loss,
|
||||
"train/accuracy": train_acc,
|
||||
"val/loss": val_loss,
|
||||
"val/accuracy": val_acc,
|
||||
"learning_rate": current_lr,
|
||||
"epoch": epoch
|
||||
})
|
||||
|
||||
# Log with custom x-axis
|
||||
wandb.log({"loss": loss}, step=global_step)
|
||||
|
||||
# Log media (images, audio, video)
|
||||
wandb.log({"examples": [wandb.Image(img) for img in images]})
|
||||
|
||||
# Log histograms
|
||||
wandb.log({"gradients": wandb.Histogram(gradients)})
|
||||
|
||||
# Log tables
|
||||
table = wandb.Table(columns=["id", "prediction", "ground_truth"])
|
||||
wandb.log({"predictions": table})
|
||||
```
|
||||
|
||||
### 4. Model Checkpointing
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
# Save model checkpoint
|
||||
checkpoint = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss,
|
||||
}
|
||||
|
||||
torch.save(checkpoint, 'checkpoint.pth')
|
||||
|
||||
# Upload to W&B
|
||||
wandb.save('checkpoint.pth')
|
||||
|
||||
# Or use Artifacts (recommended)
|
||||
artifact = wandb.Artifact('model', type='model')
|
||||
artifact.add_file('checkpoint.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
```
|
||||
|
||||
## Hyperparameter Sweeps
|
||||
|
||||
Automatically search for optimal hyperparameters.
|
||||
|
||||
### Define Sweep Configuration
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes', # or 'grid', 'random'
|
||||
'metric': {
|
||||
'name': 'val/accuracy',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop']
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize sweep
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
```
|
||||
|
||||
### Define Training Function
|
||||
|
||||
```python
|
||||
def train():
|
||||
# Initialize run
|
||||
run = wandb.init()
|
||||
|
||||
# Access sweep parameters
|
||||
lr = wandb.config.learning_rate
|
||||
batch_size = wandb.config.batch_size
|
||||
optimizer_name = wandb.config.optimizer
|
||||
|
||||
# Build model with sweep config
|
||||
model = build_model(wandb.config)
|
||||
optimizer = get_optimizer(optimizer_name, lr)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
train_loss = train_epoch(model, optimizer, batch_size)
|
||||
val_acc = validate(model)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
"train/loss": train_loss,
|
||||
"val/accuracy": val_acc
|
||||
})
|
||||
|
||||
# Run sweep
|
||||
wandb.agent(sweep_id, function=train, count=50) # Run 50 trials
|
||||
```
|
||||
|
||||
### Sweep Strategies
|
||||
|
||||
```python
|
||||
# Grid search - exhaustive
|
||||
sweep_config = {
|
||||
'method': 'grid',
|
||||
'parameters': {
|
||||
'lr': {'values': [0.001, 0.01, 0.1]},
|
||||
'batch_size': {'values': [16, 32, 64]}
|
||||
}
|
||||
}
|
||||
|
||||
# Random search
|
||||
sweep_config = {
|
||||
'method': 'random',
|
||||
'parameters': {
|
||||
'lr': {'distribution': 'uniform', 'min': 0.0001, 'max': 0.1},
|
||||
'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
|
||||
}
|
||||
}
|
||||
|
||||
# Bayesian optimization (recommended)
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/loss', 'goal': 'minimize'},
|
||||
'parameters': {
|
||||
'lr': {'distribution': 'log_uniform', 'min': 1e-5, 'max': 1e-1}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Artifacts
|
||||
|
||||
Track datasets, models, and other files with lineage.
|
||||
|
||||
### Log Artifacts
|
||||
|
||||
```python
|
||||
# Create artifact
|
||||
artifact = wandb.Artifact(
|
||||
name='training-dataset',
|
||||
type='dataset',
|
||||
description='ImageNet training split',
|
||||
metadata={'size': '1.2M images', 'split': 'train'}
|
||||
)
|
||||
|
||||
# Add files
|
||||
artifact.add_file('data/train.csv')
|
||||
artifact.add_dir('data/images/')
|
||||
|
||||
# Log artifact
|
||||
wandb.log_artifact(artifact)
|
||||
```
|
||||
|
||||
### Use Artifacts
|
||||
|
||||
```python
|
||||
# Download and use artifact
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Download artifact
|
||||
artifact = run.use_artifact('training-dataset:latest')
|
||||
artifact_dir = artifact.download()
|
||||
|
||||
# Use the data
|
||||
data = load_data(f"{artifact_dir}/train.csv")
|
||||
```
|
||||
|
||||
### Model Registry
|
||||
|
||||
```python
|
||||
# Log model as artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='resnet50-model',
|
||||
type='model',
|
||||
metadata={'architecture': 'ResNet50', 'accuracy': 0.95}
|
||||
)
|
||||
|
||||
model_artifact.add_file('model.pth')
|
||||
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
|
||||
|
||||
# Link to model registry
|
||||
run.link_artifact(model_artifact, 'model-registry/production-models')
|
||||
```
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="hf-transformers")
|
||||
|
||||
# Training arguments with W&B
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb", # Enable W&B logging
|
||||
run_name="bert-finetuning",
|
||||
logging_steps=100,
|
||||
save_steps=500
|
||||
)
|
||||
|
||||
# Trainer automatically logs to W&B
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### PyTorch Lightning
|
||||
|
||||
```python
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Create W&B logger
|
||||
wandb_logger = WandbLogger(
|
||||
project="lightning-demo",
|
||||
log_model=True # Log model checkpoints
|
||||
)
|
||||
|
||||
# Use with Trainer
|
||||
trainer = Trainer(
|
||||
logger=wandb_logger,
|
||||
max_epochs=10
|
||||
)
|
||||
|
||||
trainer.fit(model, datamodule=dm)
|
||||
```
|
||||
|
||||
### Keras/TensorFlow
|
||||
|
||||
```python
|
||||
import wandb
|
||||
from wandb.keras import WandbCallback
|
||||
|
||||
# Initialize
|
||||
wandb.init(project="keras-demo")
|
||||
|
||||
# Add callback
|
||||
model.fit(
|
||||
x_train, y_train,
|
||||
validation_data=(x_val, y_val),
|
||||
epochs=10,
|
||||
callbacks=[WandbCallback()] # Auto-logs metrics
|
||||
)
|
||||
```
|
||||
|
||||
## Visualization & Analysis
|
||||
|
||||
### Custom Charts
|
||||
|
||||
```python
|
||||
# Log custom visualizations
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(x, y)
|
||||
wandb.log({"custom_plot": wandb.Image(fig)})
|
||||
|
||||
# Log confusion matrix
|
||||
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
|
||||
probs=None,
|
||||
y_true=ground_truth,
|
||||
preds=predictions,
|
||||
class_names=class_names
|
||||
)})
|
||||
```
|
||||
|
||||
### Reports
|
||||
|
||||
Create shareable reports in W&B UI:
|
||||
- Combine runs, charts, and text
|
||||
- Markdown support
|
||||
- Embeddable visualizations
|
||||
- Team collaboration
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Organize with Tags and Groups
|
||||
|
||||
```python
|
||||
wandb.init(
|
||||
project="my-project",
|
||||
tags=["baseline", "resnet50", "imagenet"],
|
||||
group="resnet-experiments", # Group related runs
|
||||
job_type="train" # Type of job
|
||||
)
|
||||
```
|
||||
|
||||
### 2. Log Everything Relevant
|
||||
|
||||
```python
|
||||
# Log system metrics
|
||||
wandb.log({
|
||||
"gpu/util": gpu_utilization,
|
||||
"gpu/memory": gpu_memory_used,
|
||||
"cpu/util": cpu_utilization
|
||||
})
|
||||
|
||||
# Log code version
|
||||
wandb.log({"git_commit": git_commit_hash})
|
||||
|
||||
# Log data splits
|
||||
wandb.log({
|
||||
"data/train_size": len(train_dataset),
|
||||
"data/val_size": len(val_dataset)
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Use Descriptive Names
|
||||
|
||||
```python
|
||||
# ✅ Good: Descriptive run names
|
||||
wandb.init(
|
||||
project="nlp-classification",
|
||||
name="bert-base-lr0.001-bs32-epoch10"
|
||||
)
|
||||
|
||||
# ❌ Bad: Generic names
|
||||
wandb.init(project="nlp", name="run1")
|
||||
```
|
||||
|
||||
### 4. Save Important Artifacts
|
||||
|
||||
```python
|
||||
# Save final model
|
||||
artifact = wandb.Artifact('final-model', type='model')
|
||||
artifact.add_file('model.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
# Save predictions for analysis
|
||||
predictions_table = wandb.Table(
|
||||
columns=["id", "input", "prediction", "ground_truth"],
|
||||
data=predictions_data
|
||||
)
|
||||
wandb.log({"predictions": predictions_table})
|
||||
```
|
||||
|
||||
### 5. Use Offline Mode for Unstable Connections
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Enable offline mode
|
||||
os.environ["WANDB_MODE"] = "offline"
|
||||
|
||||
wandb.init(project="my-project")
|
||||
# ... your code ...
|
||||
|
||||
# Sync later
|
||||
# wandb sync <run_directory>
|
||||
```
|
||||
|
||||
## Team Collaboration
|
||||
|
||||
### Share Runs
|
||||
|
||||
```python
|
||||
# Runs are automatically shareable via URL
|
||||
run = wandb.init(project="team-project")
|
||||
print(f"Share this URL: {run.url}")
|
||||
```
|
||||
|
||||
### Team Projects
|
||||
|
||||
- Create team account at wandb.ai
|
||||
- Add team members
|
||||
- Set project visibility (private/public)
|
||||
- Use team-level artifacts and model registry
|
||||
|
||||
## Pricing
|
||||
|
||||
- **Free**: Unlimited public projects, 100GB storage
|
||||
- **Academic**: Free for students/researchers
|
||||
- **Teams**: $50/seat/month, private projects, unlimited storage
|
||||
- **Enterprise**: Custom pricing, on-prem options
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://docs.wandb.ai
|
||||
- **GitHub**: https://github.com/wandb/wandb (10.5k+ stars)
|
||||
- **Examples**: https://github.com/wandb/examples
|
||||
- **Community**: https://wandb.ai/community
|
||||
- **Discord**: https://wandb.me/discord
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/sweeps.md` - Comprehensive hyperparameter optimization guide
|
||||
- `references/artifacts.md` - Data and model versioning patterns
|
||||
- `references/integrations.md` - Framework-specific examples
|
||||
|
||||
|
||||
+584
@@ -0,0 +1,584 @@
|
||||
# Artifacts & Model Registry Guide
|
||||
|
||||
Complete guide to data versioning and model management with W&B Artifacts.
|
||||
|
||||
## Table of Contents
|
||||
- What are Artifacts
|
||||
- Creating Artifacts
|
||||
- Using Artifacts
|
||||
- Model Registry
|
||||
- Versioning & Lineage
|
||||
- Best Practices
|
||||
|
||||
## What are Artifacts
|
||||
|
||||
Artifacts are versioned datasets, models, or files tracked with lineage.
|
||||
|
||||
**Key Features:**
|
||||
- Automatic versioning (v0, v1, v2...)
|
||||
- Lineage tracking (which runs produced/used artifacts)
|
||||
- Efficient storage (deduplication)
|
||||
- Collaboration (team-wide access)
|
||||
- Aliases (latest, best, production)
|
||||
|
||||
**Common Use Cases:**
|
||||
- Dataset versioning
|
||||
- Model checkpoints
|
||||
- Preprocessed data
|
||||
- Evaluation results
|
||||
- Configuration files
|
||||
|
||||
## Creating Artifacts
|
||||
|
||||
### Basic Dataset Artifact
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create artifact
|
||||
dataset = wandb.Artifact(
|
||||
name='training-data',
|
||||
type='dataset',
|
||||
description='ImageNet training split with augmentations',
|
||||
metadata={
|
||||
'size': '1.2M images',
|
||||
'format': 'JPEG',
|
||||
'resolution': '224x224'
|
||||
}
|
||||
)
|
||||
|
||||
# Add files
|
||||
dataset.add_file('data/train.csv') # Single file
|
||||
dataset.add_dir('data/images') # Entire directory
|
||||
dataset.add_reference('s3://bucket/data') # Cloud reference
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(dataset)
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Model Artifact
|
||||
|
||||
```python
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Train model
|
||||
model = train_model()
|
||||
|
||||
# Save model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
|
||||
# Create model artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='resnet50-classifier',
|
||||
type='model',
|
||||
description='ResNet50 trained on ImageNet',
|
||||
metadata={
|
||||
'architecture': 'ResNet50',
|
||||
'accuracy': 0.95,
|
||||
'loss': 0.15,
|
||||
'epochs': 50,
|
||||
'framework': 'PyTorch'
|
||||
}
|
||||
)
|
||||
|
||||
# Add model file
|
||||
model_artifact.add_file('model.pth')
|
||||
|
||||
# Add config
|
||||
model_artifact.add_file('config.yaml')
|
||||
|
||||
# Log with aliases
|
||||
run.log_artifact(model_artifact, aliases=['latest', 'best'])
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Preprocessed Data Artifact
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="nlp-project")
|
||||
|
||||
# Preprocess data
|
||||
df = pd.read_csv('raw_data.csv')
|
||||
df_processed = preprocess(df)
|
||||
df_processed.to_csv('processed_data.csv', index=False)
|
||||
|
||||
# Create artifact
|
||||
processed_data = wandb.Artifact(
|
||||
name='processed-text-data',
|
||||
type='dataset',
|
||||
metadata={
|
||||
'rows': len(df_processed),
|
||||
'columns': list(df_processed.columns),
|
||||
'preprocessing_steps': ['lowercase', 'remove_stopwords', 'tokenize']
|
||||
}
|
||||
)
|
||||
|
||||
processed_data.add_file('processed_data.csv')
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(processed_data)
|
||||
```
|
||||
|
||||
## Using Artifacts
|
||||
|
||||
### Download and Use
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Download artifact
|
||||
artifact = run.use_artifact('training-data:latest')
|
||||
artifact_dir = artifact.download()
|
||||
|
||||
# Use files
|
||||
import pandas as pd
|
||||
df = pd.read_csv(f'{artifact_dir}/train.csv')
|
||||
|
||||
# Train with artifact data
|
||||
model = train_model(df)
|
||||
```
|
||||
|
||||
### Use Specific Version
|
||||
|
||||
```python
|
||||
# Use specific version
|
||||
artifact_v2 = run.use_artifact('training-data:v2')
|
||||
|
||||
# Use alias
|
||||
artifact_best = run.use_artifact('model:best')
|
||||
artifact_prod = run.use_artifact('model:production')
|
||||
|
||||
# Use from another project
|
||||
artifact = run.use_artifact('team/other-project/model:latest')
|
||||
```
|
||||
|
||||
### Check Artifact Metadata
|
||||
|
||||
```python
|
||||
artifact = run.use_artifact('training-data:latest')
|
||||
|
||||
# Access metadata
|
||||
print(artifact.metadata)
|
||||
print(f"Size: {artifact.metadata['size']}")
|
||||
|
||||
# Access version info
|
||||
print(f"Version: {artifact.version}")
|
||||
print(f"Created at: {artifact.created_at}")
|
||||
print(f"Digest: {artifact.digest}")
|
||||
```
|
||||
|
||||
## Model Registry
|
||||
|
||||
Link models to a central registry for governance and deployment.
|
||||
|
||||
### Create Model Registry
|
||||
|
||||
```python
|
||||
# In W&B UI:
|
||||
# 1. Go to "Registry" tab
|
||||
# 2. Create new registry: "production-models"
|
||||
# 3. Define stages: development, staging, production
|
||||
```
|
||||
|
||||
### Link Model to Registry
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="training")
|
||||
|
||||
# Create model artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
name='sentiment-classifier',
|
||||
type='model',
|
||||
metadata={'accuracy': 0.94, 'f1': 0.92}
|
||||
)
|
||||
|
||||
model_artifact.add_file('model.pth')
|
||||
|
||||
# Log artifact
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Link to registry
|
||||
run.link_artifact(
|
||||
model_artifact,
|
||||
'model-registry/production-models',
|
||||
aliases=['staging'] # Deploy to staging
|
||||
)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Promote Model in Registry
|
||||
|
||||
```python
|
||||
# Retrieve model from registry
|
||||
api = wandb.Api()
|
||||
artifact = api.artifact('model-registry/production-models/sentiment-classifier:staging')
|
||||
|
||||
# Promote to production
|
||||
artifact.link('model-registry/production-models', aliases=['production'])
|
||||
|
||||
# Demote from production
|
||||
artifact.aliases = ['archived']
|
||||
artifact.save()
|
||||
```
|
||||
|
||||
### Use Model from Registry
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init()
|
||||
|
||||
# Download production model
|
||||
model_artifact = run.use_artifact(
|
||||
'model-registry/production-models/sentiment-classifier:production'
|
||||
)
|
||||
|
||||
model_dir = model_artifact.download()
|
||||
|
||||
# Load and use
|
||||
import torch
|
||||
model = torch.load(f'{model_dir}/model.pth')
|
||||
model.eval()
|
||||
```
|
||||
|
||||
## Versioning & Lineage
|
||||
|
||||
### Automatic Versioning
|
||||
|
||||
```python
|
||||
# First log: creates v0
|
||||
run1 = wandb.init(project="my-project")
|
||||
dataset_v0 = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v0.add_file('data_v1.csv')
|
||||
run1.log_artifact(dataset_v0)
|
||||
|
||||
# Second log with same name: creates v1
|
||||
run2 = wandb.init(project="my-project")
|
||||
dataset_v1 = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v1.add_file('data_v2.csv') # Different content
|
||||
run2.log_artifact(dataset_v1)
|
||||
|
||||
# Third log with SAME content as v1: references v1 (no new version)
|
||||
run3 = wandb.init(project="my-project")
|
||||
dataset_v1_again = wandb.Artifact('my-dataset', type='dataset')
|
||||
dataset_v1_again.add_file('data_v2.csv') # Same content as v1
|
||||
run3.log_artifact(dataset_v1_again) # Still v1, no v2 created
|
||||
```
|
||||
|
||||
### Track Lineage
|
||||
|
||||
```python
|
||||
# Training run
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Use dataset (input)
|
||||
dataset = run.use_artifact('training-data:v3')
|
||||
data = load_data(dataset.download())
|
||||
|
||||
# Train model
|
||||
model = train(data)
|
||||
|
||||
# Save model (output)
|
||||
model_artifact = wandb.Artifact('trained-model', type='model')
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
model_artifact.add_file('model.pth')
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Lineage automatically tracked:
|
||||
# training-data:v3 --> [run] --> trained-model:v0
|
||||
```
|
||||
|
||||
### View Lineage Graph
|
||||
|
||||
```python
|
||||
# In W&B UI:
|
||||
# Artifacts → Select artifact → Lineage tab
|
||||
# Shows:
|
||||
# - Which runs produced this artifact
|
||||
# - Which runs used this artifact
|
||||
# - Parent/child artifacts
|
||||
```
|
||||
|
||||
## Artifact Types
|
||||
|
||||
### Dataset Artifacts
|
||||
|
||||
```python
|
||||
# Raw data
|
||||
raw_data = wandb.Artifact('raw-data', type='dataset')
|
||||
raw_data.add_dir('raw/')
|
||||
|
||||
# Processed data
|
||||
processed_data = wandb.Artifact('processed-data', type='dataset')
|
||||
processed_data.add_dir('processed/')
|
||||
|
||||
# Train/val/test splits
|
||||
train_split = wandb.Artifact('train-split', type='dataset')
|
||||
train_split.add_file('train.csv')
|
||||
|
||||
val_split = wandb.Artifact('val-split', type='dataset')
|
||||
val_split.add_file('val.csv')
|
||||
```
|
||||
|
||||
### Model Artifacts
|
||||
|
||||
```python
|
||||
# Checkpoint during training
|
||||
checkpoint = wandb.Artifact('checkpoint-epoch-10', type='model')
|
||||
checkpoint.add_file('checkpoint_epoch_10.pth')
|
||||
|
||||
# Final model
|
||||
final_model = wandb.Artifact('final-model', type='model')
|
||||
final_model.add_file('model.pth')
|
||||
final_model.add_file('tokenizer.json')
|
||||
|
||||
# Quantized model
|
||||
quantized = wandb.Artifact('quantized-model', type='model')
|
||||
quantized.add_file('model_int8.onnx')
|
||||
```
|
||||
|
||||
### Result Artifacts
|
||||
|
||||
```python
|
||||
# Predictions
|
||||
predictions = wandb.Artifact('test-predictions', type='predictions')
|
||||
predictions.add_file('predictions.csv')
|
||||
|
||||
# Evaluation metrics
|
||||
eval_results = wandb.Artifact('evaluation', type='evaluation')
|
||||
eval_results.add_file('metrics.json')
|
||||
eval_results.add_file('confusion_matrix.png')
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Incremental Artifacts
|
||||
|
||||
Add files incrementally without re-uploading.
|
||||
|
||||
```python
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create artifact
|
||||
dataset = wandb.Artifact('incremental-dataset', type='dataset')
|
||||
|
||||
# Add files incrementally
|
||||
for i in range(100):
|
||||
filename = f'batch_{i}.csv'
|
||||
process_batch(i, filename)
|
||||
dataset.add_file(filename)
|
||||
|
||||
# Log progress
|
||||
if (i + 1) % 10 == 0:
|
||||
print(f"Added {i + 1}/100 batches")
|
||||
|
||||
# Log complete artifact
|
||||
run.log_artifact(dataset)
|
||||
```
|
||||
|
||||
### Artifact Tables
|
||||
|
||||
Track structured data with W&B Tables.
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="my-project")
|
||||
|
||||
# Create table
|
||||
table = wandb.Table(columns=["id", "image", "label", "prediction"])
|
||||
|
||||
for idx, (img, label, pred) in enumerate(zip(images, labels, predictions)):
|
||||
table.add_data(
|
||||
idx,
|
||||
wandb.Image(img),
|
||||
label,
|
||||
pred
|
||||
)
|
||||
|
||||
# Log as artifact
|
||||
artifact = wandb.Artifact('predictions-table', type='predictions')
|
||||
artifact.add(table, "predictions")
|
||||
run.log_artifact(artifact)
|
||||
```
|
||||
|
||||
### Artifact References
|
||||
|
||||
Reference external data without copying.
|
||||
|
||||
```python
|
||||
# S3 reference
|
||||
dataset = wandb.Artifact('s3-dataset', type='dataset')
|
||||
dataset.add_reference('s3://my-bucket/data/', name='train')
|
||||
dataset.add_reference('s3://my-bucket/labels/', name='labels')
|
||||
|
||||
# GCS reference
|
||||
dataset.add_reference('gs://my-bucket/data/')
|
||||
|
||||
# HTTP reference
|
||||
dataset.add_reference('https://example.com/data.zip')
|
||||
|
||||
# Local filesystem reference (for shared storage)
|
||||
dataset.add_reference('file:///mnt/shared/data')
|
||||
```
|
||||
|
||||
## Collaboration Patterns
|
||||
|
||||
### Team Dataset Sharing
|
||||
|
||||
```python
|
||||
# Data engineer creates dataset
|
||||
run = wandb.init(project="data-eng", entity="my-team")
|
||||
dataset = wandb.Artifact('shared-dataset', type='dataset')
|
||||
dataset.add_dir('data/')
|
||||
run.log_artifact(dataset, aliases=['latest', 'production'])
|
||||
|
||||
# ML engineer uses dataset
|
||||
run = wandb.init(project="ml-training", entity="my-team")
|
||||
dataset = run.use_artifact('my-team/data-eng/shared-dataset:production')
|
||||
data = load_data(dataset.download())
|
||||
```
|
||||
|
||||
### Model Handoff
|
||||
|
||||
```python
|
||||
# Training team
|
||||
train_run = wandb.init(project="model-training", entity="ml-team")
|
||||
model = train_model()
|
||||
model_artifact = wandb.Artifact('nlp-model', type='model')
|
||||
model_artifact.add_file('model.pth')
|
||||
train_run.log_artifact(model_artifact)
|
||||
train_run.link_artifact(model_artifact, 'model-registry/nlp-models', aliases=['candidate'])
|
||||
|
||||
# Evaluation team
|
||||
eval_run = wandb.init(project="model-eval", entity="ml-team")
|
||||
model_artifact = eval_run.use_artifact('model-registry/nlp-models/nlp-model:candidate')
|
||||
metrics = evaluate_model(model_artifact)
|
||||
|
||||
if metrics['f1'] > 0.9:
|
||||
# Promote to production
|
||||
model_artifact.link('model-registry/nlp-models', aliases=['production'])
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Descriptive Names
|
||||
|
||||
```python
|
||||
# ✅ Good: Descriptive names
|
||||
wandb.Artifact('imagenet-train-augmented-v2', type='dataset')
|
||||
wandb.Artifact('bert-base-sentiment-finetuned', type='model')
|
||||
|
||||
# ❌ Bad: Generic names
|
||||
wandb.Artifact('dataset1', type='dataset')
|
||||
wandb.Artifact('model', type='model')
|
||||
```
|
||||
|
||||
### 2. Add Comprehensive Metadata
|
||||
|
||||
```python
|
||||
model_artifact = wandb.Artifact(
|
||||
'production-model',
|
||||
type='model',
|
||||
description='ResNet50 classifier for product categorization',
|
||||
metadata={
|
||||
# Model info
|
||||
'architecture': 'ResNet50',
|
||||
'framework': 'PyTorch 2.0',
|
||||
'pretrained': True,
|
||||
|
||||
# Performance
|
||||
'accuracy': 0.95,
|
||||
'f1_score': 0.93,
|
||||
'inference_time_ms': 15,
|
||||
|
||||
# Training
|
||||
'epochs': 50,
|
||||
'dataset': 'imagenet',
|
||||
'num_samples': 1200000,
|
||||
|
||||
# Business context
|
||||
'use_case': 'e-commerce product classification',
|
||||
'owner': 'ml-team@company.com',
|
||||
'approved_by': 'data-science-lead'
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Use Aliases for Deployment Stages
|
||||
|
||||
```python
|
||||
# Development
|
||||
run.log_artifact(model, aliases=['dev', 'latest'])
|
||||
|
||||
# Staging
|
||||
run.log_artifact(model, aliases=['staging'])
|
||||
|
||||
# Production
|
||||
run.log_artifact(model, aliases=['production', 'v1.2.0'])
|
||||
|
||||
# Archive old versions
|
||||
old_artifact = api.artifact('model:production')
|
||||
old_artifact.aliases = ['archived-v1.1.0']
|
||||
old_artifact.save()
|
||||
```
|
||||
|
||||
### 4. Track Data Lineage
|
||||
|
||||
```python
|
||||
def create_training_pipeline():
|
||||
run = wandb.init(project="pipeline")
|
||||
|
||||
# 1. Load raw data
|
||||
raw_data = run.use_artifact('raw-data:latest')
|
||||
|
||||
# 2. Preprocess
|
||||
processed = preprocess(raw_data)
|
||||
processed_artifact = wandb.Artifact('processed-data', type='dataset')
|
||||
processed_artifact.add_file('processed.csv')
|
||||
run.log_artifact(processed_artifact)
|
||||
|
||||
# 3. Train model
|
||||
model = train(processed)
|
||||
model_artifact = wandb.Artifact('trained-model', type='model')
|
||||
model_artifact.add_file('model.pth')
|
||||
run.log_artifact(model_artifact)
|
||||
|
||||
# Lineage: raw-data → processed-data → trained-model
|
||||
```
|
||||
|
||||
### 5. Efficient Storage
|
||||
|
||||
```python
|
||||
# ✅ Good: Reference large files
|
||||
large_dataset = wandb.Artifact('large-dataset', type='dataset')
|
||||
large_dataset.add_reference('s3://bucket/huge-file.tar.gz')
|
||||
|
||||
# ❌ Bad: Upload giant files
|
||||
# large_dataset.add_file('huge-file.tar.gz') # Don't do this
|
||||
|
||||
# ✅ Good: Upload only metadata
|
||||
metadata_artifact = wandb.Artifact('dataset-metadata', type='dataset')
|
||||
metadata_artifact.add_file('metadata.json') # Small file
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Artifacts Documentation**: https://docs.wandb.ai/guides/artifacts
|
||||
- **Model Registry**: https://docs.wandb.ai/guides/model-registry
|
||||
- **Best Practices**: https://wandb.ai/site/articles/versioning-data-and-models-in-ml
|
||||
+700
@@ -0,0 +1,700 @@
|
||||
# Framework Integrations Guide
|
||||
|
||||
Complete guide to integrating W&B with popular ML frameworks.
|
||||
|
||||
## Table of Contents
|
||||
- HuggingFace Transformers
|
||||
- PyTorch Lightning
|
||||
- Keras/TensorFlow
|
||||
- Fast.ai
|
||||
- XGBoost/LightGBM
|
||||
- PyTorch Native
|
||||
- Custom Integrations
|
||||
|
||||
## HuggingFace Transformers
|
||||
|
||||
### Automatic Integration
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="hf-transformers", name="bert-finetuning")
|
||||
|
||||
# Training arguments with W&B
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb", # Enable W&B logging
|
||||
run_name="bert-base-finetuning",
|
||||
|
||||
# Training params
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=64,
|
||||
learning_rate=2e-5,
|
||||
|
||||
# Logging
|
||||
logging_dir="./logs",
|
||||
logging_steps=100,
|
||||
logging_first_step=True,
|
||||
|
||||
# Evaluation
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_steps=500,
|
||||
|
||||
# Other
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_accuracy"
|
||||
)
|
||||
|
||||
# Trainer automatically logs to W&B
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
trainer.train()
|
||||
|
||||
# Finish W&B run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Custom Logging
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers.integrations import WandbCallback
|
||||
import wandb
|
||||
|
||||
class CustomWandbCallback(WandbCallback):
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
super().on_evaluate(args, state, control, metrics, **kwargs)
|
||||
|
||||
# Log custom metrics
|
||||
wandb.log({
|
||||
"custom/eval_score": metrics["eval_accuracy"] * 100,
|
||||
"custom/epoch": state.epoch
|
||||
})
|
||||
|
||||
# Use custom callback
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
callbacks=[CustomWandbCallback()]
|
||||
)
|
||||
```
|
||||
|
||||
### Log Model to Registry
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
report_to="wandb",
|
||||
load_best_model_at_end=True
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save final model as artifact
|
||||
model_artifact = wandb.Artifact(
|
||||
'hf-bert-model',
|
||||
type='model',
|
||||
description='BERT finetuned on sentiment analysis'
|
||||
)
|
||||
|
||||
# Save model files
|
||||
trainer.save_model("./final_model")
|
||||
model_artifact.add_dir("./final_model")
|
||||
|
||||
# Log artifact
|
||||
wandb.log_artifact(model_artifact, aliases=['best', 'production'])
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## PyTorch Lightning
|
||||
|
||||
### Basic Integration
|
||||
|
||||
```python
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Create W&B logger
|
||||
wandb_logger = WandbLogger(
|
||||
project="lightning-demo",
|
||||
name="resnet50-training",
|
||||
log_model=True, # Log model checkpoints as artifacts
|
||||
save_code=True # Save code as artifact
|
||||
)
|
||||
|
||||
# Lightning module
|
||||
class LitModel(pl.LightningModule):
|
||||
def __init__(self, learning_rate=0.001):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.model = create_model()
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
# Log metrics (automatically sent to W&B)
|
||||
self.log('train/loss', loss, on_step=True, on_epoch=True)
|
||||
self.log('train/accuracy', accuracy(y_hat, y), on_epoch=True)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
loss = F.cross_entropy(y_hat, y)
|
||||
|
||||
self.log('val/loss', loss, on_step=False, on_epoch=True)
|
||||
self.log('val/accuracy', accuracy(y_hat, y), on_epoch=True)
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
|
||||
|
||||
# Trainer with W&B logger
|
||||
trainer = pl.Trainer(
|
||||
logger=wandb_logger,
|
||||
max_epochs=10,
|
||||
accelerator="gpu",
|
||||
devices=1
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
trainer.fit(model, datamodule=dm)
|
||||
|
||||
# Finish W&B run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Log Media
|
||||
|
||||
```python
|
||||
class LitModel(pl.LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_hat = self.model(x)
|
||||
|
||||
# Log images (first batch only)
|
||||
if batch_idx == 0:
|
||||
self.logger.experiment.log({
|
||||
"examples": [wandb.Image(img) for img in x[:8]]
|
||||
})
|
||||
|
||||
return loss
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
# Log confusion matrix
|
||||
cm = compute_confusion_matrix(self.all_preds, self.all_targets)
|
||||
|
||||
self.logger.experiment.log({
|
||||
"confusion_matrix": wandb.plot.confusion_matrix(
|
||||
probs=None,
|
||||
y_true=self.all_targets,
|
||||
preds=self.all_preds,
|
||||
class_names=self.class_names
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
### Hyperparameter Sweeps
|
||||
|
||||
```python
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
import wandb
|
||||
|
||||
# Define sweep
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
'learning_rate': {'min': 1e-5, 'max': 1e-2, 'distribution': 'log_uniform'},
|
||||
'batch_size': {'values': [16, 32, 64]},
|
||||
'hidden_size': {'values': [128, 256, 512]}
|
||||
}
|
||||
}
|
||||
|
||||
sweep_id = wandb.sweep(sweep_config, project="lightning-sweeps")
|
||||
|
||||
def train():
|
||||
# Initialize W&B
|
||||
run = wandb.init()
|
||||
|
||||
# Get hyperparameters
|
||||
config = wandb.config
|
||||
|
||||
# Create logger
|
||||
wandb_logger = WandbLogger()
|
||||
|
||||
# Create model with sweep params
|
||||
model = LitModel(
|
||||
learning_rate=config.learning_rate,
|
||||
hidden_size=config.hidden_size
|
||||
)
|
||||
|
||||
# Create datamodule with sweep batch size
|
||||
dm = DataModule(batch_size=config.batch_size)
|
||||
|
||||
# Train
|
||||
trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
|
||||
trainer.fit(model, dm)
|
||||
|
||||
# Run sweep
|
||||
wandb.agent(sweep_id, function=train, count=30)
|
||||
```
|
||||
|
||||
## Keras/TensorFlow
|
||||
|
||||
### With Callback
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from wandb.keras import WandbCallback
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(
|
||||
project="keras-demo",
|
||||
config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32
|
||||
}
|
||||
)
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Build model
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.2),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(config.learning_rate),
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy']
|
||||
)
|
||||
|
||||
# Train with W&B callback
|
||||
history = model.fit(
|
||||
x_train, y_train,
|
||||
validation_data=(x_val, y_val),
|
||||
epochs=config.epochs,
|
||||
batch_size=config.batch_size,
|
||||
callbacks=[
|
||||
WandbCallback(
|
||||
log_weights=True, # Log model weights
|
||||
log_gradients=True, # Log gradients
|
||||
training_data=(x_train, y_train),
|
||||
validation_data=(x_val, y_val),
|
||||
labels=class_names
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Save model as artifact
|
||||
model.save('model.h5')
|
||||
artifact = wandb.Artifact('keras-model', type='model')
|
||||
artifact.add_file('model.h5')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### Custom Training Loop
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import wandb
|
||||
|
||||
wandb.init(project="tf-custom-loop")
|
||||
|
||||
# Model, optimizer, loss
|
||||
model = create_model()
|
||||
optimizer = tf.keras.optimizers.Adam(1e-3)
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
|
||||
# Metrics
|
||||
train_loss = tf.keras.metrics.Mean(name='train_loss')
|
||||
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
|
||||
|
||||
@tf.function
|
||||
def train_step(x, y):
|
||||
with tf.GradientTape() as tape:
|
||||
predictions = model(x, training=True)
|
||||
loss = loss_fn(y, predictions)
|
||||
|
||||
gradients = tape.gradient(loss, model.trainable_variables)
|
||||
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
|
||||
|
||||
train_loss(loss)
|
||||
train_accuracy(y, predictions)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(EPOCHS):
|
||||
train_loss.reset_states()
|
||||
train_accuracy.reset_states()
|
||||
|
||||
for step, (x, y) in enumerate(train_dataset):
|
||||
train_step(x, y)
|
||||
|
||||
# Log every 100 steps
|
||||
if step % 100 == 0:
|
||||
wandb.log({
|
||||
'train/loss': train_loss.result().numpy(),
|
||||
'train/accuracy': train_accuracy.result().numpy(),
|
||||
'epoch': epoch,
|
||||
'step': step
|
||||
})
|
||||
|
||||
# Log epoch metrics
|
||||
wandb.log({
|
||||
'epoch/train_loss': train_loss.result().numpy(),
|
||||
'epoch/train_accuracy': train_accuracy.result().numpy(),
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Fast.ai
|
||||
|
||||
### With Callback
|
||||
|
||||
```python
|
||||
from fastai.vision.all import *
|
||||
from fastai.callback.wandb import *
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="fastai-demo")
|
||||
|
||||
# Create data loaders
|
||||
dls = ImageDataLoaders.from_folder(
|
||||
path,
|
||||
train='train',
|
||||
valid='valid',
|
||||
bs=64
|
||||
)
|
||||
|
||||
# Create learner with W&B callback
|
||||
learn = vision_learner(
|
||||
dls,
|
||||
resnet34,
|
||||
metrics=accuracy,
|
||||
cbs=WandbCallback(
|
||||
log_preds=True, # Log predictions
|
||||
log_model=True, # Log model as artifact
|
||||
log_dataset=True # Log dataset as artifact
|
||||
)
|
||||
)
|
||||
|
||||
# Train (metrics logged automatically)
|
||||
learn.fine_tune(5)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## XGBoost/LightGBM
|
||||
|
||||
### XGBoost
|
||||
|
||||
```python
|
||||
import xgboost as xgb
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
run = wandb.init(project="xgboost-demo", config={
|
||||
"max_depth": 6,
|
||||
"learning_rate": 0.1,
|
||||
"n_estimators": 100
|
||||
})
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Create DMatrix
|
||||
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||
dval = xgb.DMatrix(X_val, label=y_val)
|
||||
|
||||
# XGBoost params
|
||||
params = {
|
||||
'max_depth': config.max_depth,
|
||||
'learning_rate': config.learning_rate,
|
||||
'objective': 'binary:logistic',
|
||||
'eval_metric': ['logloss', 'auc']
|
||||
}
|
||||
|
||||
# Custom callback for W&B
|
||||
def wandb_callback(env):
|
||||
"""Log XGBoost metrics to W&B."""
|
||||
for metric_name, metric_value in env.evaluation_result_list:
|
||||
wandb.log({
|
||||
f"{metric_name}": metric_value,
|
||||
"iteration": env.iteration
|
||||
})
|
||||
|
||||
# Train with callback
|
||||
model = xgb.train(
|
||||
params,
|
||||
dtrain,
|
||||
num_boost_round=config.n_estimators,
|
||||
evals=[(dtrain, 'train'), (dval, 'val')],
|
||||
callbacks=[wandb_callback],
|
||||
verbose_eval=10
|
||||
)
|
||||
|
||||
# Save model
|
||||
model.save_model('xgboost_model.json')
|
||||
artifact = wandb.Artifact('xgboost-model', type='model')
|
||||
artifact.add_file('xgboost_model.json')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### LightGBM
|
||||
|
||||
```python
|
||||
import lightgbm as lgb
|
||||
import wandb
|
||||
|
||||
run = wandb.init(project="lgbm-demo")
|
||||
|
||||
# Create datasets
|
||||
train_data = lgb.Dataset(X_train, label=y_train)
|
||||
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)
|
||||
|
||||
# Parameters
|
||||
params = {
|
||||
'objective': 'binary',
|
||||
'metric': ['binary_logloss', 'auc'],
|
||||
'learning_rate': 0.1,
|
||||
'num_leaves': 31
|
||||
}
|
||||
|
||||
# Custom callback
|
||||
def log_to_wandb(env):
|
||||
"""Log LightGBM metrics to W&B."""
|
||||
for entry in env.evaluation_result_list:
|
||||
dataset_name, metric_name, metric_value, _ = entry
|
||||
wandb.log({
|
||||
f"{dataset_name}/{metric_name}": metric_value,
|
||||
"iteration": env.iteration
|
||||
})
|
||||
|
||||
# Train
|
||||
model = lgb.train(
|
||||
params,
|
||||
train_data,
|
||||
num_boost_round=100,
|
||||
valid_sets=[train_data, val_data],
|
||||
valid_names=['train', 'val'],
|
||||
callbacks=[log_to_wandb]
|
||||
)
|
||||
|
||||
# Save model
|
||||
model.save_model('lgbm_model.txt')
|
||||
artifact = wandb.Artifact('lgbm-model', type='model')
|
||||
artifact.add_file('lgbm_model.txt')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## PyTorch Native
|
||||
|
||||
### Training Loop Integration
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
|
||||
# Initialize W&B
|
||||
wandb.init(project="pytorch-native", config={
|
||||
"learning_rate": 0.001,
|
||||
"epochs": 10,
|
||||
"batch_size": 32
|
||||
})
|
||||
|
||||
config = wandb.config
|
||||
|
||||
# Model, loss, optimizer
|
||||
model = create_model()
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
# Watch model (logs gradients and parameters)
|
||||
wandb.watch(model, criterion, log="all", log_freq=100)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for batch_idx, (data, target) in enumerate(train_loader):
|
||||
data, target = data.to(device), target.to(device)
|
||||
|
||||
# Forward pass
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
train_loss += loss.item()
|
||||
_, predicted = output.max(1)
|
||||
total += target.size(0)
|
||||
correct += predicted.eq(target).sum().item()
|
||||
|
||||
# Log every 100 batches
|
||||
if batch_idx % 100 == 0:
|
||||
wandb.log({
|
||||
'train/loss': loss.item(),
|
||||
'train/batch_accuracy': 100. * correct / total,
|
||||
'epoch': epoch,
|
||||
'batch': batch_idx
|
||||
})
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
val_correct = 0
|
||||
val_total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for data, target in val_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
output = model(data)
|
||||
loss = criterion(output, target)
|
||||
|
||||
val_loss += loss.item()
|
||||
_, predicted = output.max(1)
|
||||
val_total += target.size(0)
|
||||
val_correct += predicted.eq(target).sum().item()
|
||||
|
||||
# Log epoch metrics
|
||||
wandb.log({
|
||||
'epoch/train_loss': train_loss / len(train_loader),
|
||||
'epoch/train_accuracy': 100. * correct / total,
|
||||
'epoch/val_loss': val_loss / len(val_loader),
|
||||
'epoch/val_accuracy': 100. * val_correct / val_total,
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
# Save final model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
artifact = wandb.Artifact('final-model', type='model')
|
||||
artifact.add_file('model.pth')
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
## Custom Integrations
|
||||
|
||||
### Generic Framework Integration
|
||||
|
||||
```python
|
||||
import wandb
|
||||
|
||||
class WandbIntegration:
|
||||
"""Generic W&B integration wrapper."""
|
||||
|
||||
def __init__(self, project, config):
|
||||
self.run = wandb.init(project=project, config=config)
|
||||
self.config = wandb.config
|
||||
self.step = 0
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
"""Log training metrics."""
|
||||
if step is None:
|
||||
step = self.step
|
||||
self.step += 1
|
||||
|
||||
wandb.log(metrics, step=step)
|
||||
|
||||
def log_images(self, images, caption=""):
|
||||
"""Log images."""
|
||||
wandb.log({
|
||||
caption: [wandb.Image(img) for img in images]
|
||||
})
|
||||
|
||||
def log_table(self, data, columns):
|
||||
"""Log tabular data."""
|
||||
table = wandb.Table(columns=columns, data=data)
|
||||
wandb.log({"table": table})
|
||||
|
||||
def save_model(self, model_path, metadata=None):
|
||||
"""Save model as artifact."""
|
||||
artifact = wandb.Artifact(
|
||||
'model',
|
||||
type='model',
|
||||
metadata=metadata or {}
|
||||
)
|
||||
artifact.add_file(model_path)
|
||||
self.run.log_artifact(artifact)
|
||||
|
||||
def finish(self):
|
||||
"""Finish W&B run."""
|
||||
wandb.finish()
|
||||
|
||||
# Usage
|
||||
wb = WandbIntegration(project="my-project", config={"lr": 0.001})
|
||||
|
||||
# Training loop
|
||||
for epoch in range(10):
|
||||
# Your training code
|
||||
loss, accuracy = train_epoch()
|
||||
|
||||
# Log metrics
|
||||
wb.log_metrics({
|
||||
'train/loss': loss,
|
||||
'train/accuracy': accuracy
|
||||
})
|
||||
|
||||
# Save model
|
||||
wb.save_model('model.pth', metadata={'accuracy': 0.95})
|
||||
wb.finish()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Integrations Guide**: https://docs.wandb.ai/guides/integrations
|
||||
- **HuggingFace**: https://docs.wandb.ai/guides/integrations/huggingface
|
||||
- **PyTorch Lightning**: https://docs.wandb.ai/guides/integrations/lightning
|
||||
- **Keras**: https://docs.wandb.ai/guides/integrations/keras
|
||||
- **Examples**: https://github.com/wandb/examples
|
||||
+847
@@ -0,0 +1,847 @@
|
||||
# Comprehensive Hyperparameter Sweeps Guide
|
||||
|
||||
Complete guide to hyperparameter optimization with W&B Sweeps.
|
||||
|
||||
## Table of Contents
|
||||
- Sweep Configuration
|
||||
- Search Strategies
|
||||
- Parameter Distributions
|
||||
- Early Termination
|
||||
- Parallel Execution
|
||||
- Advanced Patterns
|
||||
- Real-World Examples
|
||||
|
||||
## Sweep Configuration
|
||||
|
||||
### Basic Sweep Config
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes', # Search strategy
|
||||
'metric': {
|
||||
'name': 'val/accuracy',
|
||||
'goal': 'maximize' # or 'minimize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize sweep
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
```
|
||||
|
||||
### Complete Config Example
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
# Required: Search method
|
||||
'method': 'bayes',
|
||||
|
||||
# Required: Optimization metric
|
||||
'metric': {
|
||||
'name': 'val/f1_score',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
|
||||
# Required: Parameters to search
|
||||
'parameters': {
|
||||
# Continuous parameter
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
|
||||
# Discrete values
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
|
||||
# Categorical
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
|
||||
},
|
||||
|
||||
# Uniform distribution
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
},
|
||||
|
||||
# Integer range
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 10
|
||||
},
|
||||
|
||||
# Fixed value (constant across runs)
|
||||
'epochs': {
|
||||
'value': 50
|
||||
}
|
||||
},
|
||||
|
||||
# Optional: Early termination
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 5,
|
||||
's': 2,
|
||||
'eta': 3,
|
||||
'max_iter': 27
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Search Strategies
|
||||
|
||||
### 1. Grid Search
|
||||
|
||||
Exhaustively search all combinations.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'grid',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'values': [0.001, 0.01, 0.1]
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Total runs: 3 × 3 × 2 = 18 runs
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Comprehensive search
|
||||
- Reproducible results
|
||||
- No randomness
|
||||
|
||||
**Cons:**
|
||||
- Exponential growth with parameters
|
||||
- Inefficient for continuous parameters
|
||||
- Not scalable beyond 3-4 parameters
|
||||
|
||||
**When to use:**
|
||||
- Few parameters (< 4)
|
||||
- All discrete values
|
||||
- Need complete coverage
|
||||
|
||||
### 2. Random Search
|
||||
|
||||
Randomly sample parameter combinations.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'random',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128, 256]
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
},
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 8
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Run 100 random trials
|
||||
wandb.agent(sweep_id, function=train, count=100)
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Scales to many parameters
|
||||
- Can run indefinitely
|
||||
- Often finds good solutions quickly
|
||||
|
||||
**Cons:**
|
||||
- No learning from previous runs
|
||||
- May miss optimal region
|
||||
- Results vary with random seed
|
||||
|
||||
**When to use:**
|
||||
- Many parameters (> 4)
|
||||
- Quick exploration
|
||||
- Limited budget
|
||||
|
||||
### 3. Bayesian Optimization (Recommended)
|
||||
|
||||
Learn from previous trials to sample promising regions.
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {
|
||||
'name': 'val/loss',
|
||||
'goal': 'minimize'
|
||||
},
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
},
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.1,
|
||||
'max': 0.5
|
||||
},
|
||||
'num_layers': {
|
||||
'values': [2, 3, 4, 5, 6]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Most sample-efficient
|
||||
- Learns from past trials
|
||||
- Focuses on promising regions
|
||||
|
||||
**Cons:**
|
||||
- Initial random exploration phase
|
||||
- May get stuck in local optima
|
||||
- Slower per iteration
|
||||
|
||||
**When to use:**
|
||||
- Expensive training runs
|
||||
- Need best performance
|
||||
- Limited compute budget
|
||||
|
||||
## Parameter Distributions
|
||||
|
||||
### Continuous Distributions
|
||||
|
||||
```python
|
||||
# Log-uniform: Good for learning rates, regularization
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-1
|
||||
}
|
||||
|
||||
# Uniform: Good for dropout, momentum
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
}
|
||||
|
||||
# Normal distribution
|
||||
'parameter': {
|
||||
'distribution': 'normal',
|
||||
'mu': 0.5,
|
||||
'sigma': 0.1
|
||||
}
|
||||
|
||||
# Log-normal distribution
|
||||
'parameter': {
|
||||
'distribution': 'log_normal',
|
||||
'mu': 0.0,
|
||||
'sigma': 1.0
|
||||
}
|
||||
```
|
||||
|
||||
### Discrete Distributions
|
||||
|
||||
```python
|
||||
# Fixed values
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128, 256]
|
||||
}
|
||||
|
||||
# Integer uniform
|
||||
'num_layers': {
|
||||
'distribution': 'int_uniform',
|
||||
'min': 2,
|
||||
'max': 10
|
||||
}
|
||||
|
||||
# Quantized uniform (step size)
|
||||
'layer_size': {
|
||||
'distribution': 'q_uniform',
|
||||
'min': 32,
|
||||
'max': 512,
|
||||
'q': 32 # Step by 32: 32, 64, 96, 128...
|
||||
}
|
||||
|
||||
# Quantized log-uniform
|
||||
'hidden_size': {
|
||||
'distribution': 'q_log_uniform',
|
||||
'min': 32,
|
||||
'max': 1024,
|
||||
'q': 32
|
||||
}
|
||||
```
|
||||
|
||||
### Categorical Parameters
|
||||
|
||||
```python
|
||||
# Optimizers
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'rmsprop', 'adamw']
|
||||
}
|
||||
|
||||
# Model architectures
|
||||
'model': {
|
||||
'values': ['resnet18', 'resnet34', 'resnet50', 'efficientnet_b0']
|
||||
}
|
||||
|
||||
# Activation functions
|
||||
'activation': {
|
||||
'values': ['relu', 'gelu', 'silu', 'leaky_relu']
|
||||
}
|
||||
```
|
||||
|
||||
## Early Termination
|
||||
|
||||
Stop underperforming runs early to save compute.
|
||||
|
||||
### Hyperband
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {...},
|
||||
|
||||
# Hyperband early termination
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 3, # Minimum iterations before termination
|
||||
's': 2, # Bracket count
|
||||
'eta': 3, # Downsampling rate
|
||||
'max_iter': 27 # Maximum iterations
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Runs trials in brackets
|
||||
- Keeps top 1/eta performers each round
|
||||
- Eliminates bottom performers early
|
||||
|
||||
### Custom Termination
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
|
||||
for epoch in range(MAX_EPOCHS):
|
||||
loss = train_epoch()
|
||||
val_acc = validate()
|
||||
|
||||
wandb.log({'val/accuracy': val_acc, 'epoch': epoch})
|
||||
|
||||
# Custom early stopping
|
||||
if epoch > 5 and val_acc < 0.5:
|
||||
print("Early stop: Poor performance")
|
||||
break
|
||||
|
||||
if epoch > 10 and val_acc > best_acc - 0.01:
|
||||
print("Early stop: No improvement")
|
||||
break
|
||||
```
|
||||
|
||||
## Training Function
|
||||
|
||||
### Basic Template
|
||||
|
||||
```python
|
||||
def train():
|
||||
# Initialize W&B run
|
||||
run = wandb.init()
|
||||
|
||||
# Get hyperparameters
|
||||
config = wandb.config
|
||||
|
||||
# Build model with config
|
||||
model = build_model(
|
||||
hidden_size=config.hidden_size,
|
||||
num_layers=config.num_layers,
|
||||
dropout=config.dropout
|
||||
)
|
||||
|
||||
# Create optimizer
|
||||
optimizer = create_optimizer(
|
||||
model.parameters(),
|
||||
name=config.optimizer,
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(config.epochs):
|
||||
# Train
|
||||
train_loss, train_acc = train_epoch(
|
||||
model, optimizer, train_loader, config.batch_size
|
||||
)
|
||||
|
||||
# Validate
|
||||
val_loss, val_acc = validate(model, val_loader)
|
||||
|
||||
# Log metrics
|
||||
wandb.log({
|
||||
'train/loss': train_loss,
|
||||
'train/accuracy': train_acc,
|
||||
'val/loss': val_loss,
|
||||
'val/accuracy': val_acc,
|
||||
'epoch': epoch
|
||||
})
|
||||
|
||||
# Log final model
|
||||
torch.save(model.state_dict(), 'model.pth')
|
||||
wandb.save('model.pth')
|
||||
|
||||
# Finish run
|
||||
wandb.finish()
|
||||
```
|
||||
|
||||
### With PyTorch
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import wandb
|
||||
|
||||
def train():
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Data
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
# Model
|
||||
model = ResNet(
|
||||
num_classes=config.num_classes,
|
||||
dropout=config.dropout
|
||||
).to(device)
|
||||
|
||||
# Optimizer
|
||||
if config.optimizer == 'adam':
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
elif config.optimizer == 'sgd':
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
momentum=config.momentum,
|
||||
weight_decay=config.weight_decay
|
||||
)
|
||||
|
||||
# Scheduler
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, T_max=config.epochs
|
||||
)
|
||||
|
||||
# Training
|
||||
for epoch in range(config.epochs):
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
|
||||
for data, target in train_loader:
|
||||
data, target = data.to(device), target.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
output = model(data)
|
||||
loss = nn.CrossEntropyLoss()(output, target)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
val_loss, val_acc = validate(model, val_loader)
|
||||
|
||||
# Step scheduler
|
||||
scheduler.step()
|
||||
|
||||
# Log
|
||||
wandb.log({
|
||||
'train/loss': train_loss / len(train_loader),
|
||||
'val/loss': val_loss,
|
||||
'val/accuracy': val_acc,
|
||||
'learning_rate': scheduler.get_last_lr()[0],
|
||||
'epoch': epoch
|
||||
})
|
||||
```
|
||||
|
||||
## Parallel Execution
|
||||
|
||||
### Multiple Agents
|
||||
|
||||
Run sweep agents in parallel to speed up search.
|
||||
|
||||
```python
|
||||
# Initialize sweep once
|
||||
sweep_id = wandb.sweep(sweep_config, project="my-project")
|
||||
|
||||
# Run multiple agents in parallel
|
||||
# Agent 1 (Terminal 1)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Agent 2 (Terminal 2)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Agent 3 (Terminal 3)
|
||||
wandb.agent(sweep_id, function=train, count=20)
|
||||
|
||||
# Total: 60 runs across 3 agents
|
||||
```
|
||||
|
||||
### Multi-GPU Execution
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
def train():
|
||||
# Get available GPU
|
||||
gpu_id = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
|
||||
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
# Train on specific GPU
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
model = model.to(device)
|
||||
|
||||
# ... rest of training ...
|
||||
|
||||
# Run agents on different GPUs
|
||||
# Terminal 1
|
||||
# CUDA_VISIBLE_DEVICES=0 wandb agent sweep_id
|
||||
|
||||
# Terminal 2
|
||||
# CUDA_VISIBLE_DEVICES=1 wandb agent sweep_id
|
||||
|
||||
# Terminal 3
|
||||
# CUDA_VISIBLE_DEVICES=2 wandb agent sweep_id
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Nested Parameters
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'val/accuracy', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
'model': {
|
||||
'parameters': {
|
||||
'type': {
|
||||
'values': ['resnet', 'efficientnet']
|
||||
},
|
||||
'size': {
|
||||
'values': ['small', 'medium', 'large']
|
||||
}
|
||||
}
|
||||
},
|
||||
'optimizer': {
|
||||
'parameters': {
|
||||
'type': {
|
||||
'values': ['adam', 'sgd']
|
||||
},
|
||||
'lr': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Access nested config
|
||||
def train():
|
||||
run = wandb.init()
|
||||
model_type = wandb.config.model.type
|
||||
model_size = wandb.config.model.size
|
||||
opt_type = wandb.config.optimizer.type
|
||||
lr = wandb.config.optimizer.lr
|
||||
```
|
||||
|
||||
### Conditional Parameters
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'parameters': {
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd']
|
||||
},
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-1
|
||||
},
|
||||
# Only used if optimizer == 'sgd'
|
||||
'momentum': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.5,
|
||||
'max': 0.99
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def train():
|
||||
run = wandb.init()
|
||||
config = wandb.config
|
||||
|
||||
if config.optimizer == 'adam':
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate
|
||||
)
|
||||
elif config.optimizer == 'sgd':
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.learning_rate,
|
||||
momentum=config.momentum # Conditional parameter
|
||||
)
|
||||
```
|
||||
|
||||
## Real-World Examples
|
||||
|
||||
### Image Classification
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {
|
||||
'name': 'val/top1_accuracy',
|
||||
'goal': 'maximize'
|
||||
},
|
||||
'parameters': {
|
||||
# Model
|
||||
'architecture': {
|
||||
'values': ['resnet50', 'resnet101', 'efficientnet_b0', 'efficientnet_b3']
|
||||
},
|
||||
'pretrained': {
|
||||
'values': [True, False]
|
||||
},
|
||||
|
||||
# Training
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-5,
|
||||
'max': 1e-2
|
||||
},
|
||||
'batch_size': {
|
||||
'values': [16, 32, 64, 128]
|
||||
},
|
||||
'optimizer': {
|
||||
'values': ['adam', 'sgd', 'adamw']
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
},
|
||||
|
||||
# Regularization
|
||||
'dropout': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.5
|
||||
},
|
||||
'label_smoothing': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.2
|
||||
},
|
||||
|
||||
# Data augmentation
|
||||
'mixup_alpha': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 1.0
|
||||
},
|
||||
'cutmix_alpha': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 1.0
|
||||
}
|
||||
},
|
||||
'early_terminate': {
|
||||
'type': 'hyperband',
|
||||
'min_iter': 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### NLP Fine-Tuning
|
||||
|
||||
```python
|
||||
sweep_config = {
|
||||
'method': 'bayes',
|
||||
'metric': {'name': 'eval/f1', 'goal': 'maximize'},
|
||||
'parameters': {
|
||||
# Model
|
||||
'model_name': {
|
||||
'values': ['bert-base-uncased', 'roberta-base', 'distilbert-base-uncased']
|
||||
},
|
||||
|
||||
# Training
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-4
|
||||
},
|
||||
'per_device_train_batch_size': {
|
||||
'values': [8, 16, 32]
|
||||
},
|
||||
'num_train_epochs': {
|
||||
'values': [3, 4, 5]
|
||||
},
|
||||
'warmup_ratio': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.0,
|
||||
'max': 0.1
|
||||
},
|
||||
'weight_decay': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-4,
|
||||
'max': 1e-1
|
||||
},
|
||||
|
||||
# Optimizer
|
||||
'adam_beta1': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.8,
|
||||
'max': 0.95
|
||||
},
|
||||
'adam_beta2': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.95,
|
||||
'max': 0.999
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Small
|
||||
|
||||
```python
|
||||
# Initial exploration: Random search, 20 runs
|
||||
sweep_config_v1 = {
|
||||
'method': 'random',
|
||||
'parameters': {...}
|
||||
}
|
||||
wandb.agent(sweep_id_v1, train, count=20)
|
||||
|
||||
# Refined search: Bayes, narrow ranges
|
||||
sweep_config_v2 = {
|
||||
'method': 'bayes',
|
||||
'parameters': {
|
||||
'learning_rate': {
|
||||
'min': 5e-5, # Narrowed from 1e-6 to 1e-4
|
||||
'max': 1e-4
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Use Log Scales
|
||||
|
||||
```python
|
||||
# ✅ Good: Log scale for learning rate
|
||||
'learning_rate': {
|
||||
'distribution': 'log_uniform',
|
||||
'min': 1e-6,
|
||||
'max': 1e-2
|
||||
}
|
||||
|
||||
# ❌ Bad: Linear scale
|
||||
'learning_rate': {
|
||||
'distribution': 'uniform',
|
||||
'min': 0.000001,
|
||||
'max': 0.01
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Set Reasonable Ranges
|
||||
|
||||
```python
|
||||
# Base ranges on prior knowledge
|
||||
'learning_rate': {'min': 1e-5, 'max': 1e-3}, # Typical for Adam
|
||||
'batch_size': {'values': [16, 32, 64]}, # GPU memory limits
|
||||
'dropout': {'min': 0.1, 'max': 0.5} # Too high hurts training
|
||||
```
|
||||
|
||||
### 4. Monitor Resource Usage
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
|
||||
# Log system metrics
|
||||
wandb.log({
|
||||
'system/gpu_memory_allocated': torch.cuda.memory_allocated(),
|
||||
'system/gpu_memory_reserved': torch.cuda.memory_reserved()
|
||||
})
|
||||
```
|
||||
|
||||
### 5. Save Best Models
|
||||
|
||||
```python
|
||||
def train():
|
||||
run = wandb.init()
|
||||
best_acc = 0.0
|
||||
|
||||
for epoch in range(config.epochs):
|
||||
val_acc = validate(model)
|
||||
|
||||
if val_acc > best_acc:
|
||||
best_acc = val_acc
|
||||
# Save best checkpoint
|
||||
torch.save(model.state_dict(), 'best_model.pth')
|
||||
wandb.save('best_model.pth')
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Sweeps Documentation**: https://docs.wandb.ai/guides/sweeps
|
||||
- **Configuration Reference**: https://docs.wandb.ai/guides/sweeps/configuration
|
||||
- **Examples**: https://github.com/wandb/examples/tree/master/examples/wandb-sweeps
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: weights-and-biases
|
||||
description: Track ML experiments with automatic logging, visualize training in real-time, optimize hyperparameters with sweeps, and manage model registry with W&B - collaborative MLOps platform
|
||||
description: "W&B: log ML experiments, sweeps, model registry, dashboards."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: huggingface-hub
|
||||
description: Hugging Face Hub CLI (hf) — search, download, and upload models and datasets, manage repos, query datasets with SQL, deploy inference endpoints, manage Spaces and buckets.
|
||||
description: "HuggingFace hf CLI: search/download/upload models, datasets."
|
||||
version: 1.0.0
|
||||
author: Hugging Face
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: obliteratus
|
||||
description: Remove refusal behaviors from open-weight LLMs using OBLITERATUS — mechanistic interpretability techniques (diff-in-means, SVD, whitened SVD, LEACE, SAE decomposition, etc.) to excise guardrails while preserving reasoning. 9 CLI methods, 28 analysis modules, 116 model presets across 5 compute tiers, tournament evaluation, and telemetry-driven recommendations. Use when a user wants to uncensor, abliterate, or remove refusal from an LLM.
|
||||
description: "OBLITERATUS: abliterate LLM refusals (diff-in-means)."
|
||||
version: 2.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
@@ -13,10 +13,21 @@ metadata:
|
||||
|
||||
# OBLITERATUS Skill
|
||||
|
||||
## What's inside
|
||||
|
||||
9 CLI methods, 28 analysis modules, 116 model presets across 5 compute tiers, tournament evaluation, and telemetry-driven recommendations.
|
||||
|
||||
Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, LEACE concept erasure, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.
|
||||
|
||||
**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.
|
||||
|
||||
## Video Guide
|
||||
|
||||
Walkthrough of OBLITERATUS used by a Hermes agent to abliterate Gemma:
|
||||
https://www.youtube.com/watch?v=8fG9BrNTeHs ("OBLITERATUS: An AI Agent Removed Gemma 4's Safety Guardrails")
|
||||
|
||||
Useful when the user wants a visual overview of the end-to-end workflow before running it themselves.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Trigger when the user:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: outlines
|
||||
description: Guarantee valid JSON/XML/code structure during generation, use Pydantic models for type-safe outputs, support local models (Transformers, vLLM), and maximize inference speed with Outlines - dottxt.ai's structured generation library
|
||||
description: "Outlines: structured JSON/regex/Pydantic LLM generation."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: serving-llms-vllm
|
||||
description: Serves LLMs with high throughput using vLLM's PagedAttention and continuous batching. Use when deploying production LLM APIs, optimizing inference latency/throughput, or serving models with limited GPU memory. Supports OpenAI-compatible endpoints, quantization (GPTQ/AWQ/FP8), and tensor parallelism.
|
||||
description: "vLLM: high-throughput LLM serving, OpenAI API, quantization."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
@@ -13,6 +13,10 @@ metadata:
|
||||
|
||||
# vLLM - High-Performance LLM Serving
|
||||
|
||||
## When to use
|
||||
|
||||
Use when deploying production LLM APIs, optimizing inference latency/throughput, or serving models with limited GPU memory. Supports OpenAI-compatible endpoints, quantization (GPTQ/AWQ/FP8), and tensor parallelism.
|
||||
|
||||
## Quick start
|
||||
|
||||
vLLM achieves 24x higher throughput than standard transformers through PagedAttention (block-based KV cache) and continuous batching (mixing prefill/decode requests).
|
||||
|
||||
@@ -0,0 +1,567 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
@@ -0,0 +1,567 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: "AudioCraft: MusicGen text-to-music, AudioGen text-to-sound."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
+666
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
+504
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
description: "AudioCraft: MusicGen text-to-music, AudioGen text-to-sound."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: segment-anything-model
|
||||
description: Foundation model for image segmentation with zero-shot transfer. Use when you need to segment any object in images using points, boxes, or masks as prompts, or automatically generate all object masks in an image.
|
||||
description: "SAM: zero-shot image segmentation via points, boxes, masks."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: dspy
|
||||
description: Build complex AI systems with declarative programming, optimize prompts automatically, create modular RAG systems and agents with DSPy - Stanford NLP's framework for systematic LM programming
|
||||
description: "DSPy: declarative LM programs, auto-optimize prompts, RAG."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
|
||||
|
||||
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1029
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,15 @@
|
||||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
|
||||
description: "Axolotl: YAML LLM fine-tuning (LoRA, DPO, GRPO)."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
@@ -13,6 +13,10 @@ metadata:
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
## What's inside
|
||||
|
||||
Expert guidance for fine-tuning LLMs with Axolotl — YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support.
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: fine-tuning-with-trl
|
||||
description: Fine-tune LLMs using reinforcement learning with TRL - SFT for instruction tuning, DPO for preference alignment, PPO/GRPO for reward optimization, and reward model training. Use when need RLHF, align model with preferences, or train from human feedback. Works with HuggingFace Transformers.
|
||||
description: "TRL: SFT, DPO, PPO, GRPO, reward modeling for LLM RLHF."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
name: unsloth
|
||||
description: Expert guidance for fast fine-tuning with Unsloth - 2-5x faster training, 50-80% less memory, LoRA/QLoRA optimization
|
||||
description: "Unsloth: 2-5x faster LoRA/QLoRA fine-tuning, less VRAM."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
|
||||
Reference in New Issue
Block a user