Lab 7: Transformers


NOTE: This is a lab project accompanying the following book [MLF] and it should be used together with the book.

[MLF] H. Jiang, "Machine Learning Fundamentals: A Concise Introduction", Cambridge University Press, 2021. (bibtex)


The purpose of this lab is to study the popular transformer achitecture for a simple language modeling task. First, we show how to implement the multi-head attention module in transformers using pytorch. Next, we extend it to a multi-layer transformer structure for general sequence modeling purpose. In particular, we use the multi-layer transformer model to perform char-level language modeling on a small text corpus, i.e. the tiny-shakespeare corpus. At last, after the transformer is trained, we how that it can be used to generate new text similar to the training samples.

Prerequisites: basic understanding on pytorch.

I. Multi-head Self-Attention Module

Example 7.1:

Use pytorch to implement a multi-head causal self-attention model similar to Figure 8.27 on page 173. The causal attention means that each output vector depends on only all preceding input vectors but not any vectors appearing after it. Compare the running speeds of its forward and backward passes when running on CPUs and GPUs.

All hyper-parameters: (usually we have $d= n \times h$)

  • $d$: embedding dimension
  • $B$: batch size
  • $T$: block size
  • $n$: head number
  • $h$: head size
  1. Given a mini-batch of $B$ input sequences, packed as $\mathbf{X} \in \mathbb{R}^{B \times d \times T}$, and a multi-head transformer consisting of all parameter matrices as $\mathbf{A}^{(j)}, \mathbf{B}^{(j)}, \mathbf{C}^{(j)} \in \mathbb{R}^{h \times d}$ ($j=1,2,\cdots,n$). We first arrange these matrices into three larger parameter matrices, $\mathbf{A}, \mathbf{B}, \mathbf{C} \in \mathbb{R}^{nh \times d}$, and then further pack these three matrices into a single large matrix $\mathbf{M} \in \mathbb{R}^{3nh \times d}$. We can generate query, key and value matrices as follows: $$ \begin{bmatrix} \mathbf{Q} \\ \mathbf{K} \\ \mathbf{V} \end{bmatrix} = \begin{bmatrix}\mathbf{A}\\ \mathbf{B} \\\mathbf{C} \end{bmatrix} \mathbf{X} = \mathbf{M} \, \mathbf{X} $$ where $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{B \times nh \times T}$.

  2. Re-shape the query, key and value matrices as follows: $$ \mathbf{Q}, \mathbf{K}, \mathbf{V}: \mathbb{R}^{B \times nh \times T} → \mathbb{R}^{B \times n \times h \times T} $$

  3. Apply the point-wise attention as: $$ \mathbf{Q}^\intercal \mathbf{K} \in \mathbb{R}^{B \times n \times T \times T} $$ If necessary, apply an upper triangular masking for causal attetion.

  4. Apply softmax column by column $$ \mathcal{A} = \mathrm{softmax} \Big(\mathbf{Q}^\intercal \mathbf{K} \Big) \;\;\; (\in \mathbb{R}^{B \times n \times T \times T}) $$

  1. Generate the output $$ \mathbf{Z} = \mathbf{V} \, \mathcal{A} \;\;\; (\in \mathbb{R}^{B \times n \times h \times T}) $$
In [2]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class config():
  def __init__(self, batch_size = 10, n_embd = 768, block_size = 128, n_head = 8, causal = True, device='cuda'):
    assert n_embd %  n_head == 0
    self.batch_size = batch_size
    self.n_embd = n_embd
    self.block_size = block_size
    self.n_head = n_head
    self.causal = causal
    self.device = device

class SelfAttentionModule(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.batch_size = cfg.batch_size
        self.n_embd = cfg.n_embd
        self.block_size = cfg.block_size
        self.n_head = cfg.n_head
        self.device = cfg.device
        self.causual = cfg.causal
        mask = torch.tril(torch.ones(self.block_size, self.block_size)).transpose(0,1).view(1, 1, self.block_size, self.block_size)
        mask = mask.to(self.device)
        self.register_buffer("mask", mask)

    def forward(self, X, M):
        B, d, T = X.size() # batch size,  embedding dimensionality (n_embd), block size

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        Q, K, V  = (M @ X).split(d, dim=1)
        K = K.view(B, self.n_head, d // self.n_head, T)   # (B, nh, hs, T)
        Q = Q.view(B, self.n_head, d // self.n_head, T)   # (B, nh, hs, T)
        V = V.view(B, self.n_head, d // self.n_head, T)   # (B, nh, hs, T)

        att = (Q.transpose(-2, -1) @ K) * (1.0 / math.sqrt(Q.size(-2)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-2)   # column-wise softmax
        Z = V @ att   # (B, nh, hs, T) x (B, nh, T, T) -> (B, nh, hs, T)

        Z = Z.contiguous().view(B, d, T)  # re-assemble all head outputs side by side

        return Z
In [3]:
cfg=config()

def loss(model, x, M):
  y = model(x,M)
  return torch.sum(y*y)

X = torch.rand((cfg.batch_size, cfg.n_embd, cfg.block_size), requires_grad=True)
M = (torch.rand((3*cfg.n_embd, cfg.n_embd),requires_grad=True) - 0.5) * (1.0/math.sqrt(cfg.n_embd))

#######################
cfg.device ='cpu'

print(f'batch_size={cfg.batch_size} n_embd={cfg.n_embd} block_size={cfg.block_size} n_head={cfg.n_head} device={cfg.device}')

X = X.to(cfg.device)
M = M.to(cfg.device)

mm1=SelfAttentionModule(cfg)

ls = loss(mm1, X, M)
print(f'pytorch forward pass loss: {ls}')

print(f'forward() pass and pytorch auto-grad backward() running on {cfg.device}:')
%timeit loss(mm1, X, M)
%timeit ls.backward(retain_graph=True)

#######################
cfg.device ='cuda'

print(f'batch_size={cfg.batch_size} n_embd={cfg.n_embd} block_size={cfg.block_size} n_head={cfg.n_head} device={cfg.device}')

X = X.to(cfg.device)
M = M.to(cfg.device)

mm2=SelfAttentionModule(cfg)

ls = loss(mm2, X, M)
print(f'pytorch forward pass loss: {ls}')

print(f'forward() pass and pytorch auto-grad backward() running on {cfg.device}:')
%timeit loss(mm2, X, M)
%timeit ls.backward(retain_graph=True)
batch_size=10 n_embd=768 block_size=128 n_head=8 device=cpu
pytorch forward pass loss: 20090.15234375
forward() pass and pytorch auto-grad backward() running on cpu:
67.3 ms ± 11.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
108 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
batch_size=10 n_embd=768 block_size=128 n_head=8 device=cuda
pytorch forward pass loss: 20090.15234375
forward() pass and pytorch auto-grad backward() running on cuda:
2.01 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
10.2 ms ± 1.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

II. Transformers for Language Modelling

Example 7.2:

Use pytorch to implement a multi-layer causal transformer model on page 174, similar to the popular GPT architecture, for char-level language modelling on the tiny-shakespeare corpus. In other words, each char in text is treated as a distinct token in the language model. The transformer is trained to predict next character based on all preceding characters in the transformer input block. After it is trained, use it to generate new text sequences.

In [1]:
# download tiny shakespeare text corpus from Google drive

!gdown --folder https://drive.google.com/drive/folders/1VtY71Iym2uOC4bUUukJ9I-b1JgR5e18o 2> /dev/null
Processing file 1Sl0000nPuCW6RcF3b3SL2n9EPJpI582q input.txt
Building directory structure completed
In [2]:
# load text file as a character string; build char-level vocabulary, encoder and decoder
#  (borrowed from nanoGPT, https://github.com/karpathy/nanoGPT)

train_txt_file = 'tinyshakespeare/input.txt'

with open(train_txt_file, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens
In [3]:
# Define all parts in a GPT model, such as attention, MLP, LN modules
#  (adapted from nanoGPT, https://github.com/karpathy/nanoGPT)

import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

class config():
  def __init__(self, batch_size=10, n_layer=12, n_head=12, n_embd=768, block_size=1024, vocab_size=50304, causal=True, device='cpu'):
    assert n_embd %  n_head == 0
    self.batch_size = batch_size
    self.n_embd = n_embd
    self.block_size = block_size
    self.n_head = n_head
    self.causal = causal
    self.device = device
    self.n_layer = n_layer
    self.vocab_size = vocab_size

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)).transpose(0,1)
                            .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, C // self.n_head, self.n_head).transpose(1, 3) # (B, nh, T, hs)
        q = q.view(B, T, C // self.n_head, self.n_head).transpose(1, 3) # (B, nh, T, hs)
        v = v.view(B, T, C // self.n_head, self.n_head).transpose(1, 3) # (B, nh, T, hs)

        att = (q.transpose(-2, -1) @ k) * (1.0 / math.sqrt(q.size(-2)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-2)
        y = v @ att  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 3).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.c_fc(x)
        x = self.relu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm([config.n_embd])
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm([config.n_embd])

        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm([config.n_embd]),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        # Return the number of parameters in the model.
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
In [4]:
# training script to learn GPT models based on all given hyper-parameters
#  (adapted from nanoGPT, https://github.com/karpathy/nanoGPT)

import os
import time
import math

import numpy as np
import torch

# -----------------------------------------------------------------------------
# default config values
# I/O

######### model size ##########
n_layer = 3      # num of layers
n_head = 8       # num of attn heads per layer
head_size = 96   # size of each attn head
###############################

###### training hyperparameters ######
learning_rate = 6e-4   # max learning rate
max_iters = 10000      # total number of training iterations
batch_size = 12        # mini-batch size
block_size = 128       # block size
#####################################

# adamw optimizer
beta1 = 0.9
beta2 = 0.95

# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
eval_interval = 1000
eval_iters = 200

# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler

torch.manual_seed(1337)

n_embd = n_head * head_size

data_tr  = np.array(train_ids, dtype=np.uint16)
data_val = np.array(val_ids, dtype=np.uint16)

# poor man's data loader
def get_batch(split):
    data = data_tr if split == 'train' else data_val

    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# init a new model from scratch
print("Initializing a new model from scratch")
conf = config(batch_size=batch_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd,\
              block_size=block_size, vocab_size=vocab_size, device=device)
model = GPT(conf)

model = model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(beta1, beta2))

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)


# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
iter_num = 0

while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate

    logits, loss = model(X, Y)
    X, Y = get_batch('train')

    # backward pass, with gradient scaling if training in fp16
    scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0:
      # timing and logging
      t1 = time.time()
      dt = t1 - t0
      t0 = t1

      losses = estimate_loss()
      print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} (time lapses {dt:.4f} seconds)")

    iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break
Initializing a new model from scratch
number of parameters: 21.29M
step 0: train loss 6.0489, val loss 6.1314 (time lapses 1.3229 seconds)
step 1000: train loss 1.9487, val loss 2.0776 (time lapses 71.9827 seconds)
step 2000: train loss 1.7173, val loss 1.8682 (time lapses 72.5995 seconds)
step 3000: train loss 1.6054, val loss 1.7837 (time lapses 72.4764 seconds)
step 4000: train loss 1.5453, val loss 1.7346 (time lapses 72.9500 seconds)
step 5000: train loss 1.5098, val loss 1.7246 (time lapses 72.6370 seconds)
step 6000: train loss 1.4861, val loss 1.6935 (time lapses 72.3111 seconds)
step 7000: train loss 1.4586, val loss 1.6863 (time lapses 72.0951 seconds)
step 8000: train loss 1.4294, val loss 1.6498 (time lapses 71.9847 seconds)
step 9000: train loss 1.4137, val loss 1.6569 (time lapses 71.8970 seconds)
step 10000: train loss 1.4048, val loss 1.6409 (time lapses 71.9453 seconds)
In [6]:
# sample from a trained model to generate new text
# (borrowed from nanoGPT, https://github.com/karpathy/nanoGPT)

import torch

start = "\n"     # specify a text string as prompt
#start = "FILE:tinyQA/prompts.txt"  #can also specify a file of multiple prompts, use as: "FILE:prompt.txt"

num_samples = 5 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.9 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = None #200 # retain only the top_k most likely tokens, clamp others to have 0 probability

seed = 1337
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

model.eval()
model.to(device)

# sampling the model based on one prompt
def SampleOnePrompt(prompt):
    start_ids = encode(prompt)
    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

    for k in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print('---------------')

# encode the beginning of the prompt
if start.startswith('FILE:'):     # for multiple prompts
    with open(start[5:], 'r', encoding='utf-8') as f:
        Lines = f.readlines()
        for prompt in Lines:
            SampleOnePrompt(prompt[:-1])  # chop the trailing newline char
else:
     SampleOnePrompt(start)     # for one prompt

ANGELO:
And comfort is not the aires boast were?
Broach me?
I am such by me uuth, her bardshirts,
Butwick my faction of one heavy sprinfully come my fade
Whomes it ender content Hermions to and quiet.

DUKE VINCERDONE:
For come humb onceive: therefore hew young Bolingbroke?

AUTIO:
'Forld, demate kneel, not what evily, the most
His haply deeds poor gold; the execute,
Turk for treamiss mufficing allowers,
Maderive my of tranking, as to discraffair the Earldom.
Go could all not rel my hand for hi
---------------

Men pay to Rome.

YORK:
Why dead so you myself ancy;
But which they should I have not us. I have my sighs:
I should so, my many more alive,
Or hued the heavy life mouth comfort coureen;
How what most it my Rome, sir, come,
But sunnish'd, let he need out
to accusature but from the way soul of burldness
Whose lack'd this herce advise,
To I stand plaven in the eate,
Though off, if out orning tears more portance whoset,
To the silver were belived so him the word
By the times often their; you contone
---------------

Most and breath. Their fellows are. She, the is, do did
Till of anon thy hoose on thee, master to sea;
Be wereing think of his victor wouldst,
And friends abstime Inname a canity to make the and
inkery dolement tell the leing buwield to-night
each I did the deter your Honest too land;
As there cold consume an the vow
Shrough he it heart some not is conditing torment:
Slad my lord, as belive
To cave but forth this name and who hear the tewn,
To may mine enemy treless fins yet of Rome,
Than art el
---------------

The smools kills; and place, say, as I
Is but of a-father my sun.

POMPEY:
No, sir,
Sir, sir, I well accured my heart he know.

SAMPSON:
If your son mistrumer trial ruin good afe, cousinhim,
And as hour surethere have black'd end,
My fafter good madam, Edward's cold, office likeous,
And away which wither follow'd our true-wrong,
From o'll like punish our own arectator's conver.

GLOUCESTER:
My lord, which ans my and toward's lanch is he parlet.

DUKE VINCERDIO:
What mayou must forth not know'd;

---------------

That latten who cond the tears o' their lives.
Grace not me, so the dead me to thee.

EDBURY:
I love him took a writt of love of this enough,
I have so, bold, my with milk your graves in than ymild,
and is confound the hollow-there countrymand
And bold obsequest allent upon'd:
As I watereford my name,
Our meanishallow made to our his dead.

Clown:
Let's hone mad my face, good of the gentler of pared to be
her one should moving bodies. I would I tree,
And rehem to sleep not teems hronefull man of
---------------

Exercises

Problem 7.1:

Use pytorch to implement the simple recurrent neural network (RNN) structure on Figure 8.25 on page 170, and then extend it to multi-layer RNN models to conduct char-level language modelling on the tiny-shakespeare corpus. At last, sample the trained RNN model to generate new text. Compare the RNN model with the transformer model in Example 7.2 in terms of training speed and the quality of generated text.

Problem 7.2:

Based on the result of Q8.9 (part b) on page 202, use pytorch to explicitly implement the backward pass for the attention model in Example 7.1. Make sure the computed gradients are equal to those obtained from the autograd method in Example 7.1. Also compare the self-implemented backward pass with the pytorch's autograd method in terms of execution time when running on CPUs or GPUs.