Callbacks for predicting within AdaptNLP using the fastai framework

class GatherInputsCallback[source]

GatherInputsCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Prepares a basic input dictionary for HuggingFace Transformers models

This Callback generates a basic dictionary consisting of `input_ids`, `attention_masks`, and `token_type_ids`, and saves it to the attribute `self.learn.inputs`.

If further data is expected or needed from the batch, any additional Callback(s) should have an `order` of -2.
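fastai runs callbacks in ascending `order`, which is why extra-data callbacks need an `order` of -2: they must fire after the inputs dictionary exists but before it is consumed. The sketch below is framework-free and illustrative; the class bodies and the orders -3 and -1 are assumptions, not the actual AdaptNLP source.

```python
# Minimal sketch of fastai-style callback ordering: callbacks run in
# ascending `order`, mutating shared state on a Learner stand-in.

class FakeLearner:
    """Stand-in for a fastai Learner holding the shared state."""
    def __init__(self):
        self.inputs = {}

class GatherInputs:
    order = -3  # assumed: runs first and builds the base input dictionary
    def before_batch(self, learn):
        learn.inputs = {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1]}

class AddExtraData:
    order = -2  # as the docs note: extra-data callbacks use order -2
    def before_batch(self, learn):
        learn.inputs["labels"] = [0]

class SetInputs:
    order = -1  # assumed: runs last, consuming the finished dictionary
    def before_batch(self, learn):
        learn.xb = tuple(learn.inputs.values())

learn = FakeLearner()
# Deliberately shuffled: sorting by `order` restores the intended sequence
for cb in sorted([SetInputs(), AddExtraData(), GatherInputs()], key=lambda c: c.order):
    cb.before_batch(learn)
```

After the loop, `learn.inputs` contains the base keys plus the extra `labels` entry, and `learn.xb` holds the dictionary's values as a tuple.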

Parameters:

  • after_create : NoneType, optional

  • before_fit : NoneType, optional

  • before_epoch : NoneType, optional

  • before_train : NoneType, optional

  • before_batch : NoneType, optional

  • after_pred : NoneType, optional

  • after_loss : NoneType, optional

  • before_backward : NoneType, optional

  • before_step : NoneType, optional

  • after_cancel_step : NoneType, optional

  • after_step : NoneType, optional

  • after_cancel_batch : NoneType, optional

  • after_batch : NoneType, optional

  • after_cancel_train : NoneType, optional

  • after_train : NoneType, optional

  • before_validate : NoneType, optional

  • after_cancel_validate : NoneType, optional

  • after_validate : NoneType, optional

  • after_cancel_epoch : NoneType, optional

  • after_epoch : NoneType, optional

  • after_cancel_fit : NoneType, optional

  • after_fit : NoneType, optional

GatherInputsCallback.before_validate[source]

GatherInputsCallback.before_validate()

Sets the number of inputs in `self.dls`

GatherInputsCallback.before_batch[source]

GatherInputsCallback.before_batch()

Turns `self.xb` from a tuple into a dictionary of either `{"input_ids", "attention_masks", "token_type_ids"}` or `{"input_ids", "attention_masks"}`

class SetInputsCallback[source]

SetInputsCallback(as_dict=False) :: Callback

Callback that runs after GatherInputsCallback and sets `self.learn.xb`

Parameters:

  • as_dict : <class 'bool'>, optional

    Whether to leave `self.xb` as a dictionary of values

SetInputsCallback.before_batch[source]

SetInputsCallback.before_batch()

Sets `self.learn.xb` to `self.learn.inputs.values()`
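The `as_dict` flag decides whether the model receives its inputs positionally or by keyword. The helper below is an assumed sketch of that handoff, not the actual AdaptNLP implementation: it returns what `self.learn.xb` would be set to for each setting.

```python
# Sketch of the SetInputsCallback handoff (assumed behavior): after
# GatherInputsCallback fills `learn.inputs`, before_batch replaces
# `learn.xb` with either the dictionary itself (as_dict=True) or
# just its values (as_dict=False).

def set_inputs_before_batch(inputs: dict, as_dict: bool = False):
    """Return what `self.learn.xb` would be set to."""
    return inputs if as_dict else tuple(inputs.values())

inputs = {"input_ids": [101, 102], "attention_mask": [1, 1]}
as_tuple = set_inputs_before_batch(inputs)                  # positional inputs
as_mapping = set_inputs_before_batch(inputs, as_dict=True)  # keyword inputs
```

Leaving `xb` as a dictionary is useful when the model's `forward` expects keyword arguments (as HuggingFace models do); the tuple form matches fastai's usual positional batch convention.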

class GeneratorCallback[source]

GeneratorCallback(num_beams:int, min_length:int, max_length:int, early_stopping:bool, input_ids:Optional[LongTensor]=None, do_sample:Optional[bool]=None, temperature:Optional[float]=None, top_k:Optional[int]=None, top_p:Optional[float]=None, repetition_penalty:Optional[float]=None, bad_words_ids:Optional[Iterable[int]]=None, bos_token_id:Optional[int]=None, pad_token_id:Optional[int]=None, eos_token_id:Optional[int]=None, length_penalty:Optional[float]=None, no_repeat_ngram_size:Optional[int]=None, encoder_no_repeat_ngram_size:Optional[int]=None, num_return_sequences:Optional[int]=None, max_time:Optional[float]=None, max_new_tokens:Optional[int]=None, decoder_start_token_id:Optional[int]=None, use_cache:Optional[bool]=None, num_beam_groups:Optional[int]=None, diversity_penalty:Optional[float]=None, prefix_allowed_tokens_fn:Optional[Callable[[int, Tensor], List[int]]]=None, output_attentions:Optional[bool]=None, output_hidden_states:Optional[bool]=None, output_scores:Optional[bool]=None, return_dict_in_generate:Optional[bool]=None, forced_bos_token_id:Optional[int]=None, forced_eos_token_id:Optional[int]=None, remove_invalid_values:Optional[bool]=None, synced_gpus:Optional[bool]=None) :: Callback

Callback used for models that utilize `self.model.generate`

Parameters:

  • num_beams : int

    Number of beams for beam search (1 means no beam search)

  • min_length : int

    Minimum length of the generated sequence

  • max_length : int

    Maximum length of the generated sequence

  • early_stopping : bool

    Whether to stop beam search once `num_beams` complete candidates are found

  • input_ids : Optional[torch.LongTensor], optional

    The sequence used as the generation prompt

  • do_sample : Optional[bool], optional

    Whether to sample the next token instead of using greedy or beam search

  • temperature : Optional[float], optional

    Value used to modulate next-token probabilities when sampling

  • top_k : Optional[int], optional

    Number of highest-probability vocabulary tokens kept for top-k filtering

  • top_p : Optional[float], optional

    Cumulative probability cutoff for nucleus (top-p) sampling

  • repetition_penalty : Optional[float], optional

    Penalty applied to tokens that have already been generated (1.0 means no penalty)

  • bad_words_ids : Optional[Iterable[int]], optional

    Token ids that are not allowed to be generated

  • bos_token_id : Optional[int], optional

    Id of the beginning-of-sequence token

  • pad_token_id : Optional[int], optional

    Id of the padding token

  • eos_token_id : Optional[int], optional

    Id of the end-of-sequence token

  • length_penalty : Optional[float], optional

    Exponential penalty applied to sequence length during beam search

  • no_repeat_ngram_size : Optional[int], optional

    If set to an int > 0, n-grams of that size can occur only once

  • encoder_no_repeat_ngram_size : Optional[int], optional

    If set to an int > 0, n-grams of that size appearing in the encoder input cannot occur in the decoder output

  • num_return_sequences : Optional[int], optional

    Number of independently generated sequences to return per input

  • max_time : Optional[float], optional

    Maximum time in seconds allowed for generation

  • max_new_tokens : Optional[int], optional

    Maximum number of tokens to generate, not counting the prompt

  • decoder_start_token_id : Optional[int], optional

    Start token id for the decoder of encoder-decoder models

  • use_cache : Optional[bool], optional

    Whether to reuse past key/value attention states to speed up decoding

  • num_beam_groups : Optional[int], optional

    Number of groups to divide `num_beams` into for diverse beam search

  • diversity_penalty : Optional[float], optional

    Penalty subtracted from a beam's score when it generates a token already chosen by another group

  • prefix_allowed_tokens_fn : Optional[Callable[[int, torch.Tensor], List[int]]], optional

    Function that, given the batch id and input ids, returns the token ids allowed at the next step

  • output_attentions : Optional[bool], optional

    Whether to return attention tensors

  • output_hidden_states : Optional[bool], optional

    Whether to return hidden states

  • output_scores : Optional[bool], optional

    Whether to return prediction scores

  • return_dict_in_generate : Optional[bool], optional

    Whether to return a ModelOutput instead of a plain tensor

  • forced_bos_token_id : Optional[int], optional

    Token id forced as the first generated token

  • forced_eos_token_id : Optional[int], optional

    Token id forced as the last token when `max_length` is reached

  • remove_invalid_values : Optional[bool], optional

    Whether to remove possible nan and inf model outputs to prevent generation from crashing

  • synced_gpus : Optional[bool], optional

    Whether to keep the generation loop synchronized across GPUs (needed for DeepSpeed ZeRO-3)
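All of these arguments ultimately flow into HuggingFace's `model.generate(**kwargs)`. A minimal sketch of that mapping, assuming only standard Python (the helper name `build_generate_kwargs` and the values are illustrative, not part of AdaptNLP):

```python
# Hedged sketch: collect only the generation parameters the user actually
# set, so unset (None) options fall back to the model's own defaults
# rather than being passed explicitly.

def build_generate_kwargs(**params):
    """Keep only the parameters that were given a value."""
    return {k: v for k, v in params.items() if v is not None}

kwargs = build_generate_kwargs(
    num_beams=4,             # beam search with 4 beams
    min_length=10,
    max_length=50,
    early_stopping=True,
    no_repeat_ngram_size=3,  # forbid repeating any 3-gram
    do_sample=None,          # unset -> dropped from the dict
)
# With a real model: outputs = model.generate(input_ids, **kwargs)
```

Filtering out `None` values matters because `generate` treats an explicit `None` and an omitted argument identically for most options, but passing every keyword explicitly makes the call harder to read and debug.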