代码:
# Float-model inference + perplexity (PPL) computation over GSM8K samples.
# For each record: generate an answer with the float model, write it back
# into the json, then score the generated answer tokens with cross-entropy.
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import math
from tqdm import tqdm
import torch
from torch.nn import CrossEntropyLoss

model_name = "/data1/huf/Qwen3-0.6B"
json_file = "/data1/05_ax/01_eval/ppl/qnn/gsm8k_17d799.json"
out_json_file = "/data1/05_ax/01_eval/ppl/qnn/gsm8k_17d799_qwen3_0.6B_pred.json"
num_samples = 100  # evaluate at most this many records
device = "cuda:7"

# Token ids used below (Qwen3 vocabulary):
#   151668 -> "</think>"
#   [151644, 77091, 198] -> "<|im_start|>assistant\n"
THINK_END_ID = 151668
ASSISTANT_MARKER = [151644, 77091, 198]

# Load the tokenizer and the model in full float32 precision.
# (device_map already places the model; the original extra .to(device) was redundant.)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32, device_map=device
)


def find_subsequence(tensor, subsequence):
    """Return the start index of the first occurrence of ``subsequence``
    inside the 1-D ``tensor``, or -1 if it does not occur."""
    len_tensor = len(tensor)
    len_subsequence = len(subsequence)
    for i in range(len_tensor - len_subsequence + 1):
        if torch.equal(tensor[i:i + len_subsequence], subsequence):
            return i
    return -1


with open(json_file, 'r') as f:
    datas = json.loads(f.read())

# Chat template with an empty <think> block (non-thinking mode).
template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

loss = 0.0
num_batches = 0
for key, item in tqdm(datas.items()):
    with torch.no_grad():
        prompt = item['origin_prompt'][0]['prompt']
        prompt = template.format(prompt)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Float-model generation.
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        output_ids = generated_ids[0][len(inputs.input_ids[0]):].tolist()

        # Split off any thinking content at the last "</think>" token
        # (same parsing as the official Qwen3 usage example).
        try:
            index = len(output_ids) - output_ids[::-1].index(THINK_END_ID)
        except ValueError:
            index = 0
        thinking_content = tokenizer.decode(
            output_ids[:index], skip_special_tokens=True).strip("\n")
        content = tokenizer.decode(
            output_ids[index:], skip_special_tokens=True).strip("\n")

        # Write the float-model prediction back into the json record.
        item['prediction'] = content

        # Compute the float model's PPL over the generated answer tokens.
        new_prompt = prompt + content
        new_inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)
        input_ids = new_inputs['input_ids']

        marker = torch.tensor(ASSISTANT_MARKER).to(device)
        index = find_subsequence(input_ids[0], marker)
        if index == -1:
            # BUGFIX: previously a -1 ("not found") fell through and the
            # sample was silently scored from token 6; skip it instead.
            continue
        # +3 skips the assistant marker itself; +4 presumably skips the
        # empty "<think>\n\n</think>\n\n" block -- TODO confirm token count.
        index = index + 3 + 4

        # Persist the encoding for later comparison (e.g. quantized model).
        item['input_ids'] = input_ids[0].cpu().numpy().tolist()
        item['gen_index'] = index
        item['origin_prompt'][0]['new_prompt'] = new_prompt

        outputs = model(input_ids)
        # Shift logits/labels so position t predicts token t+1, restricted
        # to the answer region starting at `index`.
        shift_logits = outputs.logits[..., index:-1, :].contiguous().to(dtype=torch.float32)
        shift_labels = input_ids[..., index + 1:].contiguous().to(shift_logits.device)
        loss_fct = CrossEntropyLoss()
        ce_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).detach()
        loss += ce_loss.item()
        # BUGFIX: format spec was ":3f" (width 3, default 6 decimals);
        # ":.3f" (3 decimal places) is clearly what was intended.
        print(f"ce_loss:{ce_loss:.3f}, ppl:{ce_loss.exp():.3f}")

    num_batches += 1
    if num_batches >= num_samples:
        break

# Mean cross-entropy over evaluated samples; guard against an empty run
# (the original divided by num_batches unconditionally).
if num_batches > 0:
    ppl = math.exp(loss / num_batches)
    print(f"ppl loss:{ppl:.3f}")

# Save the augmented records.
with open(out_json_file, 'w') as f:
    json.dump(datas, f, indent=4)