---
license: bigcode-openrail-m
pipeline_tag: text-generation
tags:
- code
- automated program repair
---
# StarCoder-15B_for_NTR
We fine-tuned [StarCoder-15B](https://huggingface.co/bigcode/starcoder) on [Transfer_dataset](https://drive.google.com/drive/folders/1F1BPfTxHDGX-OCBthudCbu_6Qvcg_fbP?usp=drive_link) under the [NTR](https://sites.google.com/view/neuraltemplaterepair) framework for automated program repair (APR) research.
## Model Use
To use this model, please make sure to install transformers, peft, bitsandbytes, and accelerate.
```bash
pip install transformers
pip install peft
pip install bitsandbytes
pip install accelerate
```
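Before running the merge and inference steps, it can help to confirm the four packages are actually importable in the active environment. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper name `missing_packages` is hypothetical and not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages listed in the install step
REQUIRED_PACKAGES = ("transformers", "peft", "bitsandbytes", "accelerate")


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the names from `names` whose distributions are not installed."""
    missing = []
    for name in names:
        try:
            version(name)  # raises PackageNotFoundError if the distribution is absent
        except PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print(missing_packages() or "all dependencies installed")
```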
Then, run the following script to merge the LoRA adapter into StarCoder-15B.
```bash
bash merge.sh
```
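For readers who want to see what merging a LoRA adapter involves, here is a minimal sketch built on PEFT's `PeftModel.from_pretrained` and `merge_and_unload`. It assumes the adapter is a standard PEFT LoRA checkpoint; the function name `merge_lora_adapter` and all paths are hypothetical placeholders and are not taken from `merge.sh`.

```python
def merge_lora_adapter(base_model_id: str, adapter_dir: str, output_dir: str) -> None:
    """Merge a PEFT LoRA adapter into its base model and save the merged weights.

    Hypothetical helper: a sketch of a typical merge, not the contents of merge.sh.
    """
    # Heavy dependencies are imported lazily so defining this function is cheap
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the base model in fp16 (not 8-bit) so the LoRA deltas can be
    # added directly to the full-precision weights
    base = AutoModelForCausalLM.from_pretrained(
        base_model_id, torch_dtype=torch.float16, device_map="auto"
    )
    model = PeftModel.from_pretrained(base, adapter_dir)
    merged = model.merge_and_unload()  # folds LoRA updates into the base weights
    merged.save_pretrained(output_dir)
    # Save the tokenizer alongside the weights so the output directory is self-contained
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained(output_dir)
```

A call would look like `merge_lora_adapter("bigcode/starcoder", "path/to/adapter", "path/to/merged")`, with the two paths replaced by your own locations.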
Finally, you can load the model to generate patches for buggy code.
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('bigcode/starcoderbase', token=True)
model = AutoModelForCausalLM.from_pretrained(
    "StarCoder-15B_for_NTR/Epoch_1/-merged",
    token=True,
    use_cache=True,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto"
)
# prepare_model_for_kbit_training replaces the removed prepare_model_for_int8_training
model = prepare_model_for_kbit_training(model)
# this LoRA wrapper mirrors the fine-tuning setup; its fresh adapter starts as a zero update
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["c_proj", "c_attn", "q_attn"]
)
model = get_peft_model(model, lora_config)
model.eval()

# a bug-fix pair (triple quotes so the multi-line Java code is valid Python)
buggy_code = """
public MultiplePiePlot(CategoryDataset dataset){
super();
// bug_start
this.dataset=dataset;
// bug_end
PiePlot piePlot=new PiePlot(null);
this.pieChart=new JFreeChart(piePlot);
this.pieChart.removeLegend();
this.dataExtractOrder=TableOrder.BY_COLUMN;
this.pieChart.setBackgroundPaint(null);
TextTitle seriesTitle=new TextTitle("Series Title",new Font("SansSerif",Font.BOLD,12));
seriesTitle.setPosition(RectangleEdge.BOTTOM);
this.pieChart.setTitle(seriesTitle);
this.aggregatedItemsKey="Other";
this.aggregatedItemsPaint=Color.lightGray;
this.sectionPaints=new HashMap();
}
""".strip()
repair_template = "OtherTemplate"
# ground-truth fix, shown for reference only (not used during inference)
fixed_code = """
// fix_start
setDataset(dataset);
// fix_end
""".strip()

# model inference
input_text = '<commit_before>\n' + buggy_code + '\n<commit_msg>\n' + repair_template + '\n<commit_after>\n'
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=256,
        num_beams=10,
        num_return_sequences=10,
        early_stopping=True,
        pad_token_id=eos_id,
        eos_token_id=eos_id
    )
for generated_id in generated_ids:
    generated_text = tokenizer.decode(generated_id, skip_special_tokens=False)
    # keep only the text after <commit_after> and drop the EOS token
    patch = generated_text.split('\n<commit_after>\n')[1]
    patch = patch.replace('<|endoftext|>', '')
    print(patch)
```
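The prompt layout and post-processing in the script can be factored into small pure-Python helpers, which makes the commit-style format easier to inspect and test without loading the model. This is a sketch: `build_prompt`, `extract_patch`, and `apply_patch` are hypothetical names, and `apply_patch` assumes, based only on the example bug-fix pair, that a generated fix replaces exactly the lines between `// bug_start` and `// bug_end`.

```python
def build_prompt(buggy_code: str, repair_template: str) -> str:
    """Assemble the same commit-style prompt the inference script uses."""
    return ("<commit_before>\n" + buggy_code + "\n<commit_msg>\n"
            + repair_template + "\n<commit_after>\n")


def extract_patch(generated_text: str, eos_token: str = "<|endoftext|>") -> str:
    """Keep only the text after the <commit_after> marker and drop EOS tokens."""
    patch = generated_text.split("\n<commit_after>\n", 1)[1]
    return patch.replace(eos_token, "")


def _find_marker(lines, marker):
    # Return the index of the first line equal to `marker` (ignoring surrounding spaces)
    for i, line in enumerate(lines):
        if line.strip() == marker:
            return i
    raise ValueError(f"marker {marker!r} not found")


def apply_patch(buggy_code: str, patch: str) -> str:
    """Hypothetical helper: replace the bug_start..bug_end region with the fix lines."""
    lines = buggy_code.splitlines()
    start = _find_marker(lines, "// bug_start")
    end = _find_marker(lines, "// bug_end")
    # Drop the fix_start / fix_end markers and keep the fix body
    fix_lines = [l for l in patch.strip().splitlines()
                 if l.strip() not in ("// fix_start", "// fix_end")]
    return "\n".join(lines[:start] + fix_lines + lines[end + 1:])
```

Applied to a candidate such as the reference `fixed_code`, `apply_patch` yields the full method with `setDataset(dataset);` in place of the marked buggy line, which is a convenient form for compiling and running tests against each of the beam outputs.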
## Model Details
The model is licensed under the BigCode OpenRAIL-M v1 license agreement. You can find the full agreement [here](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement).