Module data
from torch.utils.data import Dataset
import os, pickle, json
import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)


class SrcCodeDataset(Dataset):
    def __init__(self, file_path, model, cache_path=None):
        """
        Dataset for loading tokenized source code in batches when fine-tuning a GPT2LMModel.

        :param file_path: JSON-lines file with one tokenized sample per line
        :param model: the model that the dataset will be fed to
        :param cache_path: optional directory used to cache the encoded inputs
        """
        self.inputs = []
        load_cache = False
        if cache_path is not None:
            load_cache = self._load_cache(cache_path)
        if not load_cache:
            self._build(file_path, model)
            if cache_path is not None:
                self._cache(cache_path)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        input_ids = self.inputs[index]["input_ids"]
        # attention_mask is not needed for this task, so only input_ids are returned
        return {"input_ids": input_ids}

    def _load_cache(self, cache_path):
        # Reuse previously encoded inputs if a cache file exists.
        load_cache = False
        if os.path.isdir(cache_path):
            if os.path.isfile(os.path.join(cache_path, "inputs.pk")):
                with open(os.path.join(cache_path, "inputs.pk"), "rb") as f:
                    logger.info(
                        f"load cached token ids of model from {cache_path}")
                    self.inputs = pickle.load(f)
                    load_cache = True
        return load_cache

    def _cache(self, cache_path):
        # Persist the encoded inputs so later runs can skip tokenization.
        if not os.path.isdir(cache_path):
            os.makedirs(cache_path)
        with open(os.path.join(cache_path, "inputs.pk"), "wb") as f:
            pickle.dump(self.inputs, f)
            logger.info(
                f"save tokenized ids of samples to: {cache_path}/inputs.pk")

    def _build(self, file_path, model):
        # Each line of file_path is a JSON object with a "label" ("python" or "java")
        # and the pre-tokenized "token_ids" of the sample.
        with open(file_path) as f:
            for line in tqdm(f):
                example = json.loads(line.strip())
                if example["label"].lower() == "python":
                    encoded_plus = model.tokenizer.encode_plus(
                        model.tokenize("<python>") + example["token_ids"] + [model.eos_token_id],
                        max_length=model.max_seq_length)
                elif example["label"].lower() == "java":
                    encoded_plus = model.tokenizer.encode_plus(
                        model.tokenize("<java>") + example["token_ids"] + [model.eos_token_id],
                        max_length=model.max_seq_length)
                else:
                    # skip samples with an unknown language label
                    continue
                self.inputs.append(encoded_plus.data)
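For reference, _build expects file_path to be a JSON-lines file in which every line carries a "label" field ("python" or "java") and a "token_ids" field with the already-tokenized sample. A minimal sketch of producing such a file follows; the file name and the concrete token ids are made up for illustration only.

import json

samples = [
    {"label": "python", "token_ids": [318, 257, 1332]},  # hypothetical ids from the model's tokenizer
    {"label": "java", "token_ids": [1212, 318, 655]},
]

with open("train.jsonl", "w") as f:  # "train.jsonl" is an assumed file name
    for sample in samples:
        f.write(json.dumps(sample) + "\n")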
Classes
class SrcCodeDataset (file_path, model, cache_path=None)
An abstract class representing a Dataset. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__, which is expected to return the size of the dataset by many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader.
Note
torch.utils.data.DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
This dataset class is used to load a source-code dataset in batches for fine-tuning with GPT2LMModel.
:param model: the model that the dataset will be fed to
Ancestors
- torch.utils.data.dataset.Dataset
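A minimal usage sketch: the snippet below assumes a model object (such as the GPT2LMModel mentioned above) that exposes the tokenizer, tokenize, eos_token_id, and max_seq_length attributes used by _build; the paths and the padding collate function are illustrative, not part of this module.

import torch
from torch.utils.data import DataLoader

dataset = SrcCodeDataset("train.jsonl", model, cache_path="cache/")  # hypothetical paths

def collate(batch):
    # Pad the variable-length id lists to the longest sample in the batch.
    # Padding with eos_token_id is an assumption, not something this module prescribes.
    max_len = max(len(item["input_ids"]) for item in batch)
    padded = [item["input_ids"] + [model.eos_token_id] * (max_len - len(item["input_ids"]))
              for item in batch]
    return {"input_ids": torch.tensor(padded, dtype=torch.long)}

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate)
for batch in loader:
    input_ids = batch["input_ids"]  # shape: (batch_size, max_len)
    # feed input_ids to the language model for fine-tuning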