# Multi-Adapter Cycle-Consistency Training (MACCT)

Multi-Adapter Cycle-Consistency Training (MACCT) is an implementation of Cycle-Consistency Training (CCT) that uses PEFT LoRA adapters (Hu et al., 2022) in place of full model weights to learn the A->B and B->A translation mappings. Both adapters share a single frozen base model, so far fewer parameters require optimizer states, which significantly reduces the static memory footprint. So much so that we can load a frozen base model that is ~7.35x larger than either model in the full fine-tuning case.

[A HELPFUL DIAGRAM WILL GO HERE]
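
To make the layout concrete, the sketch below shows how two LoRA adapters can share one frozen base model via the Hugging Face PEFT multi-adapter API. This is a minimal illustration rather than the reference implementation: the base checkpoint, adapter names, and LoRA hyperparameters are assumptions, and the CCT cycle/loss wiring is omitted.

```python
import torch
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model

# One frozen base model; only the LoRA weights will be trained.
base = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base", torch_dtype=torch.bfloat16
)

lora_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q", "v"])

# Two adapters attached to the same base: one per translation direction.
model = get_peft_model(base, lora_cfg, adapter_name="a_to_b")
model.add_adapter("b_to_a", lora_cfg)

# During a cycle step, activate whichever direction is being applied.
model.set_adapter("a_to_b")  # A -> B translation
# ... forward pass / pseudo-target generation ...
model.set_adapter("b_to_a")  # B -> A reconstruction
```

Because only the adapter weights are trainable, the optimizer holds states for roughly $2r|\theta|$ parameters instead of $2|\theta|$, which is where the savings derived below come from.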

??? "How are these figures calculated?"

  Assuming a restricted case where we only consider static memory, i.e. model weights and optimizer states, the memory savings are significant. In the dual-model ($DMCCT$) case we assume both models are identical; likewise, in multi-adapter CCT ($MACCT$) we assume both LoRA adapters are initialised identically. In all cases we use the AdamW optimizer, as it is the most widely used; it keeps two fp32 states (first and second moments) per trainable parameter, which is where the $2 \text{ states} * (4*8) \text{ bits}$ term below comes from:

  With $\theta$ as the foundational model parameters, $\phi$ as the LoRA adapter parameters, $p$ as the number of bits per parameter ${\theta}_i$, $q$ as the number of bits per parameter ${\phi}_i$, and $r$ as the ratio of base model size $|\theta|$ to LoRA adapter size $|\phi|$, we can derive the memory savings as follows:

  $$
  \begin{aligned}
  DMCCT =& \left[ 2 \text{ models} * |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ models} * |\theta| \text{ params} * 2 \text{ states} * (4*8) \text{ bits} \right] \\
        =& 2|\theta|(p + 64)
  \end{aligned} \tag{1}
  $$

  $$
  \begin{aligned}
  MACCT =& \left[ |\theta| \text{ params} * p \text{ bits} \right] + \left[ 2 \text{ LoRAs} * (|\theta| \text{ params} * r) * \left( q \text{ bits} + 2 \text{ states} * (4*8) \text{ bits} \right) \right] \\
        =& |\theta|(p + 2r(q + 64))
  \end{aligned} \tag{2}
  $$

  Assuming $p = 16$, $q = 32$ (LoRA adapters are trained in 32-bit by default), $r = 0.03$ (each adapter is ~3% of the base model size), and a 1B-parameter model for each translation direction in $DMCCT$, we can set the two memory budgets equal and solve for the $MACCT$ base model size $N$:

  $$
  \begin{aligned}
  DMCCT =& 2 * 1e9 * (16 + 64) \\
        =& 1.6e11 \text{ bits} \\
  \\
  DMCCT =& MACCT \\
  1.6e11 \text{ bits} =& N * (16 + 2 * 0.03 * (32 + 64)) = 21.76N \\
  N =& \frac{1.6e11}{21.76} \\
        \approx& 7.35e9 \text{ params} = 7.35B \text{ params} \\
  \end{aligned}
  $$
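
  As a quick sanity check, both formulas can be evaluated directly. This is a minimal sketch using the assumptions above; the function names are illustrative only.

  ```python
  # Static-memory estimates from equations (1) and (2), in bits.
  def dmcct_bits(n_params, p=16):
      # 2 full models, each with weights plus 2 AdamW fp32 states per parameter.
      return 2 * n_params * (p + 64)

  def macct_bits(n_params, p=16, q=32, r=0.03):
      # 1 frozen base model plus 2 LoRA adapters (weights + AdamW states).
      return n_params * (p + 2 * r * (q + 64))

  dual = dmcct_bits(1e9)  # 1.6e11 bits for the dual 1B-model setup
  breakeven = dual / (16 + 2 * 0.03 * (32 + 64))  # solve DMCCT = MACCT for N
  print(f"DMCCT: {dual:.2e} bits, break-even MACCT base size: {breakeven / 1e9:.2f}B params")
  # -> DMCCT: 1.60e+11 bits, break-even MACCT base size: 7.35B params
  ```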