make rms_norm_eps a parameter
slaren committed Jul 24, 2023
commit 9fe47c747f46adbc1c010ac2d5fb33756b09f10c
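Summary: the RMSNorm epsilon, previously hard-coded as `1e-6f` inside each backend, becomes an explicit parameter of `ggml_rms_norm`. It is stored in the result tensor's `op_params` at graph-build time and read back by each backend at compute time. A minimal before/after sketch of a call site, using the identifiers from the diffs below:

```cpp
// before this commit: epsilon fixed inside the op implementation
cur = ggml_rms_norm(ctx0, inpL);

// after: the caller supplies it explicitly
static const float rms_norm_eps = 1e-6f;
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
```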
32 changes: 17 additions & 15 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

+static const float rms_norm_eps = 1e-6f;
+
struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{

// inpL shape [n_embd,N,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{

// inpL shape [n_embd,N*n_batch,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{

// inpL shape [n_embd,N*n_batch,1,1]
-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(

// norm
{
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{
// norm
{
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm
{

-inpL = ggml_rms_norm(ctx0, inpL);
+inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
-use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
}
clr_buf(0);
use_buf(0);
-struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1);
13 changes: 7 additions & 6 deletions ggml-cuda.cu
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}

-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;

-const float eps = 1e-6f;
-
float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}

-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
-rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}

static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;

+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));
+
// compute
-rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);

(void) src1;
(void) dst;
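For context, `rms_norm_f32` normalizes each row by the root mean square of its elements, and `eps` keeps the divisor away from zero for all-zero rows. A scalar sketch of the per-row computation the kernel parallelizes across a warp (an illustrative helper, not code from this PR):

```cpp
#include <math.h>

// Per-row RMS normalization: dst[i] = x[i] / sqrt(mean(x^2) + eps).
// In ggml the learned scale (e.g. attention_norm) is applied afterwards
// by a separate ggml_mul, as the call sites above show.
static void rms_norm_row_ref(const float * x, float * dst, int ncols, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < ncols; i++) {
        sum += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(sum / ncols + eps);
    for (int i = 0; i < ncols; i++) {
        dst[i] = x[i] * scale;
    }
}
```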
3 changes: 2 additions & 1 deletion ggml-metal.m
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}

-const float eps = 1e-6f;
+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 512;

16 changes: 10 additions & 6 deletions ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+float eps,
bool inplace) {
bool is_node = false;

@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-// TODO: maybe store epsilon here?
+ggml_set_op_params(result, &eps, sizeof(eps));

result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@

struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
-struct ggml_tensor * a) {
-return ggml_rms_norm_impl(ctx, a, false);
+struct ggml_tensor * a,
+float eps) {
+return ggml_rms_norm_impl(ctx, a, eps, false);
}

struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
-struct ggml_tensor * a) {
-return ggml_rms_norm_impl(ctx, a, true);
+struct ggml_tensor * a,
+float eps) {
+return ggml_rms_norm_impl(ctx, a, eps, true);
}

struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(

GGML_TENSOR_UNARY_OP_LOCALS;

-const float eps = 1e-6f; // TODO: make this a parameter
+float eps;
+memcpy(&eps, dst->op_params, sizeof(float));

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
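The carrier for the new parameter is the tensor's `op_params` scratch area: `ggml_set_op_params` byte-copies the epsilon into the result tensor when the graph is built, and each backend `memcpy`s it back out in its compute path. `memcpy`, rather than a pointer cast, keeps the read well-defined regardless of the buffer's alignment. The round trip, condensed from the hunks above:

```cpp
// graph-build time (ggml_rms_norm_impl): stash eps alongside the op
ggml_set_op_params(result, &eps, sizeof(eps));

// compute time (CPU, CUDA and Metal paths alike): recover it from dst
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
```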
7 changes: 5 additions & 2 deletions ggml.h
@@ -866,14 +866,17 @@ extern "C" {

GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
-struct ggml_tensor * a);
+struct ggml_tensor * a,
+float eps);

GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
-struct ggml_tensor * a);
+struct ggml_tensor * a,
+float eps);

// a - x
// b - dy
+// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
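Every caller of the public API now passes the epsilon explicitly; `ggml_rms_norm_inplace` takes it the same way and returns a view that overwrites `a`. A small usage sketch against the updated header (the buffer size and the `1e-6f` value are arbitrary choices for illustration):

```cpp
#include "ggml.h"

static void rms_norm_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 128);

    struct ggml_tensor * y = ggml_rms_norm        (ctx, a, 1e-6f); // new tensor
    struct ggml_tensor * z = ggml_rms_norm_inplace(ctx, a, 1e-6f); // view of a

    (void) y; (void) z;
    ggml_free(ctx);
}
```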
10 changes: 7 additions & 3 deletions llama.cpp
@@ -1396,11 +1396,15 @@ static bool llama_eval_internal(
const int64_t n_vocab = hparams.n_vocab;
const int64_t n_embd_gqa = hparams.n_embd_gqa();

+
LLAMA_ASSERT(n_embd_head == hparams.n_rot);

const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;

+// TODO: read from hparams
+const float rms_norm_eps = 1e-6f;
+
const int n_gpu_layers = model.n_gpu_layers;

auto & mem_per_token = lctx.mem_per_token;
@@ -1479,7 +1483,7 @@

// norm
{
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_0");

@@ -1627,7 +1631,7 @@
{
// norm
{
-cur = ggml_rms_norm(ctx0, inpFF);
+cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_1");

@@ -1680,7 +1684,7 @@

// norm
{
-cur = ggml_rms_norm(ctx0, inpL);
+cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2");

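`llama_eval_internal` keeps the value hard-coded behind the `// TODO: read from hparams` comment. One plausible follow-up, sketched here only (the `f_norm_rms_eps` field is hypothetical, not part of this commit), would carry the epsilon in `llama_hparams` so converted models can override the default:

```cpp
// hypothetical follow-up, not in this PR: thread eps through hparams
struct llama_hparams {
    // ... existing fields (n_embd, n_head, rope_freq_base, ...) ...
    float f_norm_rms_eps = 1e-6f; // hypothetical field, read from the model file
};

// inside llama_eval_internal:
const float rms_norm_eps = hparams.f_norm_rms_eps;
```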