diff --git a/docs/source/_static/js/custom.js b/docs/source/_static/js/custom.js index 6c574a0763..2481f92677 100644 --- a/docs/source/_static/js/custom.js +++ b/docs/source/_static/js/custom.js @@ -236,9 +236,11 @@ function platformToggle() { const createFrameworkButtons = sample => { const pytorchButton = document.createElement("button"); + pytorchButton.classList.add('pytorch-button') pytorchButton.innerText = "PyTorch"; const tensorflowButton = document.createElement("button"); + tensorflowButton.classList.add('tensorflow-button') tensorflowButton.innerText = "TensorFlow"; const selectorDiv = document.createElement("div"); @@ -253,22 +255,36 @@ function platformToggle() { tensorflowButton.classList.remove("selected"); pytorchButton.addEventListener("click", () => { - sample.element.innerHTML = sample.pytorchSample; - pytorchButton.classList.add("selected"); - tensorflowButton.classList.remove("selected"); + for(const codeBlock of updatedCodeBlocks){ + codeBlock.element.innerHTML = codeBlock.pytorchSample; + } + Array.from(document.getElementsByClassName('pytorch-button')).forEach(button => { + button.classList.add("selected"); + }) + Array.from(document.getElementsByClassName('tensorflow-button')).forEach(button => { + button.classList.remove("selected"); + }) }); tensorflowButton.addEventListener("click", () => { - sample.element.innerHTML = sample.tensorflowSample; - tensorflowButton.classList.add("selected"); - pytorchButton.classList.remove("selected"); + for(const codeBlock of updatedCodeBlocks){ + codeBlock.element.innerHTML = codeBlock.tensorflowSample; + } + Array.from(document.getElementsByClassName('tensorflow-button')).forEach(button => { + button.classList.add("selected"); + }) + Array.from(document.getElementsByClassName('pytorch-button')).forEach(button => { + button.classList.remove("selected"); + }) }); }; - codeBlocks + const updatedCodeBlocks = codeBlocks .map(element => {return {element: element.firstChild, innerText: element.innerText}}) .filter(codeBlock => codeBlock.innerText.includes(pytorchIdentifier) && codeBlock.innerText.includes(tensorflowIdentifier)) .map(getFrameworkSpans) - .forEach(createFrameworkButtons); + + updatedCodeBlocks + .forEach(createFrameworkButtons) }