Commit 3fc56504 by wyt2000

update model and inference.

parent 1cda58a7
*.slurm
submit.sh
ret_one
{
"_name_or_path": "openbmb/CPM-2B",
"architectures": [
"MiniCPMForCausalLM"
],
"auto_map": {
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
"AutoModel": "modeling_minicpm.MiniCPMModel",
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
},
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.1,
"intermediate_size": 3840,
"max_position_embeddings": 4096,
"num_attention_heads": 24,
"num_hidden_layers": 52,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"torch_dtype": "bfloat16",
"transformers_version": "4.36.0",
"use_cache": true,
"vocab_size": 73440,
"scale_emb": 12,
"dim_model_base": 256,
"scale_depth": 1.4
"_name_or_path": "data/MiniCPM_quant_per_head_fp4_LSQ_after_rope_safesoft_lowrope_prune_fixed",
"architectures": [
"MiniCPMForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
"AutoModel": "modeling_minicpm.MiniCPMModel",
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
},
"bos_token_id": 1,
"dim_model_base": 256,
"eos_token_id": 2,
"head_w_quantbit": 4,
"head_x_quantbit": 8,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.1,
"intermediate_size": 3840,
"kv_cache_quantbit": 4,
"linear_w_quantbit": 4,
"linear_x_quantbit": 8,
"lm_head_rank": 1024,
"max_position_embeddings": 4096,
"model_type": "minicpm",
"num_attention_heads": 24,
"num_hidden_layers": 52,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"scale_depth": 1.4,
"scale_emb": 12,
"torch_dtype": "bfloat16",
"transformers_version": "4.41.2",
"use_cache": true,
"vocab_size": 73440
}
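For reference, the quantization-related fields added above (head_w_quantbit, kv_cache_quantbit, lm_head_rank, ...) become plain attributes on the loaded config object. A minimal sketch, assuming the checkpoint directory contains the custom configuration_minicpm.py referenced in auto_map:

from transformers import AutoConfig

# Local checkpoint directory taken from _name_or_path above; trust_remote_code loads MiniCPMConfig.
config = AutoConfig.from_pretrained(
    "data/MiniCPM_quant_per_head_fp4_LSQ_after_rope_safesoft_lowrope_prune_fixed",
    trust_remote_code=True,
)
print(config.linear_w_quantbit, config.linear_x_quantbit)   # 4 8
print(config.kv_cache_quantbit, config.lm_head_rank)        # 4 1024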
@@ -176,7 +176,8 @@ class MiniCPMConfig(PretrainedConfig):
        )
        try:
            import flash_attn
            self._attn_implementation = "flash_attention_2"
            # self._attn_implementation = "flash_attention_2"
            self._attn_implementation = "eager"
        except:
            pass
......
{
"do_sample": true,
"top_p": 0.8,
"temperature": 0.8,
"bos_token_id": 1,
"eos_token_id": 2
}
\ No newline at end of file
"bos_token_id": 1,
"do_sample": true,
"eos_token_id": 2,
"temperature": 0.8,
"top_p": 0.8,
"transformers_version": "4.41.2"
}
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Load the quantized checkpoint from the current directory with the custom
# MiniCPM modeling code (trust_remote_code) and eager attention.
model_path = os.getcwd()
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map='auto',
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
tokenizer.pad_token = ''

# Greedy generation for a single prompt.
input_list = ["### Problem: Write a Python program to calculate the 10th prime."]
inputs = tokenizer(input_list, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=2048,
    num_return_sequences=1,
    do_sample=False
)
print("response:", tokenizer.decode(outputs[0], skip_special_tokens=True))
/home/S/wuyt/lustre/model/aimo-progress-prize-trained-models/Code-Math-QA-WizardLM-deepseekproof-Lean-Workbook-V3-MiniF2F-Valid-Diff-Prompt-minicpm-quant-per-head-fp4-LSQ-after-rope-safesoft-lowrope-rotamul-fixed-1022/model.safetensors
\ No newline at end of file
/lustre/S/huangdi/open_for_out/models/MiniCPM_quant_qilei/pytorch_model.bin
\ No newline at end of file
@@ -13,6 +13,7 @@
"rstrip": false,
"single_word": false
},
"pad_token": "</s>",
"unk_token": {
"content": "<unk>",
"lstrip": false,
......
/lustre/S/huangdi/open_for_out/models/MiniCPM_quant_qilei/tokenizer.json
\ No newline at end of file
/home/S/wuyt/lustre/model/aimo-progress-prize-trained-models/Code-Math-QA-WizardLM-deepseekproof-Lean-Workbook-V3-MiniF2F-Valid-Diff-Prompt-minicpm-quant-per-head-fp4-LSQ-after-rope-safesoft-lowrope-rotamul-fixed-1022/tokenizer.json
\ No newline at end of file
/lustre/S/huangdi/open_for_out/models/MiniCPM_quant_qilei/tokenizer.model
\ No newline at end of file
/home/S/wuyt/lustre/model/aimo-progress-prize-trained-models/Code-Math-QA-WizardLM-deepseekproof-Lean-Workbook-V3-MiniF2F-Valid-Diff-Prompt-minicpm-quant-per-head-fp4-LSQ-after-rope-safesoft-lowrope-rotamul-fixed-1022/tokenizer.model
\ No newline at end of file
@@ -28,11 +28,12 @@
}
},
"bos_token": "<s>",
"chat_template": "{% for message in messages %}{% if (message['role'] == 'system')%}{{ '' }}{% elif (message['role'] == 'user')%}{{ message['content'] }}{% elif (message['role'] == 'assistant')%}{{ message['content'] }}{% endif %}{% if loop.last and message['role'] == 'user' and add_generation_prompt %}{{ '' }}{% endif %}{% endfor %}",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
"legacy": true,
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"model_max_length": 2048,
"pad_token": "</s>",
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "LlamaTokenizer",
......
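The chat_template added above concatenates user and assistant message contents verbatim: system messages render as empty strings and no role tags or separators are inserted. A minimal sketch of what that means in practice, assuming the standard apply_chat_template API:

messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "### Problem: Write a Python program to calculate the 10th prime."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# text is just the user content; the system message is dropped and nothing is appended.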
@@ -2,26 +2,61 @@ import math
import torch
from torch import nn
def grad_scale(x, scale):
    y = x
    y_grad = x * scale
    return (y - y_grad).detach() + y_grad

def weight_quant(weight, num_bits=1):
    # Symmetric integer weight quantization with a mean-absolute-value based scale.
    dtype = weight.dtype
    weight = weight.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)

def round_pass(x):
    # Round with a straight-through estimator: forward uses round(), backward passes the gradient of x.
    y = x.round()
    y_grad = x
    return (y - y_grad).detach() + y_grad
def activation_quant(x, num_bits=8):
    # Per-token symmetric integer activation quantization with a max-based scale.
    dtype = x.dtype
    x = x.float()
    Qn = -2 ** (num_bits - 1)
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)

class Quantizer(nn.Module):
    # Learnable-scale (LSQ-style) activation quantizer with one scale per sequence position.
    def __init__(self, num_bits, seq_len):
        super().__init__()
        self.thd_neg = - 2 ** (num_bits - 1)
        self.thd_pos = 2 ** (num_bits - 1) - 1
        self.s = torch.nn.Parameter(torch.ones(seq_len))

    def forward(self, x, input_idx):
        # Scale by the learned per-position factors, round (straight-through), and clamp.
        s = self.s[input_idx:input_idx + x.shape[1]]
        s_scale = s[None, :, None]
        x = x * s_scale
        x = round_pass(x)
        x = torch.clamp(x, self.thd_neg, self.thd_pos)
        return s, x
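# Illustrative shapes (not part of the module): for Quantizer(num_bits=8, seq_len=4096) and an
# activation x of shape (batch, seq, hidden) starting at position input_idx, forward() applies one
# learnable scale per position and returns s of shape (seq,) together with a quantized tensor of
# x's shape whose values are clamped to the signed 8-bit range [-128, 127].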
def get_scale_f32(src_amax, dst_max):
    scale = dst_max / src_amax.float()
    return scale

def round_to_FP4(input):
    # Round magnitudes to the nearest representable FP4 (E2M1) value; the largest finite value is 6.0.
    dst_max = 6.0
    emax = 2
    emin = 0
    p = 2
    part = (2 - 2 ** (1 - p))
    ab = torch.where(torch.isinf(input) | torch.isnan(input), torch.ones_like(input) * dst_max, input)
    ab = torch.where(ab > dst_max, torch.ones_like(ab) * dst_max, ab)
    ab = torch.where(ab < 2.0 ** emin * 2 ** (-p), torch.zeros_like(ab), ab)
    E = torch.where(ab < 2 ** emin, torch.ones_like(ab) * emin, torch.floor(torch.log2(ab.float())))
    P = torch.round(ab * 2 ** (-E) * 2 ** (p - 1)) / 2 ** (p - 1)
    data = 2 ** E * P
    return data
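# Note (illustrative): the magnitudes produced above lie in the FP4 (E2M1) set
# {0, 0.5, 1.0, 1.5, 2, 3, 4, 6}; for example round_to_FP4(torch.tensor([0.3, 1.2, 2.6, 5.3]))
# gives tensor([0.5, 1.0, 3.0, 6.0]). Values above 6.0, inf and NaN are clamped to 6.0, and
# magnitudes below 0.25 flush to zero.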
def quant_fp4(data, num_bits):
    # Per-row (last dim) FP4 quantization: scale so the row max maps to 6.0, round the magnitudes
    # to FP4, restore signs, and use a straight-through estimator so gradients flow through rounding.
    sign = torch.sign(data)
    abs_data = torch.abs(data).float()
    amax, index = torch.max(abs_data, -1, True)
    qscale = get_scale_f32(amax, 6.0)
    quant_data = round_to_FP4(abs_data * qscale)
    quant_data = quant_data * sign
    quant_data = data + (quant_data - data).detach()
    return qscale, quant_data, data.dtype

def dequant_fp4(qscale, quant_data, target_type):
    return (quant_data / qscale).to(target_type)
class CLMLinear(nn.Linear):
@@ -29,6 +64,7 @@ class CLMLinear(nn.Linear):
        *kargs,
        weight_bits=1,
        input_bits=8,
        seq_len=4096,
        **kwargs
    ):
        super(CLMLinear, self).__init__(*kargs, **kwargs)
@@ -37,14 +73,37 @@ class CLMLinear(nn.Linear):
        """
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.seq_len = seq_len
        self.activation_quant = Quantizer(input_bits, seq_len)
    def forward(self, input):
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (activation_quant(self.weight, self.weight_bits) - self.weight).detach()
        if input.shape[1] != 1:
            self.input_idx = 0
        out = nn.functional.linear(quant_input, quant_weight)
        if input.shape[1] + self.input_idx <= self.seq_len:
            input_s, tobe_dequant_input = self.activation_quant(input, self.input_idx)
            self.input_idx = input.shape[1] + self.input_idx
        else:
            raise ValueError(f"input.shape[1]: {input.shape[1]}, self.input_idx: {self.input_idx}, self.seq_len: {self.seq_len}")
        weight_s, tobe_dequant_weight, _ = quant_fp4(self.weight, self.weight_bits)
        out = self.elementwise_multiply_and_div(tobe_dequant_input, tobe_dequant_weight, input_s, weight_s, operate_type=torch.bfloat16)
        out = out.type(input.dtype)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
    def elementwise_multiply_and_div(self, A, B, C, D, operate_type=torch.bfloat16):
        # Quantized matmul followed by rescaling: divide A @ B.T by the outer product of the
        # activation scales C and weight scales D.
        A = A.type(operate_type)
        B = B.type(operate_type)
        C = C.type(operate_type)
        D = D.type(operate_type)
        E = torch.matmul(C[:, None], D.T)
        E = torch.clamp(E, min=1e-5)
        F = torch.matmul(A, B.T)
        result = F / E
        return result
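# Putting it together (illustrative): with A ≈ x * C (position-scaled, rounded activations) and
# B ≈ W * D (FP4 weights scaled so each row's max maps to 6.0), element (t, o) of A @ B.T carries
# a factor C[t] * D[o]; dividing by E = C[:, None] @ D.T removes exactly that factor, so
# elementwise_multiply_and_div(A, B, C, D) approximates the full-precision x @ W.T of nn.Linear.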