class langtorch.tt.TextModule

TextModule is a fundamental building block in LangTorch that takes tensors as input and produces tensors as output. It inherits from torch.nn.Module and can operate on both torch.Tensors and TextTensors, chaining text transformations, embedding operations, and language model inferences.

TextModule: input_class -> output_class is the base class; its subclasses can implement any of the following signatures:

  • TextModule: TextTensor -> TextTensor, e.g. an OpenAI API activation
  • TextModule: TextTensor -> torch.Tensor, e.g. an EmbeddingModule
  • TextModule: torch.Tensor -> TextTensor, e.g. local LLM inference from tokens, or a custom retriever
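
For instance, the first case is just the base class with an LLM activation. A minimal sketch (the prompt and model name are illustrative):

from langtorch import TextModule, TextTensor

translator = TextModule("Translate to English: {}", activation="gpt-3.5-turbo")
result = translator(TextTensor("Bonjour le monde"))  # TextTensor in, TextTensor out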

Initialization

TextModule itself can be used to initialize a TextTensor -> TextTensor module that formats prompts with the input text data and sends them to an LLM activation. It can be initialized with the following parameters:

  • prompt (TextTensor): A TextTensor of prompt templates; it is formatted with the input text data via multiplication.
  • activation (Activation or str, optional): An Activation module or a string with an OpenAI model name; see the Activation documentation for details.
  • key (str or TextTensor, optional): The key or keys (of the same shape as the output) that are automatically assigned to the entries of the output TextTensor.

Tip

Setting a key is useful when chaining TextModules, as prompts in the second module can use "{key}" to route outputs of the first module to a specific place.
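
For instance, a two-step chain where the first module keys its output so the second prompt can reference it (a minimal sketch; the prompts and model name are illustrative):

from langtorch import TextModule, TextTensor

extract = TextModule("Extract the main claim from: {}", activation="gpt-3.5-turbo", key="claim")
verify = TextModule("Is the following claim accurate? {claim}", activation="gpt-3.5-turbo")

result = verify(extract(TextTensor("LangTorch lets you chain LLM calls like PyTorch layers.")))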

Forward Pass

The forward method of TextModule processes the input TextTensor by formatting it with the module's prompt template using TextTensor multiplication. If an activation function is provided, it is then applied to the formatted texts to obtain the output. Finally, if a key is specified, it is assigned to the output TextTensor.

def forward(self, input: TextTensor) -> TextTensor:
    # Format the prompt template with the input via TextTensor multiplication
    formatted_input = self.prompt * input
    # Apply the activation (e.g. an LLM call) if one was provided
    output = self.activation(formatted_input) if self.activation else formatted_input
    # Assign the key to the output entries, if one was specified
    if self.key is not None:
        output.set_key_(self.key)
    return output
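
With no activation, a TextModule simply formats its prompt with the input, which makes it useful as a pure templating step. A minimal sketch:

from langtorch import TextModule, TextTensor

template = TextModule("Question: {}\nAnswer:")
formatted = template(TextTensor("What is a TextModule?"))  # formatted prompt only, no LLM call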

Subclassing TextModule

TextModule is designed to be subclassed to create custom text processing modules. By subclassing TextModule, you can extend its functionality and nest multiple modules within a larger architecture. Submodules can be assigned as regular attributes, forming a tree-like structure.

Here's an example TextModule subclassing pattern with sequential calls that you can modify freely:

from langtorch import TextModule

class TextTransformationModel(TextModule):
    def __init__(self):
        super().__init__()
        # The first step keys its output as "text", so the second prompt
        # can route it into the {text} placeholder
        self.translate = TextModule("Translate to English: {}", activation="gpt-3.5-turbo", key="text")
        self.summarize = TextModule("Summarize this text: '{text}'", activation="gpt-3.5-turbo")

    def forward(self, input_text):
        translated_text = self.translate(input_text)
        return self.summarize(translated_text)

In this example, TextTransformationModel performs text translation followed by summarization. Each processing step is modularized by a TextModule, which specifies the task through its prompt and the language model through its activation.

Another useful pattern is to chain formatting steps and LLM activations explicitly with torch.nn.Sequential:

import torch
from langtorch import TextModule, OpenAI

task = TextModule("Translate to English: {}")
chain_of_thought = TextModule("{}\nLet's think step by step.")

# OpenAI activation is also a Module, so we can chain it explicitly here:
task_module_w_CoT = torch.nn.Sequential(
    task,
    chain_of_thought,
    OpenAI("gpt-4")
)

Subclasses can also add new components and extend the base forward pass. For example, a retrieval-augmented generation (RAG) module can retrieve documents relevant to the user message, append them as context, and only then call super().forward() to apply the prompt and activation:

class RAG(TextModule):
    def __init__(self, documents: TextTensor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.retriever = Retriever(documents)

    def forward(self, user_message: TextTensor, k: int = 5):
        # Retrieve the k most relevant documents and join them into one context block
        retrieved_context = self.retriever(user_message, k) + "\n"
        user_message = user_message + "\nCONTEXT:\n" + retrieved_context.sum()
        # Delegate to TextModule.forward: prompt formatting + activation
        return super().forward(user_message)
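
Usage could then look like the following (a hedged sketch; it assumes TextModule accepts prompt and activation as keyword arguments and that a Retriever implementation is available):

docs = TextTensor(["LangTorch builds on PyTorch.", "TextTensors hold entries of text."])
rag = RAG(docs, prompt="Answer using the context: {}", activation="gpt-3.5-turbo")
answer = rag(TextTensor("What is a TextModule?"))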

Parallel Processing

TextModule inherently supports parallel processing of multiple tasks on multiple inputs. Both the prompt and input are TextTensor objects, which can hold multiple entries. The activation function (e.g., language model API calls) is automatically applied in parallel to all entries.
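
A hedged sketch of this behavior, assuming TextTensors follow torch-style broadcasting rules:

from langtorch import TextModule, TextTensor

# A (2, 1) tensor of prompts applied to a (3,) tensor of inputs broadcasts
# to (2, 3): six prompt-input combinations, sent to the LLM in parallel
prompts = TextTensor([["Summarize in one sentence: {}"],
                      ["List three keywords for: {}"]])
module = TextModule(prompts, activation="gpt-3.5-turbo")
outputs = module(TextTensor(["text one", "text two", "text three"]))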

Methods

parameters()

Returns an iterator over the module's parameters. This method is inherited from torch.nn.Module.

named_parameters()

Returns an iterator over the module's named parameters. This method is inherited from torch.nn.Module.

state_dict()

Returns a dictionary containing the module's state. This method is inherited from torch.nn.Module.

load_state_dict(state_dict)

Loads the module's state from a dictionary. This method is inherited from torch.nn.Module.

to(device)

Moves the module's parameters and buffers to the specified device. This method is inherited from torch.nn.Module.

train(mode=True)

Sets the module in training mode. This method is inherited from torch.nn.Module.

eval()

Sets the module in evaluation mode. This method is inherited from torch.nn.Module.

zero_grad()

Sets the gradients of all parameters to zero. This method is inherited from torch.nn.Module.

__call__(*input)

Invokes the module's forward method along with any registered hooks. This method is inherited from torch.nn.Module; to customize the forward pass, override forward in subclasses rather than __call__.
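
Because every TextModule is a torch.nn.Module, these utilities work on nested text pipelines out of the box. A minimal sketch using the TextTransformationModel defined above:

model = TextTransformationModel()

# Walk the tree of nested (Text)Modules, as with any nn.Module
for name, submodule in model.named_modules():
    print(name or "(root)", type(submodule).__name__)

model.eval()  # switch the whole pipeline to evaluation mode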

Notes

  • Ensure that the keys of the TextTensor inputs match the placeholder names in the prompt templates. If a placeholder name in the module prompt has no matching key in the input tensor, the input text is appended at the end of the prompt instead of being placed at the placeholder (see the sketch after these notes).
  • TextModule can be extended through subclassing to implement custom forward methods, offering tailored text processing capabilities. Some currently available or future subclasses are mentioned in langtorch.tt.
  • Parallel (batched) processing of inputs through the LLM activation is handled automatically for all entries regardless of shape.
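
To illustrate the first note, keys route entries into matching placeholders during multiplication (a minimal sketch using set_key_ from the forward pass above):

from langtorch import TextTensor

content = TextTensor("A long passage to summarize...")
content.set_key_("text")  # name the entry, as TextModule does with key=

prompt = TextTensor("Summarize this text: '{text}'")
formatted = prompt * content  # the keyed entry is routed into the {text} placeholder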