Commit a3a5e312 by wyt2000

fix: modeling_minicpm.

parent beaacea5
@@ -403,8 +403,8 @@ class MiniCPMAttention(nn.Module):
         value_states = self.v_proj(hidden_states)
         # kv 4bit quantization
-        key_states = activation_quant(key_states, 4)
-        value_states = activation_quant(value_states, 4)
+        key_states = key_states + (activation_quant(key_states, 4) - key_states).detach()
+        value_states = value_states + (activation_quant(value_states, 4) - value_states).detach()
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -515,9 +515,8 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
         value_states = self.v_proj(hidden_states)
         # kv 4bit quantization
-        key_states = activation_quant(key_states, 4)
-        value_states = activation_quant(value_states, 4)
+        key_states = key_states + (activation_quant(key_states, 4) - key_states).detach()
+        value_states = value_states + (activation_quant(value_states, 4) - value_states).detach()
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
@@ -715,8 +714,8 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
         value_states = self.v_proj(hidden_states)
         # kv 4bit quantization
-        key_states = activation_quant(key_states, 4)
-        value_states = activation_quant(value_states, 4)
+        key_states = key_states + (activation_quant(key_states, 4) - key_states).detach()
+        value_states = value_states + (activation_quant(value_states, 4) - value_states).detach()
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
......
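The change is the same in all three attention implementations: instead of overwriting the key/value projections with their 4-bit fake-quantized values (which blocks gradients at round()), the quantization is rewired as a straight-through estimator, so the forward pass still sees the quantized tensors while the backward pass treats quantization as identity. Below is a minimal sketch of that pattern; `activation_quant` here is an assumption modeled on a common per-token absmax fake quantizer and may differ from the helper actually defined in modeling_minicpm.py.

```python
# Sketch only (not the repository's code): illustrates the straight-through
# estimator (STE) pattern this commit switches to.
import torch


def activation_quant(x: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Assumed per-token absmax fake quantization to `bits` bits."""
    qmax = 2 ** (bits - 1) - 1                       # e.g. 7 for signed 4-bit
    scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-qmax - 1, qmax) / scale


# Before this commit: the rounded output replaces the activations directly,
# so round() stops gradients from reaching k_proj / v_proj during training:
#   key_states = activation_quant(key_states, 4)
#
# After this commit: the forward value equals the quantized tensor, but the
# detached residual makes backward treat quantization as identity:
#   key_states = key_states + (activation_quant(key_states, 4) - key_states).detach()

if __name__ == "__main__":
    x = torch.randn(2, 8, 16, requires_grad=True)
    y = x + (activation_quant(x, 4) - x).detach()
    y.sum().backward()
    # Gradient is all ones, exactly as if no quantization had been applied.
    print(torch.allclose(x.grad, torch.ones_like(x)))
```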