|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "ggml/ggml.h" |
|
|
|
#include <algorithm> |
|
#include <cmath> |
|
#include <cstdio> |
|
#include <cstring> |
|
#include <ctime> |
|
#include <fstream> |
|
#include <vector> |
|
|
|
#if defined(_MSC_VER) |
|
#pragma warning(disable: 4244 4267) |
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int mnist_eval( |
|
const char * fname_cgraph, |
|
const int n_threads, |
|
std::vector<float> digit) { |
|
|
|
struct ggml_context * ctx_data = NULL; |
|
struct ggml_context * ctx_eval = NULL; |
|
|
|
struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval); |
|
|
|
|
|
GGML_ASSERT(ggml_graph_get_tensor(&gfi, "fc1_bias")->op_params[0] == int(0xdeadbeef)); |
|
|
|
|
|
|
|
static size_t buf_size = 128ull*1024*1024; |
|
static void * buf = malloc(buf_size); |
|
|
|
struct ggml_init_params params = { |
|
buf_size, |
|
buf, |
|
false, |
|
}; |
|
|
|
struct ggml_context * ctx_work = ggml_init(params); |
|
|
|
struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input"); |
|
memcpy(input->data, digit.data(), ggml_nbytes(input)); |
|
|
|
ggml_graph_compute_with_ctx(ctx_work, &gfi, n_threads); |
|
|
|
const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs")); |
|
|
|
const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data; |
|
|
|
ggml_free(ctx_work); |
|
ggml_free(ctx_data); |
|
ggml_free(ctx_eval); |
|
|
|
return prediction; |
|
} |
|
|
|
int main(int argc, char ** argv) { |
|
srand(time(NULL)); |
|
ggml_time_init(); |
|
|
|
if (argc != 3) { |
|
fprintf(stderr, "Usage: %s models/mnist/mnist.ggml models/mnist/t10k-images.idx3-ubyte\n", argv[0]); |
|
exit(0); |
|
} |
|
|
|
uint8_t buf[784]; |
|
std::vector<float> digit; |
|
|
|
|
|
{ |
|
std::ifstream fin(argv[2], std::ios::binary); |
|
if (!fin) { |
|
fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]); |
|
return 1; |
|
} |
|
|
|
|
|
fin.seekg(16 + 784 * (rand() % 10000)); |
|
fin.read((char *) &buf, sizeof(buf)); |
|
} |
|
|
|
|
|
{ |
|
digit.resize(sizeof(buf)); |
|
|
|
for (int row = 0; row < 28; row++) { |
|
for (int col = 0; col < 28; col++) { |
|
fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_'); |
|
digit[row*28 + col] = ((float)buf[row*28 + col]); |
|
} |
|
|
|
fprintf(stderr, "\n"); |
|
} |
|
|
|
fprintf(stderr, "\n"); |
|
} |
|
|
|
const int prediction = mnist_eval(argv[1], 1, digit); |
|
|
|
fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction); |
|
|
|
return 0; |
|
} |
|
|