Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improvement v3-5-1
72B:
 -- sum(log(err)): -884.396164
 -- max(err): 0.017426
vs OGB:
 -- sum(log(err)): -842.360744
 -- max(err): 0.018692
  • Loading branch information
Originalimoc committed Jan 10, 2025
commit 228262169dd33de9f8622d12ec4b4ea54cc8260c
123 changes: 95 additions & 28 deletions exllamav2/exllamav2_ext/ext_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,36 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
{
// --- Internal Parameters ---
const int redistribution_iterations = 25;
const float bpw_penalty_scale = 0.01f;
const float min_bpw_limit = 2.0f;
const float bpw_penalty_scale = 0.05f; // Increased penalty
const float min_bpw_base = 2.8f; // Absolute minimum BPW
const int opportunistic_iterations = 10000;
const float initial_opportunistic_temp = 0.01f;
const float low_error_threshold = 0.0009f;

// --- Dynamic Minimum BPW ---
// Temperature-dependent lower bound on BPW.
// Interpolates between the absolute floor `min_bpw_base` (temp_ratio == 0) and
// the point halfway between that floor and `target_bpw` (temp_ratio == 1).
auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) -> float {
    // Point halfway between the floor and the target.
    const float halfway_bpw = min_bpw_base + 0.5f * (target_bpw - min_bpw_base);
    // How far that halfway point sits above the floor.
    const float span_above_base = halfway_bpw - min_bpw_base;
    return min_bpw_base + temp_ratio * span_above_base;
};

// --- Calculate BPW ---
// Converts an option's cost (first tuple element) into a BPW figure by
// scaling it with 8 / 1024.
auto calculate_bpw = [&](const std::tuple<uint64_t, float>& option) {
    // Same integer-to-float conversion the mixed expression performed implicitly.
    const float cost = static_cast<float>(std::get<0>(option));
    return 8.0f * cost / 1024.0f;
};

// --- Calculate BPW stats ---
// Returns {mean, population standard deviation} of the BPW over every entry
// of `sol`, where each entry's BPW comes from calculate_bpw.
auto calculate_bpw_stats = [&](const std::vector<std::tuple<uint64_t, float>>& sol) {
    const int slot_count = sol.size();

    // Running sum and sum of squares, accumulated in index order.
    float bpw_total = 0.0f;
    float bpw_sq_total = 0.0f;
    for (int idx = 0; idx < slot_count; ++idx) {
        const float bpw = calculate_bpw(sol[idx]);
        bpw_total += bpw;
        bpw_sq_total += bpw * bpw;
    }

    const float mean = bpw_total / slot_count;
    // E[x^2] - E[x]^2 can come out slightly negative through rounding, so it
    // is clamped at zero before the square root.
    const float variance = bpw_sq_total / slot_count - mean * mean;
    return std::make_pair(mean, std::sqrt(std::max(0.0f, variance)));
};

// --- Original Simulated Annealing ---
int num_slots = slots.size();

Expand All @@ -194,6 +218,7 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui

float temp = initial_temp;
int iterations_outer = static_cast<int>(std::log(min_temp / temp) / std::log(cooling_factor));
float target_bpw = max_cost * 8.0f / 1024.0f / num_slots;

for (int i = 0; i < num_slots; ++i)
{
Expand All @@ -204,6 +229,9 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui

for (int j = 0; j < iterations_outer; ++j)
{
float temp_ratio = temp / initial_temp;
float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio);

for (int k = 0; k < iterations; ++k)
{
int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
Expand All @@ -224,10 +252,18 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option));
}

// BPW Penalty (Dynamic and Temperature-Dependent)
float bpw_new = calculate_bpw(new_option);
float bpw_penalty = 0.0f;

if (bpw_new < min_bpw_limit) {
bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio); // Stronger penalty at higher temp
}

if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost))
{
if (delta_e < 0 ||
std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_e / temp))
if (delta_e + bpw_penalty < 0 ||
std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_e + bpw_penalty) / temp))
{
solution[i] = new_option;
solution_idx[i] = n;
Expand All @@ -240,22 +276,11 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
}

// --- Post-processing: Bit Redistribution ---
// Converts an option's cost (first tuple element) into a BPW figure by
// scaling it with 8 / 1024.
auto calculate_bpw = [&](const std::tuple<uint64_t, float>& option) {
    // Same integer-to-float conversion the mixed expression performed implicitly.
    const float cost = static_cast<float>(std::get<0>(option));
    return 8.0f * cost / 1024.0f;
};

// Returns {mean, population standard deviation} of the BPW over the first
// `num_slots` entries of `sol`. `num_slots` is taken from the enclosing scope
// rather than from sol.size().
auto calculate_bpw_stats = [&](const std::vector<std::tuple<uint64_t, float>>& sol) {
    // Running sum and sum of squares, accumulated in index order.
    float bpw_total = 0.0f;
    float bpw_sq_total = 0.0f;
    for (int idx = 0; idx < num_slots; ++idx) {
        const float bpw = calculate_bpw(sol[idx]);
        bpw_total += bpw;
        bpw_sq_total += bpw * bpw;
    }

    const float mean = bpw_total / num_slots;
    // E[x^2] - E[x]^2 can come out slightly negative through rounding, so it
    // is clamped at zero before the square root.
    const float variance = bpw_sq_total / num_slots - mean * mean;
    return std::make_pair(mean, std::sqrt(std::max(0.0f, variance)));
};

for (int r = 0; r < redistribution_iterations; ++r) {
float temp_ratio = temp / initial_temp;
float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio);

// Calculate BPW statistics and dynamic bpw_threshold
auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution);
float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev);
Expand Down Expand Up @@ -357,6 +382,9 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui

float local_temp = initial_opportunistic_temp;
for (int i = 0; i < opportunistic_iterations; ++i) {
float temp_ratio = temp / initial_temp;
float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio);

// Select a neighborhood of slots
int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
int neighborhood_size = std::min(5, num_slots); // Example neighborhood size
Expand All @@ -380,18 +408,18 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
float current_bpw = calculate_bpw(solution[j]);
float target_bpw = neighborhood_bpw_avg;

// Error-weighted adjustment
// Error-weighted adjustment with bias towards higher BPW
float avg_error = 0;
for (int k = start_slot; k <= end_slot; ++k) {
avg_error += std::get<1>(solution[k]);
}
avg_error /= (end_slot - start_slot + 1);
float error_ratio = std::get<1>(solution[j]) / avg_error;

float adjustment = 0.125f;
float adjustment = 0.25f + 0.25f * error_ratio; // Increased adjustment with bias

// Adjust BPW towards the target, weighted by error
if (current_bpw < target_bpw) {
// Adjust BPW towards the target, weighted by error, with a bias towards higher BPW
if (current_bpw < target_bpw + adjustment) { // Bias towards higher BPW
// Search for a higher BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
Expand All @@ -408,7 +436,7 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
}
} else if (current_bpw > target_bpw) {
// Search for a lower BPW option
for (int n = 0; n < slots[j].size(); ++n) {
for (int n = slots[j].size() - 1; n >= 0; --n) { // Iterate in reverse order
auto new_option = slots[j][n];
if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) {
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
Expand Down Expand Up @@ -440,13 +468,16 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
error_factor = 0.1f; // Reduce the weight of sum_log_err
}

if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) {
if (new_cost <= max_cost) {
if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) {
accept = true;
}
// Give high priority to the solution that has high minimum BPW
if (calculate_bpw(new_solution[i]) < min_bpw_limit) {
accept = false;
// Further penalize if below min_bpw_limit
for (int j = 0; j < num_slots; ++j) {
if (calculate_bpw(new_solution[j]) < min_bpw_limit) {
accept = false;
break;
}
}
}
}

Expand All @@ -472,6 +503,42 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
solution_idx = best_solution_idx;
current_sum_log_err = best_sum_log_err;

// --- BPW Smoothing (Post-processing) ---
// For every interior slot (the first and last slot are never touched), compare
// its BPW with the mean BPW of its two immediate neighbours. When it lies more
// than 0.5 below that mean, scan slots[i] in index order and take the first
// option that (a) has a higher BPW than the current one without exceeding the
// neighbour mean, (b) keeps the total cost within max_cost, and (c) keeps the
// overall maximum error below 1.1x current_max_exp_error.
// Slots are updated in place, so slot i + 1 sees the already-smoothed slot i
// as its previous neighbour.
// NOTE(review): current_sum_log_err is not adjusted when a slot is replaced
// here — confirm it is recomputed (or no longer used) after this loop.
for (int i = 1; i < num_slots - 1; ++i) {
float current_bpw = calculate_bpw(solution[i]);
float prev_bpw = calculate_bpw(solution[i - 1]);
float next_bpw = calculate_bpw(solution[i + 1]);
float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f;

if (current_bpw < avg_neighbor_bpw - 0.5f) { // Significant difference
// Find a higher BPW option for the current slot
for (int n = 0; n < slots[i].size(); ++n) {
auto new_option = slots[i][n];
// Candidate must raise the BPW but not overshoot the neighbour mean.
if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) {
// Budget check: total cost after swapping this slot's option.
if (current_cost - std::get<0>(solution[i]) + std::get<0>(new_option) <= max_cost) {
// Check if the new option doesn't significantly increase max_err
// Recomputes the maximum error over all slots, substituting the
// candidate's error for slot i.
float new_max_err = 0;
for (int j = 0; j < num_slots; ++j) {
if (j == i) {
new_max_err = std::max(new_max_err, std::get<1>(new_option));
} else {
new_max_err = std::max(new_max_err, std::get<1>(solution[j]));
}
}

if (new_max_err < current_max_exp_error * 1.1f) { // Allow a small increase in max_err
// Commit the swap and stop scanning options for this slot.
current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option);
solution[i] = new_option;
solution_idx[i] = n;
current_max_exp_error = new_max_err;
break;
}
}
}
}
}
}

// --- Final Cost Check and Rollback (if necessary) ---
if (current_cost > max_cost) {
std::vector<std::pair<float, int>> error_indices(num_slots);
Expand Down