Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improvement v3-2
 -- sum(log(err)): -840.236110
 -- max(err): 0.005603
  • Loading branch information
Originalimoc committed Jan 10, 2025
commit 117a60a352157265c9ba1226468c83d7329807b5
331 changes: 227 additions & 104 deletions exllamav2/exllamav2_ext/ext_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,10 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
const int redistribution_iterations = 25;
const float bpw_penalty_scale = 0.01f;
const float min_bpw_limit = 2.0f;
const int opportunistic_iterations = 5000;
const float bpw_transfer_step = 0.0625f; // Amount of BPW to transfer in each step
const int opportunistic_iterations_stage1 = 5000;
const int opportunistic_iterations_stage2 = 10000;
const float initial_opportunistic_temp = 0.01f;
const float min_exp_error_threshold = 0.001f;

// --- Original Simulated Annealing ---
int num_slots = slots.size();
Expand Down Expand Up @@ -334,132 +336,253 @@ std::tuple<std::vector<std::tuple<uint64_t, float>>, std::vector<int>, float, ui
}
}

// --- Opportunistic Optimization ---
// Track the best solution found during opportunistic optimization
std::vector<std::tuple<uint64_t, float>> best_solution_opportunistic = solution;
std::vector<int> best_solution_idx_opportunistic = solution_idx;
float best_sum_log_err_opportunistic = 1e18f;
uint64_t best_cost_opportunistic = current_cost;

for (int i = 0; i < opportunistic_iterations; ++i) {
auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution);
float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev);
// --- Opportunistic Optimization (Stage 1: Focus on Sum of Log Errors) ---
float current_sum_log_err = 0;
for (int i = 0; i < num_slots; ++i) {
current_sum_log_err += log(std::get<1>(solution[i]));
}

int slot1 = -1;
// Find a slot with BPW above the threshold
std::vector<int> high_bpw_indices;
for(int j = 0; j < num_slots; j++) {
if(calculate_bpw(solution[j]) > bpw_threshold) {
high_bpw_indices.push_back(j);
}
float best_sum_log_err = current_sum_log_err;
std::vector<std::tuple<uint64_t, float>> best_solution = solution;
std::vector<int> best_solution_idx = solution_idx;
float best_max_exp_error = current_max_exp_error;

float local_temp = initial_opportunistic_temp;
for (int i = 0; i < opportunistic_iterations_stage1; ++i) {
// Select a neighborhood of slots
int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
int neighborhood_size = std::min(5, num_slots);
int start_slot = std::max(0, center_slot - neighborhood_size / 2);
int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2);

// Calculate average BPW in the neighborhood
float neighborhood_bpw_sum = 0;
for (int j = start_slot; j <= end_slot; ++j) {
neighborhood_bpw_sum += calculate_bpw(solution[j]);
}
if(high_bpw_indices.empty()) continue;
slot1 = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)];

int slot2 = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
if (slot1 == slot2) continue;

int option1 = solution_idx[slot1];
int option2 = solution_idx[slot2];

// Find a lower BPW option for slot1
int best_option1 = -1;
float best_option1_error = 1e10f;
for (int new_option1 = 0; new_option1 < slots[slot1].size(); new_option1++) {
if (calculate_bpw(slots[slot1][new_option1]) < calculate_bpw(solution[slot1])) {
if (std::get<1>(slots[slot1][new_option1]) < best_option1_error) {
best_option1_error = std::get<1>(slots[slot1][new_option1]);
best_option1 = new_option1;
float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1);

// Adjust BPWs within the neighborhood, weighted by error
std::vector<std::tuple<uint64_t, float>> new_solution = solution;
std::vector<int> new_solution_idx = solution_idx;
float new_sum_log_err = current_sum_log_err;
uint64_t new_cost = current_cost;

for (int j = start_slot; j <= end_slot; ++j) {
float current_bpw = calculate_bpw(solution[j]);
float target_bpw = neighborhood_bpw_avg;
float error = std::get<1>(solution[j]);
float adjustment = 0.125f;

// Error-weighted adjustment
float error_weight = std::max(0.0f, error - min_exp_error_threshold);

if (current_bpw < target_bpw) {
// Search for a higher BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment)
{
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
{
new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option);
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option));
new_solution[j] = new_option;
new_solution_idx[j] = n;
break;
}
}
}
} else if (current_bpw > target_bpw) {
// Search for a lower BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment)
{
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
{
new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option);
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option));
new_solution[j] = new_option;
new_solution_idx[j] = n;
break;
}
}
}
}
}

// Find a higher BPW option for slot2
int best_option2 = -1;
float best_option2_error = 1e10f;
for (int new_option2 = 0; new_option2 < slots[slot2].size(); new_option2++) {
if (calculate_bpw(slots[slot2][new_option2]) > calculate_bpw(solution[slot2])) {
if (std::get<1>(slots[slot2][new_option2]) < best_option2_error) {
best_option2_error = std::get<1>(slots[slot2][new_option2]);
best_option2 = new_option2;
// Calculate new max exp error
float new_max_exp_error = 0;
for (int j = 0; j < num_slots; ++j) {
new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j]));
}

// Acceptance criterion with a small probability of accepting worse solutions
if (new_cost <= max_cost) {
float delta_sum_log_err = new_sum_log_err - current_sum_log_err;
if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) {
solution = new_solution;
solution_idx = new_solution_idx;
current_sum_log_err = new_sum_log_err;
current_cost = new_cost;
current_max_exp_error = new_max_exp_error;

if (current_sum_log_err < best_sum_log_err) {
best_sum_log_err = current_sum_log_err;
best_solution = solution;
best_solution_idx = solution_idx;
best_max_exp_error = current_max_exp_error;
}
}
}

if (best_option1 != -1 && best_option2 != -1) {
auto new_option1 = slots[slot1][best_option1];
auto new_option2 = slots[slot2][best_option2];

if (calculate_bpw(new_option2) < min_bpw_limit) continue;

uint64_t new_cost = current_cost - std::get<0>(solution[slot1]) - std::get<0>(solution[slot2]) + std::get<0>(new_option1) + std::get<0>(new_option2);
local_temp *= 0.95f;
}

if (new_cost <= max_cost) {
// Calculate new max exp error
float new_max_exp_error = std::get<1>(new_option2);
for (int j = 0; j < num_slots; j++) {
if (j == slot2) continue;
if (j == slot1) {
new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_option1));
} else {
new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[j]));
// --- Opportunistic Optimization (Stage 2: Focus on Max Error and Min BPW) ---
local_temp = initial_opportunistic_temp * 0.1f; // Lower temperature for Stage 2
for (int i = 0; i < opportunistic_iterations_stage2; ++i) {
// Select a neighborhood of slots
int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen);
int neighborhood_size = std::min(5, num_slots);
int start_slot = std::max(0, center_slot - neighborhood_size / 2);
int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2);

// Calculate average BPW in the neighborhood
float neighborhood_bpw_sum = 0;
for (int j = start_slot; j <= end_slot; ++j) {
neighborhood_bpw_sum += calculate_bpw(solution[j]);
}
float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1);

// Adjust BPWs within the neighborhood, weighted by error
std::vector<std::tuple<uint64_t, float>> new_solution = solution;
std::vector<int> new_solution_idx = solution_idx;
float new_sum_log_err = current_sum_log_err;
uint64_t new_cost = current_cost;

for (int j = start_slot; j <= end_slot; ++j) {
float current_bpw = calculate_bpw(solution[j]);
float target_bpw = neighborhood_bpw_avg;
float error = std::get<1>(solution[j]);
float adjustment = 0.125f;

// Focus on increasing BPW if below min_bpw_limit
if (current_bpw < min_bpw_limit) {
// Search for a higher BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment)
{
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
{
new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option);
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option));
new_solution[j] = new_option;
new_solution_idx[j] = n;
break;
}
}
}

// Calculate sum of log errors
float new_sum_log_err = 0;
for (int j = 0; j < num_slots; ++j) {
if (j == slot1) {
new_sum_log_err += log(std::get<1>(new_option1));
} else if (j == slot2) {
new_sum_log_err += log(std::get<1>(new_option2));
} else {
new_sum_log_err += log(std::get<1>(solution[j]));
} else {
// Error-weighted adjustment (less aggressive if error is already low)
float error_weight = std::max(0.0f, error - min_exp_error_threshold);

if (current_bpw < target_bpw) {
// Search for a higher BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment)
{
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
{
new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option);
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option));
new_solution[j] = new_option;
new_solution_idx[j] = n;
break;
}
}
}
} else if (current_bpw > target_bpw) {
// Search for a lower BPW option
for (int n = 0; n < slots[j].size(); ++n) {
auto new_option = slots[j][n];
if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment)
{
if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost)
{
new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option);
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option));
new_solution[j] = new_option;
new_solution_idx[j] = n;
break;
}
}
}
}
}
}

// Calculate current sum of log errors
float current_sum_log_err = 0;
for (int j = 0; j < num_slots; ++j) {
current_sum_log_err += log(std::get<1>(solution[j]));
}
// Calculate new max exp error
float new_max_exp_error = 0;
for (int j = 0; j < num_slots; ++j) {
new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j]));
}

// Accept change if it reduces sum of log errors without increasing max error
if (new_sum_log_err < current_sum_log_err && new_max_exp_error <= current_max_exp_error)
{
solution[slot1] = new_option1;
solution_idx[slot1] = best_option1;
solution[slot2] = new_option2;
solution_idx[slot2] = best_option2;
current_cost = new_cost;
current_max_exp_error = new_max_exp_error;
current_sum_log_err = new_sum_log_err;

// Update best solution found during opportunistic optimization
if (current_sum_log_err < best_sum_log_err_opportunistic) {
best_sum_log_err_opportunistic = current_sum_log_err;
best_cost_opportunistic = current_cost;
best_solution_opportunistic = solution;
best_solution_idx_opportunistic = solution_idx;
}
// Acceptance criterion (more emphasis on max error and min BPW)
auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats(new_solution);
if (new_cost <= max_cost && new_bpw_mean >= min_bpw_limit) {
float delta_sum_log_err = new_sum_log_err - current_sum_log_err;
float delta_max_exp_error = new_max_exp_error - current_max_exp_error;

// Prioritize reducing max error and increasing min BPW
if ((delta_max_exp_error < 0 || (delta_max_exp_error == 0 && delta_sum_log_err < 0)) ||
std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_sum_log_err + delta_max_exp_error * 100.0f) / local_temp)) {
solution = new_solution;
solution_idx = new_solution_idx;
current_sum_log_err = new_sum_log_err;
current_cost = new_cost;
current_max_exp_error = new_max_exp_error;

if (current_sum_log_err < best_sum_log_err) {
best_sum_log_err = current_sum_log_err;
best_solution = solution;
best_solution_idx = solution_idx;
best_max_exp_error = current_max_exp_error;
}
}
}
}

// Use the best solution found during opportunistic optimization
if (best_sum_log_err_opportunistic < 1e18f) {
solution = best_solution_opportunistic;
solution_idx = best_solution_idx_opportunistic;
current_cost = best_cost_opportunistic;
local_temp *= 0.95f;
}

// --- Final Cost Check and Rollback (if necessary) ---
// --- Final Cost Correction (if needed) ---
if (current_cost > max_cost) {
// Revert to the solution before opportunistic optimization
solution = best_solution_opportunistic;
solution_idx = best_solution_idx_opportunistic;
current_cost = best_cost_opportunistic;
std::vector<std::pair<float, int>> error_indices(num_slots);
for (int i = 0; i < num_slots; ++i) {
error_indices[i] = {std::get<1>(solution[i]), i};
}
std::sort(error_indices.begin(), error_indices.end());

for (const auto& pair : error_indices) {
int i = pair.second;
for (int n = slots[i].size() - 1; n >= 0; --n) {
if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i]))
{
if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost)
{
uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]);
current_cost += delta_cost;
solution[i] = slots[i][n];
solution_idx[i] = n;
break;
}
}
}
if (current_cost <= max_cost) break;
}
}

// Calculate final max error and sum of log errors
Expand Down