XSum Text Summarization
This tutorial shows how to train text summarization models using Lightning Flash on the XSum dataset.

Goal

This example covers an abstractive Text Summarization deep learning task:
    1. What is Text Summarization
    2. Training the model using the train.py script
    3. Loading Grid Weights in Flash and Running the Summarization Model
This tutorial uses PyTorch Lightning.

Tutorial time

5 minutes

Task: Text Summarization

Text Summarization is the task of condensing a text's main point into a concise overview, typically a short sentence or description. For example, taking a web article and describing its topic in one short sentence.

Dataset: XSum

The Extreme Summarization (XSum) dataset is used to evaluate abstractive single-document summarization systems. It consists of 226,711 news articles collected from the BBC (2010 to 2017), each accompanied by a one-sentence summary. The articles cover a wide variety of subjects (e.g., News, Politics, Sports, Weather, Business, Technology, Science, Health, Family, Education, Entertainment and Arts).
Source: https://arxiv.org/pdf/1808.08745.pdf
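
The training code in Step 1 loads the data from CSV files. The following is a minimal sketch of that layout, assuming each row has an `input` column (the article) and a `target` column (the one-sentence summary), as the column names passed to `SummarizationData.from_csv` suggest. The sample row and the `read_pairs` helper are made up for illustration and are not taken from XSum or from the tutorial's code. Note that CSV escapes a quote character inside a quoted field by doubling it.

```python
import csv
import io

# Hypothetical two-column sample in the assumed layout; not real XSum data.
SAMPLE_CSV = '''input,target
"The council met on Tuesday. It voted to ""extend"" the market's opening hours.",Council extends market hours.
'''


def read_pairs(handle):
    """Yield (article, summary) pairs from a CSV with input/target columns."""
    for row in csv.DictReader(handle):
        yield row["input"], row["target"]


pairs = list(read_pairs(io.StringIO(SAMPLE_CSV)))
for article, summary in pairs:
    # Compare the length of the article with the length of its summary.
    print(f"{len(article.split())} article words -> {len(summary.split())} summary words")
```

Reading the file with `csv.DictReader` rather than splitting on commas matters here, because articles routinely contain commas and quotation marks.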

Step 1: Model

PyTorch Lightning Flash enables quick training, fine-tuning, and inference of state-of-the-art models for tasks such as text summarization.
For this demo, we're going to use the train.py script at https://github.com/aribornstein/T5-Summarization-Demo/blob/main/train.py
Here's a preview of this code.
# Import paths assume a Flash release that provides SummarizationData.from_csv
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the data
# train_file and test_file are paths to CSV files with "input" and "target" columns
datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file=train_file,
    test_file=test_file,
)

# 3. Build the model
model = SummarizationTask(backbone='google/mt5-small')

# 4. Create the trainer. Run once on data
trainer = Trainer()

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule)

# 6. Save it!
trainer.save_checkpoint("summarization_model_xsum.pt")

Step 2: Start a RUN

To train this model on Grid manually, follow these 4 steps:
    1. Create a Run.
    2. Copy and paste the model script:
       https://github.com/aribornstein/T5-Summarization-Demo/blob/main/train.py
    3. Select 1xT4 (16 GB) $0.68/h (g4dn.xlarge) as the Accelerator.
    4. Provide the run arguments --max_epochs 5 --gpus 1
You can add optional flags to this script:
--backbone summarization_model_to_train
--train_file /path/to/train.csv
--valid_file /path/to/val.csv
--test_file /path/to/test.csv
--download False
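
To make the flags above concrete, here is a sketch of how a script could parse them with `argparse`. This is an assumption for illustration, not the actual contents of train.py: only the flag names and the `google/mt5-small` backbone come from this tutorial, while the other defaults and the `str2bool` helper are hypothetical. The helper is there because `--download False` arrives as the string "False", and `bool("False")` evaluates to `True` in Python.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False' style strings; plain bool('False') would be True."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Flag names come from the tutorial; defaults are placeholders (assumptions).
    parser = argparse.ArgumentParser(description="Hypothetical train.py flags")
    parser.add_argument("--backbone", default="google/mt5-small")
    parser.add_argument("--train_file", default=None)
    parser.add_argument("--valid_file", default=None)
    parser.add_argument("--test_file", default=None)
    parser.add_argument("--download", type=str2bool, default=True)
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--gpus", type=int, default=0)
    return parser


# Parse the run arguments from Step 2 plus one optional flag.
args = build_parser().parse_args(
    ["--max_epochs", "5", "--gpus", "1", "--download", "False"]
)
print(args)
```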

Step 3: Use the model for predictions

In this step, we load the Grid weights in Flash and run the model to summarize text.
    1. Download the Artifacts from the Grid Run.
    2. Load the model into our script and run inference in 2 lines of code.
from flash.text import SummarizationTask

# 1. Load the model from a checkpoint
model = SummarizationTask.load_from_checkpoint("./summarization_model_xsum.pt")

# 2. Summarize an article!
print(model.predict([
    """
    Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
    people within the community. The couple were surrounded by shoppers as they walked along Electric Avenue.
    They came to Brixton to see work which has started to revitalise the borough.
    It was Charles' first visit to the area since 1996, when he was accompanied by the former
    South African president Nelson Mandela. Greengrocer Derek Chong, who has run a stall on Electric Avenue
    for 20 years, said Camilla had been "nice and pleasant" when she purchased the fruit.
    "She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
    She asked me were they ripe and I said yes - they're from the Dominican Republic."
    Mr Chong is one of 170 local retailers who accept the Brixton Pound.
    Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
    or in participating shops.
    """
]))
Congratulations! You have successfully trained and run inference with your first Text Summarization Model on Grid.
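
As an optional follow-up, the article in Step 3 is a triple-quoted string, so it carries line breaks and leading whitespace into the model input. The sketch below collapses that whitespace before prediction. It is an assumption that this is worth doing: many tokenizers already treat runs of whitespace alike, so it may make no difference to the summary. The `normalize_article` helper is hypothetical and not part of Flash.

```python
import re


def normalize_article(text):
    """Collapse runs of whitespace (including newlines) into single spaces."""
    return re.sub(r"\s+", " ", text).strip()


raw = """
    Mr Chong is one of 170 local retailers who accept the Brixton Pound.
    Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
    or in participating shops.
"""

clean = normalize_article(raw)
print(clean)
# The cleaned string could then be passed in a list, e.g. model.predict([clean]),
# assuming `model` was loaded from the checkpoint as in Step 3.
```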

Bonus: CLI equivalent

Here are the equivalent commands for the CLI.
First, clone the project:
git clone https://github.com/aribornstein/T5-Summarization-Demo.git
cd T5-Summarization-Demo
Then start the run:
grid run \
  --instance_type 1_v100_16gb \
  --framework lightning \
  --gpus 1 \
  train.py \
  --max_epochs 5 \
  --gpus 1