{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sampling from UstaModel logits\n",
    "\n",
    "Walks through the three decoding controls accepted by `u_model.generate` (temperature, top-k and top-p) by applying them one at a time to the model output for a small batch of prompts.\n",
    "\n",
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n",
      "tensor([ 0, 61, 1, 61, 2, 61, 0, 61, 3], device='mps:0')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 32])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "from usta_model import UstaModel\n",
    "from usta_tokenizer import UstaTokenizer\n",
    "\n",
    "# Sequence length shared by encode_batch below and the model definition.\n",
    "context_length = 32\n",
    "\n",
    "# Prefer CUDA, then Apple MPS, and fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "u_tokenizer = UstaTokenizer(\"tokenizer.json\")\n",
    "\n",
    "prompts = [\n",
    "    \"the capital of the united\",\n",
    "    \"madrid is in\",\n",
    "    \"the capital of france is\",\n",
    "    \"the capital of germany is\",\n",
    "]\n",
    "\n",
    "# Single prompt, used by the generation cell at the end of the notebook.\n",
    "tokens = u_tokenizer.encode(prompts[0]).to(device)\n",
    "print(tokens)\n",
    "\n",
    "# All prompts as one batch, used for the forward pass.\n",
    "batch_tokens = u_tokenizer.encode_batch(prompts, context_length).to(device)\n",
    "batch_tokens.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(1)\n",
    "\n",
    "u_model = UstaModel(\n",
    "    vocab_size=len(u_tokenizer.vocab),\n",
    "    embedding_dim=12,\n",
    "    num_heads=4,\n",
    "    context_length=context_length,\n",
    "    num_layers=8,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# map_location keeps the load working when the checkpoint was saved on a\n",
    "# different device; weights_only restricts unpickling to tensors and plain\n",
    "# containers instead of arbitrary Python objects.\n",
    "state_dict = torch.load(\"../u_model_4000.pth\", map_location=device, weights_only=True)\n",
    "\n",
    "# NOTE(review): if UstaModel contains dropout, call u_model.eval() before\n",
    "# running inference below. Confirm against the model definition.\n",
    "u_model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 32, 64])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out = u_model(batch_tokens)\n",
    "\n",
    "# Scores at the final position of the last prompt; every cell below works on these.\n",
    "# NOTE(review): if encode_batch right-pads to context_length, position -1 is a\n",
    "# padding slot rather than the last real token. Confirm this is intended.\n",
    "last_logits = out[-1][-1]\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling controls\n",
    "\n",
    "- **temperature**: divides the logits before the softmax.\n",
    "- **top_k**: keeps only the `k` highest logits.\n",
    "- **top_p**: keeps the most likely tokens until their cumulative probability reaches `p`.\n",
    "\n",
    "### Top-k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([17.6884, 14.0799, 9.0104, 8.4548, 7.3207, 7.2960, 6.8096, 6.6073,\n",
       " 6.6009, 6.3761]),\n",
       " [61, 60, 35, 58, 9, 38, 59, 4, 18, 49])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual top-k: pair every logit with its vocabulary index and sort once.\n",
    "# Sorting (index, value) pairs keeps tied logits mapped to distinct indexes,\n",
    "# which a value -> list.index() lookup cannot do.\n",
    "ranked = sorted(enumerate(last_logits.tolist()), key=lambda pair: pair[1], reverse=True)[:top_k]\n",
    "\n",
    "sorted_indexes = [index for index, _ in ranked]\n",
    "sorted_outs = torch.tensor([value for _, value in ranked])\n",
    "sorted_outs, sorted_indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([17.6884, 14.0799, 9.0104, 8.4548, 7.3207, 7.2960, 6.8096, 6.6073,\n",
       " 6.6009, 6.3761], device='mps:0', grad_fn=),\n",
       " tensor([61, 60, 35, 58, 9, 38, 59, 4, 18, 49], device='mps:0'))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The same selection with the built-in; should match the manual result above.\n",
    "values, indexes = torch.topk(last_logits, k=top_k)\n",
    "values, indexes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.6830, 1.3397, 0.8573, 0.8045, 0.6965, 0.6942, 0.6479, 0.6287, 0.6281,\n",
       " 0.6067])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temperature = 10.51\n",
    "\n",
    "# sorted_outs is already a tensor, so divide it directly; wrapping it in\n",
    "# torch.tensor() again only triggers a copy-construct UserWarning.\n",
    "adjusted_outs = sorted_outs / temperature\n",
    "adjusted_outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2128, 0.1509, 0.0932, 0.0884, 0.0793, 0.0791, 0.0756, 0.0741, 0.0741,\n",
       " 0.0725])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs = torch.softmax(adjusted_outs, dim=-1)\n",
    "probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Top-p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_p = 0.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5453)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cumulative probability of the four most likely tokens, using the rounded\n",
    "# values displayed for `probs` above: still below top_p.\n",
    "torch.sum(torch.tensor([0.2128, 0.1509, 0.0932, 0.0884]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply top_p to the real distribution: `probs` is already sorted in\n",
    "# descending order, so the nucleus is the shortest prefix whose cumulative\n",
    "# probability reaches top_p.\n",
    "cumulative_probs = torch.cumsum(probs, dim=-1)\n",
    "\n",
    "# Count the entries strictly below the threshold, then include the one that crosses it.\n",
    "nucleus_size = int((cumulative_probs < top_p).sum().item()) + 1\n",
    "\n",
    "# Renormalise so the kept probabilities sum to 1 again.\n",
    "nucleus_probs = probs[:nucleus_size] / probs[:nucleus_size].sum()\n",
    "cumulative_probs, nucleus_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling\n",
    "\n",
    "Draw repeatedly from `probs` and count how often each position is picked. The keys are positions within the top-k list, not vocabulary ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 212, 4: 82, 5: 87, 9: 83, 2: 74, 6: 73, 1: 154, 3: 91, 8: 80, 7: 64}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_samples = 1000\n",
    "\n",
    "sample_count = {}\n",
    "for _ in range(num_samples):\n",
    "    # One draw per iteration; .item() is read once and reused as the key.\n",
    "    sampled_position = torch.multinomial(probs, 1).item()\n",
    "    sample_count[sampled_position] = sample_count.get(sampled_position, 0) + 1\n",
    "sample_count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation\n",
    "\n",
    "Run `u_model.generate` repeatedly on the first prompt and count the distinct completions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'the capital of the united.': 3,\n",
       " 'the capital of the united the ': 22,\n",
       " 'the capital of the united identity,': 1,\n",
       " 'the capital of the united capitals': 5,\n",
       " 'the capital of the united country ': 8,\n",
       " 'the capital of the united europe ': 26,\n",
       " 'the capital of the united is ': 7,\n",
       " 'the capital of the united place ': 4,\n",
       " 'the capital of the united europe,': 3,\n",
       " 'the capital of the united united ': 6,\n",
       " 'the capital of the united for ': 1,\n",
       " 'the capital of the united spain,': 2,\n",
       " 'the capital of the united europe.': 1,\n",
       " 'the capital of the united italy,': 4,\n",
       " 'the capital of the united art ': 1,\n",
       " 'the capital of the united of ': 1,\n",
       " 'the capital of the united united': 1,\n",
       " 'the capital of the united capitaled': 1,\n",
       " 'the capital of the united, country': 1,\n",
       " 'the capital of the united place.': 1,\n",
       " 'the capital of the united, europe': 1}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Kept separate from `temperature` above, which is deliberately large to\n",
    "# flatten the distribution for the walkthrough.\n",
    "generation_temperature = 1.7\n",
    "num_generations = 100\n",
    "\n",
    "outs = {}\n",
    "for _ in range(num_generations):\n",
    "    # Use a dedicated name so the `out` tensor from the forward pass is not\n",
    "    # overwritten, which would break re-running the earlier cells.\n",
    "    generated = u_model.generate(\n",
    "        tokens,\n",
    "        max_new_tokens=3,\n",
    "        temperature=generation_temperature,\n",
    "        top_k=top_k,\n",
    "        top_p=top_p,\n",
    "    )\n",
    "    decoded = u_tokenizer.decode(generated)\n",
    "    outs[decoded] = outs.get(decoded, 0) + 1\n",
    "outs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}