class langtorch.tt.TextModule

TextModule is a fundamental building block in LangTorch that takes tensors as input and produces tensors as output. It inherits from torch.nn.Module and can operate on both torch.Tensors and TextTensors, chaining text transformations, operations on embeddings, and language model inferences.
TextModule: input_class -> output_class is the base class, whose subclasses can operate on:

- TextModule: TextTensor -> TextTensor, e.g. an OpenAI API activation
- TextModule: TextTensor -> torch.Tensor, e.g. EmbeddingModule
- TextModule: torch.Tensor -> TextTensor, e.g. a local LLM inference from tokens or a custom retriever
Initialization

TextModule itself can be used to initialize a TextTensor-to-TextTensor module that formats prompts with the inputs and sends them to an LLM activation. It can be initialized with the following parameters:
| Parameter | Type | Description |
|---|---|---|
| prompt | TextTensor | A TextTensor of prompt templates. It will be formatted with the input text data via multiplication. |
| activation | Activation or str (optional) | An Activation module or a string with an OpenAI model name. Learn more here. |
| key | str or TextTensor (optional) | The key or keys (of the same shape as the output) that will be automatically assigned to the output TextTensor entries. |
Tip

Setting a key is useful when chaining TextModules, as prompts in the second module can use "{key}" to route outputs of the first module to a specific place.
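For illustration, a minimal construction sketch (the prompt text and model name here are hypothetical, and an OpenAI API key is assumed to be configured):

```python
from langtorch import TextModule

# Formats each input with the template, sends it through the LLM activation,
# and tags each output entry with the key "summary"
summarize = TextModule(
    prompt="Summarize in one sentence: {}",
    activation="gpt-3.5-turbo",
    key="summary",
)
```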
Forward Pass

The forward method of TextModule processes the input TextTensor by formatting it with the module's prompt template using TextTensor multiplication. If an activation is provided, it is then applied to the formatted texts to obtain the output. Finally, if a key is specified, it is assigned to the output TextTensor.
```python
def forward(self, input: TextTensor) -> TextTensor:
    # Format the prompt template with the input via TextTensor multiplication
    formatted_input = self.prompt * input
    # Apply the activation (e.g. an LLM call) if one was provided
    output = self.activation(formatted_input) if self.activation else formatted_input
    # Assign the configured key to the output entries
    if self.key is not None:
        output.set_key_(self.key)
    return output
```
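As a quick sketch of the formatting step (the prompt text is hypothetical; without an activation, forward simply returns the formatted text):

```python
from langtorch import TextTensor, TextModule

greet = TextModule("Say hello to {}!")  # no activation configured
print(greet(TextTensor("Alice")))       # prompt * input -> "Say hello to Alice!"
```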
Subclassing TextModule

TextModule is designed to be subclassed to create custom text processing modules. By subclassing TextModule, you can extend its functionality and nest multiple modules in one bigger architecture. Submodules can be assigned as regular attributes, forming a tree-like structure.

Here's an example TextModule subclassing pattern with sequential calls that you can modify freely:
```python
from langtorch import TextModule


class TextTransformationModel(TextModule):
    def __init__(self):
        super().__init__()
        # The first module tags its output with the key "text" ...
        self.translate = TextModule("Translate to English: {}", activation="gpt-3.5-turbo", key="text")
        # ... so the second prompt can route it via the "{text}" placeholder
        self.summarize = TextModule("Summarize this text: '{text}'", activation="gpt-3.5-turbo")

    def forward(self, input_text):
        translated_text = self.translate(input_text)
        return self.summarize(translated_text)
```
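A hedged usage sketch (the input text is hypothetical and an OpenAI API key is assumed):

```python
from langtorch import TextTensor

model = TextTransformationModel()
summary = model(TextTensor("Cześć! Jak się masz?"))  # translate, then summarize
```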
A useful pattern is also to explicitly chain both formatting and LLM activations with nn.Sequential:

```python
import torch
from langtorch import TextModule, OpenAI

translate = TextModule("Translate to English: {}")
task = TextModule(some_prompt_template)  # some_prompt_template: your own template
# e.g. a simple chain-of-thought prompt appended to the task output
chain_of_thought = TextModule("{} Let's think step by step.")

# OpenAI activation is also a Module, so we can chain it explicitly here:
task_module_w_CoT = torch.nn.Sequential(
    task,
    chain_of_thought,
    OpenAI("gpt-4"),
)
```
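Once assembled, the chain is called like any other nn.Module (a sketch assuming some_prompt_template and an API key are set up):

```python
from langtorch import TextTensor

answer = task_module_w_CoT(TextTensor("What is 17 * 24?"))
```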
In the first example, the TextTransformationModel performs text translation followed by summarization. Each processing step is modularized by a TextModule, specifying the task through its prompt and the language model for activation.

Subclasses can also compose TextModule with other components, as in this retrieval-augmented generation (RAG) pattern:
```python
class RAG(TextModule):
    def __init__(self, documents: TextTensor, *args, **kwargs):
        super().__init__(*args, **kwargs)  # prompt, activation, key pass through
        self.retriever = Retriever(documents)

    def forward(self, user_message: TextTensor, k: int = 5):
        # Fetch the k most relevant documents and join them into one context
        retrieved_context = self.retriever(user_message, k) + "\n"
        user_message = user_message + "\nCONTEXT:\n" + retrieved_context.sum()
        return super().forward(user_message)
```
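A hedged construction sketch (document texts, prompt, and model name are hypothetical; a Retriever module is assumed to be available in your setup):

```python
from langtorch import TextTensor

docs = TextTensor(["LangTorch chains text modules.", "TextTensors hold structured text."])
rag = RAG(docs, prompt="Answer the question: {}", activation="gpt-3.5-turbo")
answer = rag(TextTensor("What does a TextModule do?"))
```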
Parallel Processing

TextModule inherently supports parallel processing of multiple tasks on multiple inputs. Both the prompt and the input are TextTensor objects, which can hold multiple entries. The activation function (e.g., language model API calls) is automatically applied in parallel to all entries.
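For instance, a single call can fan one template out over several inputs (a minimal sketch; prompt text and model name are hypothetical):

```python
from langtorch import TextTensor, TextModule

reviews = TextTensor(["Great battery life.", "Screen cracked quickly.", "Average sound."])
classify = TextModule("Label the sentiment as positive or negative: {}",
                      activation="gpt-3.5-turbo")
labels = classify(reviews)  # the LLM calls for all three entries run in parallel
```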
Methods

parameters()
Returns an iterator over the module's parameters. This method is inherited from torch.nn.Module.

named_parameters()
Returns an iterator over the module's named parameters. This method is inherited from torch.nn.Module.

state_dict()
Returns a dictionary containing the module's state. This method is inherited from torch.nn.Module.

load_state_dict(state_dict)
Loads the module's state from a dictionary. This method is inherited from torch.nn.Module.

to(device)
Moves the module's parameters and buffers to the specified device. This method is inherited from torch.nn.Module.

train(mode=True)
Sets the module in training mode. This method is inherited from torch.nn.Module.

eval()
Sets the module in evaluation mode. This method is inherited from torch.nn.Module.

zero_grad()
Sets the gradients of all parameters to zero. This method is inherited from torch.nn.Module.

__call__(*input)
Invokes the module's forward pass. This method is inherited from torch.nn.Module; customize behavior by overriding forward in subclasses.
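These behave as on any PyTorch module; a brief sketch using the subclass defined earlier:

```python
model = TextTransformationModel()

# Inspect the module's (possibly empty) parameter tree
for name, param in model.named_parameters():
    print(name, type(param))

model.eval()       # switch to evaluation mode
model.zero_grad()  # clear gradients of all parameters
```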
Notes

- Ensure that the TextTensor inputs match the expected format of the prompt templates. If there is a mismatch between the placeholder name used in the module prompt and the corresponding key in the input tensor, the completion text will be appended at the end of the text instead of being placed at the placeholder (see the sketch after this list).
- TextModule can be extended through subclassing to implement custom forward methods, offering tailored text processing capabilities. Some currently available or future subclasses are mentioned in langtorch.tt.
- Parallel (batched) processing of inputs through the LLM activation is handled automatically for all entries, regardless of shape.
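To illustrate the placeholder-mismatch behavior from the first note (a hedged sketch; the exact concatenation format may differ):

```python
from langtorch import TextTensor, TextModule

prompt = TextModule("Hello, {name}!")
entry = TextTensor("Alice")
entry.set_key_("text")  # key "text" does not match the "{name}" placeholder
print(prompt(entry))    # "Alice" is appended after the template, not substituted
```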