From 326b418b59b6d48d854c4461a2303e8ac0a311e6 Mon Sep 17 00:00:00 2001
From: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Date: Fri, 12 Jan 2024 06:59:57 +0100
Subject: [PATCH] Importance Matrix calculation (#4861)

* imatrix: 1st version

* imatrix: WIP

* Cleanup

* Update examples/imatrix/imatrix.cpp

Co-authored-by: Georgi Gerganov

---------

Co-authored-by: Iwan Kawrakow
Co-authored-by: Georgi Gerganov
---
 Makefile                        |   5 +-
 examples/CMakeLists.txt         |   1 +
 examples/imatrix/CMakeLists.txt |   5 +
 examples/imatrix/imatrix.cpp    | 380 ++++++++++++++++++++++++++++++++
 ggml.c                          |  14 ++
 ggml.h                          |   6 +
 6 files changed, 410 insertions(+), 1 deletion(-)
 create mode 100644 examples/imatrix/CMakeLists.txt
 create mode 100644 examples/imatrix/imatrix.cpp

diff --git a/Makefile b/Makefile
index 4c7e175bf6cb3..05fe9a0f6a0d2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 # Define the default target now so that it is always the first target
 BUILD_TARGETS = \
-	main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
+	main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
 	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
 
@@ -614,6 +614,9 @@ quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.o ggml.
 perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+imatrix: examples/imatrix/imatrix.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 embedding: examples/embedding/embedding.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 0c71cbdf72a65..fa127a3aa7c9e 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -36,6 +36,7 @@ else()
     add_subdirectory(lookahead)
     add_subdirectory(lookup)
     add_subdirectory(train-text-from-scratch)
+    add_subdirectory(imatrix)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/imatrix/CMakeLists.txt b/examples/imatrix/CMakeLists.txt
new file mode 100644
index 0000000000000..d688a16209049
--- /dev/null
+++ b/examples/imatrix/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET imatrix)
+add_executable(${TARGET} imatrix.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
new file mode 100644
index 0000000000000..1461bc96376a7
--- /dev/null
+++ b/examples/imatrix/imatrix.cpp
@@ -0,0 +1,380 @@
+#include "common.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <mutex>
+#include <thread>
+#include <unordered_map>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+struct Stats {
+    std::vector<float> values;
+    int                ncall = 0;
+};
+
+struct StatParams {
+    std::string ofile = "imatrix.dat";
+    int         n_output_frequency = 10;
+    int         verbosity = 1;
+    bool        collect_output_weight = false;
+};
+
+class IMatrixCollector {
+public:
+    IMatrixCollector() = default;
+    void set_parameters(StatParams&& params) { m_params = std::move(params); }
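+    // collect_imatrix() receives the model weight (src0) and the activations it multiplies (src1)
+    // from the ggml mul_mat hook; save_imatrix() periodically writes the accumulated statistics to m_params.ofile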
+    void collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1);
+    void save_imatrix() const;
+private:
+    std::unordered_map<std::string, Stats> m_stats;
+    StatParams                             m_params;
+    std::mutex                             m_mutex;
+    int                                    m_last_call = 0;
+};
+
+void IMatrixCollector::collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
+    if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return;
+    if (!(strncmp(src0->name, "blk.", 4) == 0 || (m_params.collect_output_weight && strcmp(src0->name, "output.weight") == 0))) return;
+    std::lock_guard<std::mutex> lock(m_mutex);
+    auto& e = m_stats[src0->name];
+    if (e.values.empty()) {
+        e.values.resize(src1->ne[0], 0);
+    }
+    else if (e.values.size() != (size_t)src1->ne[0]) {
+        fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]);
+        exit(1); //GGML_ASSERT(false);
+    }
+    ++e.ncall;
+    if (m_params.verbosity > 1) {
+        printf("%s[%d]: %s, %d x %d, %d\n",__func__,m_last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type);
+    }
+    for (int row = 0; row < (int)src1->ne[1]; ++row) {
+        const float * x = (const float *)src1->data + row * src1->ne[0];
+        for (int j = 0; j < (int)src1->ne[0]; ++j) {
+            e.values[j] += x[j]*x[j];
+        }
+    }
+    if (e.ncall > m_last_call) {
+        m_last_call = e.ncall;
+        if (m_last_call % m_params.n_output_frequency == 0) {
+            save_imatrix();
+        }
+    }
+}
+
+void IMatrixCollector::save_imatrix() const {
+    const char * fname = m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str();
+    std::ofstream out(fname, std::ios::binary);
+    int n_entries = m_stats.size();
+    out.write((const char*)&n_entries, sizeof(n_entries));
+    for (auto& p : m_stats) {
+        int len = p.first.size();
+        out.write((const char*)&len, sizeof(len));
+        out.write(p.first.c_str(), len);
+        out.write((const char*)&p.second.ncall, sizeof(p.second.ncall));
+        int nval = p.second.values.size();
+        out.write((const char*)&nval, sizeof(nval));
+        if (nval > 0) out.write((const char*)p.second.values.data(), nval*sizeof(float));
+    }
+    if (m_params.verbosity > 0) {
+        fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n",__func__,m_last_call,fname);
+    }
+}
+
+static IMatrixCollector g_collector;
+
+static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
+    g_collector.collect_imatrix(src0, src1);
+}
+
+
+struct results_log_softmax {
+    double log_softmax;
+    float  logit;
+    float  prob;
+};
+
+static std::vector<float> softmax(const std::vector<float>& logits) {
+    std::vector<float> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
+    return probs;
+}
+
+static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) {
+        max_logit = std::max(max_logit, logits[i]);
+    }
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) {
+        sum_exp += expf(logits[i] - max_logit);
+    }
+    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
+}
+
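+// Splits the per-token -log(prob) computation of one chunk across the worker threads,
+// accumulating nll/nll2 and recording the per-token logit and probability history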
+static void process_logits(
+    int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+    double & nll, double & nll2, float * logit_history, float * prob_history
+) {
+    std::mutex mutex;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
+        double local_nll  = 0;
+        double local_nll2 = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const double v = -results.log_softmax;
+            local_nll += v;
+            local_nll2 += v*v;
+
+            logit_history[i] = results.logit;
+            prob_history[i]  = results.prob;
+        }
+    };
+    for (auto & w : workers) {
+        w = std::thread(compute);
+    }
+    compute();
+    for (auto & w : workers) {
+        w.join();
+    }
+}
+
+static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
+
+    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    const int n_ctx = llama_n_ctx(ctx);
+
+    auto tim1 = std::chrono::high_resolution_clock::now();
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+
+    auto tim2 = std::chrono::high_resolution_clock::now();
+    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
+
+    if (int(tokens.size()) < 2*n_ctx) {
+        fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
+                n_ctx);
+        fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
+        return false;
+    }
+
+    std::vector<float> logit_history;
+    logit_history.resize(tokens.size());
+
+    std::vector<float> prob_history;
+    prob_history.resize(tokens.size());
+
+    const int n_chunk_max = tokens.size() / n_ctx;
+
+    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
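+    // (a negative params.n_chunks means: process all complete n_ctx-token chunks in the input)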
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const int n_batch = params.n_batch;
+
+    int count = 0;
+    double nll = 0.0;
+    double nll2 = 0.0;
+
+    fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
+
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start =     i * n_ctx;
+        const int end   = start + n_ctx;
+
+        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        // clear the KV cache
+        llama_kv_cache_clear(ctx);
+
+        for (int j = 0; j < num_batches; ++j) {
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
+
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
+
+            // add BOS token for the first batch of each chunk
+            if (add_bos && j == 0) {
+                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+            }
+
+            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return false;
+            }
+
+            // restore the original token in case it was set to BOS
+            tokens[batch_start] = token_org;
+
+            const auto * batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
+        }
+
+        const int first = n_ctx/2;
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
+                workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
+        count += n_ctx - first - 1;
+
+        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        fflush(stdout);
+    }
+    printf("\n");
+
+    nll2 /= count;
+    nll /= count;
+    const double ppl = exp(nll);
+    nll2 -= nll * nll;
+    if (nll2 > 0) {
+        nll2 = sqrt(nll2/(count-1));
+        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+    } else {
+        printf("Unexpected negative standard deviation of log(prob)\n");
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+
+    StatParams sparams;
+    std::vector<char *> args;
+    args.push_back(argv[0]);
+    int iarg = 1;
+    for (; iarg < argc-1; ++iarg) {
+        std::string arg{argv[iarg]};
+        if (arg == "-o" || arg == "--output-file") {
+            sparams.ofile = argv[++iarg];
+        }
+        else if (arg == "-ofreq" || arg == "--output-frequency") {
+            sparams.n_output_frequency = std::stoi(argv[++iarg]);
+        }
+        else if (arg == "-ow" || arg == "--output-weight") {
+            sparams.collect_output_weight = std::stoi(argv[++iarg]);
+        }
+        else if (arg == "--verbosity") {
+            sparams.verbosity = std::stoi(argv[++iarg]);
+        } else {
+            args.push_back(argv[iarg]);
+        }
+    }
+    if (iarg < argc) {
+        args.push_back(argv[iarg]);
+    }
+
+    gpt_params params;
+    params.n_batch = 512;
+    if (!gpt_params_parse(args.size(), args.data(), params)) {
+        return 1;
+    }
+
+    g_collector.set_parameters(std::move(sparams));
+
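+    // register the collector with ggml: every matrix multiplication now reports its weight/activation pair to ik_collect_imatrix()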
+    ggml_set_imatrix_collection(ik_collect_imatrix);
+
+    params.logits_all = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    print_build_info();
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.random_prompt) {
+        params.prompt = gpt_random_prompt(rng);
+    }
+
+    llama_backend_init(params.numa);
+
+    llama_model * model;
+    llama_context * ctx;
+
+    // load the model and apply lora adapter, if any
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    const int n_ctx_train = llama_n_ctx_train(model);
+    if (params.n_ctx > n_ctx_train) {
+        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
+    }
+
+    // print system information
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "%s\n", get_system_info(params).c_str());
+    }
+
+    bool OK = compute_imatrix(ctx, params);
+    if (!OK) {
+        return 1;
+    }
+
+    g_collector.save_imatrix();
+
+    llama_print_timings(ctx);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/ggml.c b/ggml.c
index d2a8c0478ab72..f5caeba082ea9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
 static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
 
+ggml_collect_imatrix_t g_imatrix_collect = NULL;
+
+void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) {
+    g_imatrix_collect = imatrix_collect;
+}
+
 static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
         .type_name = "i8",
@@ -9763,6 +9769,10 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    if (ith == 1 && g_imatrix_collect) {
+        g_imatrix_collect(src0, src1);
+    }
+
     const enum ggml_type type = src0->type;
 
     const bool src1_cont = ggml_is_contiguous(src1);
@@ -10066,6 +10076,10 @@
 
         const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
 
+        if (ith == 1 && g_imatrix_collect) {
+            g_imatrix_collect(src0_cur, src1);
+        }
+
         const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
diff --git a/ggml.h b/ggml.h
index 93b42a27da53d..4c2ff6c661ea3 100644
--- a/ggml.h
+++ b/ggml.h
@@ -2067,6 +2067,12 @@ extern "C" {
 
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
+    //
+    // Importance matrix
+    //
+    typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1);
+    GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect);
+
     //
     // gguf
    //
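
The hook introduced in ggml.h above is a plain function pointer, so any client can register its own collector before running inference, exactly as imatrix.cpp does with ik_collect_imatrix. A minimal sketch of such a client follows; the function my_collect_imatrix and its printf body are illustrative only and not part of this patch:

    #include <stdio.h>
    #include "ggml.h"

    // ggml calls the registered function from ggml_compute_forward_mul_mat (and _mul_mat_id)
    // on the thread with ith == 1; src0 is the weight tensor, src1 holds the activations.
    static void my_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
        printf("mul_mat: %s, %d x %d\n", src0->name, (int)src1->ne[0], (int)src1->ne[1]);
    }

    // register once before evaluation; passing NULL disables collection again
    // ggml_set_imatrix_collection(my_collect_imatrix);

With the example built, a typical run looks like ./imatrix -m model.gguf -f calibration.txt -o imatrix.dat --verbosity 1, where the model and text file names are placeholders; -o, -ofreq, -ow and --verbosity are consumed by imatrix.cpp itself and all remaining arguments are forwarded to gpt_params_parse.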