Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
M
MiniCPM-training
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Yutong Wu
MiniCPM-training
Commits
a3a5e312
Commit
a3a5e312
authored
Sep 10, 2024
by
wyt2000
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix: modeling_minicpm.
parent
beaacea5
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
7 deletions
+6
-7
models/MiniCPM-quant/modeling_minicpm.py
+6
-7
No files found.
models/MiniCPM-quant/modeling_minicpm.py
View file @
a3a5e312
...
@@ -403,8 +403,8 @@ class MiniCPMAttention(nn.Module):
...
@@ -403,8 +403,8 @@ class MiniCPMAttention(nn.Module):
value_states
=
self
.
v_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
# kv 4bit quantization
# kv 4bit quantization
key_states
=
activation_quant
(
key_states
,
4
)
key_states
=
key_states
+
(
activation_quant
(
key_states
,
4
)
-
key_states
)
.
detach
(
)
value_states
=
activation_quant
(
value_states
,
4
)
value_states
=
value_states
+
(
activation_quant
(
value_states
,
4
)
-
value_states
)
.
detach
(
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
...
@@ -515,9 +515,8 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
...
@@ -515,9 +515,8 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
value_states
=
self
.
v_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
# kv 4bit quantization
# kv 4bit quantization
key_states
=
activation_quant
(
key_states
,
4
)
key_states
=
key_states
+
(
activation_quant
(
key_states
,
4
)
-
key_states
)
.
detach
()
value_states
=
activation_quant
(
value_states
,
4
)
value_states
=
value_states
+
(
activation_quant
(
value_states
,
4
)
-
value_states
)
.
detach
()
# Flash attention requires the input to have the shape
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# batch_size x seq_length x head_dim x hidden_dim
...
@@ -715,8 +714,8 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
...
@@ -715,8 +714,8 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
value_states
=
self
.
v_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
# kv 4bit quantization
# kv 4bit quantization
key_states
=
activation_quant
(
key_states
,
4
)
key_states
=
key_states
+
(
activation_quant
(
key_states
,
4
)
-
key_states
)
.
detach
(
)
value_states
=
activation_quant
(
value_states
,
4
)
value_states
=
value_states
+
(
activation_quant
(
value_states
,
4
)
-
value_states
)
.
detach
(
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment